crossover working, test 5 now fully tests crossover on a production sized scale

thread
Brett 2024-07-03 21:27:57 -04:00
parent 30a4a0e8d7
commit f37582ab4a
8 changed files with 271 additions and 41 deletions

View File

@ -1,5 +1,5 @@
cmake_minimum_required(VERSION 3.25) cmake_minimum_required(VERSION 3.25)
project(blt-gp VERSION 0.0.49) project(blt-gp VERSION 0.0.50)
include(CTest) include(CTest)

View File

@ -37,32 +37,35 @@
#include <blt/gp/tree.h> #include <blt/gp/tree.h>
#include <blt/std/logging.h> #include <blt/std/logging.h>
#include <blt/gp/transformers.h> #include <blt/gp/transformers.h>
#include <string_view>
#include <iostream>
static constexpr long SEED = 41912; static constexpr long SEED = 41912;
blt::gp::type_provider type_system; blt::gp::type_provider type_system;
blt::gp::gp_program program(type_system, std::mt19937_64{SEED}); // NOLINT blt::gp::gp_program program(type_system, std::mt19937_64{SEED}); // NOLINT
blt::gp::operation_t add([](float a, float b) { return a + b; }); // 0 blt::gp::operation_t add([](float a, float b) { return a + b; }, "add"); // 0
blt::gp::operation_t sub([](float a, float b) { return a - b; }); // 1 blt::gp::operation_t sub([](float a, float b) { return a - b; }, "sub"); // 1
blt::gp::operation_t mul([](float a, float b) { return a * b; }); // 2 blt::gp::operation_t mul([](float a, float b) { return a * b; }, "mul"); // 2
blt::gp::operation_t pro_div([](float a, float b) { return b == 0 ? 0.0f : a / b; }); // 3 blt::gp::operation_t pro_div([](float a, float b) { return b == 0 ? 0.0f : a / b; }, "div"); // 3
blt::gp::operation_t op_if([](bool b, float a, float c) { return b ? a : c; }); // 4 blt::gp::operation_t op_if([](bool b, float a, float c) { return b ? a : c; }, "if"); // 4
blt::gp::operation_t eq_f([](float a, float b) { return a == b; }); // 5 blt::gp::operation_t eq_f([](float a, float b) { return a == b; }, "eq_f"); // 5
blt::gp::operation_t eq_b([](bool a, bool b) { return a == b; }); // 6 blt::gp::operation_t eq_b([](bool a, bool b) { return a == b; }, "eq_b"); // 6
blt::gp::operation_t lt([](float a, float b) { return a < b; }); // 7 blt::gp::operation_t lt([](float a, float b) { return a < b; }, "lt"); // 7
blt::gp::operation_t gt([](float a, float b) { return a > b; }); // 8 blt::gp::operation_t gt([](float a, float b) { return a > b; }, "gt"); // 8
blt::gp::operation_t op_and([](bool a, bool b) { return a && b; }); // 9 blt::gp::operation_t op_and([](bool a, bool b) { return a && b; }, "and"); // 9
blt::gp::operation_t op_or([](bool a, bool b) { return a || b; }); // 10 blt::gp::operation_t op_or([](bool a, bool b) { return a || b; }, "or"); // 10
blt::gp::operation_t op_xor([](bool a, bool b) { return static_cast<bool>(a ^ b); }); // 11 blt::gp::operation_t op_xor([](bool a, bool b) { return static_cast<bool>(a ^ b); }, "xor"); // 11
blt::gp::operation_t op_not([](bool b) { return !b; }); // 12 blt::gp::operation_t op_not([](bool b) { return !b; }, "not"); // 12
blt::gp::operation_t lit([]() { // 13 blt::gp::operation_t lit([]() { // 13
//static std::uniform_real_distribution<float> dist(-32000, 32000); //static std::uniform_real_distribution<float> dist(-32000, 32000);
static std::uniform_real_distribution<float> dist(0.0f, 10.0f); static std::uniform_real_distribution<float> dist(0.0f, 10.0f);
return dist(program.get_random()); return dist(program.get_random());
}); }, "lit");
/** /**
* This is a test using multiple types with blt::gp * This is a test using multiple types with blt::gp
@ -106,29 +109,88 @@ int main()
blt::gp::crossover_t crossover; blt::gp::crossover_t crossover;
auto& ind = pop.getIndividuals(); auto& ind = pop.getIndividuals();
auto results = crossover.apply(program, ind[0], ind[1]);
BLT_INFO("Post crossover:");
if (results.has_value()) std::vector<float> pre;
std::vector<float> pos;
blt::size_t errors = 0;
BLT_INFO("Pre-Crossover:");
for (auto& tree : pop.getIndividuals())
{ {
BLT_TRACE("Parent 1: %f", ind[0].get_evaluation_value<float>(nullptr)); auto f = tree.get_evaluation_value<float>(nullptr);
BLT_TRACE("Parent 2: %f", ind[1].get_evaluation_value<float>(nullptr)); pre.push_back(f);
BLT_TRACE("------------"); BLT_TRACE(f);
BLT_TRACE("Child 1: %f", results->child1.get_evaluation_value<float>(nullptr)); }
BLT_TRACE("Child 2: %f", results->child2.get_evaluation_value<float>(nullptr));
} else BLT_INFO("Crossover:");
blt::gp::population_t new_pop;
while (new_pop.getIndividuals().size() < pop.getIndividuals().size())
{ {
switch (results.error()) auto& random = program.get_random();
std::uniform_int_distribution dist(0ul, pop.getIndividuals().size() - 1);
blt::size_t first = dist(random);
blt::size_t second;
do
{ {
case blt::gp::crossover_t::error_t::NO_VALID_TYPE: second = dist(random);
BLT_ERROR("No valid type!"); } while (second == first);
break;
case blt::gp::crossover_t::error_t::TREE_TOO_SMALL: auto results = crossover.apply(program, ind[first], ind[second]);
BLT_ERROR("Tree is too small!"); if (results.has_value())
break; {
// bool print_literals = true;
// bool pretty_print = false;
// bool print_returns = false;
// BLT_TRACE("Parent 1: %f", ind[0].get_evaluation_value<float>(nullptr));
// ind[0].print(program, std::cout, print_literals, pretty_print, print_returns);
// BLT_TRACE("Parent 2: %f", ind[1].get_evaluation_value<float>(nullptr));
// ind[1].print(program, std::cout, print_literals, pretty_print, print_returns);
// BLT_TRACE("------------");
// BLT_TRACE("Child 1: %f", results->child1.get_evaluation_value<float>(nullptr));
// results->child1.print(program, std::cout, print_literals, pretty_print, print_returns);
// BLT_TRACE("Child 2: %f", results->child2.get_evaluation_value<float>(nullptr));
// results->child2.print(program, std::cout, print_literals, pretty_print, print_returns);
new_pop.getIndividuals().push_back(std::move(results->child1));
new_pop.getIndividuals().push_back(std::move(results->child2));
} else
{
switch (results.error())
{
case blt::gp::crossover_t::error_t::NO_VALID_TYPE:
BLT_ERROR("No valid type!");
break;
case blt::gp::crossover_t::error_t::TREE_TOO_SMALL:
BLT_ERROR("Tree is too small!");
break;
}
errors++;
new_pop.getIndividuals().push_back(ind[first]);
new_pop.getIndividuals().push_back(ind[second]);
} }
} }
BLT_INFO("Post-Crossover:");
for (auto& tree : new_pop.getIndividuals())
{
auto f = tree.get_evaluation_value<float>(nullptr);
pos.push_back(f);
BLT_TRACE(f);
}
BLT_INFO("Stats:");
blt::size_t eq = 0;
for (const auto& v : pos)
{
for (const auto m : pre)
{
if (v == m)
{
eq++;
break;
}
}
}
BLT_INFO("Equal values: %ld", eq);
BLT_INFO("Error times: %ld", errors);
return 0; return 0;
} }

View File

@ -22,6 +22,7 @@
#include <functional> #include <functional>
#include <blt/std/logging.h> #include <blt/std/logging.h>
#include <blt/std/types.h> #include <blt/std/types.h>
#include <ostream>
namespace blt::gp namespace blt::gp
{ {
@ -55,6 +56,8 @@ namespace blt::gp
using callable_t = std::function<void(void*, stack_allocator&, stack_allocator&)>; using callable_t = std::function<void(void*, stack_allocator&, stack_allocator&)>;
// to, from // to, from
using transfer_t = std::function<void(stack_allocator&, stack_allocator&)>; using transfer_t = std::function<void(stack_allocator&, stack_allocator&)>;
// debug function,
using print_func_t = std::function<void(std::ostream&, stack_allocator&)>;
} }
} }

View File

@ -24,6 +24,7 @@
#include <blt/gp/stack.h> #include <blt/gp/stack.h>
#include <functional> #include <functional>
#include <type_traits> #include <type_traits>
#include <optional>
namespace blt::gp namespace blt::gp
{ {
@ -83,7 +84,7 @@ namespace blt::gp
template<typename Func, blt::u64... indices, typename... ExtraArgs> template<typename Func, blt::u64... indices, typename... ExtraArgs>
inline static constexpr Return exec_sequence_to_indices(Func&& func, stack_allocator& allocator, std::integer_sequence<blt::u64, indices...>, inline static constexpr Return exec_sequence_to_indices(Func&& func, stack_allocator& allocator, std::integer_sequence<blt::u64, indices...>,
ExtraArgs&&... args) ExtraArgs&& ... args)
{ {
// expands Args and indices, providing each argument with its index calculating the current argument byte offset // expands Args and indices, providing each argument with its index calculating the current argument byte offset
return std::forward<Func>(func)(std::forward<ExtraArgs>(args)..., allocator.from<Args>(getByteOffset<indices>())...); return std::forward<Func>(func)(std::forward<ExtraArgs>(args)..., allocator.from<Args>(getByteOffset<indices>())...);
@ -120,7 +121,7 @@ namespace blt::gp
constexpr operation_t(operation_t&& move) = default; constexpr operation_t(operation_t&& move) = default;
template<typename Functor> template<typename Functor>
constexpr explicit operation_t(const Functor& functor): func(functor) constexpr explicit operation_t(const Functor& functor, std::optional<std::string_view> name = {}): func(functor), name(name)
{} {}
[[nodiscard]] constexpr inline Return operator()(stack_allocator& read_allocator) const [[nodiscard]] constexpr inline Return operator()(stack_allocator& read_allocator) const
@ -172,8 +173,14 @@ namespace blt::gp
return sizeof...(Args); return sizeof...(Args);
} }
[[nodiscard]] inline constexpr std::optional<std::string_view> get_name() const
{
return name;
}
private: private:
function_t func; function_t func;
std::optional<std::string_view> name;
}; };
template<typename Return, typename Class, typename... Args> template<typename Return, typename Class, typename... Args>
@ -189,6 +196,12 @@ namespace blt::gp
template<typename Return, typename... Args> template<typename Return, typename... Args>
operation_t(Return(*)(Args...)) -> operation_t<Return(Args...)>; operation_t(Return(*)(Args...)) -> operation_t<Return(Args...)>;
template<typename Lambda>
operation_t(Lambda, std::optional<std::string_view>) -> operation_t<decltype(&Lambda::operator())>;
template<typename Return, typename... Args>
operation_t(Return(*)(Args...), std::optional<std::string_view>) -> operation_t<Return(Args...)>;
// templat\e<typename Return, typename Class, typename... Args> // templat\e<typename Return, typename Class, typename... Args>
// operation_t<Return(Args...)> make_operator(Return (Class::*)(Args...) const lambda) // operation_t<Return(Args...)> make_operator(Return (Class::*)(Args...) const lambda)
// { // {

View File

@ -48,6 +48,11 @@ namespace blt::gp
{ {
blt::u32 argc = 0; blt::u32 argc = 0;
blt::u32 argc_context = 0; blt::u32 argc_context = 0;
[[nodiscard]] bool is_terminal() const
{
return argc == 0;
}
}; };
struct config_t struct config_t
@ -56,14 +61,21 @@ namespace blt::gp
blt::u16 max_crossover_tries = 5; blt::u16 max_crossover_tries = 5;
// if we fail to find a point in the tree, should we search forward from the last point to the end of the operators? // if we fail to find a point in the tree, should we search forward from the last point to the end of the operators?
bool should_crossover_try_forward = false; bool should_crossover_try_forward = false;
// avoid selecting terminals when doing crossover
bool avoid_terminals = false;
}; };
struct operator_info struct operator_info
{ {
// types of the arguments
std::vector<type_id> argument_types; std::vector<type_id> argument_types;
// return type of this operator
type_id return_type; type_id return_type;
// number of arguments for this operator
argc_t argc; argc_t argc;
// function to call this operator
detail::callable_t function; detail::callable_t function;
// function used to transfer values between stacks
detail::transfer_t transfer; detail::transfer_t transfer;
}; };
@ -80,6 +92,8 @@ namespace blt::gp
// std::vector<detail::callable_t> operators; // std::vector<detail::callable_t> operators;
// std::vector<detail::transfer_t> transfer_funcs; // std::vector<detail::transfer_t> transfer_funcs;
std::vector<operator_info> operators; std::vector<operator_info> operators;
std::vector<detail::print_func_t> print_funcs;
std::vector<std::optional<std::string_view>> names;
}; };
template<typename Context = detail::empty_t> template<typename Context = detail::empty_t>
@ -120,7 +134,7 @@ namespace blt::gp
info.transfer = [](stack_allocator& to, stack_allocator& from) { info.transfer = [](stack_allocator& to, stack_allocator& from) {
#if BLT_DEBUG_LEVEL >= 3 #if BLT_DEBUG_LEVEL >= 3
auto value = from.pop<Return>(); auto value = from.pop<Return>();
BLT_TRACE_STREAM << value << "\n"; //BLT_TRACE_STREAM << value << "\n";
to.push(value); to.push(value);
#else #else
to.push(from.pop<Return>()); to.push(from.pop<Return>());
@ -128,6 +142,10 @@ namespace blt::gp
}; };
storage.operators.push_back(info); storage.operators.push_back(info);
storage.print_funcs.push_back([](std::ostream& out, stack_allocator& stack) {
out << stack.pop<Return>();
});
storage.names.push_back(op.get_name());
if (is_static) if (is_static)
storage.static_types.insert(operator_id); storage.static_types.insert(operator_id);
return *this; return *this;
@ -279,6 +297,16 @@ namespace blt::gp
return storage.operators[id]; return storage.operators[id];
} }
inline detail::print_func_t& get_print_func(operator_id id)
{
return storage.print_funcs[id];
}
inline std::optional<std::string_view> get_name(operator_id id)
{
return storage.names[id];
}
inline std::vector<operator_id>& get_type_terminals(type_id id) inline std::vector<operator_id>& get_type_terminals(type_id id)
{ {
return storage.terminals[id]; return storage.terminals[id];

View File

@ -26,6 +26,7 @@
#include <utility> #include <utility>
#include <stack> #include <stack>
#include <ostream>
namespace blt::gp namespace blt::gp
{ {
@ -101,6 +102,8 @@ namespace blt::gp
return results.values.pop<T>(); return results.values.pop<T>();
} }
void print(gp_program& program, std::ostream& output, bool print_literals = true, bool pretty_indent = false, bool include_types = false);
private: private:
std::vector<op_container_t> operations; std::vector<op_container_t> operations;
blt::gp::stack_allocator values; blt::gp::stack_allocator values;

View File

@ -45,6 +45,10 @@ namespace blt::gp
std::uniform_int_distribution op_sel2(3ul, c2_ops.size() - 1); std::uniform_int_distribution op_sel2(3ul, c2_ops.size() - 1);
blt::size_t crossover_point = op_sel1(program.get_random()); blt::size_t crossover_point = op_sel1(program.get_random());
while (config.avoid_terminals && program.get_operator_info(c1_ops[crossover_point].id).argc.is_terminal())
crossover_point = op_sel1(program.get_random());
blt::size_t attempted_point = 0; blt::size_t attempted_point = 0;
const auto& crossover_point_type = program.get_operator_info(c1_ops[crossover_point].id); const auto& crossover_point_type = program.get_operator_info(c1_ops[crossover_point].id);
@ -57,16 +61,22 @@ namespace blt::gp
{ {
if (config.should_crossover_try_forward) if (config.should_crossover_try_forward)
{ {
bool found = false;
for (auto i = attempted_point + 1; i < c2_ops.size(); i++) for (auto i = attempted_point + 1; i < c2_ops.size(); i++)
{ {
auto* info = &program.get_operator_info(c2_ops[i].id); auto* info = &program.get_operator_info(c2_ops[i].id);
if (info->return_type == crossover_point_type.return_type) if (info->return_type == crossover_point_type.return_type)
{ {
if (config.avoid_terminals && info->argc.is_terminal())
continue;
attempted_point = i; attempted_point = i;
attempted_point_type = info; attempted_point_type = info;
found = true;
break; break;
} }
} }
if (!found)
return blt::unexpected(error_t::NO_VALID_TYPE);
} }
// should we try again over the whole tree? probably not. // should we try again over the whole tree? probably not.
return blt::unexpected(error_t::NO_VALID_TYPE); return blt::unexpected(error_t::NO_VALID_TYPE);
@ -74,9 +84,13 @@ namespace blt::gp
{ {
attempted_point = op_sel2(program.get_random()); attempted_point = op_sel2(program.get_random());
attempted_point_type = &program.get_operator_info(c2_ops[attempted_point].id); attempted_point_type = &program.get_operator_info(c2_ops[attempted_point].id);
if (config.avoid_terminals && attempted_point_type->argc.is_terminal())
continue;
if (crossover_point_type.return_type == attempted_point_type->return_type)
break;
counter++; counter++;
} }
} while (crossover_point_type.return_type != attempted_point_type->return_type); } while (true);
blt::i64 children_left = 0; blt::i64 children_left = 0;
blt::size_t index = crossover_point; blt::size_t index = crossover_point;
@ -85,12 +99,15 @@ namespace blt::gp
{ {
const auto& type = program.get_operator_info(c1_ops[index].id); const auto& type = program.get_operator_info(c1_ops[index].id);
#if BLT_DEBUG_LEVEL > 1 #if BLT_DEBUG_LEVEL > 1
BLT_TRACE("Crossover type: %s, op %ld", std::string(program.get_typesystem().get_type(type.return_type).name()).c_str(), c1_ops[index].id); #define MAKE_C_STR() program.get_name(c1_ops[index].id).has_value() ? std::string(program.get_name(c1_ops[index].id).value()).c_str() : std::to_string(c1_ops[index].id).c_str()
BLT_TRACE("Crossover type: %s, op: %s", std::string(program.get_typesystem().get_type(type.return_type).name()).c_str(), MAKE_C_STR());
#undef MAKE_C_STR
#endif #endif
// this is a child to someone
if (children_left != 0)
children_left--;
if (type.argc.argc > 0) if (type.argc.argc > 0)
children_left += type.argc.argc; children_left += type.argc.argc;
else
children_left--;
index++; index++;
} while (children_left > 0); } while (children_left > 0);
@ -107,12 +124,16 @@ namespace blt::gp
{ {
const auto& type = program.get_operator_info(c2_ops[index].id); const auto& type = program.get_operator_info(c2_ops[index].id);
#if BLT_DEBUG_LEVEL > 1 #if BLT_DEBUG_LEVEL > 1
BLT_TRACE("Found type: %s, op: %ld", std::string(program.get_typesystem().get_type(type.return_type).name()).c_str(), c2_ops[index].id); #define MAKE_C_STR() program.get_name(c2_ops[index].id).has_value() ? std::string(program.get_name(c2_ops[index].id).value()).c_str() : std::to_string(c2_ops[index].id).c_str()
BLT_TRACE("Found type: %s, op: %s", std::string(program.get_typesystem().get_type(type.return_type).name()).c_str(), MAKE_C_STR());
#undef MAKE_C_STR
#endif #endif
// this is a child to someone
if (children_left != 0)
children_left--;
if (type.argc.argc > 0) if (type.argc.argc > 0)
children_left += type.argc.argc; children_left += type.argc.argc;
else
children_left--;
index++; index++;
} while (children_left > 0); } while (children_left > 0);

View File

@ -19,6 +19,8 @@
#include <blt/gp/stack.h> #include <blt/gp/stack.h>
#include <blt/std/assert.h> #include <blt/std/assert.h>
#include <blt/std/logging.h> #include <blt/std/logging.h>
#include <blt/gp/program.h>
#include <stack>
namespace blt::gp namespace blt::gp
{ {
@ -49,4 +51,102 @@ namespace blt::gp
return results; return results;
} }
std::ostream& create_indent(std::ostream& out, blt::size_t amount, bool pretty_print)
{
if (!pretty_print)
return out;
for (blt::size_t i = 0; i < amount; i++)
out << '\t';
return out;
}
std::string_view end_indent(bool pretty_print)
{
return pretty_print ? "\n" : "";
}
std::string get_return_type(gp_program& program, type_id id, bool use_returns)
{
if (!use_returns)
return "";
return "(" + std::string(program.get_typesystem().get_type(id).name()) + ")";
}
void tree_t::print(gp_program& program, std::ostream& out, bool print_literals, bool pretty_print, bool include_types)
{
std::stack<blt::size_t> arguments_left;
blt::size_t indent = 0;
stack_allocator reversed;
if (print_literals)
{
// I hate this.
stack_allocator copy = values;
// reverse the order of the stack
for (const auto& v : operations)
{
if (v.is_value)
v.transfer(reversed, copy);
}
}
for (const auto& v : operations)
{
auto info = program.get_operator_info(v.id);
auto name = program.get_name(v.id) ? program.get_name(v.id).value() : "NULL";
auto return_type = get_return_type(program, info.return_type, include_types);
if (info.argc.argc > 0)
{
create_indent(out, indent, pretty_print) << "(";
indent++;
arguments_left.emplace(info.argc.argc);
out << name << return_type << end_indent(pretty_print);
} else
{
if (print_literals)
{
create_indent(out, indent, pretty_print);
program.get_print_func(v.id)(out, reversed);
out << return_type << end_indent(pretty_print);
} else
create_indent(out, indent, pretty_print) << name << return_type << end_indent(pretty_print);
}
while (!arguments_left.empty())
{
auto top = arguments_left.top();
arguments_left.pop();
if (top == 0)
{
indent--;
create_indent(out, indent, pretty_print) << ")" << end_indent(pretty_print);
continue;
} else
{
if (!pretty_print)
out << " ";
arguments_left.push(top - 1);
break;
}
}
}
while (!arguments_left.empty())
{
auto top = arguments_left.top();
arguments_left.pop();
if (top == 0)
{
indent--;
create_indent(out, indent, pretty_print) << ")" << end_indent(pretty_print);
continue;
} else
{
BLT_ERROR("Failed to print tree correctly!");
break;
}
}
out << '\n';
}
} }