io-chess
UCI chess engine
Loading...
Searching...
No Matches
PersistentThreadPool.hpp
Go to the documentation of this file.
1#pragma once
2
#include <algorithm>
#include <atomic>
#include <condition_variable>
#include <exception>
#include <functional>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>
12
25public:
26 explicit PersistentThreadPool(int workers) {
27 const int n = std::max(0, workers);
28 for (int i = 0; i < n; ++i)
30 }
31
33 {
34 std::lock_guard<std::mutex> lk(mu_);
35 stop_ = true;
36 }
37 cv_.notify_all();
38 for (auto &t : threads_) {
39 if (t.joinable())
40 t.join();
41 }
42 }
43
46
47 void parallel_for(int n, const std::function<void(int)> &fn) {
48 if (n <= 0)
49 return;
50 if (threads_.empty() || n == 1) {
51 for (int i = 0; i < n; ++i)
52 fn(i);
53 return;
54 }
55
56 auto job = std::make_shared<Job>();
57 job->total = n;
58 job->next.store(0, std::memory_order_relaxed);
59 job->remaining.store(n, std::memory_order_relaxed);
60 job->fn = fn;
61
62 {
63 std::lock_guard<std::mutex> lk(mu_);
64 for (size_t i = 0; i < threads_.size(); ++i)
65 jobs_.push(job);
66 }
67 cv_.notify_all();
68
69 run_job(job);
70
71 std::unique_lock<std::mutex> lk(job->doneMu);
72 job->doneCv.wait(lk, [&] {
73 return job->remaining.load(std::memory_order_acquire) == 0;
74 });
75 }
76
77private:
78 struct Job {
79 int total = 0;
80 std::atomic<int> next{0};
81 std::atomic<int> remaining{0};
82 std::function<void(int)> fn;
83 std::mutex doneMu;
84 std::condition_variable doneCv;
85 };
86
87 std::vector<std::thread> threads_;
88 std::mutex mu_;
89 std::condition_variable cv_;
90 std::queue<std::shared_ptr<Job>> jobs_;
91 bool stop_ = false;
92
93 static void finish_one(const std::shared_ptr<Job> &job) {
94 if (job->remaining.fetch_sub(1, std::memory_order_acq_rel) == 1) {
95 std::lock_guard<std::mutex> lk(job->doneMu);
96 job->doneCv.notify_one();
97 }
98 }
99
100 static void run_job(const std::shared_ptr<Job> &job) {
101 while (true) {
102 const int idx = job->next.fetch_add(1, std::memory_order_relaxed);
103 if (idx >= job->total)
104 break;
105 job->fn(idx);
106 finish_one(job);
107 }
108 }
109
110 void worker_loop() {
111 while (true) {
112 std::shared_ptr<Job> job;
113 {
114 std::unique_lock<std::mutex> lk(mu_);
115 cv_.wait(lk, [&] { return stop_ || !jobs_.empty(); });
116 if (stop_ && jobs_.empty())
117 return;
118 job = jobs_.front();
119 jobs_.pop();
120 }
121 run_job(job);
122 }
123 }
124};
PersistentThreadPool(int workers)
Definition PersistentThreadPool.hpp:26
PersistentThreadPool(const PersistentThreadPool &)=delete
std::mutex mu_
Definition PersistentThreadPool.hpp:88
std::queue< std::shared_ptr< Job > > jobs_
Definition PersistentThreadPool.hpp:90
static void finish_one(const std::shared_ptr< Job > &job)
Definition PersistentThreadPool.hpp:93
PersistentThreadPool & operator=(const PersistentThreadPool &)=delete
static void run_job(const std::shared_ptr< Job > &job)
Definition PersistentThreadPool.hpp:100
bool stop_
Definition PersistentThreadPool.hpp:91
~PersistentThreadPool()
Definition PersistentThreadPool.hpp:32
std::condition_variable cv_
Definition PersistentThreadPool.hpp:89
std::vector< std::thread > threads_
Definition PersistentThreadPool.hpp:87
void parallel_for(int n, const std::function< void(int)> &fn)
Definition PersistentThreadPool.hpp:47
void worker_loop()
Definition PersistentThreadPool.hpp:110
PersistentThreadPool::Job
Definition PersistentThreadPool.hpp:78
int total
Definition PersistentThreadPool.hpp:79
std::condition_variable doneCv
Definition PersistentThreadPool.hpp:84
std::mutex doneMu
Definition PersistentThreadPool.hpp:83
std::atomic< int > remaining
Definition PersistentThreadPool.hpp:81
std::atomic< int > next
Definition PersistentThreadPool.hpp:80
std::function< void(int)> fn
Definition PersistentThreadPool.hpp:82