mp-coro main
Coroutine support tools
when_all.h
Go to the documentation of this file.
1// The MIT License (MIT)
2//
3// Copyright (c) 2021 Mateusz Pusz
4//
5// Permission is hereby granted, free of charge, to any person obtaining a copy
6// of this software and associated documentation files (the "Software"), to deal
7// in the Software without restriction, including without limitation the rights
8// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9// copies of the Software, and to permit persons to whom the Software is
10// furnished to do so, subject to the following conditions:
11//
12// The above copyright notice and this permission notice shall be included in all
13// copies or substantial portions of the Software.
14//
15// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21// SOFTWARE.
22
23#pragma once
24
26#include <mp-coro/concepts.h>
27#include <mp-coro/trace.h>
28#include <mp-coro/type_traits.h>
29#include <atomic>
30#include <coroutine>
31#include <cstdint>
32#include <ranges>
33#include <tuple>
34#include <vector>
35
36namespace mp_coro {
37
38namespace detail {
39
41 std::atomic<std::size_t> counter_;
42 std::coroutine_handle<> continuation_;
43
44 public:
45 constexpr when_all_sync(std::size_t count) noexcept
46 : counter_(count + 1) {} // +1 for attaching a continuation
47 when_all_sync(when_all_sync &&other) noexcept
48 : counter_(other.counter_.load()), continuation_(other.continuation_) {}
49
50 /// @retval false when the continuation is attached after all work is
51 /// already done, so the awaiting coroutine should be resumed right away
52 /// (a `bool`-returning `await_suspend()` does that by returning `false`).
53 bool set_continuation(std::coroutine_handle<> cont) {
54 continuation_ = cont;
55 return counter_.fetch_sub(1, std::memory_order_acq_rel) > 1;
56 }
57
58 /// On completion of each task a counter is being decremented and if the
59 /// continuation is already attached (a counter was decremented by
60 /// @ref set_continuation() it is being resumed.
62 if (counter_.fetch_sub(1, std::memory_order_acq_rel) == 1)
63 continuation_.resume();
64 }
65
66 /// @retval true if the continuation is already assigned which means that
67 /// someone already awaited for the awaitable completion.
68 bool is_ready() const { return static_cast<bool>(continuation_); }
69};
70
71template <typename T>
72std::size_t tasks_size(T &container) {
73 if constexpr (std::ranges::range<T>)
74 return size(container);
75 else
76 return std::tuple_size_v<T>;
77}
78
79template <typename T>
80decltype(auto) start_all_tasks(T &container, when_all_sync &sync) {
81 if constexpr (std::ranges::range<T>)
82 for (auto &t : container)
83 t.start(sync);
84 else
85 std::apply([&](auto &...tasks) { (..., tasks.start(sync)); }, container);
86}
87
88template <typename T>
89decltype(auto) make_all_results(T &&container) {
90 if constexpr (std::ranges::range<T>) {
91 if constexpr (std::is_void_v<typename std::ranges::range_value_t<T>::value_type>) {
92 // in case of `void` check for exception and do not return any result
93 for (auto &task : container)
94 task.get();
95 } else {
96 std::vector<typename std::ranges::range_value_t<T>::value_type> result;
97 result.reserve(size(container));
98 for (auto &&task : std::forward<T>(container))
99 result.emplace_back(std::forward<decltype(task)>(task).get());
100 return result;
101 }
102 } else {
103 return std::apply(
104 [&]<typename... Tasks>(Tasks &&...tasks) {
105 if constexpr ((...
106 && std::is_void_v<
107 typename std::remove_cvref_t<Tasks>::value_type>)) {
108 // in case of all `void` check for exception and do not return any result
109 (..., tasks.get());
110 } else {
111 using ret_type = std::tuple<remove_rvalue_reference_t<
112 decltype(std::forward<Tasks>(tasks).nonvoid_get())>...>;
113 return ret_type(std::forward<Tasks>(tasks).nonvoid_get()...);
114 }
115 },
116 std::forward<T>(container));
117 }
118}
119
120template <typename T>
122 explicit when_all_awaitable(T &&tasks) : tasks_(std::move(tasks)) {}
123
124 decltype(auto) operator co_await() & {
125 struct awaiter : awaiter_base {
126 decltype(auto) await_resume() {
127 TRACE_FUNC();
128 return make_all_results(this->awaitable.tasks_);
129 }
130 };
131 return awaiter {{*this}};
132 }
133
134 decltype(auto) operator co_await() && {
135 struct awaiter : awaiter_base {
136 decltype(auto) await_resume() {
137 TRACE_FUNC();
138 return make_all_results(std::move(this->awaitable.tasks_));
139 }
140 };
141 return awaiter {{*this}};
142 }
143
144 private:
147
148 bool await_ready() const noexcept {
149 TRACE_FUNC();
150 return awaitable.sync_.is_ready();
151 }
152 bool await_suspend(std::coroutine_handle<> handle) {
153 TRACE_FUNC();
154 start_all_tasks(awaitable.tasks_, awaitable.sync_);
155 return awaitable.sync_.set_continuation(handle);
156 }
157 };
160};
161
162} // namespace detail
163
164template <awaitable... Awaitables>
165awaitable auto when_all(Awaitables &&...awaitables) {
166 TRACE_FUNC();
168 std::make_tuple(detail::make_synchronized_task<detail::when_all_sync>(
169 std::forward<Awaitables>(awaitables))...));
170}
171
172template <std::ranges::range R>
173awaitable auto when_all(R &&awaitables) {
174 TRACE_FUNC();
177 std::vector<task_t> tasks;
178 tasks.reserve(size(awaitables));
179 for (auto &&awaitable : std::forward<R>(awaitables))
180 tasks.emplace_back(detail::make_synchronized_task<detail::when_all_sync>(
181 std::forward<decltype(awaitable)>(awaitable)));
182 return detail::when_all_awaitable(std::move(tasks));
183}
184
185} // namespace mp_coro
Lazy task that can later be started explicitly, and that notifies another variable (the “sync” object...
when_all_sync(when_all_sync &&other) noexcept
Definition: when_all.h:47
std::atomic< std::size_t > counter_
Definition: when_all.h:41
std::coroutine_handle continuation_
Definition: when_all.h:42
constexpr when_all_sync(std::size_t count) noexcept
Definition: when_all.h:45
bool set_continuation(std::coroutine_handle<> cont)
Definition: when_all.h:53
void notify_awaitable_completed()
On completion of each task the counter is decremented; if that was the last outstanding decrement (the continuation is already attached...
Definition: when_all.h:61
Task that produces a value of type T: to get that value, simply await the task.
Definition: task.h:65
decltype(auto) make_all_results(T &&container)
Definition: when_all.h:89
decltype(auto) start_all_tasks(T &container, when_all_sync &sync)
Definition: when_all.h:80
std::size_t tasks_size(T &container)
Definition: when_all.h:72
Definition: async.h:31
typename remove_rvalue_reference< T >::type remove_rvalue_reference_t
decltype(std::declval< awaiter_for_t< A > >().await_resume()) await_result_t
Definition: type_traits.h:35
awaitable auto when_all(Awaitables &&...awaitables)
Definition: when_all.h:165
bool await_suspend(std::coroutine_handle<> handle)
Definition: when_all.h:152
decltype(auto) await_resume() const
Return the value of the task's promise.
Definition: task.h:131
#define TRACE_FUNC()
Definition: trace.h:27