more thread pool changes, added counting of tasks. im sure parker will hate this :3

v1
Brett 2024-03-22 11:40:01 -04:00
parent 6a5b7a6865
commit 16641a27cb
2 changed files with 53 additions and 16 deletions

View File

@ -1,7 +1,7 @@
cmake_minimum_required(VERSION 3.5) cmake_minimum_required(VERSION 3.5)
include(cmake/color.cmake) include(cmake/color.cmake)
set(BLT_VERSION 0.15.1) set(BLT_VERSION 0.15.2)
set(BLT_TEST_VERSION 0.0.1) set(BLT_TEST_VERSION 0.0.1)
set(BLT_TARGET BLT) set(BLT_TARGET BLT)

View File

@ -36,24 +36,24 @@ namespace blt
/** /**
* @tparam queue should we use a queue or execute the same function over and over? * @tparam queue should we use a queue or execute the same function over and over?
*/ */
template<bool queue = false> template<bool queue = false, typename... Args>
class thread_pool class thread_pool
{ {
private: private:
typedef std::function<void()> thread_function; using thread_function = std::function<void(Args...)>;
volatile std::atomic_bool should_stop = false; volatile std::atomic_bool should_stop = false;
volatile std::atomic_uint64_t stopped = 0; volatile std::atomic_uint64_t stopped = 0;
std::uint64_t number_of_threads = 0; std::uint64_t number_of_threads = 0;
std::vector<std::thread*> threads; std::vector<std::thread*> threads;
std::variant<std::queue<thread_function>, thread_function> func_queue; std::variant<std::queue<thread_function>, thread_function> func_queue;
std::mutex queue_mutex; std::mutex queue_mutex;
// only used when a queue
volatile std::atomic_uint64_t tasks = 0;
volatile std::atomic_uint64_t completed_tasks = 0;
bool func_loaded = false; bool func_loaded = false;
public:
explicit thread_pool(std::uint64_t number_of_threads = 8, std::optional<thread_function> default_function = {}) void init()
{ {
if (default_function.has_value())
func_queue = default_function.value();
this->number_of_threads = number_of_threads;
for (std::uint64_t i = 0; i < number_of_threads; i++) for (std::uint64_t i = 0; i < number_of_threads; i++)
{ {
threads.push_back(new std::thread([this]() { threads.push_back(new std::thread([this]() {
@ -75,9 +75,11 @@ namespace blt
func_q.pop(); func_q.pop();
lock.unlock(); lock.unlock();
func(); func();
completed_tasks++;
} else } else
{ {
if (!func_loaded){ if (!func_loaded)
{
std::scoped_lock lock(queue_mutex); std::scoped_lock lock(queue_mutex);
if (std::holds_alternative<std::queue<thread_function>>(func_queue)) if (std::holds_alternative<std::queue<thread_function>>(func_queue))
{ {
@ -96,6 +98,25 @@ namespace blt
} }
} }
void cleanup()
{
for (auto* t : threads)
{
if (t->joinable())
t->join();
delete t;
}
}
public:
explicit thread_pool(std::uint64_t number_of_threads = 8, std::optional<thread_function> default_function = {})
{
if (default_function.has_value())
func_queue = default_function.value();
this->number_of_threads = number_of_threads;
init();
}
inline void execute(const thread_function& func) inline void execute(const thread_function& func)
{ {
std::scoped_lock lock(queue_mutex); std::scoped_lock lock(queue_mutex);
@ -103,13 +124,20 @@ namespace blt
{ {
auto& v = std::get<std::queue<thread_function>>(func_queue); auto& v = std::get<std::queue<thread_function>>(func_queue);
v.push(func); v.push(func);
tasks++;
} else } else
{ {
func_queue = func; func_queue = func;
} }
} }
[[nodiscard]] inline bool complete() const { [[nodiscard]] inline bool tasks_complete() const
{
return completed_tasks == tasks;
}
[[nodiscard]] inline bool complete() const
{
return stopped == number_of_threads; return stopped == number_of_threads;
} }
@ -118,15 +146,24 @@ namespace blt
should_stop = true; should_stop = true;
} }
inline void reset_tasks()
{
tasks = 0;
completed_tasks = 0;
}
inline void reset()
{
stop();
cleanup();
stopped = 0;
init();
}
~thread_pool() ~thread_pool()
{ {
should_stop = true; should_stop = true;
for (auto* t : threads) cleanup();
{
if (t->joinable())
t->join();
delete t;
}
} }
}; };
} }