#pragma once /* * Copyright (C) 2024 Brett Terpstra * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program. If not, see <https://www.gnu.org/licenses/>. */ #ifndef BLT_GP_THREADING_H #define BLT_GP_THREADING_H #include <blt/std/types.h> #include <blt/std/thread.h> #include <thread> #include <functional> #include <atomic> #include <type_traits> namespace blt::gp { namespace detail { struct empty_callable { void operator()() const { } }; } template <typename EnumId> class task_builder_t; template <typename EnumId, typename Parallel, typename Single = detail::empty_callable> class task_t { static_assert(std::is_enum_v<EnumId>, "Enum ID must be of enum type!"); static_assert(std::is_invocable_v<Parallel, int>, "Parallel must be invocable with exactly one argument (thread index)"); static_assert(std::is_invocable_v<Single>, "Single must be invocable with no arguments"); friend task_builder_t<EnumId>; public: task_t(const EnumId task_id, const Parallel& parallel, const Single& single): parallel(std::forward<Parallel>(parallel)), single(std::forward<Single>(single)), requires_single_sync(true), task_id(task_id) { } explicit task_t(const EnumId task_id, const Parallel& parallel): parallel(std::forward<Parallel>(parallel)), single(detail::empty_callable{}), task_id(task_id) { } void call_parallel(size_t thread_index) const { parallel(thread_index); } void call_single() const { single(); } [[nodiscard]] EnumId get_task_id() const { return task_id; } private: const Parallel& parallel; const Single& single; bool requires_single_sync = false; EnumId task_id; }; template <typename EnumId, typename Parallel, typename Single = detail::empty_callable> task_t(EnumId, Parallel, Single) -> task_t<EnumId, Parallel, Single>; template <typename EnumId> class task_builder_t { static_assert(std::is_enum_v<EnumId>, "Enum ID must be of enum type!"); using EnumInt = std::underlying_type_t<EnumId>; public: task_builder_t() = default; template <typename... Tasks> static std::function<void(barrier&, EnumId, size_t)> make_callable(Tasks&&... tasks) { return [&tasks...](barrier& sync_barrier, EnumId task, size_t thread_index) { call_jmp_table(sync_barrier, task, thread_index, tasks...); }; } private: template <typename Task> static void execute(barrier& sync_barrier, const size_t thread_index, Task&& task) { // sync_barrier.wait(); if (task.requires_single_sync) { if (thread_index == 0) task.call_single(); sync_barrier.wait(); } task.call_parallel(thread_index); // sync_barrier.wait(); } template <typename Task> static bool call(barrier& sync_barrier, const EnumId current_task, const size_t thread_index, Task&& task) { if (static_cast<EnumInt>(current_task) == static_cast<EnumInt>(task.get_task_id())) { execute(sync_barrier, thread_index, std::forward<Task>(task)); return false; } return true; } template <typename... Tasks> static void call_jmp_table(barrier& sync_barrier, const EnumId current_task, const size_t thread_index, Tasks&&... tasks) { if (static_cast<EnumInt>(current_task) >= sizeof...(tasks)) BLT_UNREACHABLE; (call(sync_barrier, current_task, thread_index, std::forward<Tasks>(tasks)) && ...); } }; template <typename EnumId> class thread_manager_t { static_assert(std::is_enum_v<EnumId>, "Enum ID must be of enum type!"); public: explicit thread_manager_t(const size_t thread_count, std::function<void(barrier&, EnumId, size_t)> task_func, const bool will_main_block = true): barrier(thread_count), will_main_block(will_main_block) { thread_callable = [this, task_func = std::move(task_func)](const size_t thread_index) { while (should_run) { barrier.wait(); if (tasks_remaining > 0) task_func(barrier, tasks.back(), thread_index); barrier.wait(); if (thread_index == 0) { if (this->will_main_block) { tasks.pop_back(); --tasks_remaining; } else { std::scoped_lock lock{task_lock}; tasks.pop_back(); --tasks_remaining; } } } }; for (size_t i = 0; i < will_main_block ? thread_count - 1 : thread_count; ++i) threads.emplace_back(thread_callable, will_main_block ? i + 1 : i); } void execute() const { BLT_ASSERT(will_main_block && "You attempted to call this function without specifying that " "you want an external blocking thread (try passing will_main_block = true)"); thread_callable(0); } void add_task(EnumId task) { if (will_main_block) { tasks.push_back(task); ++tasks_remaining; } else { std::scoped_lock lock(task_lock); tasks.push_back(task); ++tasks_remaining; } } bool has_tasks_left() { if (will_main_block) { return !tasks.empty(); } std::scoped_lock lock{task_lock}; return tasks.empty(); } ~thread_manager_t() { should_run = false; for (auto& thread : threads) { if (thread.joinable()) thread.join(); } } private: [[nodiscard]] size_t thread_count() const { return will_main_block ? threads.size() + 1 : threads.size(); } blt::barrier barrier; std::atomic_bool should_run = true; bool will_main_block; std::vector<EnumId> tasks; std::atomic_uint64_t tasks_remaining = 0; std::vector<std::thread> threads; std::mutex task_lock; std::function<void(size_t)> thread_callable; }; } #endif //BLT_GP_THREADING_H