main
Brett 2024-02-01 18:46:57 -05:00
parent 0e08a15b5e
commit 298c640d03
1 changed files with 52 additions and 30 deletions

View File

@ -228,6 +228,16 @@ class tree
class gp_population class gp_population
{ {
public: public:
enum class crossover_error_t
{
};
struct selection
{
blt::size_t index;
tree* t_ref;
};
struct gp_i struct gp_i
{ {
std::unique_ptr<tree> t = nullptr; std::unique_ptr<tree> t = nullptr;
@ -262,9 +272,10 @@ class gp_population
evaluate_population(); evaluate_population();
} }
[[nodiscard]] static tree* tournament_select(population_storage_t& p) [[nodiscard]] static selection tournament_select(population_storage_t& p)
{ {
tree* n = nullptr; tree* n = nullptr;
blt::size_t ni = 0;
float fitness = -2 * 8192; float fitness = -2 * 8192;
for (int i = 0; i < TOURNAMENT_SIZE; i++) for (int i = 0; i < TOURNAMENT_SIZE; i++)
{ {
@ -272,18 +283,18 @@ class gp_population
do do
{ {
index = random_individual_index(); index = random_individual_index();
} while (n == p[index].t.get());
auto& v = p[index]; auto& v = p[index];
if (v.fitness >= fitness) if (v.fitness >= fitness)
{ {
n = v.t.get(); n = v.t.get();
ni = index;
fitness = v.fitness; fitness = v.fitness;
} }
//BLT_TRACE("%d: %p -> %f ? %p -> %f || %d %d", index, v.t.get(), v.fitness, n, fitness, n != v.t.get(), v.fitness > fitness);
} while (n == p[index].t.get());
} }
//BLT_DEBUG("%p -> %f", n, fitness);
BLT_ASSERT(n != nullptr); BLT_ASSERT(n != nullptr);
return n; return {ni, n};
} }
void run_step() void run_step()
@ -302,49 +313,41 @@ class gp_population
blt::size_t crossover_count = 0; blt::size_t crossover_count = 0;
blt::size_t mutation_count = 0; blt::size_t mutation_count = 0;
BLT_TRACE("Running Crossover"); BLT_TRACE("Running Crossover");
for (blt::size_t i = insert_pos; i < POPULATION_SIZE; i++) for (blt::size_t i = 0; i < POPULATION_SIZE; i++)
{ {
if (insert_pos >= POPULATION_SIZE)
break;
if (rand(engine) < CROSSOVER_RATE) if (rand(engine) < CROSSOVER_RATE)
{ {
tree* p1 = select(); selection p1 = tournament_select(new_pop);
tree* p2; selection p2{};
do do
{ {
p2 = select(); p2 = tournament_select(new_pop);
} while (p2 == p1); } while (p2.t_ref == p1.t_ref);
if (auto r = tree::crossover(p1, p2)) if (auto r = tree::crossover(p1.t_ref, p2.t_ref))
{ {
new_pop[insert_pos++] = {std::move(r->c1)}; new_pop[p1.index] = {std::move(r->c1)};
new_pop[insert_pos++] = {std::move(r->c2)}; new_pop[p2.index] = {std::move(r->c2)};
crossover_count++; crossover_count++;
} }
} }
} }
BLT_TRACE("Running Mutation"); BLT_TRACE("Running Mutation");
for (blt::size_t i = insert_pos; i < POPULATION_SIZE; i++) for (blt::size_t i = 0; i < POPULATION_SIZE; i++)
{ {
if (insert_pos >= POPULATION_SIZE)
break;
if (rand(engine) < MUTATION_RATE) if (rand(engine) < MUTATION_RATE)
{ {
auto* p1 = select(); auto p1 = tournament_select(new_pop);
if (auto r = tree::mutate(p1)) if (auto r = tree::mutate(p1.t_ref))
{ {
new_pop[insert_pos++] = gp_i{std::move(r->c), 0}; new_pop[p1.index] = gp_i{std::move(r->c), 0};
mutation_count++; mutation_count++;
} }
} }
} }
while (insert_pos < POPULATION_SIZE)
{
new_pop[insert_pos++] = {select()->clone()};
}
pop = std::move(new_pop); pop = std::move(new_pop);
BLT_TRACE("ran %d crossovers and %d mutations", crossover_count, mutation_count); BLT_TRACE("ran %d crossovers and %d mutations", crossover_count, mutation_count);
} }
@ -414,6 +417,25 @@ void print_bits(float f)
std::cout << std::endl; std::cout << std::endl;
} }
enum class error
{
S1, S2, S3
};
struct simple
{
float f;
};
blt::expected<float, error> t1(bool b)
{
float f = 50.0f * b;
if (b)
return f;
else
return error::S1;
}
void init() void init()
{ {
global_matrices.create_internals(); global_matrices.create_internals();