selection and next generation

thread
Brett 2024-07-09 03:57:58 -04:00
parent 46a22b520b
commit e367411220
11 changed files with 168 additions and 46 deletions

View File

@ -1,5 +1,5 @@
cmake_minimum_required(VERSION 3.25) cmake_minimum_required(VERSION 3.25)
project(blt-gp VERSION 0.0.52) project(blt-gp VERSION 0.0.53)
include(CTest) include(CTest)

View File

@ -149,8 +149,8 @@ int main()
// results->child1.print(program, std::cout, print_literals, pretty_print, print_returns); // results->child1.print(program, std::cout, print_literals, pretty_print, print_returns);
// BLT_TRACE("Child 2: %f", results->child2.get_evaluation_value<float>(nullptr)); // BLT_TRACE("Child 2: %f", results->child2.get_evaluation_value<float>(nullptr));
// results->child2.print(program, std::cout, print_literals, pretty_print, print_returns); // results->child2.print(program, std::cout, print_literals, pretty_print, print_returns);
new_pop.get_individuals().push_back({std::move(results->child1)}); new_pop.get_individuals().emplace_back(std::move(results->child1));
new_pop.get_individuals().push_back({std::move(results->child2)}); new_pop.get_individuals().emplace_back(std::move(results->child2));
} else } else
{ {
switch (results.error()) switch (results.error())

View File

@ -114,7 +114,7 @@ int main()
BLT_INFO("Mutation:"); BLT_INFO("Mutation:");
for (auto& tree : pop.for_each_tree()) for (auto& tree : pop.for_each_tree())
{ {
new_pop.get_individuals().push_back({mutator.apply(program, generator, tree)}); new_pop.get_individuals().emplace_back(mutator.apply(program, generator, tree));
} }
BLT_INFO("Post-Mutation"); BLT_INFO("Post-Mutation");
for (auto& tree : new_pop.for_each_tree()) for (auto& tree : new_pop.for_each_tree())

View File

@ -96,7 +96,7 @@ int main()
BLT_INFO("Mutation:"); BLT_INFO("Mutation:");
for (auto& tree : pop.for_each_tree()) for (auto& tree : pop.for_each_tree())
{ {
new_pop.get_individuals().push_back({mutator.apply(program, generator, tree)}); new_pop.get_individuals().emplace_back(mutator.apply(program, generator, tree));
} }
BLT_INFO("Post-Mutation"); BLT_INFO("Post-Mutation");
for (auto& tree : new_pop.for_each_tree()) for (auto& tree : new_pop.for_each_tree())

View File

@ -40,6 +40,7 @@
#include <blt/gp/typesystem.h> #include <blt/gp/typesystem.h>
#include <blt/gp/operations.h> #include <blt/gp/operations.h>
#include <blt/gp/transformers.h> #include <blt/gp/transformers.h>
#include <blt/gp/selection.h>
#include <blt/gp/tree.h> #include <blt/gp/tree.h>
#include <blt/gp/stack.h> #include <blt/gp/stack.h>
@ -230,9 +231,20 @@ namespace blt::gp
struct config_t struct config_t
{ {
blt::size_t population_size = 500; blt::size_t population_size = 500;
blt::size_t max_generations = 50;
blt::size_t initial_min_tree_size = 3; blt::size_t initial_min_tree_size = 3;
blt::size_t initial_max_tree_size = 10; blt::size_t initial_max_tree_size = 10;
// percent chance that we will do crossover
double crossover_chance = 0.8;
// percent chance that we will do mutation
double mutation_chance = 0.1;
// everything else will just be selected
blt::size_t elites = 0;
bool try_mutation_on_crossover_failure = true;
std::reference_wrapper<mutation_t> mutator; std::reference_wrapper<mutation_t> mutator;
std::reference_wrapper<crossover_t> crossover; std::reference_wrapper<crossover_t> crossover;
std::reference_wrapper<population_initializer_t> pop_initializer; std::reference_wrapper<population_initializer_t> pop_initializer;
@ -243,14 +255,9 @@ namespace blt::gp
// default config with a user specified initializer // default config with a user specified initializer
config_t(const std::reference_wrapper<population_initializer_t>& popInitializer); // NOLINT config_t(const std::reference_wrapper<population_initializer_t>& popInitializer); // NOLINT
config_t(size_t populationSize, size_t initialMinTreeSize, size_t initialMaxTreeSize); config_t(size_t populationSize, const std::reference_wrapper<population_initializer_t>& popInitializer);
config_t(size_t populationSize, size_t initialMinTreeSize, size_t initialMaxTreeSize, config_t(size_t populationSize); // NOLINT
const std::reference_wrapper<population_initializer_t>& popInitializer);
config_t(size_t populationSize, size_t initialMinTreeSize, size_t initialMaxTreeSize,
const std::reference_wrapper<mutation_t>& mutator, const std::reference_wrapper<crossover_t>& crossover,
const std::reference_wrapper<population_initializer_t>& popInitializer);
config_t& set_pop_size(blt::size_t pop) config_t& set_pop_size(blt::size_t pop)
{ {
@ -287,6 +294,36 @@ namespace blt::gp
pop_initializer = ref; pop_initializer = ref;
return *this; return *this;
} }
config_t& set_elite_count(blt::size_t new_elites)
{
elites = new_elites;
return *this;
}
config_t& set_crossover_chance(double new_crossover_chance)
{
crossover_chance = new_crossover_chance;
return *this;
}
config_t& set_mutation_chance(double new_mutation_chance)
{
mutation_chance = new_mutation_chance;
return *this;
}
config_t& set_max_generations(blt::size_t new_max_generations)
{
max_generations = new_max_generations;
return *this;
}
config_t& set_try_mutation_on_crossover_failure(bool new_try_mutation_on_crossover_failure)
{
try_mutation_on_crossover_failure = new_try_mutation_on_crossover_failure;
return *this;
}
}; };
/** /**
@ -307,6 +344,60 @@ namespace blt::gp
void generate_population(type_id root_type); void generate_population(type_id root_type);
template<typename Crossover, typename Mutation, typename Reproduction>
void create_next_generation(Crossover&& crossover_selection, Mutation&& mutation_selection, Reproduction&& reproduction_selection)
{
static std::uniform_real_distribution dist(0.0, 1.0);
double total_prob = config.mutation_chance + config.crossover_chance;
double crossover_chance = config.crossover_chance / total_prob;
double mutation_chance = crossover_chance + config.mutation_chance / total_prob;
// should already be empty
next_pop.clear();
crossover_selection.pre_process(*this, current_pop, current_stats);
mutation_selection.pre_process(*this, current_pop, current_stats);
reproduction_selection.pre_process(*this, current_pop, current_stats);
for (blt::size_t i = 0; i < config.population_size; i++)
{
auto type = dist(get_random());
if (type > crossover_chance && type < mutation_chance)
{
// crossover
auto& p1 = crossover_selection.select(*this, current_pop, current_stats);
auto& p2 = crossover_selection.select(*this, current_pop, current_stats);
auto results = config.crossover.get().apply(*this, p1, p2);
// if crossover fails, we can check for mutation on these guys. otherwise straight copy them into the next pop
if (results)
{
next_pop.get_individuals().emplace_back(std::move(results->child1));
next_pop.get_individuals().emplace_back(std::move(results->child2));
} else
{
if (config.try_mutation_on_crossover_failure && choice(config.mutation_chance))
next_pop.get_individuals().emplace_back(std::move(config.mutator.get().apply(*this, p1)));
else
next_pop.get_individuals().push_back(p1);
if (config.try_mutation_on_crossover_failure && choice(config.mutation_chance))
next_pop.get_individuals().emplace_back(std::move(config.mutator.get().apply(*this, p2)));
else
next_pop.get_individuals().push_back(p2);
}
} else if (type > mutation_chance)
{
// mutation
auto& p = mutation_selection.select(*this, current_pop, current_stats);
next_pop.get_individuals().emplace_back(std::move(config.mutator.get().apply(*this, p)));
} else
{
// reproduction
auto& p = reproduction_selection.select(*this, current_pop, current_stats);
next_pop.get_individuals().push_back(p);
}
}
}
/** /**
* takes in a lambda for the fitness evaluation function (must return a value convertable to double) * takes in a lambda for the fitness evaluation function (must return a value convertable to double)
@ -361,7 +452,7 @@ namespace blt::gp
void next_generation() void next_generation()
{ {
current_pop = next_pop; current_pop = std::move(next_pop);
current_generation++; current_generation++;
} }

View File

@ -79,6 +79,13 @@ namespace blt::gp
{ {
blt::size_t replacement_min_depth = 3; blt::size_t replacement_min_depth = 3;
blt::size_t replacement_max_depth = 7; blt::size_t replacement_max_depth = 7;
std::reference_wrapper<tree_generator_t> generator;
config_t(tree_generator_t& generator): generator(generator) // NOLINT
{}
config_t();
}; };
mutation_t() = default; mutation_t() = default;
@ -86,9 +93,10 @@ namespace blt::gp
explicit mutation_t(const config_t& config): config(config) explicit mutation_t(const config_t& config): config(config)
{} {}
virtual tree_t apply(gp_program& program, tree_generator_t& generator, const tree_t& p); // NOLINT virtual tree_t apply(gp_program& program, const tree_t& p); // NOLINT
virtual ~mutation_t() = default; virtual ~mutation_t() = default;
private: private:
config_t config; config_t config;
}; };

View File

@ -107,7 +107,6 @@ namespace blt::gp
private: private:
std::vector<op_container_t> operations; std::vector<op_container_t> operations;
blt::gp::stack_allocator values; blt::gp::stack_allocator values;
blt::size_t depth;
}; };
struct individual struct individual
@ -116,6 +115,22 @@ namespace blt::gp
double raw_fitness = 0; double raw_fitness = 0;
double adjusted_fitness = 0; double adjusted_fitness = 0;
double probability = 0; double probability = 0;
individual() = default;
explicit individual(tree_t&& tree): tree(tree)
{}
explicit individual(const tree_t& tree): tree(tree)
{}
individual(const individual&) = default;
individual(individual&&) = default;
individual& operator=(const individual&) = delete;
individual& operator=(individual&&) = default;
}; };
struct population_stats struct population_stats
@ -213,6 +228,21 @@ namespace blt::gp
{ {
return individuals.end(); return individuals.end();
} }
void clear()
{
individuals.clear();
}
population_t() = default;
population_t(const population_t&) = default;
population_t(population_t&&) = default;
population_t& operator=(const population_t&) = delete;
population_t& operator=(population_t&&) = default;
private: private:
std::vector<individual> individuals; std::vector<individual> individuals;

View File

@ -120,7 +120,7 @@ namespace blt::gp
population_t pop; population_t pop;
for (auto i = 0ul; i < args.size; i++) for (auto i = 0ul; i < args.size; i++)
pop.get_individuals().push_back({grow.generate(args.to_gen_args())}); pop.get_individuals().emplace_back(grow.generate(args.to_gen_args()));
return pop; return pop;
} }
@ -130,7 +130,7 @@ namespace blt::gp
population_t pop; population_t pop;
for (auto i = 0ul; i < args.size; i++) for (auto i = 0ul; i < args.size; i++)
pop.get_individuals().push_back({full.generate(args.to_gen_args())}); pop.get_individuals().emplace_back(full.generate(args.to_gen_args()));
return pop; return pop;
} }
@ -142,9 +142,9 @@ namespace blt::gp
for (auto i = 0ul; i < args.size; i++) for (auto i = 0ul; i < args.size; i++)
{ {
if (args.program.choice()) if (args.program.choice())
pop.get_individuals().push_back({full.generate(args.to_gen_args())}); pop.get_individuals().emplace_back(full.generate(args.to_gen_args()));
else else
pop.get_individuals().push_back({grow.generate(args.to_gen_args())}); pop.get_individuals().emplace_back(grow.generate(args.to_gen_args()));
} }
return pop; return pop;
@ -162,18 +162,18 @@ namespace blt::gp
for (auto i = 0ul; i < per_step; i++) for (auto i = 0ul; i < per_step; i++)
{ {
if (args.program.choice()) if (args.program.choice())
pop.get_individuals().push_back({full.generate({args.program, args.root_type, args.min_depth, depth})}); pop.get_individuals().emplace_back(full.generate({args.program, args.root_type, args.min_depth, depth}));
else else
pop.get_individuals().push_back({grow.generate({args.program, args.root_type, args.min_depth, depth})}); pop.get_individuals().emplace_back(grow.generate({args.program, args.root_type, args.min_depth, depth}));
} }
} }
for (auto i = 0ul; i < remainder; i++) for (auto i = 0ul; i < remainder; i++)
{ {
if (args.program.choice()) if (args.program.choice())
pop.get_individuals().push_back({full.generate(args.to_gen_args())}); pop.get_individuals().emplace_back(full.generate(args.to_gen_args()));
else else
pop.get_individuals().push_back({grow.generate(args.to_gen_args())}); pop.get_individuals().emplace_back(grow.generate(args.to_gen_args()));
} }
blt_assert(pop.get_individuals().size() == args.size); blt_assert(pop.get_individuals().size() == args.size);

View File

@ -27,24 +27,6 @@ namespace blt::gp
static crossover_t s_crossover; static crossover_t s_crossover;
static ramped_half_initializer_t s_init; static ramped_half_initializer_t s_init;
gp_program::config_t::config_t(size_t populationSize, size_t initialMinTreeSize, size_t initialMaxTreeSize):
population_size(populationSize), initial_min_tree_size(initialMinTreeSize), initial_max_tree_size(initialMaxTreeSize), mutator(s_mutator),
crossover(s_crossover), pop_initializer(s_init)
{}
gp_program::config_t::config_t(size_t populationSize, size_t initialMinTreeSize, size_t initialMaxTreeSize,
const std::reference_wrapper<mutation_t>& mutator, const std::reference_wrapper<crossover_t>& crossover,
const std::reference_wrapper<population_initializer_t>& popInitializer):
population_size(populationSize), initial_min_tree_size(initialMinTreeSize), initial_max_tree_size(initialMaxTreeSize), mutator(mutator),
crossover(crossover), pop_initializer(popInitializer)
{}
gp_program::config_t::config_t(size_t populationSize, size_t initialMinTreeSize, size_t initialMaxTreeSize,
const std::reference_wrapper<population_initializer_t>& popInitializer):
population_size(populationSize), initial_min_tree_size(initialMinTreeSize), initial_max_tree_size(initialMaxTreeSize),
mutator(s_mutator), crossover(s_crossover), pop_initializer(popInitializer)
{}
gp_program::config_t::config_t(): mutator(s_mutator), crossover(s_crossover), pop_initializer(s_init) gp_program::config_t::config_t(): mutator(s_mutator), crossover(s_crossover), pop_initializer(s_init)
{ {
@ -52,7 +34,13 @@ namespace blt::gp
gp_program::config_t::config_t(const std::reference_wrapper<population_initializer_t>& popInitializer): gp_program::config_t::config_t(const std::reference_wrapper<population_initializer_t>& popInitializer):
mutator(s_mutator), crossover(s_crossover), pop_initializer(popInitializer) mutator(s_mutator), crossover(s_crossover), pop_initializer(popInitializer)
{ {}
} gp_program::config_t::config_t(size_t populationSize, const std::reference_wrapper<population_initializer_t>& popInitializer):
population_size(populationSize), mutator(s_mutator), crossover(s_crossover), pop_initializer(s_init)
{}
gp_program::config_t::config_t(size_t populationSize):
population_size(populationSize), mutator(s_mutator), crossover(s_crossover), pop_initializer(s_init)
{}
} }

View File

@ -80,7 +80,7 @@ namespace blt::gp
return ind->tree; return ind->tree;
} }
tree_t& select_fitness_proportionate_t::select(gp_program& program, population_t& pop, population_stats& stats) tree_t& select_fitness_proportionate_t::select(gp_program& program, population_t& pop, population_stats&)
{ {
static std::uniform_real_distribution dist(0.0, 1.0); static std::uniform_real_distribution dist(0.0, 1.0);
auto choice = dist(program.get_random()); auto choice = dist(program.get_random());

View File

@ -22,6 +22,8 @@
namespace blt::gp namespace blt::gp
{ {
grow_generator_t grow_generator;
blt::expected<crossover_t::result_t, crossover_t::error_t> crossover_t::apply(gp_program& program, const tree_t& p1, const tree_t& p2) // NOLINT blt::expected<crossover_t::result_t, crossover_t::error_t> crossover_t::apply(gp_program& program, const tree_t& p1, const tree_t& p2) // NOLINT
{ {
result_t result{p1, p2}; result_t result{p1, p2};
@ -254,7 +256,7 @@ namespace blt::gp
return result; return result;
} }
tree_t mutation_t::apply(gp_program& program, tree_generator_t& generator, const tree_t& p) tree_t mutation_t::apply(gp_program& program, const tree_t& p)
{ {
auto c = p; auto c = p;
@ -305,7 +307,7 @@ namespace blt::gp
ops.erase(begin_p, end_p); ops.erase(begin_p, end_p);
auto new_tree = generator.generate({program, type_info.return_type, config.replacement_min_depth, config.replacement_max_depth}); auto new_tree = config.generator.get().generate({program, type_info.return_type, config.replacement_min_depth, config.replacement_max_depth});
auto& new_ops = new_tree.get_operations(); auto& new_ops = new_tree.get_operations();
auto& new_vals = new_tree.get_values(); auto& new_vals = new_tree.get_values();
@ -329,4 +331,7 @@ namespace blt::gp
return c; return c;
} }
mutation_t::config_t::config_t(): generator(grow_generator)
{}
} }