hmm works?
parent
68dc109dad
commit
ac76b3c5df
|
@ -1,5 +1,5 @@
|
||||||
cmake_minimum_required(VERSION 3.25)
|
cmake_minimum_required(VERSION 3.25)
|
||||||
project(blt-gp VERSION 0.1.27)
|
project(blt-gp VERSION 0.1.28)
|
||||||
|
|
||||||
include(CTest)
|
include(CTest)
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,8 @@
|
||||||
/*
|
/*
|
||||||
* <Short Description>
|
* This rice classification example uses data from the UC Irvine Machine Learning repository.
|
||||||
|
* The data for this example can be found at:
|
||||||
|
* https://archive.ics.uci.edu/dataset/545/rice+cammeo+and+osmancik
|
||||||
|
*
|
||||||
* Copyright (C) 2024 Brett Terpstra
|
* Copyright (C) 2024 Brett Terpstra
|
||||||
*
|
*
|
||||||
* This program is free software: you can redistribute it and/or modify
|
* This program is free software: you can redistribute it and/or modify
|
||||||
|
@ -25,8 +28,6 @@
|
||||||
#include "operations_common.h"
|
#include "operations_common.h"
|
||||||
#include "blt/fs/loader.h"
|
#include "blt/fs/loader.h"
|
||||||
|
|
||||||
|
|
||||||
//static constexpr long SEED = 41912;
|
|
||||||
static const unsigned long SEED = std::random_device()();
|
static const unsigned long SEED = std::random_device()();
|
||||||
|
|
||||||
enum class rice_type_t
|
enum class rice_type_t
|
||||||
|
@ -47,7 +48,7 @@ struct rice_record
|
||||||
rice_type_t type;
|
rice_type_t type;
|
||||||
};
|
};
|
||||||
|
|
||||||
std::vector<rice_record> fitness_cases;
|
std::vector<rice_record> training_cases;
|
||||||
std::vector<rice_record> testing_cases;
|
std::vector<rice_record> testing_cases;
|
||||||
|
|
||||||
blt::gp::prog_config_t config = blt::gp::prog_config_t()
|
blt::gp::prog_config_t config = blt::gp::prog_config_t()
|
||||||
|
@ -58,7 +59,7 @@ blt::gp::prog_config_t config = blt::gp::prog_config_t()
|
||||||
.set_mutation_chance(0.1)
|
.set_mutation_chance(0.1)
|
||||||
.set_reproduction_chance(0)
|
.set_reproduction_chance(0)
|
||||||
.set_max_generations(50)
|
.set_max_generations(50)
|
||||||
.set_pop_size(500)
|
.set_pop_size(5000)
|
||||||
.set_thread_count(0);
|
.set_thread_count(0);
|
||||||
|
|
||||||
blt::gp::type_provider type_system;
|
blt::gp::type_provider type_system;
|
||||||
|
@ -97,21 +98,25 @@ blt::gp::operation_t op_extent([](const rice_record& rice_data) {
|
||||||
}, "extent");
|
}, "extent");
|
||||||
|
|
||||||
constexpr auto fitness_function = [](blt::gp::tree_t& current_tree, blt::gp::fitness_t& fitness, blt::size_t) {
|
constexpr auto fitness_function = [](blt::gp::tree_t& current_tree, blt::gp::fitness_t& fitness, blt::size_t) {
|
||||||
constexpr double value_cutoff = 1.e15;
|
for (auto& training_case : training_cases)
|
||||||
for (auto& fitness_case : fitness_cases)
|
|
||||||
{
|
{
|
||||||
auto diff = std::abs(fitness_case.y - current_tree.get_evaluation_value<float>(&fitness_case));
|
auto v = current_tree.get_evaluation_value<float>(&training_case);
|
||||||
if (diff < value_cutoff)
|
switch (training_case.type)
|
||||||
{
|
{
|
||||||
fitness.raw_fitness += diff;
|
case rice_type_t::Cammeo:
|
||||||
if (diff < 0.01)
|
if (v >= 0)
|
||||||
fitness.hits++;
|
fitness.hits++;
|
||||||
} else
|
break;
|
||||||
fitness.raw_fitness += value_cutoff;
|
case rice_type_t::Osmancik:
|
||||||
|
if (v < 0)
|
||||||
|
fitness.hits++;
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
fitness.raw_fitness = static_cast<double>(fitness.hits) / static_cast<double>(training_cases.size());
|
||||||
fitness.standardized_fitness = fitness.raw_fitness;
|
fitness.standardized_fitness = fitness.raw_fitness;
|
||||||
fitness.adjusted_fitness = (1.0 / (1.0 + fitness.standardized_fitness));
|
fitness.adjusted_fitness = 1.0 - (1.0 / (1.0 + fitness.standardized_fitness));
|
||||||
return static_cast<blt::size_t>(fitness.hits) == fitness_cases.size();
|
return static_cast<blt::size_t>(fitness.hits) == training_cases.size();
|
||||||
};
|
};
|
||||||
|
|
||||||
void load_rice_data(std::string_view rice_file_path)
|
void load_rice_data(std::string_view rice_file_path)
|
||||||
|
@ -126,30 +131,32 @@ void load_rice_data(std::string_view rice_file_path)
|
||||||
{
|
{
|
||||||
auto data = blt::string::split(v, ',');
|
auto data = blt::string::split(v, ',');
|
||||||
rice_record r{std::stof(data[0]), std::stof(data[1]), std::stof(data[2]), std::stof(data[3]), std::stof(data[4]), std::stof(data[5]),
|
rice_record r{std::stof(data[0]), std::stof(data[1]), std::stof(data[2]), std::stof(data[3]), std::stof(data[4]), std::stof(data[5]),
|
||||||
std::stof(data[6])};
|
std::stof(data[6]), blt::string::contains(data[7], "Cammeo") ? rice_type_t::Cammeo : rice_type_t::Osmancik};
|
||||||
if (blt::string::contains(data[7], "Cammeo"))
|
switch (r.type)
|
||||||
{
|
{
|
||||||
r.type = rice_type_t::Cammeo;
|
case rice_type_t::Cammeo:
|
||||||
c.push_back(r);
|
c.push_back(r);
|
||||||
} else
|
break;
|
||||||
{
|
case rice_type_t::Osmancik:
|
||||||
r.type = rice_type_t::Osmancik;
|
|
||||||
o.push_back(r);
|
o.push_back(r);
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
blt::size_t total_records = c.size() + o.size();
|
blt::size_t total_records = c.size() + o.size();
|
||||||
blt::size_t training_size = total_records / 3;
|
blt::size_t training_size = std::min(total_records / 3, 1000ul);
|
||||||
for (blt::size_t i = 0; i < training_size; i++)
|
for (blt::size_t i = 0; i < training_size; i++)
|
||||||
{
|
{
|
||||||
auto& random = program.get_random();
|
auto& random = program.get_random();
|
||||||
auto& vec = random.choice() ? c : o;
|
auto& vec = random.choice() ? c : o;
|
||||||
auto pos = random.get_i64(0, static_cast<blt::i64>(vec.size()));
|
auto pos = random.get_i64(0, static_cast<blt::i64>(vec.size()));
|
||||||
fitness_cases.push_back(vec[pos]);
|
training_cases.push_back(vec[pos]);
|
||||||
vec.erase(vec.begin() + pos);
|
vec.erase(vec.begin() + pos);
|
||||||
}
|
}
|
||||||
testing_cases.insert(testing_cases.end(), c.begin(), c.end());
|
testing_cases.insert(testing_cases.end(), c.begin(), c.end());
|
||||||
testing_cases.insert(testing_cases.end(), o.begin(), o.end());
|
testing_cases.insert(testing_cases.end(), o.begin(), o.end());
|
||||||
std::shuffle(testing_cases.begin(), testing_cases.end(), program.get_random());
|
std::shuffle(testing_cases.begin(), testing_cases.end(), program.get_random());
|
||||||
|
BLT_INFO("Created training set of size %ld, testing set is of size %ld", training_size, testing_cases.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
int main(int argc, const char** argv)
|
int main(int argc, const char** argv)
|
||||||
|
@ -159,7 +166,13 @@ int main(int argc, const char** argv)
|
||||||
|
|
||||||
auto args = parser.parse_args(argc, argv);
|
auto args = parser.parse_args(argc, argv);
|
||||||
|
|
||||||
auto rice_file_path = args.get<std::string>("-f");
|
if (!args.contains("file"))
|
||||||
|
{
|
||||||
|
BLT_WARN("Please provide path to file with -f or --file");
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto rice_file_path = args.get<std::string>("file");
|
||||||
|
|
||||||
BLT_INFO("Starting BLT-GP Rice Classification Example");
|
BLT_INFO("Starting BLT-GP Rice Classification Example");
|
||||||
BLT_START_INTERVAL("Rice Classification", "Main");
|
BLT_START_INTERVAL("Rice Classification", "Main");
|
||||||
|
@ -170,7 +183,8 @@ int main(int argc, const char** argv)
|
||||||
type_system.register_type<float>();
|
type_system.register_type<float>();
|
||||||
|
|
||||||
blt::gp::operator_builder<rice_record> builder{type_system};
|
blt::gp::operator_builder<rice_record> builder{type_system};
|
||||||
program.set_operations(builder.build(add, sub, mul, pro_div, op_sin, op_cos, op_exp, op_log, lit, op_x));
|
program.set_operations(builder.build(add, sub, mul, pro_div, op_exp, op_log, lit, op_area, op_perimeter, op_major_axis_length,
|
||||||
|
op_minor_axis_length, op_eccentricity, op_convex_area, op_extent));
|
||||||
|
|
||||||
BLT_DEBUG("Generate Initial Population");
|
BLT_DEBUG("Generate Initial Population");
|
||||||
auto sel = blt::gp::select_tournament_t{};
|
auto sel = blt::gp::select_tournament_t{};
|
||||||
|
@ -222,8 +236,47 @@ int main(int argc, const char** argv)
|
||||||
for (auto& i_ref : best)
|
for (auto& i_ref : best)
|
||||||
{
|
{
|
||||||
auto& i = i_ref.get();
|
auto& i = i_ref.get();
|
||||||
|
struct match_t
|
||||||
|
{
|
||||||
|
blt::size_t cc = 0;
|
||||||
|
blt::size_t co = 0;
|
||||||
|
blt::size_t oo = 0;
|
||||||
|
blt::size_t oc = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
match_t match;
|
||||||
|
|
||||||
|
for (auto& testing_case : testing_cases)
|
||||||
|
{
|
||||||
|
auto result = i.tree.get_evaluation_value<float>(&testing_case);
|
||||||
|
switch (testing_case.type)
|
||||||
|
{
|
||||||
|
case rice_type_t::Cammeo:
|
||||||
|
if (result >= 0)
|
||||||
|
match.cc++; // cammeo cammeo
|
||||||
|
else if (result < 0)
|
||||||
|
match.co++; // cammeo osmancik
|
||||||
|
break;
|
||||||
|
case rice_type_t::Osmancik:
|
||||||
|
if (result < 0)
|
||||||
|
match.oo++; // osmancik osmancik
|
||||||
|
else if (result >= 0)
|
||||||
|
match.oc++; // osmancik cammeo
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto hits = match.cc + match.oo;
|
||||||
|
auto size = testing_cases.size();
|
||||||
|
|
||||||
|
BLT_INFO("Hits %ld, Total Cases %ld, Percent Hit: %lf", hits, size, static_cast<double>(hits) / static_cast<double>(size) * 100);
|
||||||
|
BLT_DEBUG("Cammeo Cammeo: %ld", match.cc);
|
||||||
|
BLT_DEBUG("Cammeo Osmancik: %ld", match.co);
|
||||||
|
BLT_DEBUG("Osmancik Osmancik: %ld", match.oo);
|
||||||
|
BLT_DEBUG("Osmancik Cammeo: %ld", match.oc);
|
||||||
BLT_DEBUG("Fitness: %lf, stand: %lf, raw: %lf", i.fitness.adjusted_fitness, i.fitness.standardized_fitness, i.fitness.raw_fitness);
|
BLT_DEBUG("Fitness: %lf, stand: %lf, raw: %lf", i.fitness.adjusted_fitness, i.fitness.standardized_fitness, i.fitness.raw_fitness);
|
||||||
i.tree.print(program, std::cout);
|
i.tree.print(program, std::cout);
|
||||||
|
|
||||||
std::cout << "\n";
|
std::cout << "\n";
|
||||||
}
|
}
|
||||||
auto& stats = program.get_population_stats();
|
auto& stats = program.get_population_stats();
|
||||||
|
|
|
@ -31,7 +31,7 @@ struct context
|
||||||
float x, y;
|
float x, y;
|
||||||
};
|
};
|
||||||
|
|
||||||
std::array<context, 200> fitness_cases;
|
std::array<context, 200> training_cases;
|
||||||
|
|
||||||
blt::gp::prog_config_t config = blt::gp::prog_config_t()
|
blt::gp::prog_config_t config = blt::gp::prog_config_t()
|
||||||
.set_initial_min_tree_size(2)
|
.set_initial_min_tree_size(2)
|
||||||
|
@ -57,7 +57,7 @@ blt::gp::operation_t op_x([](const context& context) {
|
||||||
|
|
||||||
constexpr auto fitness_function = [](blt::gp::tree_t& current_tree, blt::gp::fitness_t& fitness, blt::size_t) {
|
constexpr auto fitness_function = [](blt::gp::tree_t& current_tree, blt::gp::fitness_t& fitness, blt::size_t) {
|
||||||
constexpr double value_cutoff = 1.e15;
|
constexpr double value_cutoff = 1.e15;
|
||||||
for (auto& fitness_case : fitness_cases)
|
for (auto& fitness_case : training_cases)
|
||||||
{
|
{
|
||||||
auto diff = std::abs(fitness_case.y - current_tree.get_evaluation_value<float>(&fitness_case));
|
auto diff = std::abs(fitness_case.y - current_tree.get_evaluation_value<float>(&fitness_case));
|
||||||
if (diff < value_cutoff)
|
if (diff < value_cutoff)
|
||||||
|
@ -70,7 +70,7 @@ constexpr auto fitness_function = [](blt::gp::tree_t& current_tree, blt::gp::fit
|
||||||
}
|
}
|
||||||
fitness.standardized_fitness = fitness.raw_fitness;
|
fitness.standardized_fitness = fitness.raw_fitness;
|
||||||
fitness.adjusted_fitness = (1.0 / (1.0 + fitness.standardized_fitness));
|
fitness.adjusted_fitness = (1.0 / (1.0 + fitness.standardized_fitness));
|
||||||
return static_cast<blt::size_t>(fitness.hits) == fitness_cases.size();
|
return static_cast<blt::size_t>(fitness.hits) == training_cases.size();
|
||||||
};
|
};
|
||||||
|
|
||||||
float example_function(float x)
|
float example_function(float x)
|
||||||
|
@ -83,7 +83,7 @@ int main()
|
||||||
BLT_INFO("Starting BLT-GP Symbolic Regression Example");
|
BLT_INFO("Starting BLT-GP Symbolic Regression Example");
|
||||||
BLT_START_INTERVAL("Symbolic Regression", "Main");
|
BLT_START_INTERVAL("Symbolic Regression", "Main");
|
||||||
BLT_DEBUG("Setup Fitness cases");
|
BLT_DEBUG("Setup Fitness cases");
|
||||||
for (auto& fitness_case : fitness_cases)
|
for (auto& fitness_case : training_cases)
|
||||||
{
|
{
|
||||||
constexpr float range = 10;
|
constexpr float range = 10;
|
||||||
constexpr float half_range = range / 2.0;
|
constexpr float half_range = range / 2.0;
|
||||||
|
|
2
lib/blt
2
lib/blt
|
@ -1 +1 @@
|
||||||
Subproject commit b6354bed7846078e863767ce5afc7daa53b93988
|
Subproject commit 79e080cfd34fb47342f67f19b95ffa27efb0f715
|
|
@ -107,7 +107,7 @@ struct context
|
||||||
float x, y;
|
float x, y;
|
||||||
};
|
};
|
||||||
|
|
||||||
std::array<context, 200> fitness_cases;
|
std::array<context, 200> training_cases;
|
||||||
|
|
||||||
blt::gp::prog_config_t config = blt::gp::prog_config_t()
|
blt::gp::prog_config_t config = blt::gp::prog_config_t()
|
||||||
.set_initial_min_tree_size(2)
|
.set_initial_min_tree_size(2)
|
||||||
|
@ -141,7 +141,7 @@ blt::gp::operation_t op_x([](const context& context) {
|
||||||
|
|
||||||
constexpr auto fitness_function = [](blt::gp::tree_t& current_tree, blt::gp::fitness_t& fitness, blt::size_t) {
|
constexpr auto fitness_function = [](blt::gp::tree_t& current_tree, blt::gp::fitness_t& fitness, blt::size_t) {
|
||||||
constexpr double value_cutoff = 1.e15;
|
constexpr double value_cutoff = 1.e15;
|
||||||
for (auto& fitness_case : fitness_cases)
|
for (auto& fitness_case : training_cases)
|
||||||
{
|
{
|
||||||
auto ctx = current_tree.evaluate(&fitness_case);
|
auto ctx = current_tree.evaluate(&fitness_case);
|
||||||
auto diff = std::abs(fitness_case.y - *current_tree.get_evaluation_ref<move_float>(ctx));
|
auto diff = std::abs(fitness_case.y - *current_tree.get_evaluation_ref<move_float>(ctx));
|
||||||
|
@ -157,7 +157,7 @@ constexpr auto fitness_function = [](blt::gp::tree_t& current_tree, blt::gp::fit
|
||||||
}
|
}
|
||||||
fitness.standardized_fitness = fitness.raw_fitness;
|
fitness.standardized_fitness = fitness.raw_fitness;
|
||||||
fitness.adjusted_fitness = (1.0 / (1.0 + fitness.standardized_fitness));
|
fitness.adjusted_fitness = (1.0 / (1.0 + fitness.standardized_fitness));
|
||||||
return static_cast<blt::size_t>(fitness.hits) == fitness_cases.size();
|
return static_cast<blt::size_t>(fitness.hits) == training_cases.size();
|
||||||
};
|
};
|
||||||
|
|
||||||
float example_function(float x)
|
float example_function(float x)
|
||||||
|
@ -170,7 +170,7 @@ int main()
|
||||||
BLT_INFO("Starting BLT-GP Symbolic Regression Example");
|
BLT_INFO("Starting BLT-GP Symbolic Regression Example");
|
||||||
BLT_START_INTERVAL("Symbolic Regression", "Main");
|
BLT_START_INTERVAL("Symbolic Regression", "Main");
|
||||||
BLT_DEBUG("Setup Fitness cases");
|
BLT_DEBUG("Setup Fitness cases");
|
||||||
for (auto& fitness_case : fitness_cases)
|
for (auto& fitness_case : training_cases)
|
||||||
{
|
{
|
||||||
constexpr float range = 10;
|
constexpr float range = 10;
|
||||||
constexpr float half_range = range / 2.0;
|
constexpr float half_range = range / 2.0;
|
||||||
|
|
Loading…
Reference in New Issue