this repo has no description
at main 293 lines 7.8 kB view raw
1#pragma once 2 3#include <cassert> 4#include <coroutine> 5#include <exception> 6#include <optional> 7#include <stdexec/execution.hpp> 8#include <type_traits> 9#include <utility> 10 11namespace kev 12{ 13 14template <typename T> struct task; 15 16/** 17 * @brief Common data for all promise types. 18 * 19 */ 20struct promise_data 21{ 22 std::exception_ptr m_exception{}; 23}; 24 25/** 26 * @brief Promise base for tasks with return type T. 27 * 28 * @tparam T The return type of the task. 29 */ 30template <typename T> struct promise_base : promise_data 31{ 32 /** 33 * @brief Store the returned value in the promise. 34 * 35 * @param value 36 */ 37 void return_value(T value) 38 { 39 this->m_value = std::move(value); 40 } 41 std::optional<T> m_value{std::nullopt}; 42}; 43 44/** 45 * @brief Specialization of promise base for void return type. 46 * 47 */ 48template <> struct promise_base<void> : promise_data 49{ 50 /** 51 * @brief Handle the void return type. 52 * 53 */ 54 static constexpr void return_void() 55 { 56 } 57}; 58 59/** 60 * @brief Promise type for tasks. 61 * 62 * @tparam T The return type of the task. 63 */ 64template <typename T> struct task_promise : stdexec::with_awaitable_senders<task_promise<T>>, promise_base<T> 65{ 66 using promise_type = task_promise<T>; 67 using coroutine_handle = std::coroutine_handle<promise_type>; 68 69 /** 70 * @brief Get the return object object 71 * 72 * @return task<T> 73 */ 74 auto get_return_object() -> task<T> 75 { 76 return task<T>{coroutine_handle::from_promise(*this)}; 77 } 78 /** 79 * @brief The task type always suspends at the beginning. 80 * 81 * @return auto 82 */ 83 auto initial_suspend() -> std::suspend_always 84 { 85 return std::suspend_always{}; 86 } 87 88 /** 89 * @brief Final awaiter to resume the continuation. 90 * 91 */ 92 struct final_awaiter 93 { 94 /** 95 * @brief Always suspend to resume the continuation. 96 * 97 */ 98 static constexpr auto await_ready() noexcept -> bool 99 { 100 // Always suspend to allow resuming the continuation. 101 return false; 102 } 103 104 std::coroutine_handle<> await_suspend(std::coroutine_handle<promise_type> h) noexcept 105 { 106 // cppreference: "if await_suspend returns a coroutine handle for some other coroutine, that handle is 107 // resumed (by a call to handle.resume())" 108 return h.promise().m_continuation; 109 } 110 111 void await_resume() noexcept 112 { 113 // Nothing to do here. 114 } 115 }; 116 117 /** 118 * @brief Final suspend point to resume the continuation. 119 * 120 * @return final_awaiter 121 */ 122 auto final_suspend() noexcept -> final_awaiter 123 { 124 return {}; 125 } 126 127 /** 128 * @brief Store the current exception in the promise. 129 * 130 */ 131 auto unhandled_exception() -> void 132 { 133 // cppreference: 134 // If the coroutine ends with an uncaught exception, it performs the following: 135 // catches the exception and calls promise.unhandled_exception() from within the catch-block 136 this->m_exception = std::current_exception(); 137 } 138 139 std::coroutine_handle<> m_continuation; 140}; 141 142/** 143 * @brief Awaiter for tasks. Allows awaiting the task and retrieving its result. 144 * 145 * @tparam T The return type of the task. 146 */ 147template <typename T> struct task_awaiter 148{ 149 using coroutine_handle = std::coroutine_handle<task_promise<T>>; 150 151 explicit task_awaiter(coroutine_handle handle) noexcept : m_coroutine_handle(handle) 152 { 153 // Take ownership of the coroutine handle. 154 } 155 156 task_awaiter(task_awaiter &&other) noexcept : m_coroutine_handle(std::exchange(other.m_coroutine_handle, {})) 157 { 158 // Awaiters are move only. 159 } 160 161 task_awaiter(const task_awaiter &) = delete; 162 task_awaiter &operator=(const task_awaiter &) = delete; 163 164 task_awaiter &operator=(task_awaiter &&other) noexcept 165 { 166 if (this != &other) 167 { 168 if (m_coroutine_handle) 169 { 170 m_coroutine_handle.destroy(); 171 } 172 m_coroutine_handle = std::exchange(other.m_coroutine_handle, {}); 173 } 174 return *this; 175 } 176 177 ~task_awaiter() 178 { 179 // Destroy the coroutine if we still own it. There's a possibility of transferring ownership via move 180 // construction/assignment. 181 if (m_coroutine_handle) 182 { 183 m_coroutine_handle.destroy(); 184 } 185 } 186 187 auto await_ready() -> bool 188 { 189 // Always suspend to allow the caller to await the task. 190 return false; 191 } 192 auto await_suspend(std::coroutine_handle<> continuation) -> bool 193 { 194 /* 195 cppreference on await_suspend return value: 196 if await_suspend returns bool: 197 - the value true returns control to the caller/resumer of the current coroutine 198 - the value false resumes the current coroutine. 199 */ 200 201 // Get the promise of the coroutine that is being awaited. 202 auto &promise = m_coroutine_handle.promise(); 203 // Set the continuation to a noop coroutine to avoid resuming the caller too early. 204 promise.m_continuation = std::noop_coroutine(); 205 // Resume the task coroutine. Since we set the continuation to a noop, the caller won't be resumed yet. 206 m_coroutine_handle.resume(); 207 if (m_coroutine_handle.done()) 208 { 209 // If the task is already done, we can resume the caller immediately. 210 return false; 211 } 212 // Set the actual continuation to the caller's coroutine handle. 213 promise.m_continuation = continuation; 214 // Indicate that we have a continuation to resume later. 215 return true; 216 } 217 auto await_resume() -> T 218 { 219 // The coroutine body has completed at this point. 220 if (m_coroutine_handle.promise().m_exception) 221 { 222 // If an exception was thrown in the body it would have been stored in the promise. 223 std::rethrow_exception(m_coroutine_handle.promise().m_exception); 224 } 225 if constexpr (std::is_void_v<T>) 226 { 227 return; 228 } 229 else 230 { 231 assert(m_coroutine_handle.promise().m_value.has_value() && "Task did not set a return value"); 232 return std::move(*(m_coroutine_handle.promise().m_value)); 233 } 234 } 235 236 private: 237 coroutine_handle m_coroutine_handle; 238}; 239 240template <typename T> struct task 241{ 242 using sender_concept = 243 stdexec::sender_t; // Mark task as a sender. This allows us to chain tasks with stdexec algorithms. 244 using promise_type = task_promise<T>; // The promise type associated with this task. 245 using coroutine_handle = std::coroutine_handle<promise_type>; // The coroutine handle type for this task. 246 247 task_awaiter<T> operator co_await() && 248 { 249 if (!m_handle) 250 { 251 throw std::runtime_error("Attempting to co_await a moved-from task"); 252 } 253 // Move the handle out of the task to ensure it's only awaited once. 254 return task_awaiter(std::exchange(m_handle, {})); 255 } 256 257 task(task &&other) noexcept : m_handle(std::exchange(other.m_handle, {})) 258 { 259 } 260 261 task &operator=(task &&other) noexcept 262 { 263 if (this != &other) 264 { 265 if (m_handle) 266 { 267 m_handle.destroy(); 268 } 269 m_handle = std::exchange(other.m_handle, {}); 270 } 271 return *this; 272 } 273 274 task(const task &) = delete; 275 task &operator=(const task &) = delete; 276 277 ~task() 278 { 279 if (m_handle) 280 { 281 m_handle.destroy(); 282 } 283 } 284 285 private: 286 friend promise_type; 287 explicit task(coroutine_handle handle) : m_handle(handle) 288 { 289 } 290 coroutine_handle m_handle{}; 291}; 292 293} // namespace kev