diff --git a/.idea/editor.xml b/.idea/editor.xml
index eb796be..b0d69ef 100644
--- a/.idea/editor.xml
+++ b/.idea/editor.xml
@@ -479,26 +479,5 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 18e155d..2712a56 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -27,7 +27,7 @@ macro(compile_options target_name)
sanitizers(${target_name})
endmacro()
-project(blt-gp VERSION 0.2.3)
+project(blt-gp VERSION 0.2.4)
include(CTest)
diff --git a/examples/examples_base.h b/examples/examples_base.h
index 3258d46..6175021 100644
--- a/examples/examples_base.h
+++ b/examples/examples_base.h
@@ -49,6 +49,9 @@ namespace blt::gp::example
return *this;
}
+ gp_program& get_program() { return program; } // Mutable access to the gp_program member held by this example.
+ const gp_program& get_program() const { return program; } // Const overload: read-only access to the same member.
+
protected:
gp_program program;
selection_t* crossover_sel = nullptr;
diff --git a/examples/rice_classification.h b/examples/rice_classification.h
index 8b3713a..1abb70a 100644
--- a/examples/rice_classification.h
+++ b/examples/rice_classification.h
@@ -44,95 +44,132 @@ namespace blt::gp::example
rice_type_t type;
};
- void make_operators()
- {
- static operation_t add{[](const float a, const float b) { return a + b; }, "add"};
- static operation_t sub([](const float a, const float b) { return a - b; }, "sub");
- static operation_t mul([](const float a, const float b) { return a * b; }, "mul");
- static operation_t pro_div([](const float a, const float b) { return b == 0.0f ? 1.0f : a / b; }, "div");
- static operation_t op_sin([](const float a) { return std::sin(a); }, "sin");
- static operation_t op_cos([](const float a) { return std::cos(a); }, "cos");
- static operation_t op_exp([](const float a) { return std::exp(a); }, "exp");
- static operation_t op_log([](const float a) { return a == 0.0f ? 0.0f : std::log(a); }, "log");
- static auto lit = blt::gp::operation_t([]()
- {
- return program.get_random().get_float(-32000.0f, 32000.0f);
- }, "lit").set_ephemeral();
-
- static operation_t op_area([](const rice_record& rice_data)
- {
- return rice_data.area;
- }, "area");
-
- static operation_t op_perimeter([](const rice_record& rice_data)
- {
- return rice_data.perimeter;
- }, "perimeter");
-
- static operation_t op_major_axis_length([](const rice_record& rice_data)
- {
- return rice_data.major_axis_length;
- }, "major_axis_length");
-
- static operation_t op_minor_axis_length([](const rice_record& rice_data)
- {
- return rice_data.minor_axis_length;
- }, "minor_axis_length");
-
- static operation_t op_eccentricity([](const rice_record& rice_data)
- {
- return rice_data.eccentricity;
- }, "eccentricity");
-
- static operation_t op_convex_area([](const rice_record& rice_data)
- {
- return rice_data.convex_area;
- }, "convex_area");
-
- static operation_t op_extent([](const rice_record& rice_data)
- {
- return rice_data.extent;
- }, "extent");
- }
-
- bool fitness_function(const tree_t& current_tree, fitness_t& fitness, size_t) const
- {
- for (auto& training_case : training_cases)
- {
- auto v = current_tree.get_evaluation_value(training_case);
- switch (training_case.type)
- {
- case rice_type_t::Cammeo:
- if (v >= 0)
- fitness.hits++;
- break;
- case rice_type_t::Osmancik:
- if (v < 0)
- fitness.hits++;
- break;
- }
- }
- fitness.raw_fitness = static_cast(fitness.hits);
- fitness.standardized_fitness = fitness.raw_fitness;
- fitness.adjusted_fitness = 1.0 - (1.0 / (1.0 + fitness.standardized_fitness));
- return static_cast(fitness.hits) == training_cases.size();
- }
-
- void load_rice_data(std::string_view rice_file_path);
-
+ bool fitness_function(const tree_t& current_tree, fitness_t& fitness, size_t) const;
public:
template
rice_classification_t(SEED&& seed, const prog_config_t& config): example_base_t{std::forward(seed), config}
{
+ BLT_INFO("Starting BLT-GP Rice Classification Example");
fitness_function_ref = [this](const tree_t& t, fitness_t& f, const size_t i)
{
return fitness_function(t, f, i);
};
}
+ void make_operators();
+
+ void load_rice_data(std::string_view rice_file_path);
+
+ confusion_matrix_t test_individual(const individual_t& individual) const;
+
+ void execute(const std::string_view rice_file_path) // Runs the whole example for the given data file: load, set up, evolve, evaluate, report.
+ {
+ load_rice_data(rice_file_path); // builds the fitness cases from the file
+ make_operators(); // registers the operator set on `program`
+ generate_initial_population();
+ run_generation_loop(); // loops until program.should_terminate()
+ evaluate_individuals(); // fills `results`, which the print_* calls below read
+ print_best();
+ print_average();
+ }
+
+ void run_generation_loop() // Advances `program` one generation per iteration until it reports should_terminate().
+ {
+ BLT_DEBUG("Begin Generation Loop");
+ while (!program.should_terminate())
+ {
+ BLT_TRACE("------------{Begin Generation %ld}------------", program.get_current_generation());
+ BLT_TRACE("Creating next generation");
+ program.create_next_generation();
+ BLT_TRACE("Move to next generation");
+ program.next_generation();
+ BLT_TRACE("Evaluate Fitness");
+ program.evaluate_fitness();
+ auto& stats = program.get_population_stats(); // the stats fields are atomics; they are read below with relaxed loads, for logging only
+ BLT_TRACE("Avg Fit: %lf, Best Fit: %lf, Worst Fit: %lf, Overall Fit: %lf",
+ stats.average_fitness.load(std::memory_order_relaxed), stats.best_fitness.load(std::memory_order_relaxed),
+ stats.worst_fitness.load(std::memory_order_relaxed), stats.overall_fitness.load(std::memory_order_relaxed));
+ BLT_TRACE("----------------------------------------------");
+ std::cout << std::endl; // blank separator line between generations (std::endl also flushes stdout)
+ }
+ }
+
+ void evaluate_individuals() // Rebuilds `results` from the current population, sorted in descending order of confusion matrix.
+ {
+ results.clear(); // drop entries from any previous evaluation
+ for (auto& i : program.get_current_pop().get_individuals())
+ results.emplace_back(test_individual(i), &i); // keeps a pointer into the population — NOTE(review): presumably invalid once the population changes; confirm
+ std::sort(results.begin(), results.end(), [](const auto& a, const auto& b)
+ {
+ return a.first > b.first; // compare the confusion matrices only; the individual pointer does not take part
+ });
+ }
+
+ void generate_initial_population() // Creates the first population; any selector pointer still null falls back to a shared tournament selector.
+ {
+ BLT_DEBUG("Generate Initial Population");
+ static auto sel = select_tournament_t{}; // function-local static so the address stored in the selector pointers stays valid after this call returns
+ if (crossover_sel == nullptr)
+ crossover_sel = &sel;
+ if (mutation_sel == nullptr)
+ mutation_sel = &sel;
+ if (reproduction_sel == nullptr)
+ reproduction_sel = &sel;
+ program.generate_population(program.get_typesystem().get_type<float>().id(), fitness_function_ref, *crossover_sel, *mutation_sel, // root type is float: the fitness function compares the tree output against 0
+ *reproduction_sel);
+ }
+
+ void print_best(const size_t amount = 3) // Logs up to `amount` entries from the front of `results` (filled by evaluate_individuals()).
+ {
+ BLT_INFO("Best results:");
+ for (size_t index = 0; index < amount && index < results.size(); index++) // bounded by results.size() so a short or empty result list is never indexed out of range
+ {
+ const auto& record = results[index].first;
+ const auto& i = *results[index].second;
+
+ BLT_INFO("Hits %ld, Total Cases %ld, Percent Hit: %lf", record.get_hits(), record.get_total(), record.get_percent_hit());
+ std::cout << record.pretty_print() << std::endl;
+ BLT_DEBUG("Fitness: %lf, stand: %lf, raw: %lf", i.fitness.adjusted_fitness, i.fitness.standardized_fitness, i.fitness.raw_fitness);
+ i.tree.print(program, std::cout); // dump the individual's tree after its statistics
+
+ std::cout << "\n";
+ }
+ }
+
+ void print_worst(const size_t amount = 3) const // Logs up to `amount` entries from the back of `results` (filled by evaluate_individuals()).
+ {
+ BLT_INFO("Worst Results:");
+ for (size_t index = 0; index < amount && index < results.size(); index++) // bounded by results.size() so `results.size() - 1 - index` can never wrap around
+ {
+ const auto& record = results[results.size() - 1 - index].first;
+ const auto& i = *results[results.size() - 1 - index].second;
+
+ BLT_INFO("Hits %ld, Total Cases %ld, Percent Hit: %lf", record.get_hits(), record.get_total(), record.get_percent_hit());
+ std::cout << record.pretty_print() << std::endl;
+ BLT_DEBUG("Fitness: %lf, stand: %lf, raw: %lf", i.fitness.adjusted_fitness, i.fitness.standardized_fitness, i.fitness.raw_fitness);
+
+ std::cout << "\n";
+ }
+ }
+
+ void print_average() // Sums every confusion matrix in `results`, divides by the entry count and logs the outcome.
+ {
+ BLT_INFO("Average Results");
+ confusion_matrix_t avg{};
+ avg.set_name_a("cammeo");
+ avg.set_name_b("osmancik");
+ for (const auto& [matrix, _] : results)
+ avg += matrix;
+ if (!results.empty()) avg /= results.size(); // nothing to average when results is empty; skipping avoids a division by zero
+ BLT_INFO("Hits %ld, Total Cases %ld, Percent Hit: %lf", avg.get_hits(), avg.get_total(), avg.get_percent_hit());
+ std::cout << avg.pretty_print() << std::endl;
+ std::cout << "\n";
+ }
+
private:
std::vector training_cases;
std::vector testing_cases;
+ std::vector<std::pair<confusion_matrix_t, individual_t*>> results; // one (confusion matrix, individual) entry per individual; filled and sorted by evaluate_individuals()
};
}
diff --git a/examples/src/rice_classification.cpp b/examples/src/rice_classification.cpp
index e61d74c..f234030 100644
--- a/examples/src/rice_classification.cpp
+++ b/examples/src/rice_classification.cpp
@@ -41,8 +41,109 @@ blt::gp::prog_config_t config = blt::gp::prog_config_t()
.set_pop_size(500)
.set_thread_count(0);
-void blt::gp::example::rice_classification_t::load_rice_data(std::string_view rice_file_path)
+int main(int argc, const char** argv) // Entry point: parses the data-file argument and runs the rice classification example.
{
+ blt::arg_parse parser;
+ parser.addArgument(blt::arg_builder{"-f", "--file"}.setHelp("File for rice data. Should be in .arff format.").setRequired().build());
+
+ auto args = parser.parse_args(argc, argv);
+
+ if (!args.contains("file")) // explicit check even though the argument is declared required above
+ {
+ BLT_WARN("Please provide path to file with -f or --file");
+ return 1;
+ }
+
+ auto rice_file_path = args.get("file"); // NOTE(review): get() presumably needs an explicit result type (the value is passed on as a std::string_view path) — confirm
+
+ blt::gp::example::rice_classification_t rice_classification{SEED_FUNC, config};
+
+ rice_classification.execute(rice_file_path); // load data, evolve, evaluate and print results
+
+ return 0;
+}
+
+void blt::gp::example::rice_classification_t::make_operators() // Defines the arithmetic, literal and rice_record accessor operators and installs them on `program`.
+{
+ BLT_DEBUG("Setup Types and Operators");
+ static operation_t add{[](const float a, const float b) { return a + b; }, "add"};
+ static operation_t sub([](const float a, const float b) { return a - b; }, "sub");
+ static operation_t mul([](const float a, const float b) { return a * b; }, "mul");
+ static operation_t pro_div([](const float a, const float b) { return b == 0.0f ? 0.0f : a / b; }, "div"); // protected division: yields 0 when the divisor is exactly 0
+ static operation_t op_exp([](const float a) { return std::exp(a); }, "exp");
+ static operation_t op_log([](const float a) { return a == 0.0f ? 0.0f : std::log(a); }, "log"); // only an input of exactly 0 is special-cased; other values go straight to std::log
+ static auto lit = operation_t([this]() // NOTE(review): a function-local static is initialised once, so this lambda keeps the `this` of the first caller — presumably only one instance ever exists; confirm
+ {
+ return program.get_random().get_float(-32000.0f, 32000.0f);
+ }, "lit").set_ephemeral();
+
+ static operation_t op_area([](const rice_record& rice_data) // the operators below each return one field of the rice_record being evaluated
+ {
+ return rice_data.area;
+ }, "area");
+
+ static operation_t op_perimeter([](const rice_record& rice_data)
+ {
+ return rice_data.perimeter;
+ }, "perimeter");
+
+ static operation_t op_major_axis_length([](const rice_record& rice_data)
+ {
+ return rice_data.major_axis_length;
+ }, "major_axis_length");
+
+ static operation_t op_minor_axis_length([](const rice_record& rice_data)
+ {
+ return rice_data.minor_axis_length;
+ }, "minor_axis_length");
+
+ static operation_t op_eccentricity([](const rice_record& rice_data)
+ {
+ return rice_data.eccentricity;
+ }, "eccentricity");
+
+ static operation_t op_convex_area([](const rice_record& rice_data)
+ {
+ return rice_data.convex_area;
+ }, "convex_area");
+
+ static operation_t op_extent([](const rice_record& rice_data)
+ {
+ return rice_data.extent;
+ }, "extent");
+
+ operator_builder builder{};
+ builder.build(add, sub, mul, pro_div, op_exp, op_log, lit, op_area, op_perimeter, op_major_axis_length,
+ op_minor_axis_length, op_eccentricity, op_convex_area, op_extent);
+ program.set_operations(builder.grab()); // hand the built operator set over to the program
+}
+
+bool blt::gp::example::rice_classification_t::fitness_function(const tree_t& current_tree, fitness_t& fitness, size_t) const // Scores one tree on the training cases; returns true when every case is a hit.
+{
+ for (auto& training_case : training_cases)
+ {
+ auto v = current_tree.get_evaluation_value<float>(training_case); // the sign of the float output selects the predicted class
+ switch (training_case.type)
+ {
+ case rice_type_t::Cammeo: // hit when the output is non-negative
+ if (v >= 0)
+ fitness.hits++;
+ break;
+ case rice_type_t::Osmancik: // hit when the output is negative
+ if (v < 0)
+ fitness.hits++;
+ break;
+ }
+ }
+ fitness.raw_fitness = static_cast<double>(fitness.hits);
+ fitness.standardized_fitness = fitness.raw_fitness;
+ fitness.adjusted_fitness = 1.0 - (1.0 / (1.0 + fitness.standardized_fitness)); // grows towards 1 as the hit count increases
+ return static_cast<size_t>(fitness.hits) == training_cases.size(); // size_t cast keeps the comparison with size() unsigned on both sides
+}
+
+void blt::gp::example::rice_classification_t::load_rice_data(const std::string_view rice_file_path)
+{
+ BLT_DEBUG("Setup Fitness cases");
auto rice_file_data = fs::getLinesFromFile(rice_file_path);
size_t index = 0;
while (!string::contains(rice_file_data[index++], "@DATA"))
@@ -50,7 +151,7 @@ void blt::gp::example::rice_classification_t::load_rice_data(std::string_view ri
}
std::vector c;
std::vector o;
- for (std::string_view v : iterate(rice_file_data).skip(index))
+ for (const std::string_view v : iterate(rice_file_data).skip(index))
{
auto data = string::split(v, ',');
rice_record r{
@@ -70,11 +171,11 @@ void blt::gp::example::rice_classification_t::load_rice_data(std::string_view ri
size_t total_records = c.size() + o.size();
size_t training_size = std::min(total_records / 3, 1000ul);
- for (blt::size_t i = 0; i < training_size; i++)
+ for (size_t i = 0; i < training_size; i++)
{
auto& random = program.get_random();
auto& vec = random.choice() ? c : o;
- auto pos = random.get_i64(0, static_cast(vec.size()));
+ auto pos = random.get_i64(0, static_cast(vec.size()));
training_cases.push_back(vec[pos]);
vec.erase(vec.begin() + pos);
}
@@ -84,197 +185,31 @@ void blt::gp::example::rice_classification_t::load_rice_data(std::string_view ri
BLT_INFO("Created training set of size %ld, testing set is of size %ld", training_size, testing_cases.size());
}
-struct test_results_t
+blt::gp::confusion_matrix_t blt::gp::example::rice_classification_t::test_individual(const individual_t& individual) const
{
- blt::size_t cc = 0;
- blt::size_t co = 0;
- blt::size_t oo = 0;
- blt::size_t oc = 0;
- blt::size_t hits = 0;
- blt::size_t size = 0;
- double percent_hit = 0;
-
- test_results_t& operator+=(const test_results_t& a)
- {
- cc += a.cc;
- co += a.co;
- oo += a.oo;
- oc += a.oc;
- hits += a.hits;
- size += a.size;
- percent_hit += a.percent_hit;
- return *this;
- }
-
- test_results_t& operator/=(blt::size_t s)
- {
- cc /= s;
- co /= s;
- oo /= s;
- oc /= s;
- hits /= s;
- size /= s;
- percent_hit /= static_cast(s);
- return *this;
- }
-
- friend bool operator<(const test_results_t& a, const test_results_t& b)
- {
- return a.hits < b.hits;
- }
-
- friend bool operator>(const test_results_t& a, const test_results_t& b)
- {
- return a.hits > b.hits;
- }
-};
-
-test_results_t test_individual(blt::gp::individual_t& i)
-{
- test_results_t results;
+ confusion_matrix_t confusion_matrix;
+ confusion_matrix.set_name_a("cammeo");
+ confusion_matrix.set_name_b("osmancik");
for (auto& testing_case : testing_cases)
{
- auto result = i.tree.get_evaluation_value(testing_case);
+ const auto result = individual.tree.get_evaluation_value(testing_case);
switch (testing_case.type)
{
case rice_type_t::Cammeo:
if (result >= 0)
- results.cc++; // cammeo cammeo
+ confusion_matrix.is_A_predicted_A(); // cammeo cammeo
else
- results.co++; // cammeo osmancik
+ confusion_matrix.is_A_predicted_B(); // cammeo osmancik
break;
case rice_type_t::Osmancik:
if (result < 0)
- results.oo++; // osmancik osmancik
+ confusion_matrix.is_B_predicted_B(); // osmancik osmancik
else
- results.oc++; // osmancik cammeo
+ confusion_matrix.is_B_predicted_A(); // osmancik cammeo
break;
}
}
- results.hits = results.cc + results.oo;
- results.size = testing_cases.size();
- results.percent_hit = static_cast(results.hits) / static_cast(results.size) * 100;
-
- return results;
-}
-
-int main(int argc, const char** argv)
-{
- blt::arg_parse parser;
- parser.addArgument(blt::arg_builder{"-f", "--file"}.setHelp("File for rice data. Should be in .arff format.").setRequired().build());
-
- auto args = parser.parse_args(argc, argv);
-
- if (!args.contains("file"))
- {
- BLT_WARN("Please provide path to file with -f or --file");
- return 1;
- }
-
- auto rice_file_path = args.get("file");
-
- BLT_INFO("Starting BLT-GP Rice Classification Example");
- BLT_START_INTERVAL("Rice Classification", "Main");
- BLT_DEBUG("Setup Fitness cases");
- load_rice_data(rice_file_path);
-
- BLT_DEBUG("Setup Types and Operators");
-
- blt::gp::operator_builder builder{};
- program.set_operations(builder.build(add, sub, mul, pro_div, op_exp, op_log, lit, op_area, op_perimeter, op_major_axis_length,
- op_minor_axis_length, op_eccentricity, op_convex_area, op_extent));
-
- BLT_DEBUG("Generate Initial Population");
- auto sel = blt::gp::select_tournament_t{};
- program.generate_population(program.get_typesystem().get_type().id(), fitness_function, sel, sel, sel);
-
- BLT_DEBUG("Begin Generation Loop");
- while (!program.should_terminate())
- {
- BLT_TRACE("------------{Begin Generation %ld}------------", program.get_current_generation());
- BLT_TRACE("Creating next generation");
- BLT_START_INTERVAL("Rice Classification", "Gen");
- program.create_next_generation();
- BLT_END_INTERVAL("Rice Classification", "Gen");
- BLT_TRACE("Move to next generation");
- BLT_START_INTERVAL("Rice Classification", "Fitness");
- program.next_generation();
- BLT_TRACE("Evaluate Fitness");
- program.evaluate_fitness();
- BLT_END_INTERVAL("Rice Classification", "Fitness");
- auto& stats = program.get_population_stats();
- BLT_TRACE("Stats:");
- BLT_TRACE("Average fitness: %lf", stats.average_fitness.load());
- BLT_TRACE("Best fitness: %lf", stats.best_fitness.load());
- BLT_TRACE("Worst fitness: %lf", stats.worst_fitness.load());
- BLT_TRACE("Overall fitness: %lf", stats.overall_fitness.load());
- BLT_TRACE("----------------------------------------------");
- std::cout << std::endl;
- }
-
- BLT_END_INTERVAL("Rice Classification", "Main");
-
- std::vector> results;
- for (auto& i : program.get_current_pop().get_individuals())
- results.emplace_back(test_individual(i), &i);
- std::sort(results.begin(), results.end(), [](const auto& a, const auto& b)
- {
- return a.first > b.first;
- });
-
- BLT_INFO("Best results:");
- for (blt::size_t index = 0; index < 3; index++)
- {
- const auto& record = results[index].first;
- const auto& i = *results[index].second;
-
- BLT_INFO("Hits %ld, Total Cases %ld, Percent Hit: %lf", record.hits, record.size, record.percent_hit);
- BLT_DEBUG("Cammeo Cammeo: %ld", record.cc);
- BLT_DEBUG("Cammeo Osmancik: %ld", record.co);
- BLT_DEBUG("Osmancik Osmancik: %ld", record.oo);
- BLT_DEBUG("Osmancik Cammeo: %ld", record.oc);
- BLT_DEBUG("Fitness: %lf, stand: %lf, raw: %lf", i.fitness.adjusted_fitness, i.fitness.standardized_fitness, i.fitness.raw_fitness);
- i.tree.print(program, std::cout);
-
- std::cout << "\n";
- }
-
- BLT_INFO("Worst Results:");
- for (blt::size_t index = 0; index < 3; index++)
- {
- const auto& record = results[results.size() - 1 - index].first;
- const auto& i = *results[results.size() - 1 - index].second;
-
- BLT_INFO("Hits %ld, Total Cases %ld, Percent Hit: %lf", record.hits, record.size, record.percent_hit);
- BLT_DEBUG("Cammeo Cammeo: %ld", record.cc);
- BLT_DEBUG("Cammeo Osmancik: %ld", record.co);
- BLT_DEBUG("Osmancik Osmancik: %ld", record.oo);
- BLT_DEBUG("Osmancik Cammeo: %ld", record.oc);
- BLT_DEBUG("Fitness: %lf, stand: %lf, raw: %lf", i.fitness.adjusted_fitness, i.fitness.standardized_fitness, i.fitness.raw_fitness);
-
- std::cout << "\n";
- }
-
- BLT_INFO("Average Results");
- test_results_t avg{};
- for (const auto& v : results)
- avg += v.first;
- avg /= results.size();
- BLT_INFO("Hits %ld, Total Cases %ld, Percent Hit: %lf", avg.hits, avg.size, avg.percent_hit);
- BLT_DEBUG("Cammeo Cammeo: %ld", avg.cc);
- BLT_DEBUG("Cammeo Osmancik: %ld", avg.co);
- BLT_DEBUG("Osmancik Osmancik: %ld", avg.oo);
- BLT_DEBUG("Osmancik Cammeo: %ld", avg.oc);
- std::cout << "\n";
-
- BLT_PRINT_PROFILE("Rice Classification", blt::PRINT_CYCLES | blt::PRINT_THREAD | blt::PRINT_WALL);
-
-#ifdef BLT_TRACK_ALLOCATIONS
- BLT_TRACE("Total Allocations: %ld times with a total of %s", blt::gp::tracker.getAllocations(),
- blt::byte_convert_t(blt::gp::tracker.getAllocatedBytes()).convert_to_nearest_type().to_pretty_string().c_str());
-#endif
-
- return 0;
+ return confusion_matrix;
}
diff --git a/examples/src/symbolic_regression.cpp b/examples/src/symbolic_regression.cpp
index 55e34ea..342ff97 100644
--- a/examples/src/symbolic_regression.cpp
+++ b/examples/src/symbolic_regression.cpp
@@ -60,4 +60,50 @@ int main()
regression.execute();
return 0;
-}
\ No newline at end of file
+}
+
+bool blt::gp::example::symbolic_regression_t::fitness_function(const tree_t& current_tree, fitness_t& fitness, size_t) const
+{
+ constexpr static double value_cutoff = 1.e15;
+ for (auto& fitness_case : training_cases)
+ {
+ const auto diff = std::abs(fitness_case.y - current_tree.get_evaluation_value(fitness_case));
+ if (diff < value_cutoff)
+ {
+ fitness.raw_fitness += diff;
+ if (diff <= 0.01)
+ fitness.hits++;
+ }
+ else
+ fitness.raw_fitness += value_cutoff;
+ }
+ fitness.standardized_fitness = fitness.raw_fitness;
+ fitness.adjusted_fitness = (1.0 / (1.0 + fitness.standardized_fitness));
+ return static_cast(fitness.hits) == training_cases.size();
+}
+
+void blt::gp::example::symbolic_regression_t::setup_operations() // Defines the arithmetic, trigonometric, literal and input operators and installs them on `program`.
+{
+ BLT_DEBUG("Setup Types and Operators");
+ static operation_t add{[](const float a, const float b) { return a + b; }, "add"};
+ static operation_t sub([](const float a, const float b) { return a - b; }, "sub");
+ static operation_t mul([](const float a, const float b) { return a * b; }, "mul");
+ static operation_t pro_div([](const float a, const float b) { return b == 0.0f ? 0.0f : a / b; }, "div"); // protected division: yields 0 when the divisor is exactly 0
+ static operation_t op_sin([](const float a) { return std::sin(a); }, "sin");
+ static operation_t op_cos([](const float a) { return std::cos(a); }, "cos");
+ static operation_t op_exp([](const float a) { return std::exp(a); }, "exp");
+ static operation_t op_log([](const float a) { return a == 0.0f ? 0.0f : std::log(a); }, "log"); // only an input of exactly 0 is special-cased; other values go straight to std::log
+ static auto lit = operation_t([this]() // NOTE(review): a function-local static is initialised once, so this lambda keeps the `this` of the first caller — presumably only one instance ever exists; confirm
+ {
+ return program.get_random().get_float(-1.0f, 1.0f);
+ }, "lit").set_ephemeral();
+
+ static operation_t op_x([](const context& context) // returns the x field of the context being evaluated
+ {
+ return context.x;
+ }, "x");
+
+ operator_builder builder{};
+ builder.build(add, sub, mul, pro_div, op_sin, op_cos, op_exp, op_log, lit, op_x);
+ program.set_operations(builder.grab()); // hand the built operator set over to the program
+}
diff --git a/examples/symbolic_regression.h b/examples/symbolic_regression.h
index efa4446..3d6ed4b 100644
--- a/examples/symbolic_regression.h
+++ b/examples/symbolic_regression.h
@@ -35,25 +35,7 @@ namespace blt::gp::example
};
private:
- bool fitness_function(const tree_t& current_tree, fitness_t& fitness, size_t) const
- {
- constexpr static double value_cutoff = 1.e15;
- for (auto& fitness_case : training_cases)
- {
- const auto diff = std::abs(fitness_case.y - current_tree.get_evaluation_value(fitness_case));
- if (diff < value_cutoff)
- {
- fitness.raw_fitness += diff;
- if (diff <= 0.01)
- fitness.hits++;
- }
- else
- fitness.raw_fitness += value_cutoff;
- }
- fitness.standardized_fitness = fitness.raw_fitness;
- fitness.adjusted_fitness = (1.0 / (1.0 + fitness.standardized_fitness));
- return static_cast(fitness.hits) == training_cases.size();
- }
+ bool fitness_function(const tree_t& current_tree, fitness_t& fitness, size_t) const;
static float example_function(const float x)
{
@@ -61,7 +43,7 @@ namespace blt::gp::example
}
public:
- template
+ template
symbolic_regression_t(SEED seed, const prog_config_t& config): example_base_t{std::forward(seed), config}
{
BLT_INFO("Starting BLT-GP Symbolic Regression Example");
@@ -81,37 +63,7 @@ namespace blt::gp::example
};
}
- template
- auto make_operations(operator_builder& builder)
- {
- static operation_t add{[](const float a, const float b) { return a + b; }, "add"};
- static operation_t sub([](const float a, const float b) { return a - b; }, "sub");
- static operation_t mul([](const float a, const float b) { return a * b; }, "mul");
- static operation_t pro_div([](const float a, const float b) { return b == 0.0f ? 1.0f : a / b; }, "div");
- static operation_t op_sin([](const float a) { return std::sin(a); }, "sin");
- static operation_t op_cos([](const float a) { return std::cos(a); }, "cos");
- static operation_t op_exp([](const float a) { return std::exp(a); }, "exp");
- static operation_t op_log([](const float a) { return a == 0.0f ? 0.0f : std::log(a); }, "log");
- static auto lit = operation_t([this]()
- {
- return program.get_random().get_float(-1.0f, 1.0f);
- }, "lit").set_ephemeral();
-
- static operation_t op_x([](const context& context)
- {
- return context.x;
- }, "x");
-
- return builder.build(add, sub, mul, pro_div, op_sin, op_cos, op_exp, op_log, lit, op_x);
- }
-
- void setup_operations()
- {
- BLT_DEBUG("Setup Types and Operators");
- operator_builder builder{};
- make_operations(builder);
- program.set_operations(builder.grab());
- }
+ void setup_operations();
void generate_initial_population()
{
diff --git a/include/blt/gp/util/statistics.h b/include/blt/gp/util/statistics.h
index 600aa31..3fd868d 100644
--- a/include/blt/gp/util/statistics.h
+++ b/include/blt/gp/util/statistics.h
@@ -86,6 +86,26 @@ namespace blt::gp
return is_B_pred_A;
}
+ [[nodiscard]] u64 get_hits() const // Correct predictions: A predicted as A plus B predicted as B.
+ {
+ return is_A_pred_A + is_B_pred_B;
+ }
+
+ [[nodiscard]] u64 get_misses() const // Wrong predictions: B predicted as A plus A predicted as B.
+ {
+ return is_B_pred_A + is_A_pred_B;
+ }
+
+ [[nodiscard]] u64 get_total() const // Sum of all four cells of the matrix (hits + misses).
+ {
+ return get_hits() + get_misses();
+ }
+
+ [[nodiscard]] double get_percent_hit() const // Hit ratio hits / total as a fraction (not scaled by 100); 0.0 for an empty matrix.
+ {
+ return get_total() == 0 ? 0.0 : static_cast<double>(get_hits()) / static_cast<double>(get_total()); // the empty case would otherwise be 0/0 = NaN, which breaks ordering by this value
+ }
+
confusion_matrix_t& operator+=(const confusion_matrix_t& op)
{
is_A_pred_A += op.is_A_pred_A;
@@ -118,6 +138,16 @@ namespace blt::gp
return result;
}
+ friend bool operator<(const confusion_matrix_t& a, const confusion_matrix_t& b) // Orders matrices by get_percent_hit() only; names and raw counts are not compared.
+ {
+ return a.get_percent_hit() < b.get_percent_hit();
+ }
+
+ friend bool operator>(const confusion_matrix_t& a, const confusion_matrix_t& b) // Mirror of operator<: true when a has the higher get_percent_hit().
+ {
+ return a.get_percent_hit() > b.get_percent_hit();
+ }
+
[[nodiscard]] std::string pretty_print(const std::string& table_name = "Confusion Matrix") const;
private:
@@ -129,40 +159,6 @@ namespace blt::gp
std::string name_B = "B";
};
- struct classifier_results_t : public confusion_matrix_t
- {
- public:
- [[nodiscard]] u64 get_hits() const
- {
- return hits;
- }
-
- [[nodiscard]] u64 get_size() const
- {
- return size;
- }
-
- [[nodiscard]] double get_percent_hit() const
- {
- return static_cast(hits) / static_cast(hits + misses);
- }
-
- void hit()
- {
- ++hits;
- }
-
- void miss()
- {
- ++misses;
- }
-
-
- private:
- u64 hits = 0;
- u64 misses = 0;
- };
-
struct population_stats
{
population_stats() = default;
diff --git a/src/util/statistics.cpp b/src/util/statistics.cpp
index b824d06..7b2f708 100644
--- a/src/util/statistics.cpp
+++ b/src/util/statistics.cpp
@@ -26,18 +26,18 @@ namespace blt::gp {
string::TableFormatter formatter{table_name};
formatter.addColumn("Predicted " + name_A);
formatter.addColumn("Predicted " + name_B);
- formatter.addColumn("");
+ formatter.addColumn("Actual Class");
string::TableRow row;
row.rowValues.push_back(std::to_string(is_A_pred_A));
row.rowValues.push_back(std::to_string(is_A_pred_B));
- row.rowValues.push_back("Actual" + name_A);
+ row.rowValues.push_back(name_A);
formatter.addRow(row);
string::TableRow row2;
row2.rowValues.push_back(std::to_string(is_B_pred_A));
row2.rowValues.push_back(std::to_string(is_B_pred_B));
- row2.rowValues.push_back("Actual" + name_B);
+ row2.rowValues.push_back(name_B);
formatter.addRow(row2);
auto tbl = formatter.createTable(true, true);