fixup ThreadPool

This commit is contained in:
B3n30 2017-09-30 14:41:02 +02:00
parent 16fb89fef0
commit cbd4623df7

View File

@ -1,9 +1,8 @@
// Copyright 2016 Citra Emulator Project / PPSSPP Project // Copyright 2017 Citra Emulator Project
// Licensed under GPLv2 or any later version // Licensed under GPLv2 or any later version
// Refer to the license.txt file included. // Refer to the license.txt file included.
#include <condition_variable> #include <condition_variable>
#include <deque>
#include <functional> #include <functional>
#include <future> #include <future>
#include <mutex> #include <mutex>
@ -16,7 +15,7 @@ namespace Common {
class ThreadPool { class ThreadPool {
private: private:
explicit ThreadPool(unsigned int num_threads) : num_threads(num_threads), workers(num_threads) { explicit ThreadPool(size_t num_threads) : num_threads(num_threads), workers(num_threads) {
ASSERT(num_threads); ASSERT(num_threads);
} }
@ -27,13 +26,13 @@ public:
} }
template <typename F, typename... Args> template <typename F, typename... Args>
auto push(F&& f, Args&&... args) { auto Push(F&& f, Args&&... args) {
auto ret = workers[next_worker].push(std::forward<F>(f), std::forward<Args>(args)...); auto ret = workers[next_worker].Push(std::forward<F>(f), std::forward<Args>(args)...);
next_worker = (next_worker + 1) % num_threads; next_worker = (next_worker + 1) % num_threads;
return ret; return ret;
} }
unsigned int total_threads() { const size_t total_threads() const {
return num_threads; return num_threads;
} }
@ -51,7 +50,7 @@ private:
queue_storage.reserve(capacity); queue_storage.reserve(capacity);
} }
void push(const T& element) { void Push(const T& element) {
std::unique_lock<std::mutex> lock(mutex); std::unique_lock<std::mutex> lock(mutex);
while (queue_storage.size() >= capacity) { while (queue_storage.size() >= capacity) {
queue_changed.wait(lock); queue_changed.wait(lock);
@ -60,7 +59,7 @@ private:
queue_changed.notify_one(); queue_changed.notify_one();
} }
T pop() { T Pop() {
std::unique_lock<std::mutex> lock(mutex); std::unique_lock<std::mutex> lock(mutex);
while (queue_storage.empty()) { while (queue_storage.empty()) {
queue_changed.wait(lock); queue_changed.wait(lock);
@ -71,12 +70,12 @@ private:
return element; return element;
} }
void push(T&& element) { void Push(T&& element) {
std::unique_lock<std::mutex> lock(mutex); std::unique_lock<std::mutex> lock(mutex);
while (queue_storage.size() >= capacity) { while (queue_storage.size() >= capacity) {
queue_changed.wait(lock); queue_changed.wait(lock);
} }
queue_storage.emplace_back(element); queue_storage.emplace_back(std::move(element));
queue_changed.notify_one(); queue_changed.notify_one();
} }
}; };
@ -88,16 +87,16 @@ private:
static constexpr size_t MAX_QUEUE_CAPACITY = 100; static constexpr size_t MAX_QUEUE_CAPACITY = 100;
public: public:
Worker() : queue(MAX_QUEUE_CAPACITY), thread([this] { loop(); }) {} Worker() : queue(MAX_QUEUE_CAPACITY), thread([this] { Loop(); }) {}
~Worker() { ~Worker() {
queue.push(nullptr); // Exit the loop queue.Push(nullptr); // Exit the loop
thread.join(); thread.join();
} }
void loop() { void Loop() {
for (;;) { while (true) {
std::function<void()> fn(queue.pop()); std::function<void()> fn(queue.Pop());
if (!fn) // a nullptr function is the signal to exit the loop if (!fn) // a nullptr function is the signal to exit the loop
break; break;
fn(); fn();
@ -105,16 +104,16 @@ private:
} }
template <typename F, typename... Args> template <typename F, typename... Args>
auto push(F&& f, Args&&... args) { auto Push(F&& f, Args&&... args) {
auto task = std::make_shared<std::packaged_task<decltype(f(args...))()>>( auto task = std::make_shared<std::packaged_task<decltype(f(args...))()>>(
std::bind(std::forward<F>(f), std::forward<Args>(args)...)); std::bind(std::forward<F>(f), std::forward<Args>(args)...));
queue.push([task]() { (*task)(); }); queue.Push([task] { (*task)(); });
return task->get_future(); return task->get_future();
} }
}; };
const unsigned int num_threads; const size_t num_threads;
int next_worker = 0; size_t next_worker = 0;
std::vector<Worker> workers; std::vector<Worker> workers;
}; };