diff --git a/src/main.cpp b/src/main.cpp index dd11c62..ffefaba 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -228,6 +228,16 @@ class tree class gp_population { public: + enum class crossover_error_t + { + + }; + struct selection + { + blt::size_t index; + tree* t_ref; + }; + struct gp_i { std::unique_ptr t = nullptr; @@ -262,9 +272,10 @@ class gp_population evaluate_population(); } - [[nodiscard]] static tree* tournament_select(population_storage_t& p) + [[nodiscard]] static selection tournament_select(population_storage_t& p) { tree* n = nullptr; + blt::size_t ni = 0; float fitness = -2 * 8192; for (int i = 0; i < TOURNAMENT_SIZE; i++) { @@ -272,18 +283,18 @@ class gp_population do { index = random_individual_index(); - auto& v = p[index]; - if (v.fitness >= fitness) - { - n = v.t.get(); - fitness = v.fitness; - } - //BLT_TRACE("%d: %p -> %f ? %p -> %f || %d %d", index, v.t.get(), v.fitness, n, fitness, n != v.t.get(), v.fitness > fitness); } while (n == p[index].t.get()); + + auto& v = p[index]; + if (v.fitness >= fitness) + { + n = v.t.get(); + ni = index; + fitness = v.fitness; + } } - //BLT_DEBUG("%p -> %f", n, fitness); BLT_ASSERT(n != nullptr); - return n; + return {ni, n}; } void run_step() @@ -302,49 +313,41 @@ class gp_population blt::size_t crossover_count = 0; blt::size_t mutation_count = 0; BLT_TRACE("Running Crossover"); - for (blt::size_t i = insert_pos; i < POPULATION_SIZE; i++) + for (blt::size_t i = 0; i < POPULATION_SIZE; i++) { - if (insert_pos >= POPULATION_SIZE) - break; if (rand(engine) < CROSSOVER_RATE) { - tree* p1 = select(); - tree* p2; + selection p1 = tournament_select(new_pop); + selection p2{}; do { - p2 = select(); - } while (p2 == p1); + p2 = tournament_select(new_pop); + } while (p2.t_ref == p1.t_ref); - if (auto r = tree::crossover(p1, p2)) + if (auto r = tree::crossover(p1.t_ref, p2.t_ref)) { - new_pop[insert_pos++] = {std::move(r->c1)}; - new_pop[insert_pos++] = {std::move(r->c2)}; + new_pop[p1.index] = {std::move(r->c1)}; + new_pop[p2.index] = {std::move(r->c2)}; crossover_count++; } } } BLT_TRACE("Running Mutation"); - for (blt::size_t i = insert_pos; i < POPULATION_SIZE; i++) + for (blt::size_t i = 0; i < POPULATION_SIZE; i++) { - if (insert_pos >= POPULATION_SIZE) - break; if (rand(engine) < MUTATION_RATE) { - auto* p1 = select(); + auto p1 = tournament_select(new_pop); - if (auto r = tree::mutate(p1)) + if (auto r = tree::mutate(p1.t_ref)) { - new_pop[insert_pos++] = gp_i{std::move(r->c), 0}; + new_pop[p1.index] = gp_i{std::move(r->c), 0}; mutation_count++; } } } - while (insert_pos < POPULATION_SIZE) - { - new_pop[insert_pos++] = {select()->clone()}; - } pop = std::move(new_pop); BLT_TRACE("ran %d crossovers and %d mutations", crossover_count, mutation_count); } @@ -414,6 +417,25 @@ void print_bits(float f) std::cout << std::endl; } +enum class error +{ + S1, S2, S3 +}; + +struct simple +{ + float f; +}; + +blt::expected t1(bool b) +{ + float f = 50.0f * b; + if (b) + return f; + else + return error::S1; +} + void init() { global_matrices.create_internals();