this repo has no description
1#pragma once
2
#include <cassert>
#include <coroutine>
#include <exception>
#include <optional>
#include <stdexcept>
#include <type_traits>
#include <utility>

#include <stdexec/execution.hpp>
10
11namespace kev
12{
13
template <class T> struct task;

/**
 * @brief State shared by every task promise, independent of the result type.
 *
 * Holds the exception (if any) that escaped the coroutine body so it can later be
 * rethrown to the coroutine that awaits the task.
 */
struct promise_data
{
    /// Exception captured from the coroutine body; null when the body did not throw.
    std::exception_ptr m_exception = nullptr;
};
24
25/**
26 * @brief Promise base for tasks with return type T.
27 *
28 * @tparam T The return type of the task.
29 */
30template <typename T> struct promise_base : promise_data
31{
32 /**
33 * @brief Store the returned value in the promise.
34 *
35 * @param value
36 */
37 void return_value(T value)
38 {
39 this->m_value = std::move(value);
40 }
41 std::optional<T> m_value{std::nullopt};
42};
43
44/**
45 * @brief Specialization of promise base for void return type.
46 *
47 */
48template <> struct promise_base<void> : promise_data
49{
50 /**
51 * @brief Handle the void return type.
52 *
53 */
54 static constexpr void return_void()
55 {
56 }
57};
58
59/**
60 * @brief Promise type for tasks.
61 *
62 * @tparam T The return type of the task.
63 */
64template <typename T> struct task_promise : stdexec::with_awaitable_senders<task_promise<T>>, promise_base<T>
65{
66 using promise_type = task_promise<T>;
67 using coroutine_handle = std::coroutine_handle<promise_type>;
68
69 /**
70 * @brief Get the return object object
71 *
72 * @return task<T>
73 */
74 auto get_return_object() -> task<T>
75 {
76 return task<T>{coroutine_handle::from_promise(*this)};
77 }
78 /**
79 * @brief The task type always suspends at the beginning.
80 *
81 * @return auto
82 */
83 auto initial_suspend() -> std::suspend_always
84 {
85 return std::suspend_always{};
86 }
87
88 /**
89 * @brief Final awaiter to resume the continuation.
90 *
91 */
92 struct final_awaiter
93 {
94 /**
95 * @brief Always suspend to resume the continuation.
96 *
97 */
98 static constexpr auto await_ready() noexcept -> bool
99 {
100 // Always suspend to allow resuming the continuation.
101 return false;
102 }
103
104 std::coroutine_handle<> await_suspend(std::coroutine_handle<promise_type> h) noexcept
105 {
106 // cppreference: "if await_suspend returns a coroutine handle for some other coroutine, that handle is
107 // resumed (by a call to handle.resume())"
108 return h.promise().m_continuation;
109 }
110
111 void await_resume() noexcept
112 {
113 // Nothing to do here.
114 }
115 };
116
117 /**
118 * @brief Final suspend point to resume the continuation.
119 *
120 * @return final_awaiter
121 */
122 auto final_suspend() noexcept -> final_awaiter
123 {
124 return {};
125 }
126
127 /**
128 * @brief Store the current exception in the promise.
129 *
130 */
131 auto unhandled_exception() -> void
132 {
133 // cppreference:
134 // If the coroutine ends with an uncaught exception, it performs the following:
135 // catches the exception and calls promise.unhandled_exception() from within the catch-block
136 this->m_exception = std::current_exception();
137 }
138
139 std::coroutine_handle<> m_continuation;
140};
141
142/**
143 * @brief Awaiter for tasks. Allows awaiting the task and retrieving its result.
144 *
145 * @tparam T The return type of the task.
146 */
147template <typename T> struct task_awaiter
148{
149 using coroutine_handle = std::coroutine_handle<task_promise<T>>;
150
151 explicit task_awaiter(coroutine_handle handle) noexcept : m_coroutine_handle(handle)
152 {
153 // Take ownership of the coroutine handle.
154 }
155
156 task_awaiter(task_awaiter &&other) noexcept : m_coroutine_handle(std::exchange(other.m_coroutine_handle, {}))
157 {
158 // Awaiters are move only.
159 }
160
161 task_awaiter(const task_awaiter &) = delete;
162 task_awaiter &operator=(const task_awaiter &) = delete;
163
164 task_awaiter &operator=(task_awaiter &&other) noexcept
165 {
166 if (this != &other)
167 {
168 if (m_coroutine_handle)
169 {
170 m_coroutine_handle.destroy();
171 }
172 m_coroutine_handle = std::exchange(other.m_coroutine_handle, {});
173 }
174 return *this;
175 }
176
177 ~task_awaiter()
178 {
179 // Destroy the coroutine if we still own it. There's a possibility of transferring ownership via move
180 // construction/assignment.
181 if (m_coroutine_handle)
182 {
183 m_coroutine_handle.destroy();
184 }
185 }
186
187 auto await_ready() -> bool
188 {
189 // Always suspend to allow the caller to await the task.
190 return false;
191 }
192 auto await_suspend(std::coroutine_handle<> continuation) -> bool
193 {
194 /*
195 cppreference on await_suspend return value:
196 if await_suspend returns bool:
197 - the value true returns control to the caller/resumer of the current coroutine
198 - the value false resumes the current coroutine.
199 */
200
201 // Get the promise of the coroutine that is being awaited.
202 auto &promise = m_coroutine_handle.promise();
203 // Set the continuation to a noop coroutine to avoid resuming the caller too early.
204 promise.m_continuation = std::noop_coroutine();
205 // Resume the task coroutine. Since we set the continuation to a noop, the caller won't be resumed yet.
206 m_coroutine_handle.resume();
207 if (m_coroutine_handle.done())
208 {
209 // If the task is already done, we can resume the caller immediately.
210 return false;
211 }
212 // Set the actual continuation to the caller's coroutine handle.
213 promise.m_continuation = continuation;
214 // Indicate that we have a continuation to resume later.
215 return true;
216 }
217 auto await_resume() -> T
218 {
219 // The coroutine body has completed at this point.
220 if (m_coroutine_handle.promise().m_exception)
221 {
222 // If an exception was thrown in the body it would have been stored in the promise.
223 std::rethrow_exception(m_coroutine_handle.promise().m_exception);
224 }
225 if constexpr (std::is_void_v<T>)
226 {
227 return;
228 }
229 else
230 {
231 assert(m_coroutine_handle.promise().m_value.has_value() && "Task did not set a return value");
232 return std::move(*(m_coroutine_handle.promise().m_value));
233 }
234 }
235
236 private:
237 coroutine_handle m_coroutine_handle;
238};
239
240template <typename T> struct task
241{
242 using sender_concept =
243 stdexec::sender_t; // Mark task as a sender. This allows us to chain tasks with stdexec algorithms.
244 using promise_type = task_promise<T>; // The promise type associated with this task.
245 using coroutine_handle = std::coroutine_handle<promise_type>; // The coroutine handle type for this task.
246
247 task_awaiter<T> operator co_await() &&
248 {
249 if (!m_handle)
250 {
251 throw std::runtime_error("Attempting to co_await a moved-from task");
252 }
253 // Move the handle out of the task to ensure it's only awaited once.
254 return task_awaiter(std::exchange(m_handle, {}));
255 }
256
257 task(task &&other) noexcept : m_handle(std::exchange(other.m_handle, {}))
258 {
259 }
260
261 task &operator=(task &&other) noexcept
262 {
263 if (this != &other)
264 {
265 if (m_handle)
266 {
267 m_handle.destroy();
268 }
269 m_handle = std::exchange(other.m_handle, {});
270 }
271 return *this;
272 }
273
274 task(const task &) = delete;
275 task &operator=(const task &) = delete;
276
277 ~task()
278 {
279 if (m_handle)
280 {
281 m_handle.destroy();
282 }
283 }
284
285 private:
286 friend promise_type;
287 explicit task(coroutine_handle handle) : m_handle(handle)
288 {
289 }
290 coroutine_handle m_handle{};
291};
292
293} // namespace kev