batmat main
Batched linear algebra routines
Loading...
Searching...
No Matches
thread-pool.hpp
Go to the documentation of this file.
1#pragma once
2
3#include <condition_variable>
4#include <exception>
5#include <functional>
6#include <mutex>
7#include <optional>
8#include <stdexcept>
9#include <thread>
10#include <utility>
11#include <vector>
12
13namespace batmat {
14
15/// @ingroup topic-utils
17 private:
18 struct Signals { // Per-worker synchronization primitives.
19 std::mutex mtx; ///< Held by work(), schedule() and wait() whenever this worker's slot in funcs/exceptions is accessed.
20 std::condition_variable_any cv; ///< Notified by schedule() after storing a function and by work() after the function has finished.
21 };
22 std::vector<Signals> signals; ///< One entry per worker, indexed by worker number.
23 std::vector<std::function<void()>> funcs; ///< Function assigned to each worker; empty while that worker has nothing to run.
24 std::vector<std::exception_ptr> exceptions; ///< Exception captured from each worker's function, cleared when wait() rethrows it.
25 std::vector<std::jthread> threads; // must be destroyed first (declared last, so it is destroyed before the state above that the workers use)
26
27 void work(std::stop_token stop, size_t i) {
28 auto &sig = signals[i];
29 while (true) {
30 std::unique_lock lck{sig.mtx};
31 sig.cv.wait(lck, stop, [&] { return static_cast<bool>(funcs[i]); });
32 if (stop.stop_requested()) {
33 break;
34 } else {
35 try {
36 funcs[i]();
37 } catch (...) {
38 exceptions[i] = std::current_exception();
39 }
40 funcs[i] = nullptr;
41 lck.unlock();
42 sig.cv.notify_all();
43 }
44 }
45 }
46
47 public:
48 explicit thread_pool(size_t num_threads = std::thread::hardware_concurrency())
49 : signals(num_threads), funcs(num_threads), exceptions(num_threads) {
50 threads.reserve(num_threads);
51 for (size_t i = 0; i < num_threads; ++i)
52 threads.emplace_back(
53 [this, i](std::stop_token stop) { this->work(std::move(stop), i); });
54 }
55
56 thread_pool(const thread_pool &) = delete; // Not copy-constructible.
57 thread_pool &operator=(const thread_pool &) = delete; // Not copy-assignable.
58 thread_pool(thread_pool &&) = default; // NOTE(review): the worker lambdas started by the constructor capture `this`, so after a move the running workers appear to keep using the moved-from object - confirm that moving a live pool is safe.
60
61 void schedule(size_t i, std::function<void()> func) {
62 auto &sig = signals[i];
63 std::unique_lock lck{sig.mtx};
64 funcs[i] = std::move(func);
65 lck.unlock();
66 sig.cv.notify_all();
67 }
68
69 void wait(size_t i) {
70 auto &sig = signals[i];
71 std::unique_lock lck{sig.mtx};
72 sig.cv.wait(lck, [&] { return !funcs[i]; });
73 if (auto &e = exceptions[i])
74 std::rethrow_exception(std::exchange(e, nullptr));
75 }
76
77 void wait_all() {
78 for (size_t i = 0; i < size(); ++i)
79 wait(i);
80 }
81
82 template <class I = size_t, class F>
83 void sync_run_all(F &&f) {
84 const auto n = size();
85 for (size_t i = 0; i < n; ++i)
86 schedule(i, [&f, i, n] { f(static_cast<I>(i), static_cast<I>(n)); });
87 wait_all();
88 }
89
90 template <class I = size_t, class F>
91 void sync_run_n(I n, F &&f) {
92 if (static_cast<size_t>(n) > size())
93 throw std::invalid_argument("Not enough threads in pool");
94 for (size_t i = 0; i < static_cast<size_t>(n); ++i)
95 schedule(i, [&f, i, n] { f(static_cast<I>(i), n); });
96 for (size_t i = 0; i < static_cast<size_t>(n); ++i)
97 wait(i);
98 }
99
100 [[nodiscard]] size_t size() const { return threads.size(); } ///< Number of worker threads owned by this pool.
101};
102
103namespace detail {
104extern std::mutex pool_mtx; ///< Locked by pool_sync_run_all() and pool_sync_run_n() for the whole duration of the call.
105extern std::optional<thread_pool> pool; ///< Global pool; may be empty, in which case pool_sync_run_all() does nothing and pool_sync_run_n() creates it.
106} // namespace detail
107
108/// Set the number of threads in the global thread pool.
109/// @ingroup topic-utils
110/// @deprecated
111[[deprecated]] void pool_set_num_threads(size_t num_threads);
112
113/// Run a function on all threads in the global thread pool, synchronously waiting for all threads.
114/// @ingroup topic-utils
115/// @deprecated
116template <class I = size_t, class F>
117[[deprecated]] void pool_sync_run_all(F &&f) {
118 std::lock_guard<std::mutex> lck(detail::pool_mtx);
119 if (!detail::pool)
120 return;
121 detail::pool->sync_run_all(std::forward<F>(f));
122}
123
124/// Run a function on the first @p n threads in the global thread pool, synchronously waiting for
125/// those threads. If @p n is greater than the number of threads in the pool, the pool is expanded.
126/// @ingroup topic-utils
127/// @deprecated
128template <class I = size_t, class F>
129[[deprecated]] void pool_sync_run_n(I n, F &&f) {
130 std::lock_guard<std::mutex> lck(detail::pool_mtx);
131 if (!detail::pool || detail::pool->size() < static_cast<size_t>(n))
132 detail::pool.emplace(static_cast<size_t>(n));
133 detail::pool->sync_run_n(n, std::forward<F>(f));
134}
135
136} // namespace batmat
void sync_run_n(I n, F &&f)
thread_pool & operator=(const thread_pool &)=delete
std::vector< Signals > signals
std::vector< std::exception_ptr > exceptions
std::condition_variable_any cv
void sync_run_all(F &&f)
void wait(size_t i)
thread_pool(size_t num_threads=std::thread::hardware_concurrency())
void schedule(size_t i, std::function< void()> func)
thread_pool(const thread_pool &)=delete
std::vector< std::jthread > threads
thread_pool(thread_pool &&)=default
std::vector< std::function< void()> > funcs
thread_pool & operator=(thread_pool &&)=default
size_t size() const
void work(std::stop_token stop, size_t i)
void pool_sync_run_all(F &&f)
Run a function on all threads in the global thread pool, synchronously waiting for all threads.
void pool_sync_run_n(I n, F &&f)
Run a function on the first n threads in the global thread pool, synchronously waiting for those threads.
void pool_set_num_threads(size_t num_threads)
Set the number of threads in the global thread pool.
std::optional< thread_pool > pool
std::mutex pool_mtx