/*
 * <Short Description>
 * Copyright (C) 2023 Brett Terpstra
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program. If not, see <https://www.gnu.org/licenses/>.
 */

#ifndef BLT_THREAD_H
#define BLT_THREAD_H

#include <blt/std/types.h>
#include <thread>
#include <functional>
#include <string>
#include <vector>
#include <queue>
#include <utility>
#include <variant>
#include <atomic>
#include <mutex>
#include <chrono>
#include <optional>
#include <blt/std/logging.h>
#include <condition_variable>

namespace blt
{
    class barrier
    {
    public:
        explicit barrier(blt::size_t threads, std::optional<std::reference_wrapper<std::atomic_bool>> exit_cond = {}):
                thread_count(threads), threads_waiting(0), use_count(0), exit_cond(exit_cond), count_mutex(), cv()
        {
            if (threads == 0)
                throw std::runtime_error("Barrier thread count cannot be 0");
        }

        barrier(const barrier& copy) = delete;

        barrier(barrier&& move) = delete;

        barrier& operator=(const barrier& copy) = delete;

        barrier& operator=(barrier&& move) = delete;

        ~barrier() = default;

        void wait()
        {
            // unique_lock acquires the lock on construction
            std::unique_lock lock(count_mutex);
            std::size_t current_uses = use_count;

            if (++threads_waiting == thread_count)
            {
                // last thread to arrive: reset the counter, start a new "use" of the barrier and wake everyone
                threads_waiting = 0;
                use_count++;
                cv.notify_all();
            } else
            {
                if constexpr (BUSY_LOOP_WAIT > 0) // NOLINT
                {
                    // optimistically spin for a while before blocking on the condition variable
                    lock.unlock();
                    for (blt::size_t i = 0; i < BUSY_LOOP_WAIT; i++)
                    {
                        if (use_count != current_uses || (exit_cond && exit_cond->get()))
                            return;
                    }
                    lock.lock();
                }
                cv.wait(lock, [this, &current_uses]() {
                    return (use_count != current_uses || (exit_cond && exit_cond->get()));
                });
            }
        }

        void notify_all()
        {
            cv.notify_all();
        }

    private:
        blt::size_t thread_count;
        blt::size_t threads_waiting;
        blt::size_t use_count;
        std::optional<std::reference_wrapper<std::atomic_bool>> exit_cond;
        std::mutex count_mutex;
        std::condition_variable cv;

        // Improves performance by spinning for this many iterations instead of blocking the thread immediately.
        // If the condition is still not met by the end of the loop, the thread blocks on the condition variable.
        static constexpr blt::size_t BUSY_LOOP_WAIT = 200;
    };
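
    /*
     * Usage sketch (illustrative addition, not part of the original header): two worker
     * threads rendezvous at the barrier before moving on to a second phase. The include
     * path <blt/std/thread.h> is an assumption based on the BLT_THREAD_H include guard.
     *
     *   #include <blt/std/thread.h>
     *   #include <atomic>
     *   #include <thread>
     *
     *   std::atomic<int> phase_one_done{0};
     *
     *   void worker(blt::barrier& sync)
     *   {
     *       ++phase_one_done;    // phase one work
     *       sync.wait();         // block until both threads have finished phase one
     *       // phase two can rely on phase_one_done == 2 here
     *   }
     *
     *   int main()
     *   {
     *       blt::barrier sync(2);
     *       std::thread a(worker, std::ref(sync));
     *       std::thread b(worker, std::ref(sync));
     *       a.join();
     *       b.join();
     *   }
     */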

    /**
     * A simple thread pool.
     * @tparam queue if true, worker threads pull tasks submitted through execute();
     *               if false, every worker repeatedly executes the same stored function.
     */
    template<bool queue = false, typename... Args>
    class thread_pool
    {
    private:
        using thread_function = std::function<void(Args...)>;
        volatile std::atomic_bool should_stop = false;
        volatile std::atomic_uint64_t stopped = 0;
        std::uint64_t number_of_threads = 0;
        std::vector<std::thread*> threads;
        std::variant<std::queue<thread_function>, thread_function> func_queue;
        std::mutex queue_mutex;
        // only used when a queue
        volatile std::atomic_uint64_t tasks = 0;
        volatile std::atomic_uint64_t completed_tasks = 0;
        bool func_loaded = false;

        void init()
        {
            for (std::uint64_t i = 0; i < number_of_threads; i++)
            {
                threads.push_back(new std::thread([this]() {
                    while (!should_stop)
                    {
                        if constexpr (queue)
                        {
                            // hold the queue lock only while inspecting and popping the queue
                            // (std::unique_lock already owns the mutex after construction; locking it again would throw)
                            std::unique_lock lock(queue_mutex);
                            auto& func_q = std::get<std::queue<thread_function>>(func_queue);
                            if (func_q.empty())
                            {
                                lock.unlock();
                                std::this_thread::sleep_for(std::chrono::milliseconds(16));
                                continue;
                            }
                            auto func = func_q.front();
                            func_q.pop();
                            lock.unlock();
                            func();
                            completed_tasks++;
                        } else
                        {
                            if (!func_loaded)
                            {
                                std::scoped_lock lock(queue_mutex);
                                if (std::holds_alternative<std::queue<thread_function>>(func_queue))
                                {
                                    std::this_thread::sleep_for(std::chrono::milliseconds(16));
                                    //BLT_WARN("Running non queue variant with a queue inside!");
                                    continue;
                                }
                                func_loaded = true;
                            }
                            auto& func = std::get<thread_function>(func_queue);
                            func();
                        }
                    }
                    stopped++;
                }));
            }
        }

        void cleanup()
        {
            for (auto* t : threads)
            {
                if (t->joinable())
                    t->join();
                delete t;
            }
            // clear the now-dangling pointers so a subsequent init() (e.g. from reset()) starts fresh
            threads.clear();
        }

    public:
        explicit thread_pool(std::uint64_t number_of_threads = 8, std::optional<thread_function> default_function = {})
        {
            if (default_function.has_value())
                func_queue = default_function.value();
            this->number_of_threads = number_of_threads;
            init();
        }

        inline void execute(const thread_function& func)
        {
            std::scoped_lock lock(queue_mutex);
            if constexpr (queue)
            {
                auto& v = std::get<std::queue<thread_function>>(func_queue);
                v.push(func);
                tasks++;
            } else
            {
                func_queue = func;
            }
        }

        [[nodiscard]] inline bool tasks_complete() const
        {
            return completed_tasks == tasks;
        }

        [[nodiscard]] inline bool complete() const
        {
            return stopped == number_of_threads;
        }

        inline void stop()
        {
            should_stop = true;
        }

        inline void reset_tasks()
        {
            tasks = 0;
            completed_tasks = 0;
        }

        inline void reset()
        {
            stop();
            cleanup();
            stopped = 0;
            // re-arm the stop flag so the freshly created workers do not exit immediately
            should_stop = false;
            init();
        }

        ~thread_pool()
        {
            should_stop = true;
            cleanup();
        }
    };
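
    /*
     * Usage sketch (illustrative addition, not part of the original header): the pool in
     * queue mode, where each call to execute() enqueues an independent task. The include
     * path <blt/std/thread.h> is an assumption based on the BLT_THREAD_H include guard.
     *
     *   #include <blt/std/thread.h>
     *   #include <atomic>
     *   #include <chrono>
     *   #include <thread>
     *
     *   int main()
     *   {
     *       std::atomic<int> sum{0};
     *       blt::thread_pool<true> pool(4);            // 4 workers pulling from the task queue
     *       for (int i = 0; i < 16; i++)
     *           pool.execute([&sum]() { sum += 1; });  // each lambda is one queued task
     *       while (!pool.tasks_complete())             // poll until all 16 tasks have run
     *           std::this_thread::sleep_for(std::chrono::milliseconds(1));
     *       pool.stop();                               // workers exit; the destructor joins them
     *   }
     */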
}
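
/*
 * A second usage sketch (illustrative addition, not part of the original header): the pool
 * in its default non-queue mode, where every worker repeatedly runs the same stored
 * function until stop() is called.
 *
 *   #include <blt/std/thread.h>
 *   #include <atomic>
 *   #include <chrono>
 *   #include <cstdint>
 *   #include <thread>
 *
 *   int main()
 *   {
 *       std::atomic<std::uint64_t> iterations{0};
 *       // every worker keeps calling this lambda in a loop until stop()
 *       blt::thread_pool<false> pool(4, [&iterations]() { iterations++; });
 *       std::this_thread::sleep_for(std::chrono::milliseconds(10));
 *       pool.stop();
 *       while (!pool.complete())    // wait until all workers have observed the stop flag
 *           std::this_thread::sleep_for(std::chrono::milliseconds(1));
 *   }
 */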

#endif //BLT_THREAD_H