#pragma once /* * Copyright (C) 2024 Brett Terpstra * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program. If not, see . */ #ifndef BLT_GP_THREADING_H #define BLT_GP_THREADING_H #include #include #include #include #include #include namespace blt::gp { namespace detail { struct empty_callable { void operator()() const { } }; } template class task_builder_t; template class task_t { static_assert(std::is_enum_v, "Enum ID must be of enum type!"); static_assert(std::is_invocable_v, "Parallel must be invocable with exactly one argument (thread index)"); static_assert(std::is_invocable_v, "Single must be invocable with no arguments"); friend task_builder_t; public: task_t(const EnumId task_id, const Parallel& parallel, const Single& single): parallel(std::forward(parallel)), single(std::forward(single)), requires_single_sync(true), task_id(task_id) { } explicit task_t(const EnumId task_id, const Parallel& parallel): parallel(std::forward(parallel)), single(detail::empty_callable{}), task_id(task_id) { } void call_parallel(size_t thread_index) const { parallel(thread_index); } void call_single() const { single(); } [[nodiscard]] EnumId get_task_id() const { return task_id; } private: const Parallel& parallel; const Single& single; bool requires_single_sync = false; EnumId task_id; }; template task_t(EnumId, Parallel, Single) -> task_t; template class task_builder_t { static_assert(std::is_enum_v, "Enum ID must be of enum type!"); using EnumInt = std::underlying_type_t; public: task_builder_t() = default; template static std::function make_callable(Tasks&&... tasks) { return [&tasks...](barrier& sync_barrier, EnumId task, size_t thread_index) { call_jmp_table(sync_barrier, task, thread_index, tasks...); }; } private: template static void execute(barrier& sync_barrier, const size_t thread_index, Task&& task) { // sync_barrier.wait(); if (task.requires_single_sync) { if (thread_index == 0) task.call_single(); sync_barrier.wait(); } task.call_parallel(thread_index); // sync_barrier.wait(); } template static bool call(barrier& sync_barrier, const EnumId current_task, const size_t thread_index, Task&& task) { if (static_cast(current_task) == static_cast(task.get_task_id())) { execute(sync_barrier, thread_index, std::forward(task)); return false; } return true; } template static void call_jmp_table(barrier& sync_barrier, const EnumId current_task, const size_t thread_index, Tasks&&... tasks) { if (static_cast(current_task) >= sizeof...(tasks)) BLT_UNREACHABLE; (call(sync_barrier, current_task, thread_index, std::forward(tasks)) && ...); } }; template class thread_manager_t { static_assert(std::is_enum_v, "Enum ID must be of enum type!"); public: explicit thread_manager_t(const size_t thread_count, std::function task_func, const bool will_main_block = true): barrier(thread_count), will_main_block(will_main_block) { thread_callable = [this, task_func = std::move(task_func)](const size_t thread_index) { while (should_run) { barrier.wait(); if (tasks_remaining > 0) task_func(barrier, tasks.back(), thread_index); barrier.wait(); if (thread_index == 0) { if (this->will_main_block) { tasks.pop_back(); --tasks_remaining; } else { std::scoped_lock lock{task_lock}; tasks.pop_back(); --tasks_remaining; } } } }; for (size_t i = 0; i < will_main_block ? thread_count - 1 : thread_count; ++i) threads.emplace_back(thread_callable, will_main_block ? i + 1 : i); } void execute() const { BLT_ASSERT(will_main_block && "You attempted to call this function without specifying that " "you want an external blocking thread (try passing will_main_block = true)"); thread_callable(0); } void add_task(EnumId task) { if (will_main_block) { tasks.push_back(task); ++tasks_remaining; } else { std::scoped_lock lock(task_lock); tasks.push_back(task); ++tasks_remaining; } } bool has_tasks_left() { if (will_main_block) { return !tasks.empty(); } std::scoped_lock lock{task_lock}; return tasks.empty(); } ~thread_manager_t() { should_run = false; for (auto& thread : threads) { if (thread.joinable()) thread.join(); } } private: [[nodiscard]] size_t thread_count() const { return will_main_block ? threads.size() + 1 : threads.size(); } blt::barrier barrier; std::atomic_bool should_run = true; bool will_main_block; std::vector tasks; std::atomic_uint64_t tasks_remaining = 0; std::vector threads; std::mutex task_lock; std::function thread_callable; }; } #endif //BLT_GP_THREADING_H