// blt-gp/include/blt/gp/selection.h
// (source-viewer metadata: 250 lines, 10 KiB, C++)

#pragma once
/*
* Copyright (C) 2024 Brett Terpstra
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
#ifndef BLT_GP_SELECTION_H
#define BLT_GP_SELECTION_H
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <numeric>
#include <vector>

#include <blt/gp/fwdecl.h>
#include <blt/gp/tree.h>
#include <blt/gp/config.h>
#include <blt/gp/random.h>
#include <blt/std/assert.h>
namespace blt::gp
{
/**
 * Bundle of references passed to the next-population creators in this header
 * (perform_elitism, proportionate_next_pop_creator, default_next_pop_creator).
 *
 * Only references are stored, so every referenced object must outlive the selector_args instance.
 * Member order matters: the creators unpack this struct with structured bindings in declaration order.
 */
struct selector_args
{
// program passed to select() and to the crossover / mutation operators
gp_program& program;
// population being built; the creators append to next_pop.get_individuals()
population_t& next_pop;
// population that parents are selected from
population_t& current_pop;
// statistics for current_pop, forwarded to select()
population_stats& current_stats;
// run configuration (elites, population_size, operator chances, crossover / mutator operators)
prog_config_t& config;
// random source used to choose which operator to apply
random_t& random;
};
/**
 * Copies the `config.elites` best individuals of `current_pop` into `next_pop`.
 *
 * "Best" means highest `fitness.adjusted_fitness`.
 * - Ties are broken in favour of the individual that appears earlier in the current population, so the result is deterministic.
 * - Individuals whose adjusted fitness is NaN rank below every other individual.
 * - If the current population holds fewer than `config.elites` individuals, all of them are copied.
 * - Nothing happens when `config.elites` is not positive.
 *
 * @param args selection context; only next_pop, current_pop and config are used.
 */
constexpr inline auto perform_elitism = [](const selector_args& args) {
	auto& [program, next_pop, current_pop, current_stats, config, random] = args;
	if (config.elites > 0)
	{
		const auto& current = current_pop.get_individuals();
		// Never index past the end of the current population.
		const std::size_t elite_count = std::min(static_cast<std::size_t>(config.elites),
		                                         static_cast<std::size_t>(current.size()));

		// Map NaN to -inf so the comparator below is a strict weak ordering (required by partial_sort).
		const auto rank = [&current](const std::size_t index) {
			const double value = current[index].fitness.adjusted_fitness;
			return std::isnan(value) ? -std::numeric_limits<double>::infinity() : value;
		};

		std::vector<std::size_t> order(current.size());
		std::iota(order.begin(), order.end(), std::size_t{0});
		// Only the first elite_count positions need to be ordered: O(n log k) instead of a full sort.
		std::partial_sort(order.begin(), order.begin() + static_cast<std::ptrdiff_t>(elite_count), order.end(),
		                  [&rank](const std::size_t lhs, const std::size_t rhs) {
			                  const double lhs_rank = rank(lhs);
			                  const double rhs_rank = rank(rhs);
			                  return lhs_rank > rhs_rank || (lhs_rank == rhs_rank && lhs < rhs);
		                  });

		for (std::size_t i = 0; i < elite_count; ++i)
			next_pop.get_individuals().push_back(current[order[i]]);
	}
};
/**
 * Fills `next_pop` (after elitism) until it holds `config.population_size` individuals.
 *
 * For every step, one operator is chosen with probability proportional to its configured chance:
 * `config.crossover_chance`, `config.mutation_chance` or `config.reproduction_chance`.
 * If all three chances sum to zero, every step falls back to reproduction.
 *
 * When a crossover fails, both parents are still added to the next population. Each parent is mutated first
 * if `config.try_mutation_on_crossover_failure` is set and a `config.mutation_chance` roll succeeds;
 * otherwise it is copied unchanged.
 *
 * NOTE(review): assumes random.get_double() returns a value in [0, 1) — confirm against random_t.
 *
 * @param args selection context
 * @param crossover_selection selector used to pick both crossover parents
 * @param mutation_selection selector used to pick the individual to mutate
 * @param reproduction_selection selector used to pick the individual copied unchanged
 */
template<typename Crossover, typename Mutation, typename Reproduction>
constexpr inline auto proportionate_next_pop_creator = [](
	const selector_args& args, Crossover crossover_selection, Mutation mutation_selection, Reproduction reproduction_selection) {
	auto& [program, next_pop, current_pop, current_stats, config, random] = args;

	const double total_prob = config.crossover_chance + config.mutation_chance + config.reproduction_chance;
	// Cumulative thresholds over [0, 1):
	//   [0, crossover_threshold)                  -> crossover
	//   [crossover_threshold, mutation_threshold) -> mutation
	//   [mutation_threshold, 1)                   -> reproduction
	// A non-positive total would divide by zero; send everything to reproduction instead.
	const double crossover_threshold = total_prob > 0 ? config.crossover_chance / total_prob : 0.0;
	const double mutation_threshold = total_prob > 0 ? crossover_threshold + config.mutation_chance / total_prob : 0.0;

	perform_elitism(args);
	while (next_pop.get_individuals().size() < config.population_size)
	{
		const auto type = random.get_double();
		if (type < crossover_threshold)
		{
			// crossover
			auto& p1 = crossover_selection.select(program, current_pop, current_stats);
			auto& p2 = crossover_selection.select(program, current_pop, current_stats);
			auto results = config.crossover.get().apply(program, p1, p2);
			// if crossover fails, we can check for mutation on these guys. otherwise straight copy them into the next pop
			if (results)
			{
				next_pop.get_individuals().emplace_back(std::move(results->child1));
				// the first child may have filled the last free slot
				if (next_pop.get_individuals().size() < config.population_size)
					next_pop.get_individuals().emplace_back(std::move(results->child2));
			} else
			{
				if (config.try_mutation_on_crossover_failure && random.choice(config.mutation_chance))
					next_pop.get_individuals().emplace_back(std::move(config.mutator.get().apply(program, p1)));
				else
					next_pop.get_individuals().push_back(individual{p1});
				// the first parent may have filled the last free slot
				if (next_pop.get_individuals().size() < config.population_size)
				{
					if (config.try_mutation_on_crossover_failure && random.choice(config.mutation_chance))
						next_pop.get_individuals().emplace_back(std::move(config.mutator.get().apply(program, p2)));
					else
						next_pop.get_individuals().push_back(individual{p2});
				}
			}
		} else if (type < mutation_threshold)
		{
			// mutation
			auto& p = mutation_selection.select(program, current_pop, current_stats);
			next_pop.get_individuals().emplace_back(std::move(config.mutator.get().apply(program, p)));
		} else
		{
			// reproduction
			auto& p = reproduction_selection.select(program, current_pop, current_stats);
			next_pop.get_individuals().push_back(individual{p});
		}
	}
};
/**
 * Fills `next_pop` (after elitism) until it holds `config.population_size` individuals.
 *
 * Each round draws a branch index with `random.get_i32(0, 3)` and then applies the matching operator only if
 * that operator's own probability roll succeeds:
 *   0 -> crossover (roll on config.crossover_chance); a failed crossover adds nothing,
 *   1 -> mutation (roll on config.mutation_chance),
 *   2 -> reproduction (only rolled at all when config.reproduction_chance > 0),
 *   anything else aborts.
 * Rounds that add nothing simply repeat.
 *
 * NOTE(review): if no roll can ever succeed (e.g. every chance is zero), this loop never terminates.
 *
 * @param args selection context
 * @param crossover_selection selector used to pick both crossover parents
 * @param mutation_selection selector used to pick the individual to mutate
 * @param reproduction_selection selector used to pick the individual copied unchanged
 */
template<typename Crossover, typename Mutation, typename Reproduction>
constexpr inline auto default_next_pop_creator = [](
	const blt::gp::selector_args& args, Crossover crossover_selection, Mutation mutation_selection, Reproduction reproduction_selection) {
	auto& [program, next_pop, current_pop, current_stats, config, random] = args;
	perform_elitism(args);
	auto& individuals = next_pop.get_individuals();
	while (individuals.size() < config.population_size)
	{
		const int branch = random.get_i32(0, 3);
		if (branch == 0)
		{
			// crossover: every operator gets one chance per round
			if (!random.choice(config.crossover_chance))
				continue;
			auto& first_parent = crossover_selection.select(program, current_pop, current_stats);
			auto& second_parent = crossover_selection.select(program, current_pop, current_stats);
			auto offspring = config.crossover.get().apply(program, first_parent, second_parent);
			// a failed crossover contributes nothing this round
			if (!offspring)
				continue;
			individuals.emplace_back(std::move(offspring->child1));
			// the first child may have filled the last free slot
			if (individuals.size() < config.population_size)
				individuals.emplace_back(std::move(offspring->child2));
		}
		else if (branch == 1)
		{
			// mutation
			if (!random.choice(config.mutation_chance))
				continue;
			auto& parent = mutation_selection.select(program, current_pop, current_stats);
			individuals.emplace_back(std::move(config.mutator.get().apply(program, parent)));
		}
		else if (branch == 2)
		{
			// reproduction: no roll is consumed when the chance is not positive
			const bool reproduce = config.reproduction_chance > 0 && random.choice(config.reproduction_chance);
			if (!reproduce)
				continue;
			auto& parent = reproduction_selection.select(program, current_pop, current_stats);
			individuals.push_back(individual{parent});
		}
		else
		{
			BLT_ABORT("This is not possible!");
		}
	}
};
/**
 * Abstract interface for a strategy that picks individuals (trees) out of a population.
 * Concrete strategies in this header: select_best_t, select_worst_t, select_random_t,
 * select_tournament_t and select_fitness_proportionate_t.
 */
class selection_t
{
public:
/**
 * Selects one tree from the population.
 * @param program gp program to select with, used in randoms
 * @param pop population to select from
 * @param stats the population's statistics
 * @return a reference to the selected tree. NOTE(review): presumably this refers into `pop`, so it is only
 *         valid while `pop` is alive and unmodified — confirm in the implementations.
 */
virtual tree_t& select(gp_program& program, population_t& pop, population_stats& stats) = 0;
/**
 * Optional hook for strategies that need to prepare state from the population; the default does nothing.
 * NOTE(review): presumably called once before a batch of select() calls — confirm against the caller.
 */
virtual void pre_process(gp_program&, population_t&, population_stats&)
{}
// virtual so strategies can be destroyed through a selection_t pointer/reference
virtual ~selection_t() = default;
};
/**
 * Selection strategy that, going by its name, returns the best individual of the population.
 * The implementation of select() is not in this header.
 */
class select_best_t : public selection_t
{
public:
tree_t& select(gp_program& program, population_t& pop, population_stats& stats) final;
};
/**
 * Selection strategy that, going by its name, returns the worst individual of the population.
 * The implementation of select() is not in this header.
 */
class select_worst_t : public selection_t
{
public:
tree_t& select(gp_program& program, population_t& pop, population_stats& stats) final;
};
/**
 * Selection strategy that, going by its name, returns a randomly chosen individual.
 * The implementation of select() is not in this header.
 */
class select_random_t : public selection_t
{
public:
tree_t& select(gp_program& program, population_t& pop, population_stats& stats) final;
};
class select_tournament_t : public selection_t
{
public:
explicit select_tournament_t(blt::size_t selection_size = 3): selection_size(selection_size)
{
if (selection_size < 1)
BLT_ABORT("Unable to select with this size. Must select at least 1 individual!");
}
tree_t& select(gp_program& program, population_t& pop, population_stats& stats) final;
private:
const blt::size_t selection_size;
};
/**
 * Fitness-proportionate (roulette-wheel style, going by the name) selection strategy.
 * This is the only strategy in this header that overrides pre_process(). Both pre_process() and select() are
 * implemented out of line.
 */
class select_fitness_proportionate_t : public selection_t
{
public:
void pre_process(gp_program& program, population_t& pop, population_stats& stats) final;
tree_t& select(gp_program& program, population_t& pop, population_stats& stats) final;
};
}
#endif //BLT_GP_SELECTION_H