diff --git a/CMakeLists.txt b/CMakeLists.txt index 74823a9..dc51476 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,5 +1,5 @@ cmake_minimum_required(VERSION 3.25) -project(blt-gp VERSION 0.0.49) +project(blt-gp VERSION 0.0.50) include(CTest) diff --git a/examples/gp_test_5.cpp b/examples/gp_test_5.cpp index 1a25c19..2be92bc 100644 --- a/examples/gp_test_5.cpp +++ b/examples/gp_test_5.cpp @@ -37,32 +37,35 @@ #include #include #include +#include +#include static constexpr long SEED = 41912; + blt::gp::type_provider type_system; blt::gp::gp_program program(type_system, std::mt19937_64{SEED}); // NOLINT -blt::gp::operation_t add([](float a, float b) { return a + b; }); // 0 -blt::gp::operation_t sub([](float a, float b) { return a - b; }); // 1 -blt::gp::operation_t mul([](float a, float b) { return a * b; }); // 2 -blt::gp::operation_t pro_div([](float a, float b) { return b == 0 ? 0.0f : a / b; }); // 3 +blt::gp::operation_t add([](float a, float b) { return a + b; }, "add"); // 0 +blt::gp::operation_t sub([](float a, float b) { return a - b; }, "sub"); // 1 +blt::gp::operation_t mul([](float a, float b) { return a * b; }, "mul"); // 2 +blt::gp::operation_t pro_div([](float a, float b) { return b == 0 ? 0.0f : a / b; }, "div"); // 3 -blt::gp::operation_t op_if([](bool b, float a, float c) { return b ? a : c; }); // 4 -blt::gp::operation_t eq_f([](float a, float b) { return a == b; }); // 5 -blt::gp::operation_t eq_b([](bool a, bool b) { return a == b; }); // 6 -blt::gp::operation_t lt([](float a, float b) { return a < b; }); // 7 -blt::gp::operation_t gt([](float a, float b) { return a > b; }); // 8 -blt::gp::operation_t op_and([](bool a, bool b) { return a && b; }); // 9 -blt::gp::operation_t op_or([](bool a, bool b) { return a || b; }); // 10 -blt::gp::operation_t op_xor([](bool a, bool b) { return static_cast(a ^ b); }); // 11 -blt::gp::operation_t op_not([](bool b) { return !b; }); // 12 +blt::gp::operation_t op_if([](bool b, float a, float c) { return b ? a : c; }, "if"); // 4 +blt::gp::operation_t eq_f([](float a, float b) { return a == b; }, "eq_f"); // 5 +blt::gp::operation_t eq_b([](bool a, bool b) { return a == b; }, "eq_b"); // 6 +blt::gp::operation_t lt([](float a, float b) { return a < b; }, "lt"); // 7 +blt::gp::operation_t gt([](float a, float b) { return a > b; }, "gt"); // 8 +blt::gp::operation_t op_and([](bool a, bool b) { return a && b; }, "and"); // 9 +blt::gp::operation_t op_or([](bool a, bool b) { return a || b; }, "or"); // 10 +blt::gp::operation_t op_xor([](bool a, bool b) { return static_cast(a ^ b); }, "xor"); // 11 +blt::gp::operation_t op_not([](bool b) { return !b; }, "not"); // 12 blt::gp::operation_t lit([]() { // 13 //static std::uniform_real_distribution dist(-32000, 32000); static std::uniform_real_distribution dist(0.0f, 10.0f); return dist(program.get_random()); -}); +}, "lit"); /** * This is a test using multiple types with blt::gp @@ -106,29 +109,88 @@ int main() blt::gp::crossover_t crossover; auto& ind = pop.getIndividuals(); - auto results = crossover.apply(program, ind[0], ind[1]); - BLT_INFO("Post crossover:"); - if (results.has_value()) + std::vector pre; + std::vector pos; + blt::size_t errors = 0; + BLT_INFO("Pre-Crossover:"); + for (auto& tree : pop.getIndividuals()) { - BLT_TRACE("Parent 1: %f", ind[0].get_evaluation_value(nullptr)); - BLT_TRACE("Parent 2: %f", ind[1].get_evaluation_value(nullptr)); - BLT_TRACE("------------"); - BLT_TRACE("Child 1: %f", results->child1.get_evaluation_value(nullptr)); - BLT_TRACE("Child 2: %f", results->child2.get_evaluation_value(nullptr)); - } else + auto f = tree.get_evaluation_value(nullptr); + pre.push_back(f); + BLT_TRACE(f); + } + + BLT_INFO("Crossover:"); + blt::gp::population_t new_pop; + while (new_pop.getIndividuals().size() < pop.getIndividuals().size()) { - switch (results.error()) + auto& random = program.get_random(); + std::uniform_int_distribution dist(0ul, pop.getIndividuals().size() - 1); + blt::size_t first = dist(random); + blt::size_t second; + do { - case blt::gp::crossover_t::error_t::NO_VALID_TYPE: - BLT_ERROR("No valid type!"); - break; - case blt::gp::crossover_t::error_t::TREE_TOO_SMALL: - BLT_ERROR("Tree is too small!"); - break; + second = dist(random); + } while (second == first); + + auto results = crossover.apply(program, ind[first], ind[second]); + if (results.has_value()) + { +// bool print_literals = true; +// bool pretty_print = false; +// bool print_returns = false; +// BLT_TRACE("Parent 1: %f", ind[0].get_evaluation_value(nullptr)); +// ind[0].print(program, std::cout, print_literals, pretty_print, print_returns); +// BLT_TRACE("Parent 2: %f", ind[1].get_evaluation_value(nullptr)); +// ind[1].print(program, std::cout, print_literals, pretty_print, print_returns); +// BLT_TRACE("------------"); +// BLT_TRACE("Child 1: %f", results->child1.get_evaluation_value(nullptr)); +// results->child1.print(program, std::cout, print_literals, pretty_print, print_returns); +// BLT_TRACE("Child 2: %f", results->child2.get_evaluation_value(nullptr)); +// results->child2.print(program, std::cout, print_literals, pretty_print, print_returns); + new_pop.getIndividuals().push_back(std::move(results->child1)); + new_pop.getIndividuals().push_back(std::move(results->child2)); + } else + { + switch (results.error()) + { + case blt::gp::crossover_t::error_t::NO_VALID_TYPE: + BLT_ERROR("No valid type!"); + break; + case blt::gp::crossover_t::error_t::TREE_TOO_SMALL: + BLT_ERROR("Tree is too small!"); + break; + } + errors++; + new_pop.getIndividuals().push_back(ind[first]); + new_pop.getIndividuals().push_back(ind[second]); } } + BLT_INFO("Post-Crossover:"); + for (auto& tree : new_pop.getIndividuals()) + { + auto f = tree.get_evaluation_value(nullptr); + pos.push_back(f); + BLT_TRACE(f); + } + + BLT_INFO("Stats:"); + blt::size_t eq = 0; + for (const auto& v : pos) + { + for (const auto m : pre) + { + if (v == m) + { + eq++; + break; + } + } + } + BLT_INFO("Equal values: %ld", eq); + BLT_INFO("Error times: %ld", errors); return 0; } diff --git a/include/blt/gp/fwdecl.h b/include/blt/gp/fwdecl.h index dfbf903..bfde2e7 100644 --- a/include/blt/gp/fwdecl.h +++ b/include/blt/gp/fwdecl.h @@ -22,6 +22,7 @@ #include #include #include +#include namespace blt::gp { @@ -55,6 +56,8 @@ namespace blt::gp using callable_t = std::function; // to, from using transfer_t = std::function; + // debug function, + using print_func_t = std::function; } } diff --git a/include/blt/gp/operations.h b/include/blt/gp/operations.h index ac7ffd0..304bb75 100644 --- a/include/blt/gp/operations.h +++ b/include/blt/gp/operations.h @@ -24,6 +24,7 @@ #include #include #include +#include namespace blt::gp { @@ -83,7 +84,7 @@ namespace blt::gp template inline static constexpr Return exec_sequence_to_indices(Func&& func, stack_allocator& allocator, std::integer_sequence, - ExtraArgs&&... args) + ExtraArgs&& ... args) { // expands Args and indices, providing each argument with its index calculating the current argument byte offset return std::forward(func)(std::forward(args)..., allocator.from(getByteOffset())...); @@ -120,7 +121,7 @@ namespace blt::gp constexpr operation_t(operation_t&& move) = default; template - constexpr explicit operation_t(const Functor& functor): func(functor) + constexpr explicit operation_t(const Functor& functor, std::optional name = {}): func(functor), name(name) {} [[nodiscard]] constexpr inline Return operator()(stack_allocator& read_allocator) const @@ -171,9 +172,15 @@ namespace blt::gp { return sizeof...(Args); } + + [[nodiscard]] inline constexpr std::optional get_name() const + { + return name; + } private: function_t func; + std::optional name; }; template @@ -188,6 +195,12 @@ namespace blt::gp template operation_t(Return(*)(Args...)) -> operation_t; + + template + operation_t(Lambda, std::optional) -> operation_t; + + template + operation_t(Return(*)(Args...), std::optional) -> operation_t; // templat\e // operation_t make_operator(Return (Class::*)(Args...) const lambda) diff --git a/include/blt/gp/program.h b/include/blt/gp/program.h index 0c4f651..440986e 100644 --- a/include/blt/gp/program.h +++ b/include/blt/gp/program.h @@ -48,6 +48,11 @@ namespace blt::gp { blt::u32 argc = 0; blt::u32 argc_context = 0; + + [[nodiscard]] bool is_terminal() const + { + return argc == 0; + } }; struct config_t @@ -56,14 +61,21 @@ namespace blt::gp blt::u16 max_crossover_tries = 5; // if we fail to find a point in the tree, should we search forward from the last point to the end of the operators? bool should_crossover_try_forward = false; + // avoid selecting terminals when doing crossover + bool avoid_terminals = false; }; struct operator_info { + // types of the arguments std::vector argument_types; + // return type of this operator type_id return_type; + // number of arguments for this operator argc_t argc; + // function to call this operator detail::callable_t function; + // function used to transfer values between stacks detail::transfer_t transfer; }; @@ -80,6 +92,8 @@ namespace blt::gp // std::vector operators; // std::vector transfer_funcs; std::vector operators; + std::vector print_funcs; + std::vector> names; }; template @@ -120,7 +134,7 @@ namespace blt::gp info.transfer = [](stack_allocator& to, stack_allocator& from) { #if BLT_DEBUG_LEVEL >= 3 auto value = from.pop(); - BLT_TRACE_STREAM << value << "\n"; + //BLT_TRACE_STREAM << value << "\n"; to.push(value); #else to.push(from.pop()); @@ -128,6 +142,10 @@ namespace blt::gp }; storage.operators.push_back(info); + storage.print_funcs.push_back([](std::ostream& out, stack_allocator& stack) { + out << stack.pop(); + }); + storage.names.push_back(op.get_name()); if (is_static) storage.static_types.insert(operator_id); return *this; @@ -279,6 +297,16 @@ namespace blt::gp return storage.operators[id]; } + inline detail::print_func_t& get_print_func(operator_id id) + { + return storage.print_funcs[id]; + } + + inline std::optional get_name(operator_id id) + { + return storage.names[id]; + } + inline std::vector& get_type_terminals(type_id id) { return storage.terminals[id]; diff --git a/include/blt/gp/tree.h b/include/blt/gp/tree.h index d209209..11e97b5 100644 --- a/include/blt/gp/tree.h +++ b/include/blt/gp/tree.h @@ -26,6 +26,7 @@ #include #include +#include namespace blt::gp { @@ -100,6 +101,8 @@ namespace blt::gp auto results = evaluate(context); return results.values.pop(); } + + void print(gp_program& program, std::ostream& output, bool print_literals = true, bool pretty_indent = false, bool include_types = false); private: std::vector operations; diff --git a/src/transformers.cpp b/src/transformers.cpp index 9b06139..094de6d 100644 --- a/src/transformers.cpp +++ b/src/transformers.cpp @@ -45,6 +45,10 @@ namespace blt::gp std::uniform_int_distribution op_sel2(3ul, c2_ops.size() - 1); blt::size_t crossover_point = op_sel1(program.get_random()); + + while (config.avoid_terminals && program.get_operator_info(c1_ops[crossover_point].id).argc.is_terminal()) + crossover_point = op_sel1(program.get_random()); + blt::size_t attempted_point = 0; const auto& crossover_point_type = program.get_operator_info(c1_ops[crossover_point].id); @@ -57,16 +61,22 @@ namespace blt::gp { if (config.should_crossover_try_forward) { + bool found = false; for (auto i = attempted_point + 1; i < c2_ops.size(); i++) { auto* info = &program.get_operator_info(c2_ops[i].id); if (info->return_type == crossover_point_type.return_type) { + if (config.avoid_terminals && info->argc.is_terminal()) + continue; attempted_point = i; attempted_point_type = info; + found = true; break; } } + if (!found) + return blt::unexpected(error_t::NO_VALID_TYPE); } // should we try again over the whole tree? probably not. return blt::unexpected(error_t::NO_VALID_TYPE); @@ -74,9 +84,13 @@ namespace blt::gp { attempted_point = op_sel2(program.get_random()); attempted_point_type = &program.get_operator_info(c2_ops[attempted_point].id); + if (config.avoid_terminals && attempted_point_type->argc.is_terminal()) + continue; + if (crossover_point_type.return_type == attempted_point_type->return_type) + break; counter++; } - } while (crossover_point_type.return_type != attempted_point_type->return_type); + } while (true); blt::i64 children_left = 0; blt::size_t index = crossover_point; @@ -85,12 +99,15 @@ namespace blt::gp { const auto& type = program.get_operator_info(c1_ops[index].id); #if BLT_DEBUG_LEVEL > 1 - BLT_TRACE("Crossover type: %s, op %ld", std::string(program.get_typesystem().get_type(type.return_type).name()).c_str(), c1_ops[index].id); + #define MAKE_C_STR() program.get_name(c1_ops[index].id).has_value() ? std::string(program.get_name(c1_ops[index].id).value()).c_str() : std::to_string(c1_ops[index].id).c_str() + BLT_TRACE("Crossover type: %s, op: %s", std::string(program.get_typesystem().get_type(type.return_type).name()).c_str(), MAKE_C_STR()); + #undef MAKE_C_STR #endif + // this is a child to someone + if (children_left != 0) + children_left--; if (type.argc.argc > 0) children_left += type.argc.argc; - else - children_left--; index++; } while (children_left > 0); @@ -107,12 +124,16 @@ namespace blt::gp { const auto& type = program.get_operator_info(c2_ops[index].id); #if BLT_DEBUG_LEVEL > 1 - BLT_TRACE("Found type: %s, op: %ld", std::string(program.get_typesystem().get_type(type.return_type).name()).c_str(), c2_ops[index].id); + #define MAKE_C_STR() program.get_name(c2_ops[index].id).has_value() ? std::string(program.get_name(c2_ops[index].id).value()).c_str() : std::to_string(c2_ops[index].id).c_str() + BLT_TRACE("Found type: %s, op: %s", std::string(program.get_typesystem().get_type(type.return_type).name()).c_str(), MAKE_C_STR()); + #undef MAKE_C_STR #endif + // this is a child to someone + if (children_left != 0) + children_left--; if (type.argc.argc > 0) children_left += type.argc.argc; - else - children_left--; + index++; } while (children_left > 0); diff --git a/src/tree.cpp b/src/tree.cpp index dc5f9e4..759a24b 100644 --- a/src/tree.cpp +++ b/src/tree.cpp @@ -19,6 +19,8 @@ #include #include #include +#include +#include namespace blt::gp { @@ -49,4 +51,102 @@ namespace blt::gp return results; } + + std::ostream& create_indent(std::ostream& out, blt::size_t amount, bool pretty_print) + { + if (!pretty_print) + return out; + for (blt::size_t i = 0; i < amount; i++) + out << '\t'; + return out; + } + + std::string_view end_indent(bool pretty_print) + { + return pretty_print ? "\n" : ""; + } + + std::string get_return_type(gp_program& program, type_id id, bool use_returns) + { + if (!use_returns) + return ""; + return "(" + std::string(program.get_typesystem().get_type(id).name()) + ")"; + } + + void tree_t::print(gp_program& program, std::ostream& out, bool print_literals, bool pretty_print, bool include_types) + { + std::stack arguments_left; + blt::size_t indent = 0; + + stack_allocator reversed; + if (print_literals) + { + // I hate this. + stack_allocator copy = values; + + // reverse the order of the stack + for (const auto& v : operations) + { + if (v.is_value) + v.transfer(reversed, copy); + } + } + for (const auto& v : operations) + { + auto info = program.get_operator_info(v.id); + auto name = program.get_name(v.id) ? program.get_name(v.id).value() : "NULL"; + auto return_type = get_return_type(program, info.return_type, include_types); + if (info.argc.argc > 0) + { + create_indent(out, indent, pretty_print) << "("; + indent++; + arguments_left.emplace(info.argc.argc); + out << name << return_type << end_indent(pretty_print); + } else + { + if (print_literals) + { + create_indent(out, indent, pretty_print); + program.get_print_func(v.id)(out, reversed); + out << return_type << end_indent(pretty_print); + } else + create_indent(out, indent, pretty_print) << name << return_type << end_indent(pretty_print); + } + + while (!arguments_left.empty()) + { + auto top = arguments_left.top(); + arguments_left.pop(); + if (top == 0) + { + indent--; + create_indent(out, indent, pretty_print) << ")" << end_indent(pretty_print); + continue; + } else + { + if (!pretty_print) + out << " "; + arguments_left.push(top - 1); + break; + } + } + } + while (!arguments_left.empty()) + { + auto top = arguments_left.top(); + arguments_left.pop(); + if (top == 0) + { + indent--; + create_indent(out, indent, pretty_print) << ")" << end_indent(pretty_print); + continue; + } else + { + BLT_ERROR("Failed to print tree correctly!"); + break; + } + } + + out << '\n'; + } } \ No newline at end of file