sexy refactoring
parent
7ce84ce6c6
commit
0e08a15b5e
|
@ -24,7 +24,7 @@
|
||||||
inline constexpr blt::i32 MIN_DEPTH = 4;
|
inline constexpr blt::i32 MIN_DEPTH = 4;
|
||||||
inline constexpr blt::i32 MAX_DEPTH = 12;
|
inline constexpr blt::i32 MAX_DEPTH = 12;
|
||||||
inline constexpr blt::i32 width = 256, height = 256;
|
inline constexpr blt::i32 width = 256, height = 256;
|
||||||
inline constexpr blt::i32 POPULATION_SIZE = 50;
|
inline constexpr blt::i32 POPULATION_SIZE = 20;
|
||||||
inline constexpr blt::i32 GEN_COUNT = 20;
|
inline constexpr blt::i32 GEN_COUNT = 20;
|
||||||
inline constexpr blt::i32 TOURNAMENT_SIZE = 4;
|
inline constexpr blt::i32 TOURNAMENT_SIZE = 4;
|
||||||
inline constexpr float CROSSOVER_RATE = 0.9;
|
inline constexpr float CROSSOVER_RATE = 0.9;
|
||||||
|
|
|
@ -82,10 +82,7 @@ void node::populate_node(size_t i, std::mt19937_64& engine, const allowed_funcs<
|
||||||
auto terminals = intersection(allowed_args, FUNC_ALLOW_TERMINALS_SET);
|
auto terminals = intersection(allowed_args, FUNC_ALLOW_TERMINALS_SET);
|
||||||
if (terminals.empty())
|
if (terminals.empty())
|
||||||
{
|
{
|
||||||
terminals = FUNC_ALLOW_TERMINALS;
|
terminals = allowed_args;
|
||||||
// BLT_INFO("%s:", function_name_map[to_underlying(type)].c_str());
|
|
||||||
// for (auto v : allowed_args)
|
|
||||||
// BLT_INFO(function_name_map[to_underlying(v)]);
|
|
||||||
}
|
}
|
||||||
std::uniform_int_distribution<blt::size_t> select(0, terminals.size() - 1);
|
std::uniform_int_distribution<blt::size_t> select(0, terminals.size() - 1);
|
||||||
sub_nodes[i] = createNode(terminals[select(engine)]);
|
sub_nodes[i] = createNode(terminals[select(engine)]);
|
||||||
|
@ -94,7 +91,7 @@ void node::populate_node(size_t i, std::mt19937_64& engine, const allowed_funcs<
|
||||||
auto non_terminals = intersection_comp(allowed_args, FUNC_ALLOW_TERMINALS_SET);
|
auto non_terminals = intersection_comp(allowed_args, FUNC_ALLOW_TERMINALS_SET);
|
||||||
if (non_terminals.empty())
|
if (non_terminals.empty())
|
||||||
{
|
{
|
||||||
BLT_WARN("Empty non-terminals set! Filling from terminals!");
|
//BLT_WARN("Empty non-terminals set for node '%s'! Filling from terminals!", function_name_map[to_underlying(type)].c_str());
|
||||||
std::uniform_int_distribution<blt::size_t> select(0, allowed_args.size() - 1);
|
std::uniform_int_distribution<blt::size_t> select(0, allowed_args.size() - 1);
|
||||||
sub_nodes[i] = createNode(allowed_args[select(engine)]);
|
sub_nodes[i] = createNode(allowed_args[select(engine)]);
|
||||||
} else
|
} else
|
||||||
|
@ -237,7 +234,7 @@ void node::evaluate()
|
||||||
}
|
}
|
||||||
#define FUNC_DEFINE(NAME, MIN_ARGS, MAX_ARGS, FUNC, ...) case function_t::NAME: { \
|
#define FUNC_DEFINE(NAME, MIN_ARGS, MAX_ARGS, FUNC, ...) case function_t::NAME: { \
|
||||||
if (function_t::NAME == function_t::IF) { \
|
if (function_t::NAME == function_t::IF) { \
|
||||||
std::cout << "__:" << function_name_map[to_underlying(this->sub_nodes[0]->type)] << std::endl; \
|
/*std::cout << "__:" << function_name_map[to_underlying(this->sub_nodes[0]->type)] << std::endl;*/ \
|
||||||
}\
|
}\
|
||||||
if (FUNC_ALLOW_TERMINALS_SET.contains(function_t::NAME)){ \
|
if (FUNC_ALLOW_TERMINALS_SET.contains(function_t::NAME)){ \
|
||||||
FUNC(img.value(), 0, 0, argc, const_cast<const image**>(sub_node_images.data()), data); \
|
FUNC(img.value(), 0, 0, argc, const_cast<const image**>(sub_node_images.data()), data); \
|
||||||
|
|
50
src/main.cpp
50
src/main.cpp
|
@ -19,6 +19,8 @@
|
||||||
#include <config.h>
|
#include <config.h>
|
||||||
#include <gp.h>
|
#include <gp.h>
|
||||||
|
|
||||||
|
blt::i8* top_of_stack = nullptr;
|
||||||
|
|
||||||
blt::gfx::matrix_state_manager global_matrices;
|
blt::gfx::matrix_state_manager global_matrices;
|
||||||
blt::gfx::resource_manager resources;
|
blt::gfx::resource_manager resources;
|
||||||
blt::gfx::batch_renderer_2d renderer_2d(resources);
|
blt::gfx::batch_renderer_2d renderer_2d(resources);
|
||||||
|
@ -185,8 +187,8 @@ class tree
|
||||||
auto d = depth(n.child);
|
auto d = depth(n.child);
|
||||||
|
|
||||||
node* new_subtree = node::construct_random_tree(d + 1);
|
node* new_subtree = node::construct_random_tree(d + 1);
|
||||||
|
destroyNode(n.parent->sub_nodes[n.index]);
|
||||||
n.parent->sub_nodes[n.index] = new_subtree;
|
n.parent->sub_nodes[n.index] = new_subtree;
|
||||||
destroyNode(n.child);
|
|
||||||
return mutation_result_t{std::move(c)};
|
return mutation_result_t{std::move(c)};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -230,14 +232,21 @@ class gp_population
|
||||||
{
|
{
|
||||||
std::unique_ptr<tree> t = nullptr;
|
std::unique_ptr<tree> t = nullptr;
|
||||||
float fitness = 0;
|
float fitness = 0;
|
||||||
|
|
||||||
|
[[nodiscard]] inline gp_i clone() const
|
||||||
|
{
|
||||||
|
return {t->clone(), fitness};
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
using population_storage_t = std::array<gp_i, POPULATION_SIZE>;
|
||||||
blt::size_t best = -1;
|
blt::size_t best = -1;
|
||||||
private:
|
private:
|
||||||
std::array<gp_i, POPULATION_SIZE> pop;
|
population_storage_t pop;
|
||||||
public:
|
public:
|
||||||
gp_population() = default;
|
gp_population() = default;
|
||||||
|
|
||||||
static blt::size_t choice()
|
static blt::size_t random_individual_index()
|
||||||
{
|
{
|
||||||
static std::random_device dev;
|
static std::random_device dev;
|
||||||
static std::mt19937_64 engine{dev()};
|
static std::mt19937_64 engine{dev()};
|
||||||
|
@ -253,25 +262,24 @@ class gp_population
|
||||||
evaluate_population();
|
evaluate_population();
|
||||||
}
|
}
|
||||||
|
|
||||||
tree* select()
|
[[nodiscard]] static tree* tournament_select(population_storage_t& p)
|
||||||
{
|
{
|
||||||
tree* n = nullptr;
|
tree* n = nullptr;
|
||||||
float fitness = -2 * 8192;
|
float fitness = -2 * 8192;
|
||||||
//BLT_TRACE("With inital %f", fitness);
|
|
||||||
for (int i = 0; i < TOURNAMENT_SIZE; i++)
|
for (int i = 0; i < TOURNAMENT_SIZE; i++)
|
||||||
{
|
{
|
||||||
blt::size_t index = 0;
|
blt::size_t index = 0;
|
||||||
do
|
do
|
||||||
{
|
{
|
||||||
index = choice();
|
index = random_individual_index();
|
||||||
auto& v = pop[index];
|
auto& v = p[index];
|
||||||
if (v.fitness >= fitness)
|
if (v.fitness >= fitness)
|
||||||
{
|
{
|
||||||
n = v.t.get();
|
n = v.t.get();
|
||||||
fitness = v.fitness;
|
fitness = v.fitness;
|
||||||
}
|
}
|
||||||
//BLT_TRACE("%d: %p -> %f ? %p -> %f || %d %d", index, v.t.get(), v.fitness, n, fitness, n != v.t.get(), v.fitness > fitness);
|
//BLT_TRACE("%d: %p -> %f ? %p -> %f || %d %d", index, v.t.get(), v.fitness, n, fitness, n != v.t.get(), v.fitness > fitness);
|
||||||
} while (n == pop[index].t.get());
|
} while (n == p[index].t.get());
|
||||||
}
|
}
|
||||||
//BLT_DEBUG("%p -> %f", n, fitness);
|
//BLT_DEBUG("%p -> %f", n, fitness);
|
||||||
BLT_ASSERT(n != nullptr);
|
BLT_ASSERT(n != nullptr);
|
||||||
|
@ -280,24 +288,32 @@ class gp_population
|
||||||
|
|
||||||
void run_step()
|
void run_step()
|
||||||
{
|
{
|
||||||
std::array<gp_i, POPULATION_SIZE> new_pop;
|
population_storage_t new_pop;
|
||||||
static std::random_device dev;
|
static std::random_device dev;
|
||||||
static std::mt19937_64 engine{dev()};
|
static std::mt19937_64 engine{dev()};
|
||||||
static std::uniform_real_distribution rand(0.0, 1.0);
|
static std::uniform_real_distribution rand(0.0, 1.0);
|
||||||
|
|
||||||
auto b = get_best();
|
auto b = get_best();
|
||||||
new_pop[0] = {pop[b.first].t->clone(), b.second};
|
new_pop[0] = {pop[b.first].t->clone(), b.second};
|
||||||
blt::size_t insert_pos = 1;
|
|
||||||
|
for (blt::size_t i = 1; i < POPULATION_SIZE; i++)
|
||||||
|
new_pop[i] = {pop[i].t->clone(), pop[i].fitness};
|
||||||
|
|
||||||
blt::size_t crossover_count = 0;
|
blt::size_t crossover_count = 0;
|
||||||
blt::size_t mutation_count = 0;
|
blt::size_t mutation_count = 0;
|
||||||
BLT_TRACE("Running Crossover");
|
BLT_TRACE("Running Crossover");
|
||||||
for (blt::size_t i = 0; i < POPULATION_SIZE; i++)
|
for (blt::size_t i = insert_pos; i < POPULATION_SIZE; i++)
|
||||||
{
|
{
|
||||||
|
if (insert_pos >= POPULATION_SIZE)
|
||||||
|
break;
|
||||||
if (rand(engine) < CROSSOVER_RATE)
|
if (rand(engine) < CROSSOVER_RATE)
|
||||||
{
|
{
|
||||||
auto* p1 = select();
|
tree* p1 = select();
|
||||||
auto* p2 = select();
|
tree* p2;
|
||||||
|
do
|
||||||
|
{
|
||||||
|
p2 = select();
|
||||||
|
} while (p2 == p1);
|
||||||
|
|
||||||
if (auto r = tree::crossover(p1, p2))
|
if (auto r = tree::crossover(p1, p2))
|
||||||
{
|
{
|
||||||
|
@ -309,8 +325,10 @@ class gp_population
|
||||||
}
|
}
|
||||||
|
|
||||||
BLT_TRACE("Running Mutation");
|
BLT_TRACE("Running Mutation");
|
||||||
for (blt::size_t i = 0; i < POPULATION_SIZE; i++)
|
for (blt::size_t i = insert_pos; i < POPULATION_SIZE; i++)
|
||||||
{
|
{
|
||||||
|
if (insert_pos >= POPULATION_SIZE)
|
||||||
|
break;
|
||||||
if (rand(engine) < MUTATION_RATE)
|
if (rand(engine) < MUTATION_RATE)
|
||||||
{
|
{
|
||||||
auto* p1 = select();
|
auto* p1 = select();
|
||||||
|
@ -447,6 +465,8 @@ void update(std::int32_t w, std::int32_t h)
|
||||||
ImGui::Text("Physical Memory Usage: %s", blt::string::fromBytes(data.resident).c_str());
|
ImGui::Text("Physical Memory Usage: %s", blt::string::fromBytes(data.resident).c_str());
|
||||||
ImGui::Text("Shared Memory Usage: %s", blt::string::fromBytes(data.shared).c_str());
|
ImGui::Text("Shared Memory Usage: %s", blt::string::fromBytes(data.shared).c_str());
|
||||||
ImGui::Text("Total Memory Usage: %s", blt::string::fromBytes(data.size).c_str());
|
ImGui::Text("Total Memory Usage: %s", blt::string::fromBytes(data.size).c_str());
|
||||||
|
blt::i8 scope = 0;
|
||||||
|
ImGui::Text("Stack Size: %s", blt::string::fromBytes((top_of_stack - &scope) * sizeof(blt::i8)).c_str());
|
||||||
|
|
||||||
auto lw = 512.0f;
|
auto lw = 512.0f;
|
||||||
auto lh = 512.0f;
|
auto lh = 512.0f;
|
||||||
|
@ -460,6 +480,8 @@ void update(std::int32_t w, std::int32_t h)
|
||||||
|
|
||||||
int main()
|
int main()
|
||||||
{
|
{
|
||||||
|
blt::i8 tos = 0;
|
||||||
|
top_of_stack = &tos;
|
||||||
// auto& funcs = function_arg_allowed_map[to_underlying(function_t::IF)];
|
// auto& funcs = function_arg_allowed_map[to_underlying(function_t::IF)];
|
||||||
// for (auto v : blt::enumerate(funcs))
|
// for (auto v : blt::enumerate(funcs))
|
||||||
// for (auto f : v.second)
|
// for (auto f : v.second)
|
||||||
|
|
Loading…
Reference in New Issue