need to add leaf nodes back

shared
Brett 2024-08-20 13:07:33 -04:00
parent 52732bc1a1
commit 82b8c82768
10 changed files with 281 additions and 210 deletions

View File

@ -1,5 +1,5 @@
cmake_minimum_required(VERSION 3.25) cmake_minimum_required(VERSION 3.25)
project(blt-gp VERSION 0.1.13) project(blt-gp VERSION 0.1.14)
include(CTest) include(CTest)

View File

@ -54,9 +54,9 @@ blt::gp::operation_t op_cos([](float a) { return std::cos(a); }, "cos");
blt::gp::operation_t op_exp([](float a) { return std::exp(a); }, "exp"); blt::gp::operation_t op_exp([](float a) { return std::exp(a); }, "exp");
blt::gp::operation_t op_log([](float a) { return a == 0.0f ? 0.0f : std::log(a); }, "log"); blt::gp::operation_t op_log([](float a) { return a == 0.0f ? 0.0f : std::log(a); }, "log");
blt::gp::operation_t lit([]() { auto lit = blt::gp::operation_t([]() {
return program.get_random().get_float(-320.0f, 320.0f); return program.get_random().get_float(-320.0f, 320.0f);
}, "lit"); }, "lit").set_ephemeral();
blt::gp::operation_t op_x([](const context& context) { blt::gp::operation_t op_x([](const context& context) {
return context.x; return context.x;
}, "x"); }, "x");
@ -65,7 +65,7 @@ constexpr auto fitness_function = [](blt::gp::tree_t& current_tree, blt::gp::fit
constexpr double value_cutoff = 1.e15; constexpr double value_cutoff = 1.e15;
for (auto& fitness_case : fitness_cases) for (auto& fitness_case : fitness_cases)
{ {
auto diff = std::abs(fitness_case.y - current_tree.get_evaluation_value<float>(&fitness_case)); auto diff = std::abs(fitness_case.y - current_tree.get_evaluation_value<float>(&fitness_case, program.get_eval_func()));
if (diff < value_cutoff) if (diff < value_cutoff)
{ {
fitness.raw_fitness += diff; fitness.raw_fitness += diff;
@ -102,19 +102,19 @@ int main()
type_system.register_type<float>(); type_system.register_type<float>();
blt::gp::operator_builder<context> builder{type_system}; blt::gp::operator_builder<context> builder{type_system};
builder.add_operator(add); // builder.add_operator(add);
builder.add_operator(sub); // builder.add_operator(sub);
builder.add_operator(mul); // builder.add_operator(mul);
builder.add_operator(pro_div); // builder.add_operator(pro_div);
builder.add_operator(op_sin); // builder.add_operator(op_sin);
builder.add_operator(op_cos); // builder.add_operator(op_cos);
builder.add_operator(op_exp); // builder.add_operator(op_exp);
builder.add_operator(op_log); // builder.add_operator(op_log);
//
// builder.add_operator(lit, true);
// builder.add_operator(op_x);
builder.add_operator(lit, true); program.set_operations(builder.build(add, sub, mul, pro_div, op_sin, op_cos, op_exp, op_log, lit, op_x));
builder.add_operator(op_x);
program.set_operations(builder.build());
BLT_DEBUG("Generate Initial Population"); BLT_DEBUG("Generate Initial Population");
auto sel = blt::gp::select_fitness_proportionate_t{}; auto sel = blt::gp::select_fitness_proportionate_t{};

View File

@ -57,7 +57,8 @@ namespace blt::gp
class operator_storage_test; class operator_storage_test;
// context*, read stack, write stack // context*, read stack, write stack
using callable_t = std::function<void(void*, stack_allocator&, stack_allocator&, bitmask_t*)>; //using callable_t = std::function<void(void*, stack_allocator&, stack_allocator&, bitmask_t*)>;
using eval_func_t = std::function<evaluation_context(const tree_t& tree, void* context)>;
// debug function, // debug function,
using print_func_t = std::function<void(std::ostream&, stack_allocator&)>; using print_func_t = std::function<void(std::ostream&, stack_allocator&)>;

View File

@ -144,11 +144,12 @@ namespace blt::gp
template<typename, typename> template<typename, typename>
class operation_t; class operation_t;
template<typename ArgType, typename Return, typename... Args> template<typename RawFunction, typename Return, typename... Args>
class operation_t<ArgType, Return(Args...)> class operation_t<RawFunction, Return(Args...)>
{ {
public: public:
using function_t = ArgType; using function_t = RawFunction;
using First_Arg = typename blt::meta::arg_helper<Args...>::First;
constexpr operation_t(const operation_t& copy) = default; constexpr operation_t(const operation_t& copy) = default;
@ -186,21 +187,21 @@ namespace blt::gp
} }
} }
template<typename Context> // template<typename Context>
[[nodiscard]] detail::callable_t make_callable() const // [[nodiscard]] detail::callable_t make_callable() const
{ // {
return [this](void* context, stack_allocator& read_allocator, stack_allocator& write_allocator, detail::bitmask_t* mask) { // return [this](void* context, stack_allocator& read_allocator, stack_allocator& write_allocator, detail::bitmask_t* mask) {
if constexpr (detail::is_same_v<Context, detail::remove_cv_ref<typename detail::first_arg<Args...>::type>>) // if constexpr (detail::is_same_v<Context, detail::remove_cv_ref<typename detail::first_arg<Args...>::type>>)
{ // {
// first arg is context // // first arg is context
write_allocator.push(this->operator()(context, read_allocator, mask)); // write_allocator.push(this->operator()(context, read_allocator, mask));
} else // } else
{ // {
// first arg isn't context // // first arg isn't context
write_allocator.push(this->operator()(read_allocator, mask)); // write_allocator.push(this->operator()(read_allocator, mask));
} // }
}; // };
} // }
[[nodiscard]] inline constexpr std::optional<std::string_view> get_name() const [[nodiscard]] inline constexpr std::optional<std::string_view> get_name() const
{ {
@ -212,38 +213,42 @@ namespace blt::gp
return func; return func;
} }
inline auto set_ephemeral()
{
is_ephemeral_ = true;
return *this;
}
inline bool is_ephemeral()
{
return is_ephemeral_;
}
operator_id id = -1; operator_id id = -1;
private: private:
function_t func; function_t func;
std::optional<std::string_view> name; std::optional<std::string_view> name;
bool is_ephemeral_ = false;
}; };
template<typename ArgType, typename Return, typename Class, typename... Args> template<typename RawFunction, typename Return, typename Class, typename... Args>
class operation_t<ArgType, Return (Class::*)(Args...) const> : public operation_t<ArgType, Return(Args...)> class operation_t<RawFunction, Return (Class::*)(Args...) const> : public operation_t<RawFunction, Return(Args...)>
{ {
public: public:
using operation_t<ArgType, Return(Args...)>::operation_t; using operation_t<RawFunction, Return(Args...)>::operation_t;
}; };
template<typename Lambda> template<typename Lambda>
operation_t(Lambda) operation_t(Lambda) -> operation_t<Lambda, decltype(&Lambda::operator())>;
->
operation_t<Lambda, decltype(&Lambda::operator())>;
template<typename Return, typename... Args> template<typename Return, typename... Args>
operation_t(Return(*) operation_t(Return(*)(Args...)) -> operation_t<Return(*)(Args...), Return(Args...)>;
(Args...)) ->
operation_t<Return(*)(Args...), Return(Args...)>;
template<typename Lambda> template<typename Lambda>
operation_t(Lambda, std::optional<std::string_view> operation_t(Lambda, std::optional<std::string_view>) -> operation_t<Lambda, decltype(&Lambda::operator())>;
) ->
operation_t<Lambda, decltype(&Lambda::operator())>;
template<typename Return, typename... Args> template<typename Return, typename... Args>
operation_t(Return(*) operation_t(Return(*)(Args...), std::optional<std::string_view>) -> operation_t<Return(*)(Args...), Return(Args...)>;
(Args...), std::optional<std::string_view>) ->
operation_t<Return(*)(Args...), Return(Args...)>;
} }
#endif //BLT_GP_OPERATIONS_H #endif //BLT_GP_OPERATIONS_H

View File

@ -75,10 +75,6 @@ namespace blt::gp
type_id return_type; type_id return_type;
// number of arguments for this operator // number of arguments for this operator
argc_t argc; argc_t argc;
// function to call this operator
detail::callable_t function;
// function used to transfer values between stacks
//detail::transfer_t transfer;
}; };
struct operator_storage struct operator_storage
@ -93,6 +89,8 @@ namespace blt::gp
std::vector<detail::print_func_t> print_funcs; std::vector<detail::print_func_t> print_funcs;
std::vector<detail::destroy_func_t> destroy_funcs; std::vector<detail::destroy_func_t> destroy_funcs;
std::vector<std::optional<std::string_view>> names; std::vector<std::optional<std::string_view>> names;
detail::eval_func_t eval_func;
}; };
template<typename Context = detail::empty_t> template<typename Context = detail::empty_t>
@ -106,65 +104,41 @@ namespace blt::gp
explicit operator_builder(type_provider& system): system(system) explicit operator_builder(type_provider& system): system(system)
{} {}
template<typename ArgType, typename Return, typename... Args> template<typename... Operators>
operator_builder& add_operator(operation_t<ArgType, Return(Args...)>& op, bool is_static = false) operator_storage& build(Operators& ... operators)
{ {
auto return_type_id = system.get_type<Return>().id(); std::vector<blt::size_t> sizes;
auto operator_id = blt::gp::operator_id(storage.operators.size()); (sizes.push_back(add_operator(operators)), ...);
op.id = operator_id; blt::size_t largest = 0;
for (auto v : sizes)
largest = std::max(v, largest);
operator_info info; storage.eval_func = [&operators..., largest](const tree_t& tree, void* context) {
const auto& ops = tree.get_operations();
const auto& vals = tree.get_values();
if constexpr (sizeof...(Args) > 0) evaluation_context results{};
auto value_stack = vals;
auto& values_process = results.values;
static thread_local detail::bitmask_t bitfield;
bitfield.clear();
for (const auto& operation : blt::reverse_iterate(ops.begin(), ops.end()))
{ {
(add_non_context_argument<detail::remove_cv_ref<Args>>(info.argument_types), ...); if (operation.is_value)
{
value_stack.transfer_bytes(values_process, operation.type_size);
bitfield.push_back(false);
continue;
}
call_jmp_table(operation.id, context, values_process, values_process, &bitfield, operators...);
bitfield.push_back(true);
} }
info.argc.argc_context = info.argc.argc = sizeof...(Args); return results;
info.return_type = system.get_type<Return>().id(); };
((std::is_same_v<detail::remove_cv_ref<Args>, Context> ? info.argc.argc -= 1 : (blt::size_t) nullptr), ...);
auto& operator_list = info.argc.argc == 0 ? storage.terminals : storage.non_terminals;
operator_list[return_type_id].push_back(operator_id);
BLT_ASSERT(info.argc.argc_context - info.argc.argc <= 1 && "Cannot pass multiple context as arguments!");
info.function = op.template make_callable<Context>();
storage.operators.push_back(info);
storage.print_funcs.push_back([&op](std::ostream& out, stack_allocator& stack) {
if constexpr (blt::meta::is_streamable_v<Return>)
{
out << stack.from<Return>(0);
(void) (op); // remove warning
} else
{
out << "[Printing Value on '" << (op.get_name() ? *op.get_name() : "") << "' Not Supported!]";
}
});
storage.destroy_funcs.push_back([](detail::destroy_t type, detail::bitmask_t* mask, stack_allocator& alloc) {
switch (type)
{
case detail::destroy_t::ARGS:
alloc.call_destructors<Args...>(mask);
break;
case detail::destroy_t::RETURN:
if constexpr (detail::has_func_drop_v<remove_cvref_t<Return>>)
{
alloc.from<detail::remove_cv_ref<Return>>(0).drop();
}
break;
}
});
storage.names.push_back(op.get_name());
if (is_static)
storage.static_types.insert(operator_id);
return *this;
}
operator_storage& build()
{
blt::hashset_t<type_id> has_terminals; blt::hashset_t<type_id> has_terminals;
for (const auto& v : blt::enumerate(storage.terminals)) for (const auto& v : blt::enumerate(storage.terminals))
@ -232,6 +206,64 @@ namespace blt::gp
} }
private: private:
template<typename RawFunction, typename Return, typename... Args>
auto add_operator(operation_t<RawFunction, Return(Args...)>& op)
{
auto total_size_required = stack_allocator::aligned_size(sizeof(Return));
((total_size_required += stack_allocator::aligned_size(sizeof(Args))) , ...);
auto return_type_id = system.get_type<Return>().id();
auto operator_id = blt::gp::operator_id(storage.operators.size());
op.id = operator_id;
operator_info info;
if constexpr (sizeof...(Args) > 0)
{
(add_non_context_argument<detail::remove_cv_ref<Args>>(info.argument_types), ...);
}
info.argc.argc_context = info.argc.argc = sizeof...(Args);
info.return_type = return_type_id;
((std::is_same_v<detail::remove_cv_ref<Args>, Context> ? info.argc.argc -= 1 : (blt::size_t) nullptr), ...);
auto& operator_list = info.argc.argc == 0 ? storage.terminals : storage.non_terminals;
operator_list[return_type_id].push_back(operator_id);
BLT_ASSERT(info.argc.argc_context - info.argc.argc <= 1 && "Cannot pass multiple context as arguments!");
storage.operators.push_back(info);
storage.print_funcs.push_back([&op](std::ostream& out, stack_allocator& stack) {
if constexpr (blt::meta::is_streamable_v<Return>)
{
out << stack.from<Return>(0);
(void) (op); // remove warning
} else
{
out << "[Printing Value on '" << (op.get_name() ? *op.get_name() : "") << "' Not Supported!]";
}
});
storage.destroy_funcs.push_back([](detail::destroy_t type, detail::bitmask_t* mask, stack_allocator& alloc) {
switch (type)
{
case detail::destroy_t::ARGS:
alloc.call_destructors<Args...>(mask);
break;
case detail::destroy_t::RETURN:
if constexpr (detail::has_func_drop_v<remove_cvref_t<Return>>)
{
alloc.from<detail::remove_cv_ref<Return>>(0).drop();
}
break;
}
});
storage.names.push_back(op.get_name());
if (op.is_ephemeral())
storage.static_types.insert(operator_id);
return total_size_required;
}
template<typename T> template<typename T>
void add_non_context_argument(decltype(operator_info::argument_types)& types) void add_non_context_argument(decltype(operator_info::argument_types)& types)
{ {
@ -241,6 +273,42 @@ namespace blt::gp
} }
} }
template<bool HasContext, size_t id, typename Lambda>
static inline bool execute(size_t op, void* context, stack_allocator& write_stack, stack_allocator& read_stack, detail::bitmask_t* mask,
Lambda lambda)
{
if (op == id)
{
if constexpr (HasContext)
{
write_stack.push(lambda(context, read_stack, mask));
} else
{
write_stack.push(lambda(read_stack, mask));
}
return false;
}
return true;
}
template<typename... Lambdas, size_t... operator_ids>
static inline void call_jmp_table_internal(size_t op, void* context, stack_allocator& write_stack, stack_allocator& read_stack,
detail::bitmask_t* mask, std::integer_sequence<size_t, operator_ids...>, Lambdas... lambdas)
{
if (op > sizeof...(operator_ids))
BLT_UNREACHABLE;
(execute<detail::is_same_v<typename Lambdas::First_Arg, Context>, operator_ids>(
op, context, write_stack, read_stack, mask, lambdas) && ...);
}
template<typename... Lambdas>
static inline void call_jmp_table(size_t op, void* context, stack_allocator& write_stack, stack_allocator& read_stack,
detail::bitmask_t* mask, Lambdas... lambdas)
{
call_jmp_table_internal(op, context, write_stack, read_stack, mask, std::index_sequence_for<Lambdas...>(),
lambdas...);
}
type_provider& system; type_provider& system;
operator_storage storage; operator_storage storage;
}; };
@ -603,6 +671,11 @@ namespace blt::gp
storage = std::move(op); storage = std::move(op);
} }
inline detail::eval_func_t& get_eval_func()
{
return storage.eval_func;
}
[[nodiscard]] inline auto get_current_generation() const [[nodiscard]] inline auto get_current_generation() const
{ {
return current_generation.load(); return current_generation.load();

View File

@ -34,11 +34,10 @@ namespace blt::gp
struct op_container_t struct op_container_t
{ {
op_container_t(detail::callable_t& func, blt::size_t type_size, operator_id id, bool is_value): op_container_t(blt::size_t type_size, operator_id id, bool is_value):
func(func), type_size(type_size), id(id), is_value(is_value) type_size(type_size), id(id), is_value(is_value)
{} {}
std::reference_wrapper<detail::callable_t> func;
blt::size_t type_size; blt::size_t type_size;
operator_id id; operator_id id;
bool is_value; bool is_value;
@ -46,11 +45,8 @@ namespace blt::gp
class evaluation_context class evaluation_context
{ {
friend class tree_t; public:
explicit evaluation_context() = default;
private:
explicit evaluation_context()
{}
blt::gp::stack_allocator values; blt::gp::stack_allocator values;
}; };
@ -86,7 +82,7 @@ namespace blt::gp
return values; return values;
} }
evaluation_context evaluate(void* context) const; evaluation_context evaluate(void* context, detail::eval_func_t& func) const;
blt::size_t get_depth(gp_program& program); blt::size_t get_depth(gp_program& program);
@ -112,9 +108,9 @@ namespace blt::gp
* Helper template for returning the result of evaluation (this calls it) * Helper template for returning the result of evaluation (this calls it)
*/ */
template<typename T> template<typename T>
T get_evaluation_value(void* context) T get_evaluation_value(void* context, detail::eval_func_t& func)
{ {
auto results = evaluate(context); auto results = evaluate(context, func);
return results.values.pop<T>(); return results.values.pop<T>();
} }

@ -1 +1 @@
Subproject commit 78710a12cca9ecf7f92394ddf66ed5e2c0301484 Subproject commit 9ce6c89ce0145902d31515194a707a9aca948121

View File

@ -63,7 +63,6 @@ namespace blt::gp
auto& info = args.program.get_operator_info(top.id); auto& info = args.program.get_operator_info(top.id);
tree.get_operations().emplace_back( tree.get_operations().emplace_back(
info.function,
args.program.get_typesystem().get_type(info.return_type).size(), args.program.get_typesystem().get_type(info.return_type).size(),
top.id, top.id,
args.program.is_static(top.id)); args.program.is_static(top.id));
@ -71,7 +70,7 @@ namespace blt::gp
if (args.program.is_static(top.id)) if (args.program.is_static(top.id))
{ {
info.function(nullptr, tree.get_values(), tree.get_values(), nullptr); //info.function(nullptr, tree.get_values(), tree.get_values(), nullptr);
continue; continue;
} }

View File

@ -457,8 +457,8 @@ namespace blt::gp
vals.copy_from(data, total_bytes_after); vals.copy_from(data, total_bytes_after);
} }
// now finally update the type. // now finally update the type.
ops[c_node] = {replacement_func_info.function, program.get_typesystem().get_type(replacement_func_info.return_type).size(), ops[c_node] = {program.get_typesystem().get_type(replacement_func_info.return_type).size(), random_replacement,
random_replacement, program.is_static(random_replacement)}; program.is_static(random_replacement)};
} }
#if BLT_DEBUG_LEVEL >= 2 #if BLT_DEBUG_LEVEL >= 2
if (!c.check(program, nullptr)) if (!c.check(program, nullptr))
@ -543,7 +543,7 @@ namespace blt::gp
vals.copy_from(combined_ptr + for_bytes, after_bytes); vals.copy_from(combined_ptr + for_bytes, after_bytes);
ops.insert(ops.begin() + static_cast<blt::ptrdiff_t>(c_node), ops.insert(ops.begin() + static_cast<blt::ptrdiff_t>(c_node),
{replacement_func_info.function, program.get_typesystem().get_type(replacement_func_info.return_type).size(), {program.get_typesystem().get_type(replacement_func_info.return_type).size(),
random_replacement, program.is_static(random_replacement)}); random_replacement, program.is_static(random_replacement)});
#if BLT_DEBUG_LEVEL >= 2 #if BLT_DEBUG_LEVEL >= 2

View File

@ -37,10 +37,7 @@ namespace blt::gp
return buffer; return buffer;
} }
inline auto empty_callable = detail::callable_t( evaluation_context tree_t::evaluate(void* context, detail::eval_func_t& func) const
[](void*, stack_allocator&, stack_allocator&, detail::bitmask_t*) { BLT_ABORT("This should never be called!"); });
evaluation_context tree_t::evaluate(void* context) const
{ {
#if BLT_DEBUG_LEVEL >= 2 #if BLT_DEBUG_LEVEL >= 2
blt::size_t expected_bytes = 0; blt::size_t expected_bytes = 0;
@ -56,27 +53,27 @@ namespace blt::gp
BLT_ABORT("Amount of bytes in stack doesn't match the number of bytes expected for the operations"); BLT_ABORT("Amount of bytes in stack doesn't match the number of bytes expected for the operations");
} }
#endif #endif
// copy the initial values // // copy the initial values
evaluation_context results{}; // evaluation_context results{};
//
// auto value_stack = values;
// auto& values_process = results.values;
// static thread_local detail::bitmask_t bitfield;
// bitfield.clear();
//
// for (const auto& operation : blt::reverse_iterate(operations.begin(), operations.end()))
// {
// if (operation.is_value)
// {
// value_stack.transfer_bytes(values_process, operation.type_size);
// bitfield.push_back(false);
// continue;
// }
// operation.func(context, values_process, values_process, &bitfield);
// bitfield.push_back(true);
// }
auto value_stack = values; return func(*this, context);
auto& values_process = results.values;
static thread_local detail::bitmask_t bitfield;
bitfield.clear();
for (const auto& operation : blt::reverse_iterate(operations.begin(), operations.end()))
{
if (operation.is_value)
{
value_stack.transfer_bytes(values_process, operation.type_size);
bitfield.push_back(false);
continue;
}
operation.func(context, values_process, values_process, &bitfield);
bitfield.push_back(true);
}
return results;
} }
std::ostream& create_indent(std::ostream& out, blt::size_t amount, bool pretty_print) std::ostream& create_indent(std::ostream& out, blt::size_t amount, bool pretty_print)
@ -216,7 +213,7 @@ namespace blt::gp
values_process.pop_back(); values_process.pop_back();
} }
value_stack.push_back(local_depth + 1); value_stack.push_back(local_depth + 1);
operations_stack.emplace_back(empty_callable, operation.type_size, operation.id, true); operations_stack.emplace_back(operation.type_size, operation.id, true);
} }
return depth; return depth;
@ -291,22 +288,22 @@ namespace blt::gp
blt::size_t total_produced = 0; blt::size_t total_produced = 0;
blt::size_t total_consumed = 0; blt::size_t total_consumed = 0;
for (const auto& operation : blt::reverse_iterate(operations.begin(), operations.end())) // for (const auto& operation : blt::reverse_iterate(operations.begin(), operations.end()))
{ // {
if (operation.is_value) // if (operation.is_value)
{ // {
value_stack.transfer_bytes(values_process, operation.type_size); // value_stack.transfer_bytes(values_process, operation.type_size);
total_produced += stack_allocator::aligned_size(operation.type_size); // total_produced += stack_allocator::aligned_size(operation.type_size);
bitfield.push_back(false); // bitfield.push_back(false);
continue; // continue;
} // }
auto& info = program.get_operator_info(operation.id); // auto& info = program.get_operator_info(operation.id);
for (auto& arg : info.argument_types) // for (auto& arg : info.argument_types)
total_consumed += stack_allocator::aligned_size(program.get_typesystem().get_type(arg).size()); // total_consumed += stack_allocator::aligned_size(program.get_typesystem().get_type(arg).size());
operation.func(context, values_process, values_process, &bitfield); // operation.func(context, values_process, values_process, &bitfield);
bitfield.push_back(true); // bitfield.push_back(true);
total_produced += stack_allocator::aligned_size(program.get_typesystem().get_type(info.return_type).size()); // total_produced += stack_allocator::aligned_size(program.get_typesystem().get_type(info.return_type).size());
} // }
auto v1 = results.values.bytes_in_head(); auto v1 = results.values.bytes_in_head();
auto v2 = static_cast<blt::ptrdiff_t>(stack_allocator::aligned_size(operations.front().type_size)); auto v2 = static_cast<blt::ptrdiff_t>(stack_allocator::aligned_size(operations.front().type_size));