working on swaping new populations

shared
Brett 2024-08-30 23:27:25 -04:00
parent cc76f2791a
commit d4e6c40fe1
12 changed files with 264 additions and 256 deletions

View File

@ -1,5 +1,5 @@
cmake_minimum_required(VERSION 3.25) cmake_minimum_required(VERSION 3.25)
project(blt-gp VERSION 0.1.34) project(blt-gp VERSION 0.1.35)
include(CTest) include(CTest)
@ -16,6 +16,8 @@ set(CMAKE_CXX_STANDARD 17)
set(THREADS_PREFER_PTHREAD_FLAG ON) set(THREADS_PREFER_PTHREAD_FLAG ON)
find_package(Threads REQUIRED) find_package(Threads REQUIRED)
SET(CMAKE_CXX_FLAGS_RELEASE "-O3 -g")
if (NOT TARGET BLT) if (NOT TARGET BLT)
add_subdirectory(lib/blt) add_subdirectory(lib/blt)
endif () endif ()

View File

@ -62,8 +62,7 @@ blt::gp::prog_config_t config = blt::gp::prog_config_t()
.set_pop_size(5000) .set_pop_size(5000)
.set_thread_count(0); .set_thread_count(0);
blt::gp::type_provider type_system; blt::gp::gp_program program{SEED_FUNC, config};
blt::gp::gp_program program{type_system, SEED_FUNC, config};
auto lit = blt::gp::operation_t([]() { auto lit = blt::gp::operation_t([]() {
return program.get_random().get_float(-32000.0f, 32000.0f); return program.get_random().get_float(-32000.0f, 32000.0f);
@ -197,6 +196,11 @@ struct test_results_t
{ {
return a.hits < b.hits; return a.hits < b.hits;
} }
friend bool operator>(const test_results_t& a, const test_results_t& b)
{
return a.hits > b.hits;
}
}; };
test_results_t test_individual(blt::gp::individual_t& i) test_results_t test_individual(blt::gp::individual_t& i)
@ -251,37 +255,23 @@ int main(int argc, const char** argv)
load_rice_data(rice_file_path); load_rice_data(rice_file_path);
BLT_DEBUG("Setup Types and Operators"); BLT_DEBUG("Setup Types and Operators");
type_system.register_type<float>();
blt::gp::operator_builder<rice_record> builder{type_system}; blt::gp::operator_builder<rice_record> builder{};
program.set_operations(builder.build(add, sub, mul, pro_div, op_exp, op_log, lit, op_area, op_perimeter, op_major_axis_length, program.set_operations(builder.build(add, sub, mul, pro_div, op_exp, op_log, lit, op_area, op_perimeter, op_major_axis_length,
op_minor_axis_length, op_eccentricity, op_convex_area, op_extent)); op_minor_axis_length, op_eccentricity, op_convex_area, op_extent));
BLT_DEBUG("Generate Initial Population"); BLT_DEBUG("Generate Initial Population");
auto sel = blt::gp::select_tournament_t{}; auto sel = blt::gp::select_tournament_t{};
program.generate_population(type_system.get_type<float>().id(), fitness_function, sel, sel, sel); program.generate_population(program.get_typesystem().get_type<float>().id(), fitness_function, sel, sel, sel);
BLT_DEBUG("Begin Generation Loop"); BLT_DEBUG("Begin Generation Loop");
while (!program.should_terminate()) while (!program.should_terminate())
{ {
BLT_TRACE("------------{Begin Generation %ld}------------", program.get_current_generation()); BLT_TRACE("------------{Begin Generation %ld}------------", program.get_current_generation());
BLT_TRACE("Creating next generation"); BLT_TRACE("Creating next generation");
#ifdef BLT_TRACK_ALLOCATIONS
auto gen_alloc = blt::gp::tracker.start_measurement();
#endif
BLT_START_INTERVAL("Rice Classification", "Gen"); BLT_START_INTERVAL("Rice Classification", "Gen");
program.create_next_generation(); program.create_next_generation();
BLT_END_INTERVAL("Rice Classification", "Gen"); BLT_END_INTERVAL("Rice Classification", "Gen");
#ifdef BLT_TRACK_ALLOCATIONS
blt::gp::tracker.stop_measurement(gen_alloc);
BLT_TRACE("Generation Allocated %ld times with a total of %s", gen_alloc.getAllocationDifference(),
blt::byte_convert_t(gen_alloc.getAllocatedByteDifference()).convert_to_nearest_type().to_pretty_string().c_str());
auto fitness_alloc = blt::gp::tracker.start_measurement();
#endif
BLT_TRACE("Move to next generation"); BLT_TRACE("Move to next generation");
BLT_START_INTERVAL("Rice Classification", "Fitness"); BLT_START_INTERVAL("Rice Classification", "Fitness");
program.next_generation(); program.next_generation();
@ -294,13 +284,6 @@ int main(int argc, const char** argv)
BLT_TRACE("Best fitness: %lf", stats.best_fitness.load()); BLT_TRACE("Best fitness: %lf", stats.best_fitness.load());
BLT_TRACE("Worst fitness: %lf", stats.worst_fitness.load()); BLT_TRACE("Worst fitness: %lf", stats.worst_fitness.load());
BLT_TRACE("Overall fitness: %lf", stats.overall_fitness.load()); BLT_TRACE("Overall fitness: %lf", stats.overall_fitness.load());
#ifdef BLT_TRACK_ALLOCATIONS
blt::gp::tracker.stop_measurement(fitness_alloc);
BLT_TRACE("Fitness Allocated %ld times with a total of %s", fitness_alloc.getAllocationDifference(),
blt::byte_convert_t(fitness_alloc.getAllocatedByteDifference()).convert_to_nearest_type().to_pretty_string().c_str());
#endif
BLT_TRACE("----------------------------------------------"); BLT_TRACE("----------------------------------------------");
std::cout << std::endl; std::cout << std::endl;
} }
@ -311,7 +294,7 @@ int main(int argc, const char** argv)
for (auto& i : program.get_current_pop().get_individuals()) for (auto& i : program.get_current_pop().get_individuals())
results.emplace_back(test_individual(i), &i); results.emplace_back(test_individual(i), &i);
std::sort(results.begin(), results.end(), [](const auto& a, const auto& b) { std::sort(results.begin(), results.end(), [](const auto& a, const auto& b) {
return !(a.first < b.first); return a.first > b.first;
}); });
BLT_INFO("Best results:"); BLT_INFO("Best results:");
@ -343,7 +326,6 @@ int main(int argc, const char** argv)
BLT_DEBUG("Osmancik Osmancik: %ld", record.oo); BLT_DEBUG("Osmancik Osmancik: %ld", record.oo);
BLT_DEBUG("Osmancik Cammeo: %ld", record.oc); BLT_DEBUG("Osmancik Cammeo: %ld", record.oc);
BLT_DEBUG("Fitness: %lf, stand: %lf, raw: %lf", i.fitness.adjusted_fitness, i.fitness.standardized_fitness, i.fitness.raw_fitness); BLT_DEBUG("Fitness: %lf, stand: %lf, raw: %lf", i.fitness.adjusted_fitness, i.fitness.standardized_fitness, i.fitness.raw_fitness);
i.tree.print(program, std::cout);
std::cout << "\n"; std::cout << "\n";
} }
@ -360,15 +342,6 @@ int main(int argc, const char** argv)
BLT_DEBUG("Osmancik Cammeo: %ld", avg.oc); BLT_DEBUG("Osmancik Cammeo: %ld", avg.oc);
std::cout << "\n"; std::cout << "\n";
auto& stats = program.get_population_stats();
BLT_INFO("Stats:");
BLT_INFO("Average fitness: %lf", stats.average_fitness.load());
BLT_INFO("Best fitness: %lf", stats.best_fitness.load());
BLT_INFO("Worst fitness: %lf", stats.worst_fitness.load());
BLT_INFO("Overall fitness: %lf", stats.overall_fitness.load());
// TODO: make stats helper
BLT_PRINT_PROFILE("Rice Classification", blt::PRINT_CYCLES | blt::PRINT_THREAD | blt::PRINT_WALL); BLT_PRINT_PROFILE("Rice Classification", blt::PRINT_CYCLES | blt::PRINT_THREAD | blt::PRINT_WALL);
#ifdef BLT_TRACK_ALLOCATIONS #ifdef BLT_TRACK_ALLOCATIONS

View File

@ -44,8 +44,7 @@ blt::gp::prog_config_t config = blt::gp::prog_config_t()
.set_pop_size(500) .set_pop_size(500)
.set_thread_count(0); .set_thread_count(0);
blt::gp::type_provider type_system; blt::gp::gp_program program{SEED, config};
blt::gp::gp_program program{type_system, SEED, config};
auto lit = blt::gp::operation_t([]() { auto lit = blt::gp::operation_t([]() {
return program.get_random().get_float(-320.0f, 320.0f); return program.get_random().get_float(-320.0f, 320.0f);
@ -93,14 +92,12 @@ int main()
} }
BLT_DEBUG("Setup Types and Operators"); BLT_DEBUG("Setup Types and Operators");
type_system.register_type<float>(); blt::gp::operator_builder<context> builder{};
blt::gp::operator_builder<context> builder{type_system};
program.set_operations(builder.build(add, sub, mul, pro_div, op_sin, op_cos, op_exp, op_log, lit, op_x)); program.set_operations(builder.build(add, sub, mul, pro_div, op_sin, op_cos, op_exp, op_log, lit, op_x));
BLT_DEBUG("Generate Initial Population"); BLT_DEBUG("Generate Initial Population");
auto sel = blt::gp::select_tournament_t{}; auto sel = blt::gp::select_tournament_t{};
program.generate_population(type_system.get_type<float>().id(), fitness_function, sel, sel, sel); program.generate_population(program.get_typesystem().get_type<float>().id(), fitness_function, sel, sel, sel);
BLT_DEBUG("Begin Generation Loop"); BLT_DEBUG("Begin Generation Loop");
while (!program.should_terminate()) while (!program.should_terminate())

View File

@ -56,7 +56,6 @@
namespace blt::gp namespace blt::gp
{ {
struct argc_t struct argc_t
{ {
blt::u32 argc = 0; blt::u32 argc = 0;
@ -80,20 +79,22 @@ namespace blt::gp
detail::operator_func_t func; detail::operator_func_t func;
}; };
struct operator_storage struct program_operator_storage_t
{ {
// indexed from return TYPE ID, returns index of operator // indexed from return TYPE ID, returns index of operator
blt::expanding_buffer<std::vector<operator_id>> terminals; blt::expanding_buffer<std::vector<operator_id>> terminals;
blt::expanding_buffer<std::vector<operator_id>> non_terminals; blt::expanding_buffer<std::vector<operator_id>> non_terminals;
blt::expanding_buffer<std::vector<std::pair<operator_id, blt::size_t>>> operators_ordered_terminals; blt::expanding_buffer<std::vector<std::pair<operator_id, blt::size_t>>> operators_ordered_terminals;
// indexed from OPERATOR ID (operator number) // indexed from OPERATOR ID (operator number)
blt::hashset_t<operator_id> static_types; blt::hashset_t<operator_id> ephemeral_leaf_operators;
std::vector<operator_info> operators; std::vector<operator_info> operators;
std::vector<detail::print_func_t> print_funcs; std::vector<detail::print_func_t> print_funcs;
std::vector<detail::destroy_func_t> destroy_funcs; std::vector<detail::destroy_func_t> destroy_funcs;
std::vector<std::optional<std::string_view>> names; std::vector<std::optional<std::string_view>> names;
detail::eval_func_t eval_func; detail::eval_func_t eval_func;
type_provider system;
}; };
template<typename Context = detail::empty_t> template<typename Context = detail::empty_t>
@ -104,11 +105,10 @@ namespace blt::gp
friend class blt::gp::detail::operator_storage_test; friend class blt::gp::detail::operator_storage_test;
public: public:
explicit operator_builder(type_provider& system): system(system) explicit operator_builder() = default;
{}
template<typename... Operators> template<typename... Operators>
operator_storage& build(Operators& ... operators) program_operator_storage_t& build(Operators& ... operators)
{ {
std::vector<blt::size_t> sizes; std::vector<blt::size_t> sizes;
(sizes.push_back(add_operator(operators)), ...); (sizes.push_back(add_operator(operators)), ...);
@ -201,7 +201,7 @@ namespace blt::gp
return storage; return storage;
} }
operator_storage&& grab() program_operator_storage_t&& grab()
{ {
return std::move(storage); return std::move(storage);
} }
@ -210,10 +210,14 @@ namespace blt::gp
template<typename RawFunction, typename Return, typename... Args> template<typename RawFunction, typename Return, typename... Args>
auto add_operator(operation_t<RawFunction, Return(Args...)>& op) auto add_operator(operation_t<RawFunction, Return(Args...)>& op)
{ {
// check for types we can register
(storage.system.register_type<Args>(), ...);
storage.system.register_type<Return>();
auto total_size_required = stack_allocator::aligned_size(sizeof(Return)); auto total_size_required = stack_allocator::aligned_size(sizeof(Return));
((total_size_required += stack_allocator::aligned_size(sizeof(Args))), ...); ((total_size_required += stack_allocator::aligned_size(sizeof(Args))), ...);
auto return_type_id = system.get_type<Return>().id(); auto return_type_id = storage.system.get_type<Return>().id();
auto operator_id = blt::gp::operator_id(storage.operators.size()); auto operator_id = blt::gp::operator_id(storage.operators.size());
op.id = operator_id; op.id = operator_id;
@ -262,7 +266,7 @@ namespace blt::gp
}); });
storage.names.push_back(op.get_name()); storage.names.push_back(op.get_name());
if (op.is_ephemeral()) if (op.is_ephemeral())
storage.static_types.insert(operator_id); storage.ephemeral_leaf_operators.insert(operator_id);
return total_size_required * 2; return total_size_required * 2;
} }
@ -271,7 +275,7 @@ namespace blt::gp
{ {
if constexpr (!std::is_same_v<Context, detail::remove_cv_ref<T>>) if constexpr (!std::is_same_v<Context, detail::remove_cv_ref<T>>)
{ {
types.push_back(system.get_type<T>().id()); types.push_back(storage.system.get_type<T>().id());
} }
} }
@ -316,8 +320,7 @@ namespace blt::gp
call_jmp_table_internal(op, context, write_stack, read_stack, std::index_sequence_for<Operators...>(), operators...); call_jmp_table_internal(op, context, write_stack, read_stack, std::index_sequence_for<Operators...>(), operators...);
} }
type_provider& system; program_operator_storage_t storage;
operator_storage storage;
}; };
class gp_program class gp_program
@ -327,37 +330,83 @@ namespace blt::gp
* Note about context size: This is required as context is passed to every operator in the GP tree, this context will be provided by your * Note about context size: This is required as context is passed to every operator in the GP tree, this context will be provided by your
* call to one of the evaluator functions. This was the nicest way to provide this as C++ lacks reflection * call to one of the evaluator functions. This was the nicest way to provide this as C++ lacks reflection
* *
* @param system type system to use in tree generation
* @param engine random engine to use throughout the program. * @param engine random engine to use throughout the program.
* @param context_size number of arguments which are always present as "context" to the GP system / operators * @param context_size number of arguments which are always present as "context" to the GP system / operators
*/ */
explicit gp_program(type_provider& system, blt::u64 seed): explicit gp_program(blt::u64 seed): seed_func([seed] { return seed; })
system(system), seed_func([seed]{return seed;})
{ create_threads(); } { create_threads(); }
explicit gp_program(type_provider& system, blt::u64 seed, prog_config_t config): explicit gp_program(blt::u64 seed, prog_config_t config): seed_func([seed] { return seed; }), config(config)
system(system), seed_func([seed]{return seed;}), config(config)
{ create_threads(); } { create_threads(); }
explicit gp_program(type_provider& system, std::function<blt::u64()> seed_func): explicit gp_program(std::function<blt::u64()> seed_func): seed_func(std::move(seed_func))
system(system), seed_func(std::move(seed_func))
{ create_threads(); } { create_threads(); }
explicit gp_program(type_provider& system, std::function<blt::u64()> seed_func, prog_config_t config): explicit gp_program(std::function<blt::u64()> seed_func, prog_config_t config): seed_func(std::move(seed_func)), config(config)
system(system), seed_func(std::move(seed_func)), config(config)
{ create_threads(); } { create_threads(); }
~gp_program()
{
thread_helper.lifetime_over = true;
thread_helper.barrier.notify_all();
thread_helper.thread_function_condition.notify_all();
for (auto& thread : thread_helper.threads)
{
if (thread->joinable())
thread->join();
}
}
void create_next_generation() void create_next_generation()
{ {
#ifdef BLT_TRACK_ALLOCATIONS
auto gen_alloc = blt::gp::tracker.start_measurement();
#endif
// should already be empty // should already be empty
next_pop.clear(); next_pop.clear();
thread_helper.next_gen_left.store(config.population_size, std::memory_order_release); thread_helper.next_gen_left.store(config.population_size, std::memory_order_release);
(*thread_execution_service)(0); (*thread_execution_service)(0);
#ifdef BLT_TRACK_ALLOCATIONS
blt::gp::tracker.stop_measurement(gen_alloc);
BLT_TRACE("Generation Allocated %ld times with a total of %s", gen_alloc.getAllocationDifference(),
blt::byte_convert_t(gen_alloc.getAllocatedByteDifference()).convert_to_nearest_type().to_pretty_string().c_str());
#endif
}
void next_generation()
{
BLT_ASSERT_MSG(next_pop.get_individuals().size() == config.population_size,
("pop size: " + std::to_string(next_pop.get_individuals().size())).c_str());
std::swap(current_pop, next_pop);
current_generation++;
} }
void evaluate_fitness() void evaluate_fitness()
{ {
#ifdef BLT_TRACK_ALLOCATIONS
auto fitness_alloc = blt::gp::tracker.start_measurement();
#endif
evaluate_fitness_internal(); evaluate_fitness_internal();
#ifdef BLT_TRACK_ALLOCATIONS
blt::gp::tracker.stop_measurement(fitness_alloc);
BLT_TRACE("Fitness Allocated %ld times with a total of %s", fitness_alloc.getAllocationDifference(),
blt::byte_convert_t(fitness_alloc.getAllocatedByteDifference()).convert_to_nearest_type().to_pretty_string().c_str());
#endif
}
void reset_program(type_id root_type, bool eval_fitness_now = true)
{
current_generation = 0;
current_pop = config.pop_initializer.get().generate(
{*this, root_type, config.population_size, config.initial_min_tree_size, config.initial_max_tree_size});
if (eval_fitness_now)
evaluate_fitness_internal();
}
void kill()
{
thread_helper.lifetime_over = true;
} }
/** /**
@ -378,32 +427,39 @@ namespace blt::gp
using LambdaReturn = typename decltype(blt::meta::lambda_helper(fitness_function))::Return; using LambdaReturn = typename decltype(blt::meta::lambda_helper(fitness_function))::Return;
current_pop = config.pop_initializer.get().generate( current_pop = config.pop_initializer.get().generate(
{*this, root_type, config.population_size, config.initial_min_tree_size, config.initial_max_tree_size}); {*this, root_type, config.population_size, config.initial_min_tree_size, config.initial_max_tree_size});
next_pop = population_t(current_pop);
if (config.threads == 1) if (config.threads == 1)
{ {
BLT_INFO("Starting with single thread variant!"); BLT_INFO("Starting with single thread variant!");
thread_execution_service = new std::function( thread_execution_service = std::unique_ptr<std::function<void(blt::size_t)>>(new std::function(
[this, &fitness_function, &crossover_selection, &mutation_selection, &reproduction_selection, &func](blt::size_t) { [this, &fitness_function, &crossover_selection, &mutation_selection, &reproduction_selection, &func](blt::size_t) {
if (thread_helper.evaluation_left > 0) if (thread_helper.evaluation_left > 0)
{ {
for (const auto& ind : blt::enumerate(current_pop.get_individuals())) current_stats.normalized_fitness.clear();
double sum_of_prob = 0;
for (const auto& [index, ind] : blt::enumerate(current_pop.get_individuals()))
{ {
if constexpr (std::is_same_v<LambdaReturn, bool> || std::is_convertible_v<LambdaReturn, bool>) if constexpr (std::is_same_v<LambdaReturn, bool> || std::is_convertible_v<LambdaReturn, bool>)
{ {
auto result = fitness_function(ind.second.tree, ind.second.fitness, ind.first); auto result = fitness_function(ind.tree, ind.fitness, index);
if (result) if (result)
fitness_should_exit = true; fitness_should_exit = true;
} else } else
{ fitness_function(ind.tree, ind.fitness, index);
fitness_function(ind.second.tree, ind.second.fitness, ind.first);
if (ind.fitness.adjusted_fitness > current_stats.best_fitness)
current_stats.best_fitness = ind.fitness.adjusted_fitness;
if (ind.fitness.adjusted_fitness < current_stats.worst_fitness)
current_stats.worst_fitness = ind.fitness.adjusted_fitness;
current_stats.overall_fitness = current_stats.overall_fitness + ind.fitness.adjusted_fitness;
} }
for (auto& ind : current_pop)
if (ind.second.fitness.adjusted_fitness > current_stats.best_fitness) {
current_stats.best_fitness = ind.second.fitness.adjusted_fitness; auto prob = (ind.fitness.adjusted_fitness / current_stats.overall_fitness);
current_stats.normalized_fitness.push_back(sum_of_prob + prob);
if (ind.second.fitness.adjusted_fitness < current_stats.worst_fitness) sum_of_prob += prob;
current_stats.worst_fitness = ind.second.fitness.adjusted_fitness;
current_stats.overall_fitness = current_stats.overall_fitness + ind.second.fitness.adjusted_fitness;
} }
thread_helper.evaluation_left = 0; thread_helper.evaluation_left = 0;
} }
@ -413,9 +469,9 @@ namespace blt::gp
new_children.clear(); new_children.clear();
auto args = get_selector_args(new_children); auto args = get_selector_args(new_children);
crossover_selection.pre_process(*this, current_pop, current_stats); crossover_selection.pre_process(*this, current_pop);
mutation_selection.pre_process(*this, current_pop, current_stats); mutation_selection.pre_process(*this, current_pop);
reproduction_selection.pre_process(*this, current_pop, current_stats); reproduction_selection.pre_process(*this, current_pop);
perform_elitism(args); perform_elitism(args);
@ -427,12 +483,12 @@ namespace blt::gp
thread_helper.next_gen_left = 0; thread_helper.next_gen_left = 0;
} }
}); }));
} else } else
{ {
BLT_INFO("Starting thread execution service!"); BLT_INFO("Starting thread execution service!");
std::scoped_lock lock(thread_helper.thread_function_control); std::scoped_lock lock(thread_helper.thread_function_control);
thread_execution_service = new std::function( thread_execution_service = std::unique_ptr<std::function<void(blt::size_t)>>(new std::function(
[this, &fitness_function, &crossover_selection, &mutation_selection, &reproduction_selection, &func](blt::size_t id) { [this, &fitness_function, &crossover_selection, &mutation_selection, &reproduction_selection, &func](blt::size_t id) {
thread_helper.barrier.wait(); thread_helper.barrier.wait();
if (thread_helper.evaluation_left > 0) if (thread_helper.evaluation_left > 0)
@ -491,11 +547,20 @@ namespace blt::gp
auto args = get_selector_args(new_children); auto args = get_selector_args(new_children);
if (id == 0) if (id == 0)
{ {
crossover_selection.pre_process(*this, current_pop, current_stats); current_stats.normalized_fitness.clear();
double sum_of_prob = 0;
for (auto& ind : current_pop)
{
auto prob = (ind.fitness.adjusted_fitness / current_stats.overall_fitness);
current_stats.normalized_fitness.push_back(sum_of_prob + prob);
sum_of_prob += prob;
}
crossover_selection.pre_process(*this, current_pop);
if (&crossover_selection != &mutation_selection) if (&crossover_selection != &mutation_selection)
mutation_selection.pre_process(*this, current_pop, current_stats); mutation_selection.pre_process(*this, current_pop);
if (&crossover_selection != &reproduction_selection) if (&crossover_selection != &reproduction_selection)
reproduction_selection.pre_process(*this, current_pop, current_stats); reproduction_selection.pre_process(*this, current_pop);
perform_elitism(args); perform_elitism(args);
@ -531,75 +596,13 @@ namespace blt::gp
} }
} }
thread_helper.barrier.wait(); thread_helper.barrier.wait();
}); }));
thread_helper.thread_function_condition.notify_all(); thread_helper.thread_function_condition.notify_all();
} }
if (eval_fitness_now) if (eval_fitness_now)
evaluate_fitness_internal(); evaluate_fitness_internal();
} }
void reset_program(type_id root_type, bool eval_fitness_now = true)
{
current_generation = 0;
current_pop = config.pop_initializer.get().generate(
{*this, root_type, config.population_size, config.initial_min_tree_size, config.initial_max_tree_size});
if (eval_fitness_now)
evaluate_fitness_internal();
}
void next_generation()
{
BLT_ASSERT_MSG(next_pop.get_individuals().size() == config.population_size, ("pop size: " + std::to_string(next_pop.get_individuals().size())).c_str());
current_pop = std::move(next_pop);
current_generation++;
}
inline auto& get_current_pop()
{
return current_pop;
}
template<blt::size_t size>
std::array<blt::size_t, size> get_best_indexes()
{
std::array<blt::size_t, size> arr;
std::vector<std::pair<blt::size_t, double>> values;
values.reserve(current_pop.get_individuals().size());
for (const auto& ind : blt::enumerate(current_pop.get_individuals()))
values.emplace_back(ind.first, ind.second.fitness.adjusted_fitness);
std::sort(values.begin(), values.end(), [](const auto& a, const auto& b) {
return a.second > b.second;
});
for (blt::size_t i = 0; i < size; i++)
arr[i] = values[i].first;
return arr;
}
template<blt::size_t size>
auto get_best_trees()
{
return convert_array<std::array<std::reference_wrapper<tree_t>, size>>(get_best_indexes<size>(),
[this](auto&& arr, blt::size_t index) -> tree_t& {
return current_pop.get_individuals()[arr[index]].tree;
},
std::make_integer_sequence<blt::size_t, size>());
}
template<blt::size_t size>
auto get_best_individuals()
{
return convert_array<std::array<std::reference_wrapper<individual_t>, size>>(get_best_indexes<size>(),
[this](auto&& arr, blt::size_t index) -> individual_t& {
return current_pop.get_individuals()[arr[index]];
},
std::make_integer_sequence<blt::size_t, size>());
}
[[nodiscard]] bool should_terminate() const [[nodiscard]] bool should_terminate() const
{ {
return current_generation >= config.max_generations || fitness_should_exit; return current_generation >= config.max_generations || fitness_should_exit;
@ -610,13 +613,6 @@ namespace blt::gp
return thread_helper.lifetime_over; return thread_helper.lifetime_over;
} }
[[nodiscard]] random_t& get_random() const;
[[nodiscard]] inline type_provider& get_typesystem()
{
return system;
}
inline operator_id select_terminal(type_id id) inline operator_id select_terminal(type_id id)
{ {
// we wanted a terminal, but could not find one, so we will select from a function that has a terminal // we wanted a terminal, but could not find one, so we will select from a function that has a terminal
@ -642,47 +638,49 @@ namespace blt::gp
return get_random().select(storage.operators_ordered_terminals[id]).first; return get_random().select(storage.operators_ordered_terminals[id]).first;
} }
inline operator_info& get_operator_info(operator_id id) inline auto& get_current_pop()
{
return current_pop;
}
[[nodiscard]] random_t& get_random() const;
[[nodiscard]] inline type_provider& get_typesystem()
{
return storage.system;
}
[[nodiscard]] inline operator_info& get_operator_info(operator_id id)
{ {
return storage.operators[id]; return storage.operators[id];
} }
inline detail::print_func_t& get_print_func(operator_id id) [[nodiscard]] inline detail::print_func_t& get_print_func(operator_id id)
{ {
return storage.print_funcs[id]; return storage.print_funcs[id];
} }
inline detail::destroy_func_t& get_destroy_func(operator_id id) [[nodiscard]] inline detail::destroy_func_t& get_destroy_func(operator_id id)
{ {
return storage.destroy_funcs[id]; return storage.destroy_funcs[id];
} }
inline std::optional<std::string_view> get_name(operator_id id) [[nodiscard]] inline std::optional<std::string_view> get_name(operator_id id)
{ {
return storage.names[id]; return storage.names[id];
} }
inline std::vector<operator_id>& get_type_terminals(type_id id) [[nodiscard]] inline std::vector<operator_id>& get_type_terminals(type_id id)
{ {
return storage.terminals[id]; return storage.terminals[id];
} }
inline std::vector<operator_id>& get_type_non_terminals(type_id id) [[nodiscard]] inline std::vector<operator_id>& get_type_non_terminals(type_id id)
{ {
return storage.non_terminals[id]; return storage.non_terminals[id];
} }
inline bool is_static(operator_id id) [[nodiscard]] inline detail::eval_func_t& get_eval_func()
{
return storage.static_types.contains(static_cast<blt::size_t>(id));
}
inline void set_operations(operator_storage op)
{
storage = std::move(op);
}
inline detail::eval_func_t& get_eval_func()
{ {
return storage.eval_func; return storage.eval_func;
} }
@ -692,65 +690,63 @@ namespace blt::gp
return current_generation.load(); return current_generation.load();
} }
[[nodiscard]] inline auto& get_population_stats() [[nodiscard]] inline const auto& get_population_stats() const
{ {
return current_stats; return current_stats;
} }
~gp_program() [[nodiscard]] inline bool is_operator_ephemeral(operator_id id)
{ {
thread_helper.lifetime_over = true; return storage.ephemeral_leaf_operators.contains(static_cast<blt::size_t>(id));
thread_helper.barrier.notify_all();
thread_helper.thread_function_condition.notify_all();
for (auto& thread : thread_helper.threads)
{
if (thread->joinable())
thread->join();
}
auto* cpy = thread_execution_service.load(std::memory_order_acquire);
thread_execution_service = nullptr;
delete cpy;
} }
void kill() inline void set_operations(program_operator_storage_t op)
{ {
thread_helper.lifetime_over = true; storage = std::move(op);
}
template<blt::size_t size>
std::array<blt::size_t, size> get_best_indexes()
{
std::array<blt::size_t, size> arr;
std::vector<std::pair<blt::size_t, double>> values;
values.reserve(current_pop.get_individuals().size());
for (const auto& ind : blt::enumerate(current_pop.get_individuals()))
values.emplace_back(ind.first, ind.second.fitness.adjusted_fitness);
std::sort(values.begin(), values.end(), [](const auto& a, const auto& b) {
return a.second > b.second;
});
for (blt::size_t i = 0; i < size; i++)
arr[i] = values[i].first;
return arr;
}
template<blt::size_t size>
auto get_best_trees()
{
return convert_array<std::array<std::reference_wrapper<individual_t>, size>>(get_best_indexes<size>(),
[this](auto&& arr, blt::size_t index) -> tree_t& {
return current_pop.get_individuals()[arr[index]].tree;
},
std::make_integer_sequence<blt::size_t, size>());
}
template<blt::size_t size>
auto get_best_individuals()
{
return convert_array<std::array<std::reference_wrapper<individual_t>, size>>(get_best_indexes<size>(),
[this](auto&& arr, blt::size_t index) -> individual_t& {
return current_pop.get_individuals()[arr[index]];
},
std::make_integer_sequence<blt::size_t, size>());
} }
private: private:
type_provider& system;
operator_storage storage;
population_t current_pop;
population_stats current_stats{};
population_t next_pop;
std::atomic_uint64_t current_generation = 0;
std::atomic_bool fitness_should_exit = false;
std::function<blt::u64()> seed_func;
prog_config_t config{};
struct concurrency_storage
{
std::vector<std::unique_ptr<std::thread>> threads;
std::mutex thread_function_control;
std::mutex thread_generation_lock;
std::condition_variable thread_function_condition{};
std::atomic_uint64_t evaluation_left = 0;
std::atomic_uint64_t next_gen_left = 0;
std::atomic_bool lifetime_over = false;
blt::barrier barrier;
explicit concurrency_storage(blt::size_t threads): barrier(threads, lifetime_over)
{}
} thread_helper{config.threads == 0 ? std::thread::hardware_concurrency() : config.threads};
// for convenience, shouldn't decrease performance too much
std::atomic<std::function<void(blt::size_t)>*> thread_execution_service = nullptr;
inline selector_args get_selector_args(tracked_vector<tree_t>& next_pop_trees) inline selector_args get_selector_args(tracked_vector<tree_t>& next_pop_trees)
{ {
return {*this, next_pop_trees, current_pop, current_stats, config, get_random()}; return {*this, next_pop_trees, current_pop, current_stats, config, get_random()};
@ -767,12 +763,48 @@ namespace blt::gp
void evaluate_fitness_internal() void evaluate_fitness_internal()
{ {
statistic_history.push_back(current_stats);
current_stats.clear(); current_stats.clear();
thread_helper.evaluation_left.store(current_pop.get_individuals().size(), std::memory_order_release); thread_helper.evaluation_left.store(current_pop.get_individuals().size(), std::memory_order_release);
(*thread_execution_service)(0); (*thread_execution_service)(0);
current_stats.average_fitness = current_stats.overall_fitness / static_cast<double>(config.population_size); current_stats.average_fitness = current_stats.overall_fitness / static_cast<double>(config.population_size);
} }
private:
program_operator_storage_t storage;
std::function<blt::u64()> seed_func;
prog_config_t config{};
population_t current_pop;
population_t next_pop;
std::atomic_uint64_t current_generation = 0;
std::atomic_bool fitness_should_exit = false;
population_stats current_stats{};
std::vector<population_stats> statistic_history;
struct concurrency_storage
{
std::vector<std::unique_ptr<std::thread>> threads;
std::mutex thread_function_control{};
std::mutex thread_generation_lock{};
std::condition_variable thread_function_condition{};
std::atomic_uint64_t evaluation_left = 0;
std::atomic_uint64_t next_gen_left = 0;
std::atomic_bool lifetime_over = false;
blt::barrier barrier;
explicit concurrency_storage(blt::size_t threads): barrier(threads, lifetime_over)
{}
} thread_helper{config.threads == 0 ? std::thread::hardware_concurrency() : config.threads};
std::unique_ptr<std::function<void(blt::size_t)>> thread_execution_service = nullptr;
}; };
} }

View File

@ -88,8 +88,8 @@ namespace blt::gp
{ {
// auto state = tracker.start_measurement(); // auto state = tracker.start_measurement();
// crossover // crossover
auto& p1 = crossover_selection.select(program, current_pop, current_stats); auto& p1 = crossover_selection.select(program, current_pop);
auto& p2 = crossover_selection.select(program, current_pop, current_stats); auto& p2 = crossover_selection.select(program, current_pop);
auto results = config.crossover.get().apply(program, p1, p2); auto results = config.crossover.get().apply(program, p1, p2);
@ -110,7 +110,7 @@ namespace blt::gp
{ {
// auto state = tracker.start_measurement(); // auto state = tracker.start_measurement();
// mutation // mutation
auto& p = mutation_selection.select(program, current_pop, current_stats); auto& p = mutation_selection.select(program, current_pop);
next_pop.push_back(std::move(config.mutator.get().apply(program, p))); next_pop.push_back(std::move(config.mutator.get().apply(program, p)));
// tracker.stop_measurement(state); // tracker.stop_measurement(state);
// BLT_TRACE("Mutation Allocated %ld times with a total of %s", state.getAllocationDifference(), // BLT_TRACE("Mutation Allocated %ld times with a total of %s", state.getAllocationDifference(),
@ -122,7 +122,7 @@ namespace blt::gp
{ {
// auto state = tracker.start_measurement(); // auto state = tracker.start_measurement();
// reproduction // reproduction
auto& p = reproduction_selection.select(program, current_pop, current_stats); auto& p = reproduction_selection.select(program, current_pop);
next_pop.push_back(p); next_pop.push_back(p);
// tracker.stop_measurement(state); // tracker.stop_measurement(state);
// BLT_TRACE("Reproduction Allocated %ld times with a total of %s", state.getAllocationDifference(), // BLT_TRACE("Reproduction Allocated %ld times with a total of %s", state.getAllocationDifference(),
@ -147,9 +147,9 @@ namespace blt::gp
* @param stats the populations statistics * @param stats the populations statistics
* @return * @return
*/ */
virtual tree_t& select(gp_program& program, population_t& pop, population_stats& stats) = 0; virtual tree_t& select(gp_program& program, population_t& pop) = 0;
virtual void pre_process(gp_program&, population_t&, population_stats&) virtual void pre_process(gp_program&, population_t&)
{} {}
virtual ~selection_t() = default; virtual ~selection_t() = default;
@ -158,19 +158,19 @@ namespace blt::gp
class select_best_t : public selection_t class select_best_t : public selection_t
{ {
public: public:
tree_t& select(gp_program& program, population_t& pop, population_stats& stats) final; tree_t& select(gp_program& program, population_t& pop) final;
}; };
class select_worst_t : public selection_t class select_worst_t : public selection_t
{ {
public: public:
tree_t& select(gp_program& program, population_t& pop, population_stats& stats) final; tree_t& select(gp_program& program, population_t& pop) final;
}; };
class select_random_t : public selection_t class select_random_t : public selection_t
{ {
public: public:
tree_t& select(gp_program& program, population_t& pop, population_stats& stats) final; tree_t& select(gp_program& program, population_t& pop) final;
}; };
class select_tournament_t : public selection_t class select_tournament_t : public selection_t
@ -182,7 +182,7 @@ namespace blt::gp
BLT_ABORT("Unable to select with this size. Must select at least 1 individual_t!"); BLT_ABORT("Unable to select with this size. Must select at least 1 individual_t!");
} }
tree_t& select(gp_program& program, population_t& pop, population_stats& stats) final; tree_t& select(gp_program& program, population_t& pop) final;
private: private:
const blt::size_t selection_size; const blt::size_t selection_size;
@ -191,9 +191,7 @@ namespace blt::gp
class select_fitness_proportionate_t : public selection_t class select_fitness_proportionate_t : public selection_t
{ {
public: public:
void pre_process(gp_program& program, population_t& pop, population_stats& stats) final; tree_t& select(gp_program& program, population_t& pop) final;
tree_t& select(gp_program& program, population_t& pop, population_stats& stats) final;
}; };
} }

View File

@ -197,6 +197,17 @@ namespace blt::gp
struct population_stats struct population_stats
{ {
population_stats() = default;
population_stats(const population_stats& copy):
overall_fitness(copy.overall_fitness.load()), average_fitness(copy.average_fitness.load()), best_fitness(copy.best_fitness.load()),
worst_fitness(copy.worst_fitness.load())
{
normalized_fitness.reserve(copy.normalized_fitness.size());
for (auto v : copy.normalized_fitness)
normalized_fitness.push_back(v);
}
std::atomic<double> overall_fitness = 0; std::atomic<double> overall_fitness = 0;
std::atomic<double> average_fitness = 0; std::atomic<double> average_fitness = 0;
std::atomic<double> best_fitness = 0; std::atomic<double> best_fitness = 0;

View File

@ -85,12 +85,13 @@ namespace blt::gp
type_provider() = default; type_provider() = default;
template<typename T> template<typename T>
inline type register_type() inline void register_type()
{ {
if (has_type<T>())
return;
auto t = type::make_type<T>(types.size()); auto t = type::make_type<T>(types.size());
types.insert({blt::type_string_raw<T>(), t}); types.insert({blt::type_string_raw<T>(), t});
types_from_id[t.id()] = t; types_from_id[t.id()] = t;
return t;
} }
template<typename T> template<typename T>
@ -99,6 +100,11 @@ namespace blt::gp
return types[blt::type_string_raw<T>()]; return types[blt::type_string_raw<T>()];
} }
template<typename T>
inline bool has_type(){
return types.find(blt::type_string_raw<T>()) != types.end();
}
inline type get_type(type_id id) inline type get_type(type_id id)
{ {
return types_from_id[id]; return types_from_id[id];

View File

@ -64,10 +64,10 @@ namespace blt::gp
tree.get_operations().emplace_back( tree.get_operations().emplace_back(
args.program.get_typesystem().get_type(info.return_type).size(), args.program.get_typesystem().get_type(info.return_type).size(),
top.id, top.id,
args.program.is_static(top.id)); args.program.is_operator_ephemeral(top.id));
max_depth = std::max(max_depth, top.depth); max_depth = std::max(max_depth, top.depth);
if (args.program.is_static(top.id)) if (args.program.is_operator_ephemeral(top.id))
{ {
info.func(nullptr, tree.get_values(), tree.get_values()); info.func(nullptr, tree.get_values(), tree.get_values());
continue; continue;

View File

@ -76,7 +76,7 @@ namespace blt::gp
if (should_thread_terminate()) if (should_thread_terminate())
return; return;
} }
execution_function = thread_execution_service.load(std::memory_order_acquire); execution_function = thread_execution_service.get();
} }
if (execution_function != nullptr) if (execution_function != nullptr)
(*execution_function)(i); (*execution_function)(i);

View File

@ -21,7 +21,7 @@
namespace blt::gp namespace blt::gp
{ {
tree_t& select_best_t::select(gp_program&, population_t& pop, population_stats&) tree_t& select_best_t::select(gp_program&, population_t& pop)
{ {
auto& first = pop.get_individuals()[0]; auto& first = pop.get_individuals()[0];
double best_fitness = first.fitness.adjusted_fitness; double best_fitness = first.fitness.adjusted_fitness;
@ -37,7 +37,7 @@ namespace blt::gp
return *tree; return *tree;
} }
tree_t& select_worst_t::select(gp_program&, population_t& pop, population_stats&) tree_t& select_worst_t::select(gp_program&, population_t& pop)
{ {
auto& first = pop.get_individuals()[0]; auto& first = pop.get_individuals()[0];
double worst_fitness = first.fitness.adjusted_fitness; double worst_fitness = first.fitness.adjusted_fitness;
@ -53,12 +53,12 @@ namespace blt::gp
return *tree; return *tree;
} }
tree_t& select_random_t::select(gp_program& program, population_t& pop, population_stats&) tree_t& select_random_t::select(gp_program& program, population_t& pop)
{ {
return pop.get_individuals()[program.get_random().get_size_t(0ul, pop.get_individuals().size())].tree; return pop.get_individuals()[program.get_random().get_size_t(0ul, pop.get_individuals().size())].tree;
} }
tree_t& select_tournament_t::select(gp_program& program, population_t& pop, population_stats&) tree_t& select_tournament_t::select(gp_program& program, population_t& pop)
{ {
blt::u64 best = program.get_random().get_u64(0, pop.get_individuals().size()); blt::u64 best = program.get_random().get_u64(0, pop.get_individuals().size());
auto& i_ref = pop.get_individuals(); auto& i_ref = pop.get_individuals();
@ -71,8 +71,9 @@ namespace blt::gp
return i_ref[best].tree; return i_ref[best].tree;
} }
tree_t& select_fitness_proportionate_t::select(gp_program& program, population_t& pop, population_stats& stats) tree_t& select_fitness_proportionate_t::select(gp_program& program, population_t& pop)
{ {
auto& stats = program.get_population_stats();
auto choice = program.get_random().get_double(); auto choice = program.get_random().get_double();
for (const auto& [index, ref] : blt::enumerate(pop)) for (const auto& [index, ref] : blt::enumerate(pop))
{ {
@ -90,16 +91,4 @@ namespace blt::gp
return pop.get_individuals()[0].tree; return pop.get_individuals()[0].tree;
//BLT_ABORT("Unable to find individual"); //BLT_ABORT("Unable to find individual");
} }
void select_fitness_proportionate_t::pre_process(gp_program&, population_t& pop, population_stats& stats)
{
stats.normalized_fitness.clear();
double sum_of_prob = 0;
for (auto& ind : pop)
{
auto prob = (ind.fitness.adjusted_fitness / stats.overall_fitness);
stats.normalized_fitness.push_back(sum_of_prob + prob);
sum_of_prob += prob;
}
}
} }

View File

@ -468,7 +468,7 @@ namespace blt::gp
} }
// now finally update the type. // now finally update the type.
ops[c_node] = {program.get_typesystem().get_type(replacement_func_info.return_type).size(), random_replacement, ops[c_node] = {program.get_typesystem().get_type(replacement_func_info.return_type).size(), random_replacement,
program.is_static(random_replacement)}; program.is_operator_ephemeral(random_replacement)};
} }
#if BLT_DEBUG_LEVEL >= 2 #if BLT_DEBUG_LEVEL >= 2
if (!c.check(program, nullptr)) if (!c.check(program, nullptr))
@ -556,7 +556,7 @@ namespace blt::gp
ops.insert(ops.begin() + static_cast<blt::ptrdiff_t>(c_node), ops.insert(ops.begin() + static_cast<blt::ptrdiff_t>(c_node),
{program.get_typesystem().get_type(replacement_func_info.return_type).size(), {program.get_typesystem().get_type(replacement_func_info.return_type).size(),
random_replacement, program.is_static(random_replacement)}); random_replacement, program.is_operator_ephemeral(random_replacement)});
#if BLT_DEBUG_LEVEL >= 2 #if BLT_DEBUG_LEVEL >= 2
if (!c.check(program, nullptr)) if (!c.check(program, nullptr))

View File

@ -92,7 +92,7 @@ namespace blt::gp
if (print_literals) if (print_literals)
{ {
create_indent(out, indent, pretty_print); create_indent(out, indent, pretty_print);
if (program.is_static(v.id)) if (program.is_operator_ephemeral(v.id))
{ {
program.get_print_func(v.id)(out, reversed); program.get_print_func(v.id)(out, reversed);
reversed.pop_bytes(stack_allocator::aligned_size(v.type_size)); reversed.pop_bytes(stack_allocator::aligned_size(v.type_size));