batmat develop
Batched linear algebra routines
Loading...
Searching...
No Matches
thread-pool.hpp
Go to the documentation of this file.
1#pragma once
2
3#include <condition_variable>
4#include <exception>
5#include <functional>
6#include <mutex>
7#include <optional>
8#include <stdexcept>
9#include <stop_token>
10#include <thread>
11#include <utility>
12#include <vector>
13
14namespace batmat {
15
16/// @ingroup topic-utils
18 private:
19 struct State {
20 std::mutex mtx;
21 std::condition_variable_any cv;
22 std::function<void()> func;
23 std::exception_ptr exception;
24
25 void run(std::stop_token stop);
26 };
27 std::vector<State> states;
28 std::vector<std::jthread> threads; // must be destroyed first
29
30 public:
31 explicit thread_pool(size_t num_threads = std::thread::hardware_concurrency())
32 : states(num_threads) {
33 threads.reserve(num_threads);
34 for (auto &state : states)
35 threads.emplace_back([&state](std::stop_token stop) { state.run(std::move(stop)); });
36 }
37
38 thread_pool(const thread_pool &) = delete;
39 thread_pool &operator=(const thread_pool &) = delete;
40 thread_pool(thread_pool &&) = default;
42
43 void schedule(size_t i, std::function<void()> func) {
44 auto &state = states[i];
45 std::unique_lock lck{state.mtx};
46 state.func = std::move(func);
47 lck.unlock();
48 state.cv.notify_all();
49 }
50
51 void wait(size_t i) {
52 auto &state = states[i];
53 std::unique_lock lck{state.mtx};
54 state.cv.wait(lck, [&] { return !state.func; });
55 if (auto &e = state.exception)
56 std::rethrow_exception(std::exchange(e, nullptr));
57 }
58
59 void wait_all() {
60 for (size_t i = 0; i < size(); ++i)
61 wait(i);
62 }
63
64 template <class I = size_t, class F>
65 void sync_run_all(F &&f) {
66 const auto n = size();
67 for (size_t i = 0; i < n; ++i)
68 schedule(i, [&f, i, n] { f(static_cast<I>(i), static_cast<I>(n)); });
69 wait_all();
70 }
71
72 template <class I = size_t, class F>
73 void sync_run_n(I n, F &&f) {
74 if (static_cast<size_t>(n) > size())
75 throw std::invalid_argument("Not enough threads in pool");
76 for (size_t i = 0; i < static_cast<size_t>(n); ++i)
77 schedule(i, [&f, i, n] { f(static_cast<I>(i), n); });
78 for (size_t i = 0; i < static_cast<size_t>(n); ++i)
79 wait(i);
80 }
81
82 [[nodiscard]] size_t size() const { return threads.size(); }
83};
84
85inline void thread_pool::State::run(std::stop_token stop) {
86 while (true) {
87 std::unique_lock lck{mtx};
88 cv.wait(lck, stop, [&] { return static_cast<bool>(func); });
89 if (stop.stop_requested()) {
90 break;
91 } else {
92 try {
93 func();
94 } catch (...) {
95 exception = std::current_exception();
96 }
97 func = nullptr;
98 lck.unlock();
99 cv.notify_all();
100 }
101 }
102}
103
namespace detail {
/// Mutex that the pool_* helpers below lock around every access to `pool`.
/// Only declared here; defined in another translation unit.
extern std::mutex pool_mtx;
/// Global pool used by the deprecated pool_* helpers. It may be empty: pool_sync_run_all does
/// nothing in that case, and pool_sync_run_n creates it on demand.
/// Only declared here; defined in another translation unit.
extern std::optional<thread_pool> pool;
} // namespace detail
108
/// Set the number of threads in the global thread pool.
/// Only declared in this header; the definition lives in another translation unit.
/// NOTE(review): presumably (re)creates `detail::pool` while holding `detail::pool_mtx` — confirm
/// against the definition.
/// @ingroup topic-utils
/// @deprecated
[[deprecated]] void pool_set_num_threads(size_t num_threads);
113
114/// Run a function on all threads in the global thread pool, synchronously waiting for all threads.
115/// @ingroup topic-utils
116/// @deprecated
117template <class I = size_t, class F>
118[[deprecated]] void pool_sync_run_all(F &&f) {
119 std::lock_guard<std::mutex> lck(detail::pool_mtx);
120 if (!detail::pool)
121 return;
122 detail::pool->sync_run_all(std::forward<F>(f));
123}
124
125/// Run a function on the first @p n threads in the global thread pool, synchronously waiting for
126/// those threads. If @p n is greater than the number of threads in the pool, the pool is expanded.
127/// @ingroup topic-utils
128/// @deprecated
129template <class I = size_t, class F>
130[[deprecated]] void pool_sync_run_n(I n, F &&f) {
131 std::lock_guard<std::mutex> lck(detail::pool_mtx);
132 if (!detail::pool || detail::pool->size() < static_cast<size_t>(n))
133 detail::pool.emplace(static_cast<size_t>(n));
134 detail::pool->sync_run_n(n, std::forward<F>(f));
135}
136
137} // namespace batmat
void sync_run_n(I n, F &&f)
thread_pool & operator=(const thread_pool &)=delete
void sync_run_all(F &&f)
void wait(size_t i)
thread_pool(size_t num_threads=std::thread::hardware_concurrency())
std::vector< State > states
void schedule(size_t i, std::function< void()> func)
thread_pool(const thread_pool &)=delete
std::vector< std::jthread > threads
thread_pool(thread_pool &&)=default
thread_pool & operator=(thread_pool &&)=default
size_t size() const
void pool_sync_run_all(F &&f)
Run a function on all threads in the global thread pool, synchronously waiting for all threads.
void pool_sync_run_n(I n, F &&f)
Run a function on the first n threads in the global thread pool, synchronously waiting for those threads.
void pool_set_num_threads(size_t num_threads)
Set the number of threads in the global thread pool.
std::optional< thread_pool > pool
std::mutex pool_mtx
std::condition_variable_any cv
std::function< void()> func
void run(std::stop_token stop)
std::exception_ptr exception