main
Brett 2024-01-30 12:21:02 -05:00
parent ce53aace34
commit 68dad0f963
4 changed files with 193 additions and 27 deletions

View File

@ -23,5 +23,10 @@
inline constexpr blt::i32 MAX_DEPTH = 12; inline constexpr blt::i32 MAX_DEPTH = 12;
inline constexpr blt::i32 width = 256, height = 256; inline constexpr blt::i32 width = 256, height = 256;
inline constexpr blt::i32 POPULATION_SIZE = 20;
inline constexpr blt::i32 GEN_COUNT = 20;
inline constexpr blt::i32 TOURNAMENT_SIZE = 3;
inline constexpr float CROSSOVER_RATE = 0.9;
inline constexpr float MUTATION_RATE = 0.9;
#endif //GP_IMAGE_TEST_CONFIG_H #endif //GP_IMAGE_TEST_CONFIG_H

View File

@ -65,7 +65,7 @@ struct node
} }
} }
static node* construct_random_tree(); static node* construct_random_tree(blt::size_t max_depth = MAX_DEPTH);
void evaluate(); void evaluate();

View File

@ -127,7 +127,7 @@ void node::print_tree()
std::cout << ") "; std::cout << ") ";
} }
node* node::construct_random_tree() node* node::construct_random_tree(blt::size_t max_depth)
{ {
static std::random_device dev; static std::random_device dev;
static std::mt19937_64 engine{dev()}; static std::mt19937_64 engine{dev()};
@ -149,9 +149,9 @@ node* node::construct_random_tree()
current_depth++; current_depth++;
if (choice(engine)) if (choice(engine))
front.first->grow(front.second >= MAX_DEPTH); front.first->grow(front.second >= max_depth);
else else
front.first->full(front.second >= MAX_DEPTH); front.first->full(front.second >= max_depth);
for (size_t i = 0; i < front.first->argc; i++) for (size_t i = 0; i < front.first->argc; i++)
grow_queue.emplace(front.first->sub_nodes[i], current_depth + 1); grow_queue.emplace(front.first->sub_nodes[i], current_depth + 1);

View File

@ -23,9 +23,11 @@ blt::gfx::batch_renderer_2d renderer_2d(resources);
class tree class tree
{ {
public: public:
struct crossover_result_t struct search_result_t
{ {
std::unique_ptr<tree> c1, c2; node* child;
node* parent;
blt::size_t index;
}; };
private: private:
std::unique_ptr<node, node_deleter> root = nullptr; std::unique_ptr<node, node_deleter> root = nullptr;
@ -64,20 +66,29 @@ class tree
} }
} }
node* select_random_child() search_result_t select_random_child()
{ {
static std::random_device dev; static std::random_device dev;
static std::mt19937_64 engine{dev()}; static std::mt19937_64 engine{dev()};
auto d = depth();
std::uniform_int_distribution depth_dist(0, static_cast<int>(d));
std::uniform_int_distribution select(0, 3); std::uniform_int_distribution select(0, 3);
node* current = root.get(); node* current = root.get();
node* parent = nullptr;
blt::size_t index = 0;
while (true) while (true)
{ {
std::uniform_int_distribution<size_t> children(0, current->argc - 1); std::uniform_int_distribution<size_t> children(0, current->argc - 1);
if (select(engine) == 0 || current->argc == 0)
break;
index = children(engine);
auto* next = current->sub_nodes[index];
if (next == nullptr)
break;
parent = current;
current = next;
} }
return {current, parent, index};
} }
public: public:
@ -86,13 +97,13 @@ class tree
return std::make_unique<tree>(tree{node::construct_random_tree()}); return std::make_unique<tree>(tree{node::construct_random_tree()});
} }
blt::size_t depth() static blt::size_t depth(node* root_node)
{ {
// depth -> node // depth -> node
std::stack<std::pair<blt::size_t, node*>> stack; std::stack<std::pair<blt::size_t, node*>> stack;
blt::size_t max_depth = 0; blt::size_t max_depth = 0;
stack.emplace(0, root.get()); stack.emplace(0, root_node);
while (!stack.empty()) while (!stack.empty())
{ {
auto top = stack.top(); auto top = stack.top();
@ -112,9 +123,38 @@ class tree
return max_depth; return max_depth;
} }
static crossover_result_t crossover(tree* p1, tree* p2) static bool crossover(tree* p1, tree* p2)
{ {
return {}; auto n1 = p1->select_random_child();
auto n2 = p2->select_random_child();
const auto& p1_allowed = function_arg_allowed_set_map[to_underlying(n1.parent->type)][n1.index];
const auto& p2_allowed = function_arg_allowed_set_map[to_underlying(n2.parent->type)][n2.index];
if (!p1_allowed.contains(n2.child->type))
return false;
if (!p2_allowed.contains(n1.child->type))
return false;
n1.parent->sub_nodes[n1.index] = n2.child;
n2.parent->sub_nodes[n2.index] = n1.child;
return true;
}
static void mutate(tree* p)
{
static std::random_device dev;
static std::mt19937_64 engine{dev()};
std::uniform_int_distribution choice(0, 1);
auto n = p->select_random_child();
auto d = depth(n.child);
node* new_subtree = node::construct_random_tree(d + 1);
n.parent->sub_nodes[n.index] = new_subtree;
destroyNode(n.child);
} }
void evaluate() void evaluate()
@ -149,7 +189,120 @@ class tree
} }
}; };
std::unique_ptr<tree> test_tree; class gp_population
{
public:
struct gp_i
{
std::unique_ptr<tree> t = nullptr;
float fitness = 0;
};
private:
std::array<gp_i, POPULATION_SIZE> pop;
public:
gp_population() = default;
static blt::size_t choice()
{
static std::random_device dev;
static std::mt19937_64 engine{dev()};
static std::uniform_int_distribution c(0, POPULATION_SIZE - 1);
return c(engine);
}
void init()
{
for (auto& v : pop)
v = {tree::construct_random_tree()};
evaluate_population();
}
tree* select()
{
tree* n = nullptr;
float fitness = 0;
for (int i = 0; i < TOURNAMENT_SIZE; i++)
{
auto& v = pop[choice()];
if (n != v.t.get() && v.fitness > fitness)
{
n = v.t.get();
fitness = v.fitness;
}
}
return n;
}
void run_step()
{
static std::random_device dev;
static std::mt19937_64 engine{dev()};
static std::uniform_real_distribution rand(0.0, 1.0);
for (blt::size_t i = 0; i < POPULATION_SIZE; i++)
{
if (rand(engine) < CROSSOVER_RATE)
{
auto* p1 = select();
auto* p2 = select();
tree::crossover(p1, p2);
}
}
for (blt::size_t i = 0; i < POPULATION_SIZE; i++)
{
if (rand(engine) < MUTATION_RATE)
{
auto* p1 = select();
tree::mutate(p1);
}
}
}
void evaluate_population()
{
for (auto& v : pop)
{
v.t->evaluate();
v.fitness = v.t->fitness();
}
}
image& display(blt::size_t i)
{
return pop[i].t->getImage();
}
std::pair<blt::size_t, float> get_best()
{
blt::size_t i = 0;
float fitness = 0;
for (blt::size_t j = 0; j < POPULATION_SIZE; j++)
{
if (pop[j].fitness > fitness)
{
i = j;
fitness = pop[j].fitness;
}
}
return {i, fitness};
}
image& display_best()
{
return display(get_best().first);
}
float best_fitness()
{
return get_best().second;
}
};
gp_population pop;
blt::gfx::texture_gl2D* texture; blt::gfx::texture_gl2D* texture;
void print_bits(float f) void print_bits(float f)
@ -178,6 +331,8 @@ void init()
resources.load_resources(); resources.load_resources();
} }
float best = 0;
void update(std::int32_t w, std::int32_t h) void update(std::int32_t w, std::int32_t h)
{ {
global_matrices.update_perspectives(w, h, 90, 0.1, 2000); global_matrices.update_perspectives(w, h, 90, 0.1, 2000);
@ -186,23 +341,29 @@ void update(std::int32_t w, std::int32_t h)
if (ImGui::Button("Regenerate")) if (ImGui::Button("Regenerate"))
{ {
BLT_INFO("Regen tree"); BLT_INFO("Regen tree");
test_tree = tree::construct_random_tree(); pop.init();
test_tree->evaluate(); best = pop.best_fitness();
BLT_INFO("Uploading"); }
texture->upload((void*) test_tree->getImage().getData().data(), GL_RGB, 0, 0, 0, -1, -1, GL_FLOAT);
if (ImGui::Button("Run Step"))
{
BLT_INFO("Running Step");
pop.run_step();
BLT_INFO("Evaluating Population");
pop.evaluate_population();
best = pop.best_fitness();
} }
if (ImGui::Button("Display")) if (ImGui::Button("Display"))
{ {
if (test_tree->hasTree()) BLT_INFO("Uploading");
test_tree->printTree(); texture->upload((void*) pop.display_best().getData().data(), GL_RGB, 0, 0, 0, -1, -1, GL_FLOAT);
}
;
if (ImGui::Button("Eval")) // if (test_tree->hasTree())
{ // test_tree->printTree();
if (test_tree->hasImage())
BLT_DEBUG(test_tree->fitness());
} }
ImGui::Text("Best Fitness: %f", best);
auto lw = 512.0f; auto lw = 512.0f;
auto lh = 512.0f; auto lh = 512.0f;
@ -216,7 +377,7 @@ void update(std::int32_t w, std::int32_t h)
int main() int main()
{ {
shapiro_test_run(); //shapiro_test_run();
blt::gfx::init(blt::gfx::window_data{"Window of GP test", init, update}.setSyncInterval(1)); blt::gfx::init(blt::gfx::window_data{"Window of GP test", init, update}.setSyncInterval(1));
global_matrices.cleanup(); global_matrices.cleanup();
resources.cleanup(); resources.cleanup();