silly little guy

thread
Brett 2024-06-29 10:47:27 -04:00
parent bf4394cb0d
commit cc9f7202c7
9 changed files with 146 additions and 86 deletions

View File

@ -1,5 +1,5 @@
cmake_minimum_required(VERSION 3.25) cmake_minimum_required(VERSION 3.25)
project(blt-gp VERSION 0.0.38) project(blt-gp VERSION 0.0.40)
include(CTest) include(CTest)

View File

@ -311,7 +311,7 @@ namespace blt::gp::detail
class operator_storage_test class operator_storage_test
{ {
public: public:
explicit operator_storage_test(blt::gp::gp_operations<context>& ops): ops(ops) explicit operator_storage_test(blt::gp::operator_builder<context>& ops): ops(ops)
{} {}
inline blt::gp::detail::callable_t& operator[](blt::size_t index) inline blt::gp::detail::callable_t& operator[](blt::size_t index)
@ -320,7 +320,7 @@ namespace blt::gp::detail
} }
private: private:
blt::gp::gp_operations<context>& ops; blt::gp::operator_builder<context>& ops;
}; };
} }
@ -413,9 +413,9 @@ int main()
return ctx.x; return ctx.x;
}); });
blt::gp::type_system system; blt::gp::type_provider system;
system.register_type<float>(); system.register_type<float>();
blt::gp::gp_operations<context> ops{system}; blt::gp::operator_builder<context> ops{system};
//BLT_TRACE(blt::type_string<decltype(silly_op_3)::first::type>()); //BLT_TRACE(blt::type_string<decltype(silly_op_3)::first::type>());
//BLT_TRACE(typeid(decltype(silly_op_3)::first::type).name()); //BLT_TRACE(typeid(decltype(silly_op_3)::first::type).name());

View File

@ -23,7 +23,7 @@
static constexpr long SEED = 41912; static constexpr long SEED = 41912;
blt::gp::type_system type_system; blt::gp::type_provider type_system;
blt::gp::gp_program program(type_system, std::mt19937_64{SEED}); // NOLINT blt::gp::gp_program program(type_system, std::mt19937_64{SEED}); // NOLINT
blt::gp::operation_t add([](float a, float b) { blt::gp::operation_t add([](float a, float b) {
@ -52,7 +52,7 @@ int main()
{ {
type_system.register_type<float>(); type_system.register_type<float>();
blt::gp::gp_operations silly{type_system}; blt::gp::operator_builder silly{type_system};
silly.add_operator(add); silly.add_operator(add);
silly.add_operator(sub); silly.add_operator(sub);
silly.add_operator(mul); silly.add_operator(mul);

View File

@ -22,7 +22,7 @@
static constexpr long SEED = 41912; static constexpr long SEED = 41912;
blt::gp::type_system type_system; blt::gp::type_provider type_system;
blt::gp::gp_program program(type_system, std::mt19937_64{SEED}); // NOLINT blt::gp::gp_program program(type_system, std::mt19937_64{SEED}); // NOLINT
blt::gp::operation_t add([](float a, float b) { return a + b; }); blt::gp::operation_t add([](float a, float b) { return a + b; });
@ -54,7 +54,7 @@ int main()
type_system.register_type<float>(); type_system.register_type<float>();
type_system.register_type<bool>(); type_system.register_type<bool>();
blt::gp::gp_operations silly{type_system}; blt::gp::operator_builder silly{type_system};
silly.add_operator(add); silly.add_operator(add);
silly.add_operator(sub); silly.add_operator(sub);
silly.add_operator(mul); silly.add_operator(mul);

View File

@ -22,7 +22,7 @@
static constexpr long SEED = 41912; static constexpr long SEED = 41912;
blt::gp::type_system type_system; blt::gp::type_provider type_system;
blt::gp::gp_program program(type_system, std::mt19937_64{SEED}); // NOLINT blt::gp::gp_program program(type_system, std::mt19937_64{SEED}); // NOLINT
blt::gp::operation_t add([](float a, float b) { return a + b; }); blt::gp::operation_t add([](float a, float b) { return a + b; });
@ -54,25 +54,25 @@ int main()
type_system.register_type<float>(); type_system.register_type<float>();
type_system.register_type<bool>(); type_system.register_type<bool>();
blt::gp::gp_operations silly{type_system}; blt::gp::operator_builder builder{type_system};
silly.add_operator(add); builder.add_operator(add);
silly.add_operator(sub); builder.add_operator(sub);
silly.add_operator(mul); builder.add_operator(mul);
silly.add_operator(pro_div); builder.add_operator(pro_div);
silly.add_operator(op_if); builder.add_operator(op_if);
silly.add_operator(eq_f); builder.add_operator(eq_f);
silly.add_operator(eq_b); builder.add_operator(eq_b);
silly.add_operator(lt); builder.add_operator(lt);
silly.add_operator(gt); builder.add_operator(gt);
silly.add_operator(op_and); builder.add_operator(op_and);
silly.add_operator(op_or); builder.add_operator(op_or);
silly.add_operator(op_xor); builder.add_operator(op_xor);
silly.add_operator(op_not); builder.add_operator(op_not);
silly.add_operator(lit, true); builder.add_operator(lit, true);
program.set_operations(std::move(silly)); program.set_operations(builder.build());
blt::gp::ramped_half_initializer_t pop_init; blt::gp::ramped_half_initializer_t pop_init;

View File

@ -29,7 +29,7 @@ namespace blt::gp
class type; class type;
class type_system; class type_provider;
struct op_container_t; struct op_container_t;

View File

@ -28,6 +28,7 @@
#include <utility> #include <utility>
#include <iostream> #include <iostream>
#include <random> #include <random>
#include <algorithm>
#include <blt/std/ranges.h> #include <blt/std/ranges.h>
#include <blt/std/hashmap.h> #include <blt/std/hashmap.h>
@ -52,24 +53,37 @@ namespace blt::gp
blt::size_t argc_context = 0; blt::size_t argc_context = 0;
}; };
struct operator_storage
{
// indexed from return TYPE ID, returns index of operator
blt::expanding_buffer<std::vector<operator_id>> terminals;
blt::expanding_buffer<std::vector<operator_id>> non_terminals;
// indexed from OPERATOR ID (operator number)
blt::expanding_buffer<std::vector<type>> argument_types;
blt::expanding_buffer<argc_t> operator_argc;
blt::hashset_t<operator_id> static_types;
std::vector<detail::callable_t> operators;
std::vector<detail::transfer_t> transfer_funcs;
};
template<typename Context = detail::empty_t> template<typename Context = detail::empty_t>
class gp_operations class operator_builder
{ {
friend class gp_program; friend class gp_program;
friend class blt::gp::detail::operator_storage_test; friend class blt::gp::detail::operator_storage_test;
public: public:
explicit gp_operations(type_system& system): system(system) explicit operator_builder(type_provider& system): system(system)
{} {}
template<typename Return, typename... Args> template<typename Return, typename... Args>
gp_operations& add_operator(const operation_t<Return(Args...)>& op, bool is_static = false) operator_builder& add_operator(const operation_t<Return(Args...)>& op, bool is_static = false)
{ {
auto return_type_id = system.get_type<Return>().id(); auto return_type_id = system.get_type<Return>().id();
auto operator_id = blt::gp::operator_id(operators.size()); auto operator_id = blt::gp::operator_id(storage.operators.size());
auto& operator_list = op.get_argc() == 0 ? terminals : non_terminals; auto& operator_list = op.get_argc() == 0 ? storage.terminals : storage.non_terminals;
operator_list[return_type_id].push_back(operator_id); operator_list[return_type_id].push_back(operator_id);
if constexpr (sizeof...(Args) > 0) if constexpr (sizeof...(Args) > 0)
@ -84,38 +98,81 @@ namespace blt::gp
BLT_ASSERT(argc.argc_context - argc.argc <= 1 && "Cannot pass multiple context as arguments!"); BLT_ASSERT(argc.argc_context - argc.argc <= 1 && "Cannot pass multiple context as arguments!");
operator_argc[operator_id] = argc; storage.operator_argc[operator_id] = argc;
operators.push_back(op.template make_callable<Context>()); storage.operators.push_back(op.template make_callable<Context>());
transfer_funcs.push_back([](stack_allocator& to, stack_allocator& from) { storage.transfer_funcs.push_back([](stack_allocator& to, stack_allocator& from) {
to.push(from.pop<Return>()); to.push(from.pop<Return>());
}); });
if (is_static) if (is_static)
static_types.insert(operator_id); storage.static_types.insert(operator_id);
return *this; return *this;
} }
operator_storage&& build()
{
blt::hashset_t<type_id> has_terminals;
for (const auto& v : blt::enumerate(storage.terminals))
{
if (!v.second.empty())
has_terminals.insert(v.first);
}
blt::expanding_buffer<std::vector<std::pair<operator_id, blt::size_t>>> operators_ordered_terminals;
for (const auto& op_r : blt::enumerate(storage.non_terminals))
{
if (op_r.second.empty())
continue;
auto return_type = op_r.first;
std::vector<std::pair<operator_id, blt::size_t>> ordered_terminals;
for (const auto& op : op_r.second)
{
// count number of terminals
blt::size_t terminals = 0;
for (const auto& type : storage.argument_types[op])
{
if (has_terminals.contains(type.id()))
terminals++;
}
ordered_terminals.emplace_back(op, terminals);
}
bool found = false;
for (const auto& terms : ordered_terminals)
{
if (terms.second != 0)
{
found = true;
break;
}
}
if (!found)
{
BLT_ABORT(("Failed to find non-terminals "));
}
std::sort(ordered_terminals.begin(), ordered_terminals.end(), [](const auto& a, const auto& b) {
return a.second > b.second;
});
operators_ordered_terminals[return_type] = ordered_terminals;
}
return std::move(storage);
}
private: private:
template<typename T> template<typename T>
void add_non_context_argument(blt::gp::operator_id operator_id) void add_non_context_argument(blt::gp::operator_id operator_id)
{ {
if constexpr (!std::is_same_v<Context, detail::remove_cv_ref<T>>) if constexpr (!std::is_same_v<Context, detail::remove_cv_ref<T>>)
{ {
argument_types[operator_id].push_back(system.get_type<T>()); storage.argument_types[operator_id].push_back(system.get_type<T>());
} }
} }
type_system& system; type_provider& system;
operator_storage storage;
// indexed from return TYPE ID, returns index of operator
blt::expanding_buffer<std::vector<operator_id>> terminals;
blt::expanding_buffer<std::vector<operator_id>> non_terminals;
// indexed from OPERATOR ID (operator number)
blt::expanding_buffer<std::vector<type>> argument_types;
blt::expanding_buffer<argc_t> operator_argc;
blt::hashset_t<operator_id> static_types;
std::vector<detail::callable_t> operators;
std::vector<detail::transfer_t> transfer_funcs;
}; };
class gp_program class gp_program
@ -129,7 +186,7 @@ namespace blt::gp
* @param engine random engine to use throughout the program. TODO replace this with something better * @param engine random engine to use throughout the program. TODO replace this with something better
* @param context_size number of arguments which are always present as "context" to the GP system / operators * @param context_size number of arguments which are always present as "context" to the GP system / operators
*/ */
explicit gp_program(type_system& system, std::mt19937_64 engine): explicit gp_program(type_provider& system, std::mt19937_64 engine):
system(system), engine(engine) system(system), engine(engine)
{} {}
@ -156,21 +213,21 @@ namespace blt::gp
return dist(engine) < cutoff; return dist(engine) < cutoff;
} }
[[nodiscard]] inline type_system& get_typesystem() [[nodiscard]] inline type_provider& get_typesystem()
{ {
return system; return system;
} }
inline operator_id select_terminal(type_id id) inline operator_id select_terminal(type_id id)
{ {
std::uniform_int_distribution<blt::size_t> dist(0, terminals[id].size() - 1); std::uniform_int_distribution<blt::size_t> dist(0, storage.terminals[id].size() - 1);
return terminals[id][dist(engine)]; return storage.terminals[id][dist(engine)];
} }
inline operator_id select_non_terminal(type_id id) inline operator_id select_non_terminal(type_id id)
{ {
std::uniform_int_distribution<blt::size_t> dist(0, non_terminals[id].size() - 1); std::uniform_int_distribution<blt::size_t> dist(0, storage.non_terminals[id].size() - 1);
return non_terminals[id][dist(engine)]; return storage.non_terminals[id][dist(engine)];
} }
// inline operator_id select_non_terminal_too_deep(type_id id) // inline operator_id select_non_terminal_too_deep(type_id id)
@ -189,65 +246,51 @@ namespace blt::gp
inline std::vector<type>& get_argument_types(operator_id id) inline std::vector<type>& get_argument_types(operator_id id)
{ {
return argument_types[id]; return storage.argument_types[id];
} }
inline std::vector<operator_id>& get_type_terminals(type_id id) inline std::vector<operator_id>& get_type_terminals(type_id id)
{ {
return terminals[id]; return storage.terminals[id];
} }
inline std::vector<operator_id>& get_type_non_terminals(type_id id) inline std::vector<operator_id>& get_type_non_terminals(type_id id)
{ {
return non_terminals[id]; return storage.non_terminals[id];
} }
inline argc_t get_argc(operator_id id) inline argc_t get_argc(operator_id id)
{ {
return operator_argc[id]; return storage.operator_argc[id];
} }
inline detail::callable_t& get_operation(operator_id id) inline detail::callable_t& get_operation(operator_id id)
{ {
return operators[id]; return storage.operators[id];
} }
inline detail::transfer_t& get_transfer_func(operator_id id) inline detail::transfer_t& get_transfer_func(operator_id id)
{ {
return transfer_funcs[id]; return storage.transfer_funcs[id];
} }
inline bool is_static(operator_id id) inline bool is_static(operator_id id)
{ {
return static_types.contains(static_cast<blt::size_t>(id)); return storage.static_types.contains(static_cast<blt::size_t>(id));
} }
template<typename Context> inline void set_operations(operator_storage&& op)
inline void set_operations(gp_operations<Context>&& op)
{ {
terminals = std::move(op.terminals); storage = std::move(op);
non_terminals = std::move(op.non_terminals);
argument_types = std::move(op.argument_types);
static_types = std::move(op.static_types);
operator_argc = std::move(op.operator_argc);
operators = std::move(op.operators);
transfer_funcs = std::move(op.transfer_funcs);
} }
private: private:
type_system& system; type_provider& system;
blt::gp::stack_allocator alloc; blt::gp::stack_allocator alloc;
std::mt19937_64 engine;
// indexed from return TYPE ID, returns index of operator operator_storage storage;
blt::expanding_buffer<std::vector<operator_id>> terminals;
blt::expanding_buffer<std::vector<operator_id>> non_terminals; std::mt19937_64 engine;
// indexed from OPERATOR ID (operator number)
blt::expanding_buffer<std::vector<type>> argument_types;
blt::expanding_buffer<argc_t> operator_argc;
blt::hashset_t<operator_id> static_types;
std::vector<detail::callable_t> operators;
std::vector<detail::transfer_t> transfer_funcs;
}; };
} }

View File

@ -80,10 +80,10 @@ namespace blt::gp
* Is a provider for the set of types possible in a GP program * Is a provider for the set of types possible in a GP program
* also provides a set of functions for converting between C++ types and BLT GP types * also provides a set of functions for converting between C++ types and BLT GP types
*/ */
class type_system class type_provider
{ {
public: public:
type_system() = default; type_provider() = default;
template<typename T> template<typename T>
inline type register_type() inline type register_type()
@ -98,6 +98,17 @@ namespace blt::gp
return types[blt::type_string_raw<T>()]; return types[blt::type_string_raw<T>()];
} }
inline type get_type(type_id id)
{
for (const auto& v : types)
{
if (v.second.id() == id)
return v.second;
}
BLT_ABORT(("Type " + std::to_string(id) + " does not exist").c_str());
std::exit(0);
}
/** /**
* This function is slow btw * This function is slow btw
* @param engine * @param engine

View File

@ -59,12 +59,12 @@ namespace blt::gp
{ {
auto top = tree_generator.top(); auto top = tree_generator.top();
tree_generator.pop(); tree_generator.pop();
//BLT_INFO("%ld D: %ld (%ld left)", top.id, top.depth, tree_generator.size());
tree.get_operations().emplace_back( tree.get_operations().emplace_back(
args.program.get_operation(top.id), args.program.get_operation(top.id),
args.program.get_transfer_func(top.id), args.program.get_transfer_func(top.id),
args.program.is_static(top.id) args.program.is_static(top.id));
);
max_depth = std::max(max_depth, top.depth); max_depth = std::max(max_depth, top.depth);
if (args.program.is_static(top.id)) if (args.program.is_static(top.id))
@ -159,10 +159,16 @@ namespace blt::gp
auto remainder = args.size % steps; auto remainder = args.size % steps;
population_t pop; population_t pop;
BLT_INFO(steps);
BLT_INFO(per_step);
BLT_INFO(remainder);
for (auto depth : blt::range(args.min_depth, args.max_depth)) for (auto depth : blt::range(args.min_depth, args.max_depth))
{ {
BLT_TRACE(depth);
for (auto i = 0ul; i < per_step; i++) for (auto i = 0ul; i < per_step; i++)
{ {
BLT_DEBUG(i);
if (args.program.choice()) if (args.program.choice())
pop.getIndividuals().push_back(full.generate({args.program, args.root_type, args.min_depth, depth})); pop.getIndividuals().push_back(full.generate({args.program, args.root_type, args.min_depth, depth}));
else else