blt-gp/examples/symbolic_regression.cpp

190 lines
8.8 KiB
C++
Raw Normal View History

2024-07-11 04:11:24 -04:00
/*
* Symbolic regression example for BLT-GP: evolves expression trees approximating f(x) = x^4 + x^3 + x^2 + x.
* Copyright (C) 2024 Brett Terpstra
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
#include <blt/gp/program.h>
2024-07-11 21:14:23 -04:00
#include <blt/profiling/profiler_v2.h>
2024-07-11 04:11:24 -04:00
#include <blt/gp/tree.h>
#include <blt/std/logging.h>
2024-09-06 00:21:55 -04:00
#include <blt/format/format.h>
2024-07-11 04:11:24 -04:00
#include <iostream>
#include "operations_common.h"
2024-09-02 15:41:20 -04:00
#include "blt/math/averages.h"
2024-07-11 04:11:24 -04:00
2024-08-06 13:45:24 -04:00
// Seed for the GP program's random number generator.
// A fresh non-deterministic value is drawn from std::random_device on every run. For reproducible runs,
// replace it with a fixed constant (this example previously used 41912).
static const unsigned long SEED = static_cast<unsigned long>(std::random_device{}());
2024-07-11 04:11:24 -04:00
struct context
{
float x, y;
};
2024-08-27 18:07:22 -04:00
std::array<context, 200> training_cases;
2024-07-11 04:11:24 -04:00
2024-08-18 03:32:42 -04:00
// Mutation operator instance.
// NOTE(review): nothing in this file's visible code references `mut`; confirm whether it is still needed.
blt::gp::mutation_t mut{};
2024-07-11 04:11:24 -04:00
// Evolution parameters for this example (builder-style; each setter is applied in order).
blt::gp::prog_config_t config = blt::gp::prog_config_t()
        .set_initial_min_tree_size(2)   // initial random trees: size bounds 2..6 (per the setter names)
        .set_initial_max_tree_size(6)
        .set_elite_count(2)             // elite individuals; presumably carried over unchanged, so confirm
        .set_crossover_chance(0.9)      // operator probabilities: 0.9 + 0.1 + 0 = 1.0
        .set_mutation_chance(0.1)
        .set_reproduction_chance(0)
        .set_max_generations(50)        // generation budget for the main loop
        .set_pop_size(500)
        .set_thread_count(0);           // NOTE(review): 0 presumably means "choose automatically"; confirm in prog_config_t
2024-07-11 04:11:24 -04:00
2024-08-30 23:27:25 -04:00
// The GP driver, seeded with SEED and configured by `config`. Because it is a global, the `lit` operator
// below can draw its constants from program.get_random().
blt::gp::gp_program program{SEED, config};
2024-07-11 04:11:24 -04:00
2024-08-20 13:07:33 -04:00
auto lit = blt::gp::operation_t([]() {
2024-08-31 00:05:04 -04:00
return program.get_random().get_float(-1.0f, 1.0f);
2024-08-20 13:07:33 -04:00
}, "lit").set_ephemeral();
2024-08-21 00:54:39 -04:00
2024-07-11 04:11:24 -04:00
// Terminal "x": yields the input value of the training sample currently being evaluated.
blt::gp::operation_t op_x([](const context& ctx) -> float {
    return ctx.x;
}, "x");
2024-07-11 21:14:23 -04:00
constexpr auto fitness_function = [](blt::gp::tree_t& current_tree, blt::gp::fitness_t& fitness, blt::size_t) {
2024-07-11 04:11:24 -04:00
constexpr double value_cutoff = 1.e15;
2024-08-27 18:07:22 -04:00
for (auto& fitness_case : training_cases)
2024-07-11 04:11:24 -04:00
{
auto diff = std::abs(fitness_case.y - current_tree.get_evaluation_value<float>(&fitness_case));
if (diff < value_cutoff)
{
fitness.raw_fitness += diff;
2024-09-05 02:20:03 -04:00
if (diff <= 0.01)
2024-07-11 04:11:24 -04:00
fitness.hits++;
} else
fitness.raw_fitness += value_cutoff;
}
fitness.standardized_fitness = fitness.raw_fitness;
fitness.adjusted_fitness = (1.0 / (1.0 + fitness.standardized_fitness));
2024-08-27 18:07:22 -04:00
return static_cast<blt::size_t>(fitness.hits) == training_cases.size();
2024-07-11 04:11:24 -04:00
};
/// Target function the GP tries to rediscover: f(x) = x^4 + x^3 + x^2 + x.
float example_function(float x)
{
    const float x_squared = x * x;
    const float x_cubed = x_squared * x;
    const float x_fourth = x_cubed * x;
    return x_fourth + x_cubed + x_squared + x;
}
2024-09-06 00:21:55 -04:00
// NOTE(review): declares a type named `test` via BLT's config macros. Judging by the macro names, it has
// a getter/setter pair for an int `silly` and a getter for an int `billy`. None of the symbolic-regression
// setup, fitness evaluation, or reporting depends on it; it looks like a leftover experiment, so consider removing it.
BLT_MAKE_CONFIG_TYPE(test, BLT_MAKE_GETTER_AND_SETTER(int, silly) BLT_MAKE_GETTER(int, billy));
2024-07-11 04:11:24 -04:00
int main()
{
2024-09-06 00:21:55 -04:00
test t;
2024-07-13 15:36:49 -04:00
BLT_INFO("Starting BLT-GP Symbolic Regression Example");
2024-07-11 21:14:23 -04:00
BLT_START_INTERVAL("Symbolic Regression", "Main");
2024-07-13 15:36:49 -04:00
BLT_DEBUG("Setup Fitness cases");
2024-08-27 18:07:22 -04:00
for (auto& fitness_case : training_cases)
2024-07-11 04:11:24 -04:00
{
constexpr float range = 10;
constexpr float half_range = range / 2.0;
auto x = program.get_random().get_float(-half_range, half_range);
auto y = example_function(x);
fitness_case = {x, y};
}
2024-07-13 15:36:49 -04:00
BLT_DEBUG("Setup Types and Operators");
2024-08-30 23:27:25 -04:00
blt::gp::operator_builder<context> builder{};
2024-08-20 13:07:33 -04:00
program.set_operations(builder.build(add, sub, mul, pro_div, op_sin, op_cos, op_exp, op_log, lit, op_x));
2024-07-11 04:11:24 -04:00
2024-07-13 15:36:49 -04:00
BLT_DEBUG("Generate Initial Population");
2024-09-05 02:20:03 -04:00
auto sel = blt::gp::select_fitness_proportionate_t{};
2024-08-30 23:27:25 -04:00
program.generate_population(program.get_typesystem().get_type<float>().id(), fitness_function, sel, sel, sel);
2024-07-11 04:11:24 -04:00
2024-07-13 15:36:49 -04:00
BLT_DEBUG("Begin Generation Loop");
2024-07-11 04:11:24 -04:00
while (!program.should_terminate())
{
2024-07-13 15:36:49 -04:00
BLT_TRACE("------------{Begin Generation %ld}------------", program.get_current_generation());
2024-08-22 02:10:55 -04:00
BLT_TRACE("Creating next generation");
2024-07-11 21:14:23 -04:00
BLT_START_INTERVAL("Symbolic Regression", "Gen");
program.create_next_generation();
2024-07-11 21:14:23 -04:00
BLT_END_INTERVAL("Symbolic Regression", "Gen");
2024-07-13 15:36:49 -04:00
BLT_TRACE("Move to next generation");
2024-07-11 21:14:23 -04:00
BLT_START_INTERVAL("Symbolic Regression", "Fitness");
2024-07-11 04:11:24 -04:00
program.next_generation();
2024-07-13 15:36:49 -04:00
BLT_TRACE("Evaluate Fitness");
2024-07-11 04:11:24 -04:00
program.evaluate_fitness();
2024-07-11 21:14:23 -04:00
BLT_END_INTERVAL("Symbolic Regression", "Fitness");
2024-07-13 15:36:49 -04:00
BLT_TRACE("----------------------------------------------");
std::cout << std::endl;
2024-07-11 04:11:24 -04:00
}
2024-07-11 21:14:23 -04:00
BLT_END_INTERVAL("Symbolic Regression", "Main");
2024-07-11 04:11:24 -04:00
auto best = program.get_best_individuals<3>();
BLT_INFO("Best approximations:");
for (auto& i_ref : best)
{
auto& i = i_ref.get();
BLT_DEBUG("Fitness: %lf, stand: %lf, raw: %lf", i.fitness.adjusted_fitness, i.fitness.standardized_fitness, i.fitness.raw_fitness);
i.tree.print(program, std::cout);
std::cout << "\n";
}
2024-07-11 21:14:23 -04:00
auto& stats = program.get_population_stats();
BLT_INFO("Stats:");
BLT_INFO("Average fitness: %lf", stats.average_fitness.load());
BLT_INFO("Best fitness: %lf", stats.best_fitness.load());
BLT_INFO("Worst fitness: %lf", stats.worst_fitness.load());
BLT_INFO("Overall fitness: %lf", stats.overall_fitness.load());
2024-07-11 13:51:14 -04:00
// TODO: make stats helper
2024-07-11 04:11:24 -04:00
2024-07-11 21:14:23 -04:00
BLT_PRINT_PROFILE("Symbolic Regression", blt::PRINT_CYCLES | blt::PRINT_THREAD | blt::PRINT_WALL);
2024-08-23 21:13:33 -04:00
2024-08-25 17:01:06 -04:00
#ifdef BLT_TRACK_ALLOCATIONS
2024-09-02 03:08:03 -04:00
BLT_TRACE("Total Allocations: %ld times with a total of %s, peak allocated bytes %s", blt::gp::tracker.getAllocations(),
blt::byte_convert_t(blt::gp::tracker.getAllocatedBytes()).convert_to_nearest_type().to_pretty_string().c_str(),
blt::byte_convert_t(blt::gp::tracker.getPeakAllocatedBytes()).convert_to_nearest_type().to_pretty_string().c_str());
2024-09-02 15:41:20 -04:00
BLT_TRACE("------------------------------------------------------");
auto evaluation_calls_v = blt::gp::evaluation_calls.get_calls();
auto evaluation_allocations_v = blt::gp::evaluation_allocations.get_calls();
BLT_TRACE("Total Evaluation Calls: %ld; Peak Bytes Allocated %s", evaluation_calls_v,
blt::string::bytes_to_pretty(blt::gp::evaluation_calls.get_value()).c_str());
BLT_TRACE("Total Evaluation Allocations: %ld; Bytes %s; Average %s", evaluation_allocations_v,
blt::string::bytes_to_pretty(blt::gp::evaluation_allocations.get_value()).c_str(),
blt::string::bytes_to_pretty(blt::average(blt::gp::evaluation_allocations.get_value(), evaluation_allocations_v)).c_str());
BLT_TRACE("Percent Evaluation calls allocate? %lf%%", blt::average(evaluation_allocations_v, evaluation_calls_v) * 100);
BLT_TRACE("------------------------------------------------------");
2024-09-01 21:55:29 -04:00
auto crossover_calls_v = blt::gp::crossover_calls.get_calls();
auto crossover_allocations_v = blt::gp::crossover_allocations.get_calls();
auto mutation_calls_v = blt::gp::mutation_calls.get_calls();
auto mutation_allocations_v = blt::gp::mutation_allocations.get_calls();
auto reproduction_calls_v = blt::gp::reproduction_calls.get_calls();
auto reproduction_allocations_v = blt::gp::reproduction_allocations.get_calls();
2024-09-02 15:41:20 -04:00
BLT_TRACE("Total Crossover Calls: %ld; Peak Bytes Allocated %s", crossover_calls_v,
blt::string::bytes_to_pretty(blt::gp::crossover_calls.get_value()).c_str());
BLT_TRACE("Total Mutation Calls: %ld; Peak Bytes Allocated %s", mutation_calls_v,
blt::string::bytes_to_pretty(blt::gp::mutation_calls.get_value()).c_str());
BLT_TRACE("Total Reproduction Calls: %ld; Peak Bytes Allocated %s", reproduction_calls_v,
blt::string::bytes_to_pretty(blt::gp::reproduction_calls.get_value()).c_str());
BLT_TRACE("Total Crossover Allocations: %ld; Bytes %s; Average %s", crossover_allocations_v,
blt::string::bytes_to_pretty(blt::gp::crossover_allocations.get_value()).c_str(),
blt::string::bytes_to_pretty(blt::average(blt::gp::crossover_allocations.get_value(), crossover_allocations_v)).c_str());
BLT_TRACE("Total Mutation Allocations: %ld; Bytes %s; Average %s", mutation_allocations_v,
blt::string::bytes_to_pretty(blt::gp::mutation_allocations.get_value()).c_str(),
blt::string::bytes_to_pretty(blt::average(blt::gp::mutation_allocations.get_value(), mutation_allocations_v)).c_str());
BLT_TRACE("Total Reproduction Allocations: %ld; Bytes %s; Average %s", reproduction_allocations_v,
blt::string::bytes_to_pretty(blt::gp::reproduction_allocations.get_value()).c_str(),
blt::string::bytes_to_pretty(blt::average(blt::gp::reproduction_allocations.get_value(), reproduction_allocations_v)).c_str());
BLT_TRACE("Percent Crossover calls allocate? %lf%%", blt::average(crossover_allocations_v, crossover_calls_v) * 100);
BLT_TRACE("Percent Mutation calls allocate? %lf%%", blt::average(mutation_allocations_v, mutation_calls_v) * 100);
BLT_TRACE("Percent Reproduction calls allocate? %lf%%", blt::average(reproduction_allocations_v, reproduction_calls_v) * 100);
2024-08-25 17:01:06 -04:00
#endif
2024-07-11 21:14:23 -04:00
2024-07-11 04:11:24 -04:00
return 0;
}