diff --git a/CMakeLists.txt b/CMakeLists.txt index ccbcf2f..0101d34 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -27,7 +27,7 @@ macro(compile_options target_name) sanitizers(${target_name}) endmacro() -project(blt-gp VERSION 0.3.7) +project(blt-gp VERSION 0.3.8) include(CTest) diff --git a/include/blt/gp/threading.h b/include/blt/gp/threading.h index 2f491d1..263a896 100644 --- a/include/blt/gp/threading.h +++ b/include/blt/gp/threading.h @@ -24,61 +24,193 @@ #include #include #include +#include namespace blt::gp { - template + namespace detail + { + struct empty_callable + { + void operator()() const + { + } + }; + } + + template + class task_builder_t; + + template class task_t { + static_assert(std::is_enum_v, "Enum ID must be of enum type!"); + static_assert(std::is_invocable_v, "Parallel must be invocable with exactly one argument (thread index)"); + static_assert(std::is_invocable_v, "Single must be invocable with no arguments"); + + friend task_builder_t; + public: - task_t(Parallel parallel, Single single): parallel(parallel), single(single), requires_single_sync(true) + task_t(const EnumId task_id, const Parallel& parallel, const Single& single): parallel(std::forward(parallel)), + single(std::forward(single)), + requires_single_sync(true), task_id(task_id) { } - explicit task_t(Parallel parallel): parallel(parallel), requires_single_sync(false) + explicit task_t(const EnumId task_id, const Parallel& parallel): parallel(std::forward(parallel)), single(detail::empty_callable{}), + task_id(task_id) { } - void call_parallel(size_t thread_index) + void call_parallel(size_t thread_index) const { parallel(thread_index); } - void call_single() + void call_single() const { single(); } + + [[nodiscard]] EnumId get_task_id() const + { + return task_id; + } + private: - Parallel parallel; - Single single; - bool explicit_sync_begin : 1 = true; - bool explicit_sync_end : 1 = true; - bool requires_single_sync : 1 = false; + const Parallel& parallel; + const Single& single; + bool requires_single_sync = false; + EnumId task_id; }; - template - class task_storage_t + template + task_t(EnumId, Parallel, Single) -> task_t; + + template + class task_builder_t { + static_assert(std::is_enum_v, "Enum ID must be of enum type!"); + using EnumInt = std::underlying_type_t; + + public: + task_builder_t() = default; + + template + static std::function make_callable(Tasks&&... tasks) + { + return [&tasks...](barrier& sync_barrier, EnumId task, size_t thread_index) + { + call_jmp_table(sync_barrier, task, thread_index, tasks...); + }; + } + + private: + template + static void execute(barrier& sync_barrier, const size_t thread_index, Task&& task) + { + // sync_barrier.wait(); + if (task.requires_single_sync) + { + if (thread_index == 0) + task.call_single(); + sync_barrier.wait(); + } + task.call_parallel(thread_index); + // sync_barrier.wait(); + } + + template + static bool call(barrier& sync_barrier, const EnumId current_task, const size_t thread_index, Task&& task) + { + if (static_cast(current_task) == static_cast(task.get_task_id())) + { + execute(sync_barrier, thread_index, std::forward(task)); + return false; + } + return true; + } + + template + static void call_jmp_table(barrier& sync_barrier, const EnumId current_task, const size_t thread_index, Tasks&&... tasks) + { + if (static_cast(current_task) >= sizeof...(tasks)) + BLT_UNREACHABLE; + (call(sync_barrier, current_task, thread_index, std::forward(tasks)) && ...); + } }; + template class thread_manager_t { + static_assert(std::is_enum_v, "Enum ID must be of enum type!"); + public: - explicit thread_manager_t(const size_t thread_count, const bool will_main_block = true): barrier(thread_count), - will_main_block(will_main_block) + explicit thread_manager_t(const size_t thread_count, std::function task_func, + const bool will_main_block = true): barrier(thread_count), will_main_block(will_main_block) { - for (size_t i = 0; i < will_main_block ? thread_count - 1 : thread_count; ++i) + thread_callable = [this, task_func = std::move(task_func)](const size_t thread_index) { - threads.emplace_back([i, this]() + while (should_run) { - while (should_run) + barrier.wait(); + if (tasks_remaining > 0) + task_func(barrier, tasks.back(), thread_index); + barrier.wait(); + if (thread_index == 0) { + if (this->will_main_block) + { + tasks.pop_back(); + --tasks_remaining; + } + else + { + std::scoped_lock lock{task_lock}; + tasks.pop_back(); + --tasks_remaining; + } } - }); + } + }; + for (size_t i = 0; i < will_main_block ? thread_count - 1 : thread_count; ++i) + threads.emplace_back(thread_callable, will_main_block ? i + 1 : i); + } + + void execute() const + { + BLT_ASSERT(will_main_block && + "You attempted to call this function without specifying that " + "you want an external blocking thread (try passing will_main_block = true)"); + thread_callable(0); + } + + void add_task(EnumId task) + { + if (will_main_block) + { + tasks.push_back(task); + ++tasks_remaining; + } + else + { + std::scoped_lock lock(task_lock); + tasks.push_back(task); + ++tasks_remaining; } } - ~thread_manager() + bool has_tasks_left() + { + if (will_main_block) + { + return !tasks.empty(); + } + std::scoped_lock lock{task_lock}; + return tasks.empty(); + } + + ~thread_manager_t() { should_run = false; for (auto& thread : threads) @@ -94,11 +226,15 @@ namespace blt::gp return will_main_block ? threads.size() + 1 : threads.size(); } - barrier barrier; + blt::barrier barrier; std::atomic_bool should_run = true; bool will_main_block; - std::vector tasks; + std::vector tasks; + std::atomic_uint64_t tasks_remaining = 0; std::vector threads; + std::mutex task_lock; + + std::function thread_callable; }; } diff --git a/tests/symbolic_regression_test.cpp b/tests/symbolic_regression_test.cpp index aea4459..8460410 100644 --- a/tests/symbolic_regression_test.cpp +++ b/tests/symbolic_regression_test.cpp @@ -165,15 +165,46 @@ void do_run() std::cout << std::endl; } -template +template auto what(What addr, What2 addr2) -> decltype(addr + addr2) { return addr + addr2; } +enum class test +{ + hello, + there +}; + +inline void hello(blt::size_t) +{ + BLT_TRACE("I did some parallel work!"); +} + +inline void there(blt::size_t) +{ + BLT_TRACE("Wow there"); +} + int main() { - for (int i = 0; i < 1; i++) - do_run(); + blt::gp::thread_manager_t threads{ + std::thread::hardware_concurrency(), blt::gp::task_builder_t::make_callable( + blt::gp::task_t{test::hello, hello}, + blt::gp::task_t{test::there, there} + ) + }; + + threads.add_task(test::hello); + threads.add_task(test::hello); + threads.add_task(test::hello); + threads.add_task(test::there); + + while (threads.has_tasks_left()) + threads.execute(); + + // for (int i = 0; i < 1; i++) + // do_run(); return 0; }