diff --git a/include/config.h b/include/config.h index 3e269e7..6085b6e 100644 --- a/include/config.h +++ b/include/config.h @@ -23,5 +23,10 @@ inline constexpr blt::i32 MAX_DEPTH = 12; inline constexpr blt::i32 width = 256, height = 256; +inline constexpr blt::i32 POPULATION_SIZE = 20; +inline constexpr blt::i32 GEN_COUNT = 20; +inline constexpr blt::i32 TOURNAMENT_SIZE = 3; +inline constexpr float CROSSOVER_RATE = 0.9; +inline constexpr float MUTATION_RATE = 0.9; #endif //GP_IMAGE_TEST_CONFIG_H diff --git a/include/gp.h b/include/gp.h index c2b484c..4a9f84a 100644 --- a/include/gp.h +++ b/include/gp.h @@ -65,7 +65,7 @@ struct node } } - static node* construct_random_tree(); + static node* construct_random_tree(blt::size_t max_depth = MAX_DEPTH); void evaluate(); diff --git a/src/gp.cpp b/src/gp.cpp index 4735d70..55047b6 100644 --- a/src/gp.cpp +++ b/src/gp.cpp @@ -127,7 +127,7 @@ void node::print_tree() std::cout << ") "; } -node* node::construct_random_tree() +node* node::construct_random_tree(blt::size_t max_depth) { static std::random_device dev; static std::mt19937_64 engine{dev()}; @@ -149,9 +149,9 @@ node* node::construct_random_tree() current_depth++; if (choice(engine)) - front.first->grow(front.second >= MAX_DEPTH); + front.first->grow(front.second >= max_depth); else - front.first->full(front.second >= MAX_DEPTH); + front.first->full(front.second >= max_depth); for (size_t i = 0; i < front.first->argc; i++) grow_queue.emplace(front.first->sub_nodes[i], current_depth + 1); diff --git a/src/main.cpp b/src/main.cpp index 60eff60..9365b87 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -23,9 +23,11 @@ blt::gfx::batch_renderer_2d renderer_2d(resources); class tree { public: - struct crossover_result_t + struct search_result_t { - std::unique_ptr c1, c2; + node* child; + node* parent; + blt::size_t index; }; private: std::unique_ptr root = nullptr; @@ -64,20 +66,29 @@ class tree } } - node* select_random_child() + search_result_t select_random_child() { static std::random_device dev; static std::mt19937_64 engine{dev()}; - auto d = depth(); - std::uniform_int_distribution depth_dist(0, static_cast(d)); std::uniform_int_distribution select(0, 3); node* current = root.get(); + node* parent = nullptr; + blt::size_t index = 0; while (true) { std::uniform_int_distribution children(0, current->argc - 1); + if (select(engine) == 0 || current->argc == 0) + break; + index = children(engine); + auto* next = current->sub_nodes[index]; + if (next == nullptr) + break; + parent = current; + current = next; } + return {current, parent, index}; } public: @@ -86,13 +97,13 @@ class tree return std::make_unique(tree{node::construct_random_tree()}); } - blt::size_t depth() + static blt::size_t depth(node* root_node) { // depth -> node std::stack> stack; blt::size_t max_depth = 0; - stack.emplace(0, root.get()); + stack.emplace(0, root_node); while (!stack.empty()) { auto top = stack.top(); @@ -112,9 +123,38 @@ class tree return max_depth; } - static crossover_result_t crossover(tree* p1, tree* p2) + static bool crossover(tree* p1, tree* p2) { - return {}; + auto n1 = p1->select_random_child(); + auto n2 = p2->select_random_child(); + + const auto& p1_allowed = function_arg_allowed_set_map[to_underlying(n1.parent->type)][n1.index]; + const auto& p2_allowed = function_arg_allowed_set_map[to_underlying(n2.parent->type)][n2.index]; + + if (!p1_allowed.contains(n2.child->type)) + return false; + if (!p2_allowed.contains(n1.child->type)) + return false; + + n1.parent->sub_nodes[n1.index] = n2.child; + n2.parent->sub_nodes[n2.index] = n1.child; + + return true; + } + + static void mutate(tree* p) + { + static std::random_device dev; + static std::mt19937_64 engine{dev()}; + std::uniform_int_distribution choice(0, 1); + + auto n = p->select_random_child(); + + auto d = depth(n.child); + + node* new_subtree = node::construct_random_tree(d + 1); + n.parent->sub_nodes[n.index] = new_subtree; + destroyNode(n.child); } void evaluate() @@ -149,7 +189,120 @@ class tree } }; -std::unique_ptr test_tree; +class gp_population +{ + public: + struct gp_i + { + std::unique_ptr t = nullptr; + float fitness = 0; + }; + private: + std::array pop; + public: + gp_population() = default; + + static blt::size_t choice() + { + static std::random_device dev; + static std::mt19937_64 engine{dev()}; + static std::uniform_int_distribution c(0, POPULATION_SIZE - 1); + + return c(engine); + } + + void init() + { + for (auto& v : pop) + v = {tree::construct_random_tree()}; + evaluate_population(); + } + + tree* select() + { + tree* n = nullptr; + float fitness = 0; + for (int i = 0; i < TOURNAMENT_SIZE; i++) + { + auto& v = pop[choice()]; + if (n != v.t.get() && v.fitness > fitness) + { + n = v.t.get(); + fitness = v.fitness; + } + } + return n; + } + + void run_step() + { + static std::random_device dev; + static std::mt19937_64 engine{dev()}; + static std::uniform_real_distribution rand(0.0, 1.0); + + for (blt::size_t i = 0; i < POPULATION_SIZE; i++) + { + if (rand(engine) < CROSSOVER_RATE) + { + auto* p1 = select(); + auto* p2 = select(); + + tree::crossover(p1, p2); + } + } + + for (blt::size_t i = 0; i < POPULATION_SIZE; i++) + { + if (rand(engine) < MUTATION_RATE) + { + auto* p1 = select(); + + tree::mutate(p1); + } + } + } + + void evaluate_population() + { + for (auto& v : pop) + { + v.t->evaluate(); + v.fitness = v.t->fitness(); + } + } + + image& display(blt::size_t i) + { + return pop[i].t->getImage(); + } + + std::pair get_best() + { + blt::size_t i = 0; + float fitness = 0; + for (blt::size_t j = 0; j < POPULATION_SIZE; j++) + { + if (pop[j].fitness > fitness) + { + i = j; + fitness = pop[j].fitness; + } + } + return {i, fitness}; + } + + image& display_best() + { + return display(get_best().first); + } + + float best_fitness() + { + return get_best().second; + } +}; + +gp_population pop; blt::gfx::texture_gl2D* texture; void print_bits(float f) @@ -178,6 +331,8 @@ void init() resources.load_resources(); } +float best = 0; + void update(std::int32_t w, std::int32_t h) { global_matrices.update_perspectives(w, h, 90, 0.1, 2000); @@ -186,23 +341,29 @@ void update(std::int32_t w, std::int32_t h) if (ImGui::Button("Regenerate")) { BLT_INFO("Regen tree"); - test_tree = tree::construct_random_tree(); - test_tree->evaluate(); - BLT_INFO("Uploading"); - texture->upload((void*) test_tree->getImage().getData().data(), GL_RGB, 0, 0, 0, -1, -1, GL_FLOAT); + pop.init(); + best = pop.best_fitness(); + } + + if (ImGui::Button("Run Step")) + { + BLT_INFO("Running Step"); + pop.run_step(); + BLT_INFO("Evaluating Population"); + pop.evaluate_population(); + best = pop.best_fitness(); } if (ImGui::Button("Display")) { - if (test_tree->hasTree()) - test_tree->printTree(); - } - - if (ImGui::Button("Eval")) - { - if (test_tree->hasImage()) - BLT_DEBUG(test_tree->fitness()); + BLT_INFO("Uploading"); + texture->upload((void*) pop.display_best().getData().data(), GL_RGB, 0, 0, 0, -1, -1, GL_FLOAT); + + ; +// if (test_tree->hasTree()) +// test_tree->printTree(); } + ImGui::Text("Best Fitness: %f", best); auto lw = 512.0f; auto lh = 512.0f; @@ -216,7 +377,7 @@ void update(std::int32_t w, std::int32_t h) int main() { - shapiro_test_run(); + //shapiro_test_run(); blt::gfx::init(blt::gfx::window_data{"Window of GP test", init, update}.setSyncInterval(1)); global_matrices.cleanup(); resources.cleanup();