now both examples are in their own header file
parent
946ddcc572
commit
e1083426fc
|
@ -479,26 +479,5 @@
|
|||
<option name="/Default/CodeInspection/Highlighting/InspectionSeverities/=StringLiteralTypo/@EntryIndexRemoved" />
|
||||
<option name="/Default/CodeInspection/Highlighting/InspectionSeverities/=CommentTypo/@EntryIndexRemoved" />
|
||||
<option name="/Default/CodeInspection/Highlighting/InspectionSeverities/=IdentifierTypo/@EntryIndexRemoved" />
|
||||
<option name="/Default/CodeStyle/Generate/=CppGettersAndSetters/@KeyIndexDefined" value="true" type="bool" />
|
||||
<option name="/Default/CodeStyle/Generate/=CppGettersAndSetters/Options/=AccessRight/@EntryIndexedValue" value="public" type="string" />
|
||||
<option name="/Default/CodeStyle/Generate/=CppGettersAndSetters/Options/=AccessRight/@EntryIndexRemoved" />
|
||||
<option name="/Default/CodeStyle/Generate/=CppGettersAndSetters/Options/=GetterAndSetterKind/@EntryIndexedValue" value="Getter" type="string" />
|
||||
<option name="/Default/CodeStyle/Generate/=CppGettersAndSetters/Options/=GetterAndSetterKind/@EntryIndexRemoved" />
|
||||
<option name="/Default/CodeStyle/Generate/=CppGettersAndSetters/Options/=AcceptParameterKind/@EntryIndexedValue" value="Value and copy into field" type="string" />
|
||||
<option name="/Default/CodeStyle/Generate/=CppGettersAndSetters/Options/=AcceptParameterKind/@EntryIndexRemoved" />
|
||||
<option name="/Default/CodeStyle/Generate/=CppGettersAndSetters/Options/=ReturnKind/@EntryIndexedValue" value="Value" type="string" />
|
||||
<option name="/Default/CodeStyle/Generate/=CppGettersAndSetters/Options/=ReturnKind/@EntryIndexRemoved" />
|
||||
<option name="/Default/CodeStyle/Generate/=CppGettersAndSetters/Options/=InlineDefinition/@EntryIndexedValue" value="True" type="string" />
|
||||
<option name="/Default/CodeStyle/Generate/=CppGettersAndSetters/Options/=InlineDefinition/@EntryIndexRemoved" />
|
||||
<option name="/Default/CodeStyle/Generate/=CppGettersAndSetters/Options/=InsertVirtualSpecifier/@EntryIndexedValue" value="False" type="string" />
|
||||
<option name="/Default/CodeStyle/Generate/=CppGettersAndSetters/Options/=InsertVirtualSpecifier/@EntryIndexRemoved" />
|
||||
<option name="/Default/CodeStyle/Generate/=CppGettersAndSetters/Options/=UseConstParameterTypes/@EntryIndexedValue" value="True" type="string" />
|
||||
<option name="/Default/CodeStyle/Generate/=CppGettersAndSetters/Options/=UseConstParameterTypes/@EntryIndexRemoved" />
|
||||
<option name="/Default/CodeStyle/Generate/=CppGettersAndSetters/Options/=AddPrefixesToGetters/@EntryIndexedValue" value="True" type="string" />
|
||||
<option name="/Default/CodeStyle/Generate/=CppGettersAndSetters/Options/=AddPrefixesToGetters/@EntryIndexRemoved" />
|
||||
<option name="/Default/CodeStyle/Generate/=CppGettersAndSetters/Options/=AddPrefixesToSetters/@EntryIndexedValue" value="True" type="string" />
|
||||
<option name="/Default/CodeStyle/Generate/=CppGettersAndSetters/Options/=AddPrefixesToSetters/@EntryIndexRemoved" />
|
||||
<option name="/Default/CodeStyle/Generate/=CppGettersAndSetters/Options/=DeclareGettersNodiscard/@EntryIndexedValue" value="True" type="string" />
|
||||
<option name="/Default/CodeStyle/Generate/=CppGettersAndSetters/Options/=DeclareGettersNodiscard/@EntryIndexRemoved" />
|
||||
</component>
|
||||
</project>
|
|
@ -27,7 +27,7 @@ macro(compile_options target_name)
|
|||
sanitizers(${target_name})
|
||||
endmacro()
|
||||
|
||||
project(blt-gp VERSION 0.2.3)
|
||||
project(blt-gp VERSION 0.2.4)
|
||||
|
||||
include(CTest)
|
||||
|
||||
|
|
|
@ -49,6 +49,9 @@ namespace blt::gp::example
|
|||
return *this;
|
||||
}
|
||||
|
||||
// Mutable access to the gp_program owned by this example.
[[nodiscard]] gp_program& get_program() { return program; }

// Read-only access to the gp_program owned by this example.
[[nodiscard]] const gp_program& get_program() const { return program; }
|
||||
|
||||
protected:
|
||||
gp_program program;
|
||||
selection_t* crossover_sel = nullptr;
|
||||
|
|
|
@ -44,95 +44,132 @@ namespace blt::gp::example
|
|||
rice_type_t type;
|
||||
};
|
||||
|
||||
void make_operators()
|
||||
{
|
||||
static operation_t add{[](const float a, const float b) { return a + b; }, "add"};
|
||||
static operation_t sub([](const float a, const float b) { return a - b; }, "sub");
|
||||
static operation_t mul([](const float a, const float b) { return a * b; }, "mul");
|
||||
static operation_t pro_div([](const float a, const float b) { return b == 0.0f ? 1.0f : a / b; }, "div");
|
||||
static operation_t op_sin([](const float a) { return std::sin(a); }, "sin");
|
||||
static operation_t op_cos([](const float a) { return std::cos(a); }, "cos");
|
||||
static operation_t op_exp([](const float a) { return std::exp(a); }, "exp");
|
||||
static operation_t op_log([](const float a) { return a == 0.0f ? 0.0f : std::log(a); }, "log");
|
||||
static auto lit = blt::gp::operation_t([]()
|
||||
{
|
||||
return program.get_random().get_float(-32000.0f, 32000.0f);
|
||||
}, "lit").set_ephemeral();
|
||||
|
||||
static operation_t op_area([](const rice_record& rice_data)
|
||||
{
|
||||
return rice_data.area;
|
||||
}, "area");
|
||||
|
||||
static operation_t op_perimeter([](const rice_record& rice_data)
|
||||
{
|
||||
return rice_data.perimeter;
|
||||
}, "perimeter");
|
||||
|
||||
static operation_t op_major_axis_length([](const rice_record& rice_data)
|
||||
{
|
||||
return rice_data.major_axis_length;
|
||||
}, "major_axis_length");
|
||||
|
||||
static operation_t op_minor_axis_length([](const rice_record& rice_data)
|
||||
{
|
||||
return rice_data.minor_axis_length;
|
||||
}, "minor_axis_length");
|
||||
|
||||
static operation_t op_eccentricity([](const rice_record& rice_data)
|
||||
{
|
||||
return rice_data.eccentricity;
|
||||
}, "eccentricity");
|
||||
|
||||
static operation_t op_convex_area([](const rice_record& rice_data)
|
||||
{
|
||||
return rice_data.convex_area;
|
||||
}, "convex_area");
|
||||
|
||||
static operation_t op_extent([](const rice_record& rice_data)
|
||||
{
|
||||
return rice_data.extent;
|
||||
}, "extent");
|
||||
}
|
||||
|
||||
bool fitness_function(const tree_t& current_tree, fitness_t& fitness, size_t) const
|
||||
{
|
||||
for (auto& training_case : training_cases)
|
||||
{
|
||||
auto v = current_tree.get_evaluation_value<float>(training_case);
|
||||
switch (training_case.type)
|
||||
{
|
||||
case rice_type_t::Cammeo:
|
||||
if (v >= 0)
|
||||
fitness.hits++;
|
||||
break;
|
||||
case rice_type_t::Osmancik:
|
||||
if (v < 0)
|
||||
fitness.hits++;
|
||||
break;
|
||||
}
|
||||
}
|
||||
fitness.raw_fitness = static_cast<double>(fitness.hits);
|
||||
fitness.standardized_fitness = fitness.raw_fitness;
|
||||
fitness.adjusted_fitness = 1.0 - (1.0 / (1.0 + fitness.standardized_fitness));
|
||||
return static_cast<size_t>(fitness.hits) == training_cases.size();
|
||||
}
|
||||
|
||||
void load_rice_data(std::string_view rice_file_path);
|
||||
|
||||
bool fitness_function(const tree_t& current_tree, fitness_t& fitness, size_t) const;
|
||||
public:
|
||||
template <typename SEED>
|
||||
rice_classification_t(SEED&& seed, const prog_config_t& config): example_base_t{std::forward<SEED>(seed), config}
|
||||
{
|
||||
BLT_INFO("Starting BLT-GP Rice Classification Example");
|
||||
fitness_function_ref = [this](const tree_t& t, fitness_t& f, const size_t i)
|
||||
{
|
||||
return fitness_function(t, f, i);
|
||||
};
|
||||
}
|
||||
|
||||
void make_operators();
|
||||
|
||||
void load_rice_data(std::string_view rice_file_path);
|
||||
|
||||
confusion_matrix_t test_individual(const individual_t& individual) const;
|
||||
|
||||
void execute(const std::string_view rice_file_path)
|
||||
{
|
||||
load_rice_data(rice_file_path);
|
||||
make_operators();
|
||||
generate_initial_population();
|
||||
run_generation_loop();
|
||||
evaluate_individuals();
|
||||
print_best();
|
||||
print_average();
|
||||
}
|
||||
|
||||
void run_generation_loop()
|
||||
{
|
||||
BLT_DEBUG("Begin Generation Loop");
|
||||
while (!program.should_terminate())
|
||||
{
|
||||
BLT_TRACE("------------{Begin Generation %ld}------------", program.get_current_generation());
|
||||
BLT_TRACE("Creating next generation");
|
||||
program.create_next_generation();
|
||||
BLT_TRACE("Move to next generation");
|
||||
program.next_generation();
|
||||
BLT_TRACE("Evaluate Fitness");
|
||||
program.evaluate_fitness();
|
||||
auto& stats = program.get_population_stats();
|
||||
BLT_TRACE("Avg Fit: %lf, Best Fit: %lf, Worst Fit: %lf, Overall Fit: %lf",
|
||||
stats.average_fitness.load(std::memory_order_relaxed), stats.best_fitness.load(std::memory_order_relaxed),
|
||||
stats.worst_fitness.load(std::memory_order_relaxed), stats.overall_fitness.load(std::memory_order_relaxed));
|
||||
BLT_TRACE("----------------------------------------------");
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
void evaluate_individuals()
|
||||
{
|
||||
results.clear();
|
||||
for (auto& i : program.get_current_pop().get_individuals())
|
||||
results.emplace_back(test_individual(i), &i);
|
||||
std::sort(results.begin(), results.end(), [](const auto& a, const auto& b)
|
||||
{
|
||||
return a.first > b.first;
|
||||
});
|
||||
}
|
||||
|
||||
void generate_initial_population()
|
||||
{
|
||||
BLT_DEBUG("Generate Initial Population");
|
||||
static auto sel = select_tournament_t{};
|
||||
if (crossover_sel == nullptr)
|
||||
crossover_sel = &sel;
|
||||
if (mutation_sel == nullptr)
|
||||
mutation_sel = &sel;
|
||||
if (reproduction_sel == nullptr)
|
||||
reproduction_sel = &sel;
|
||||
program.generate_population(program.get_typesystem().get_type<float>().id(), fitness_function_ref, *crossover_sel, *mutation_sel,
|
||||
*reproduction_sel);
|
||||
}
|
||||
|
||||
void print_best(const size_t amount = 3)
|
||||
{
|
||||
BLT_INFO("Best results:");
|
||||
for (size_t index = 0; index < amount; index++)
|
||||
{
|
||||
const auto& record = results[index].first;
|
||||
const auto& i = *results[index].second;
|
||||
|
||||
BLT_INFO("Hits %ld, Total Cases %ld, Percent Hit: %lf", record.get_hits(), record.get_total(), record.get_percent_hit());
|
||||
std::cout << record.pretty_print() << std::endl;
|
||||
BLT_DEBUG("Fitness: %lf, stand: %lf, raw: %lf", i.fitness.adjusted_fitness, i.fitness.standardized_fitness, i.fitness.raw_fitness);
|
||||
i.tree.print(program, std::cout);
|
||||
|
||||
std::cout << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
void print_worst(const size_t amount = 3) const
|
||||
{
|
||||
BLT_INFO("Worst Results:");
|
||||
for (size_t index = 0; index < amount; index++)
|
||||
{
|
||||
const auto& record = results[results.size() - 1 - index].first;
|
||||
const auto& i = *results[results.size() - 1 - index].second;
|
||||
|
||||
BLT_INFO("Hits %ld, Total Cases %ld, Percent Hit: %lf", record.get_hits(), record.get_total(), record.get_percent_hit());
|
||||
std::cout << record.pretty_print() << std::endl;
|
||||
BLT_DEBUG("Fitness: %lf, stand: %lf, raw: %lf", i.fitness.adjusted_fitness, i.fitness.standardized_fitness, i.fitness.raw_fitness);
|
||||
|
||||
std::cout << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
void print_average()
|
||||
{
|
||||
BLT_INFO("Average Results");
|
||||
confusion_matrix_t avg{};
|
||||
avg.set_name_a("cammeo");
|
||||
avg.set_name_b("osmancik");
|
||||
for (const auto& [matrix, _] : results)
|
||||
avg += matrix;
|
||||
avg /= results.size();
|
||||
BLT_INFO("Hits %ld, Total Cases %ld, Percent Hit: %lf", avg.get_hits(), avg.get_total(), avg.get_percent_hit());
|
||||
std::cout << avg.pretty_print() << std::endl;
|
||||
std::cout << "\n";
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<rice_record> training_cases;
|
||||
std::vector<rice_record> testing_cases;
|
||||
std::vector<std::pair<confusion_matrix_t, individual_t*>> results;
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
@ -41,8 +41,109 @@ blt::gp::prog_config_t config = blt::gp::prog_config_t()
|
|||
.set_pop_size(500)
|
||||
.set_thread_count(0);
|
||||
|
||||
void blt::gp::example::rice_classification_t::load_rice_data(std::string_view rice_file_path)
|
||||
int main(int argc, const char** argv)
|
||||
{
|
||||
blt::arg_parse parser;
|
||||
parser.addArgument(blt::arg_builder{"-f", "--file"}.setHelp("File for rice data. Should be in .arff format.").setRequired().build());
|
||||
|
||||
auto args = parser.parse_args(argc, argv);
|
||||
|
||||
if (!args.contains("file"))
|
||||
{
|
||||
BLT_WARN("Please provide path to file with -f or --file");
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto rice_file_path = args.get<std::string>("file");
|
||||
|
||||
blt::gp::example::rice_classification_t rice_classification{SEED_FUNC, config};
|
||||
|
||||
rice_classification.execute(rice_file_path);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
void blt::gp::example::rice_classification_t::make_operators()
|
||||
{
|
||||
BLT_DEBUG("Setup Types and Operators");
|
||||
static operation_t add{[](const float a, const float b) { return a + b; }, "add"};
|
||||
static operation_t sub([](const float a, const float b) { return a - b; }, "sub");
|
||||
static operation_t mul([](const float a, const float b) { return a * b; }, "mul");
|
||||
static operation_t pro_div([](const float a, const float b) { return b == 0.0f ? 0.0f : a / b; }, "div");
|
||||
static operation_t op_exp([](const float a) { return std::exp(a); }, "exp");
|
||||
static operation_t op_log([](const float a) { return a == 0.0f ? 0.0f : std::log(a); }, "log");
|
||||
static auto lit = operation_t([this]()
|
||||
{
|
||||
return program.get_random().get_float(-32000.0f, 32000.0f);
|
||||
}, "lit").set_ephemeral();
|
||||
|
||||
static operation_t op_area([](const rice_record& rice_data)
|
||||
{
|
||||
return rice_data.area;
|
||||
}, "area");
|
||||
|
||||
static operation_t op_perimeter([](const rice_record& rice_data)
|
||||
{
|
||||
return rice_data.perimeter;
|
||||
}, "perimeter");
|
||||
|
||||
static operation_t op_major_axis_length([](const rice_record& rice_data)
|
||||
{
|
||||
return rice_data.major_axis_length;
|
||||
}, "major_axis_length");
|
||||
|
||||
static operation_t op_minor_axis_length([](const rice_record& rice_data)
|
||||
{
|
||||
return rice_data.minor_axis_length;
|
||||
}, "minor_axis_length");
|
||||
|
||||
static operation_t op_eccentricity([](const rice_record& rice_data)
|
||||
{
|
||||
return rice_data.eccentricity;
|
||||
}, "eccentricity");
|
||||
|
||||
static operation_t op_convex_area([](const rice_record& rice_data)
|
||||
{
|
||||
return rice_data.convex_area;
|
||||
}, "convex_area");
|
||||
|
||||
static operation_t op_extent([](const rice_record& rice_data)
|
||||
{
|
||||
return rice_data.extent;
|
||||
}, "extent");
|
||||
|
||||
operator_builder<rice_record> builder{};
|
||||
builder.build(add, sub, mul, pro_div, op_exp, op_log, lit, op_area, op_perimeter, op_major_axis_length,
|
||||
op_minor_axis_length, op_eccentricity, op_convex_area, op_extent);
|
||||
program.set_operations(builder.grab());
|
||||
}
|
||||
|
||||
bool blt::gp::example::rice_classification_t::fitness_function(const tree_t& current_tree, fitness_t& fitness, size_t) const
|
||||
{
|
||||
for (auto& training_case : training_cases)
|
||||
{
|
||||
auto v = current_tree.get_evaluation_value<float>(training_case);
|
||||
switch (training_case.type)
|
||||
{
|
||||
case rice_type_t::Cammeo:
|
||||
if (v >= 0)
|
||||
fitness.hits++;
|
||||
break;
|
||||
case rice_type_t::Osmancik:
|
||||
if (v < 0)
|
||||
fitness.hits++;
|
||||
break;
|
||||
}
|
||||
}
|
||||
fitness.raw_fitness = static_cast<double>(fitness.hits);
|
||||
fitness.standardized_fitness = fitness.raw_fitness;
|
||||
fitness.adjusted_fitness = 1.0 - (1.0 / (1.0 + fitness.standardized_fitness));
|
||||
return static_cast<size_t>(fitness.hits) == training_cases.size();
|
||||
}
|
||||
|
||||
void blt::gp::example::rice_classification_t::load_rice_data(const std::string_view rice_file_path)
|
||||
{
|
||||
BLT_DEBUG("Setup Fitness cases");
|
||||
auto rice_file_data = fs::getLinesFromFile(rice_file_path);
|
||||
size_t index = 0;
|
||||
while (!string::contains(rice_file_data[index++], "@DATA"))
|
||||
|
@ -50,7 +151,7 @@ void blt::gp::example::rice_classification_t::load_rice_data(std::string_view ri
|
|||
}
|
||||
std::vector<rice_record> c;
|
||||
std::vector<rice_record> o;
|
||||
for (std::string_view v : iterate(rice_file_data).skip(index))
|
||||
for (const std::string_view v : iterate(rice_file_data).skip(index))
|
||||
{
|
||||
auto data = string::split(v, ',');
|
||||
rice_record r{
|
||||
|
@ -70,11 +171,11 @@ void blt::gp::example::rice_classification_t::load_rice_data(std::string_view ri
|
|||
|
||||
size_t total_records = c.size() + o.size();
|
||||
size_t training_size = std::min(total_records / 3, 1000ul);
|
||||
for (blt::size_t i = 0; i < training_size; i++)
|
||||
for (size_t i = 0; i < training_size; i++)
|
||||
{
|
||||
auto& random = program.get_random();
|
||||
auto& vec = random.choice() ? c : o;
|
||||
auto pos = random.get_i64(0, static_cast<blt::i64>(vec.size()));
|
||||
auto pos = random.get_i64(0, static_cast<i64>(vec.size()));
|
||||
training_cases.push_back(vec[pos]);
|
||||
vec.erase(vec.begin() + pos);
|
||||
}
|
||||
|
@ -84,197 +185,31 @@ void blt::gp::example::rice_classification_t::load_rice_data(std::string_view ri
|
|||
BLT_INFO("Created training set of size %ld, testing set is of size %ld", training_size, testing_cases.size());
|
||||
}
|
||||
|
||||
struct test_results_t
|
||||
blt::gp::confusion_matrix_t blt::gp::example::rice_classification_t::test_individual(const individual_t& individual) const
|
||||
{
|
||||
blt::size_t cc = 0;
|
||||
blt::size_t co = 0;
|
||||
blt::size_t oo = 0;
|
||||
blt::size_t oc = 0;
|
||||
blt::size_t hits = 0;
|
||||
blt::size_t size = 0;
|
||||
double percent_hit = 0;
|
||||
|
||||
test_results_t& operator+=(const test_results_t& a)
|
||||
{
|
||||
cc += a.cc;
|
||||
co += a.co;
|
||||
oo += a.oo;
|
||||
oc += a.oc;
|
||||
hits += a.hits;
|
||||
size += a.size;
|
||||
percent_hit += a.percent_hit;
|
||||
return *this;
|
||||
}
|
||||
|
||||
test_results_t& operator/=(blt::size_t s)
|
||||
{
|
||||
cc /= s;
|
||||
co /= s;
|
||||
oo /= s;
|
||||
oc /= s;
|
||||
hits /= s;
|
||||
size /= s;
|
||||
percent_hit /= static_cast<double>(s);
|
||||
return *this;
|
||||
}
|
||||
|
||||
friend bool operator<(const test_results_t& a, const test_results_t& b)
|
||||
{
|
||||
return a.hits < b.hits;
|
||||
}
|
||||
|
||||
friend bool operator>(const test_results_t& a, const test_results_t& b)
|
||||
{
|
||||
return a.hits > b.hits;
|
||||
}
|
||||
};
|
||||
|
||||
test_results_t test_individual(blt::gp::individual_t& i)
|
||||
{
|
||||
test_results_t results;
|
||||
confusion_matrix_t confusion_matrix;
|
||||
confusion_matrix.set_name_a("cammeo");
|
||||
confusion_matrix.set_name_b("osmancik");
|
||||
|
||||
for (auto& testing_case : testing_cases)
|
||||
{
|
||||
auto result = i.tree.get_evaluation_value<float>(testing_case);
|
||||
const auto result = individual.tree.get_evaluation_value<float>(testing_case);
|
||||
switch (testing_case.type)
|
||||
{
|
||||
case rice_type_t::Cammeo:
|
||||
if (result >= 0)
|
||||
results.cc++; // cammeo cammeo
|
||||
confusion_matrix.is_A_predicted_A(); // cammeo cammeo
|
||||
else
|
||||
results.co++; // cammeo osmancik
|
||||
confusion_matrix.is_A_predicted_B(); // cammeo osmancik
|
||||
break;
|
||||
case rice_type_t::Osmancik:
|
||||
if (result < 0)
|
||||
results.oo++; // osmancik osmancik
|
||||
confusion_matrix.is_B_predicted_B(); // osmancik osmancik
|
||||
else
|
||||
results.oc++; // osmancik cammeo
|
||||
confusion_matrix.is_B_predicted_A(); // osmancik cammeo
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
results.hits = results.cc + results.oo;
|
||||
results.size = testing_cases.size();
|
||||
results.percent_hit = static_cast<double>(results.hits) / static_cast<double>(results.size) * 100;
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
int main(int argc, const char** argv)
|
||||
{
|
||||
blt::arg_parse parser;
|
||||
parser.addArgument(blt::arg_builder{"-f", "--file"}.setHelp("File for rice data. Should be in .arff format.").setRequired().build());
|
||||
|
||||
auto args = parser.parse_args(argc, argv);
|
||||
|
||||
if (!args.contains("file"))
|
||||
{
|
||||
BLT_WARN("Please provide path to file with -f or --file");
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto rice_file_path = args.get<std::string>("file");
|
||||
|
||||
BLT_INFO("Starting BLT-GP Rice Classification Example");
|
||||
BLT_START_INTERVAL("Rice Classification", "Main");
|
||||
BLT_DEBUG("Setup Fitness cases");
|
||||
load_rice_data(rice_file_path);
|
||||
|
||||
BLT_DEBUG("Setup Types and Operators");
|
||||
|
||||
blt::gp::operator_builder<rice_record> builder{};
|
||||
program.set_operations(builder.build(add, sub, mul, pro_div, op_exp, op_log, lit, op_area, op_perimeter, op_major_axis_length,
|
||||
op_minor_axis_length, op_eccentricity, op_convex_area, op_extent));
|
||||
|
||||
BLT_DEBUG("Generate Initial Population");
|
||||
auto sel = blt::gp::select_tournament_t{};
|
||||
program.generate_population(program.get_typesystem().get_type<float>().id(), fitness_function, sel, sel, sel);
|
||||
|
||||
BLT_DEBUG("Begin Generation Loop");
|
||||
while (!program.should_terminate())
|
||||
{
|
||||
BLT_TRACE("------------{Begin Generation %ld}------------", program.get_current_generation());
|
||||
BLT_TRACE("Creating next generation");
|
||||
BLT_START_INTERVAL("Rice Classification", "Gen");
|
||||
program.create_next_generation();
|
||||
BLT_END_INTERVAL("Rice Classification", "Gen");
|
||||
BLT_TRACE("Move to next generation");
|
||||
BLT_START_INTERVAL("Rice Classification", "Fitness");
|
||||
program.next_generation();
|
||||
BLT_TRACE("Evaluate Fitness");
|
||||
program.evaluate_fitness();
|
||||
BLT_END_INTERVAL("Rice Classification", "Fitness");
|
||||
auto& stats = program.get_population_stats();
|
||||
BLT_TRACE("Stats:");
|
||||
BLT_TRACE("Average fitness: %lf", stats.average_fitness.load());
|
||||
BLT_TRACE("Best fitness: %lf", stats.best_fitness.load());
|
||||
BLT_TRACE("Worst fitness: %lf", stats.worst_fitness.load());
|
||||
BLT_TRACE("Overall fitness: %lf", stats.overall_fitness.load());
|
||||
BLT_TRACE("----------------------------------------------");
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
BLT_END_INTERVAL("Rice Classification", "Main");
|
||||
|
||||
std::vector<std::pair<test_results_t, blt::gp::individual_t*>> results;
|
||||
for (auto& i : program.get_current_pop().get_individuals())
|
||||
results.emplace_back(test_individual(i), &i);
|
||||
std::sort(results.begin(), results.end(), [](const auto& a, const auto& b)
|
||||
{
|
||||
return a.first > b.first;
|
||||
});
|
||||
|
||||
BLT_INFO("Best results:");
|
||||
for (blt::size_t index = 0; index < 3; index++)
|
||||
{
|
||||
const auto& record = results[index].first;
|
||||
const auto& i = *results[index].second;
|
||||
|
||||
BLT_INFO("Hits %ld, Total Cases %ld, Percent Hit: %lf", record.hits, record.size, record.percent_hit);
|
||||
BLT_DEBUG("Cammeo Cammeo: %ld", record.cc);
|
||||
BLT_DEBUG("Cammeo Osmancik: %ld", record.co);
|
||||
BLT_DEBUG("Osmancik Osmancik: %ld", record.oo);
|
||||
BLT_DEBUG("Osmancik Cammeo: %ld", record.oc);
|
||||
BLT_DEBUG("Fitness: %lf, stand: %lf, raw: %lf", i.fitness.adjusted_fitness, i.fitness.standardized_fitness, i.fitness.raw_fitness);
|
||||
i.tree.print(program, std::cout);
|
||||
|
||||
std::cout << "\n";
|
||||
}
|
||||
|
||||
BLT_INFO("Worst Results:");
|
||||
for (blt::size_t index = 0; index < 3; index++)
|
||||
{
|
||||
const auto& record = results[results.size() - 1 - index].first;
|
||||
const auto& i = *results[results.size() - 1 - index].second;
|
||||
|
||||
BLT_INFO("Hits %ld, Total Cases %ld, Percent Hit: %lf", record.hits, record.size, record.percent_hit);
|
||||
BLT_DEBUG("Cammeo Cammeo: %ld", record.cc);
|
||||
BLT_DEBUG("Cammeo Osmancik: %ld", record.co);
|
||||
BLT_DEBUG("Osmancik Osmancik: %ld", record.oo);
|
||||
BLT_DEBUG("Osmancik Cammeo: %ld", record.oc);
|
||||
BLT_DEBUG("Fitness: %lf, stand: %lf, raw: %lf", i.fitness.adjusted_fitness, i.fitness.standardized_fitness, i.fitness.raw_fitness);
|
||||
|
||||
std::cout << "\n";
|
||||
}
|
||||
|
||||
BLT_INFO("Average Results");
|
||||
test_results_t avg{};
|
||||
for (const auto& v : results)
|
||||
avg += v.first;
|
||||
avg /= results.size();
|
||||
BLT_INFO("Hits %ld, Total Cases %ld, Percent Hit: %lf", avg.hits, avg.size, avg.percent_hit);
|
||||
BLT_DEBUG("Cammeo Cammeo: %ld", avg.cc);
|
||||
BLT_DEBUG("Cammeo Osmancik: %ld", avg.co);
|
||||
BLT_DEBUG("Osmancik Osmancik: %ld", avg.oo);
|
||||
BLT_DEBUG("Osmancik Cammeo: %ld", avg.oc);
|
||||
std::cout << "\n";
|
||||
|
||||
BLT_PRINT_PROFILE("Rice Classification", blt::PRINT_CYCLES | blt::PRINT_THREAD | blt::PRINT_WALL);
|
||||
|
||||
#ifdef BLT_TRACK_ALLOCATIONS
|
||||
BLT_TRACE("Total Allocations: %ld times with a total of %s", blt::gp::tracker.getAllocations(),
|
||||
blt::byte_convert_t(blt::gp::tracker.getAllocatedBytes()).convert_to_nearest_type().to_pretty_string().c_str());
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
return confusion_matrix;
|
||||
}
|
||||
|
|
|
@ -60,4 +60,50 @@ int main()
|
|||
regression.execute();
|
||||
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
bool blt::gp::example::symbolic_regression_t::fitness_function(const tree_t& current_tree, fitness_t& fitness, size_t) const
|
||||
{
|
||||
constexpr static double value_cutoff = 1.e15;
|
||||
for (auto& fitness_case : training_cases)
|
||||
{
|
||||
const auto diff = std::abs(fitness_case.y - current_tree.get_evaluation_value<float>(fitness_case));
|
||||
if (diff < value_cutoff)
|
||||
{
|
||||
fitness.raw_fitness += diff;
|
||||
if (diff <= 0.01)
|
||||
fitness.hits++;
|
||||
}
|
||||
else
|
||||
fitness.raw_fitness += value_cutoff;
|
||||
}
|
||||
fitness.standardized_fitness = fitness.raw_fitness;
|
||||
fitness.adjusted_fitness = (1.0 / (1.0 + fitness.standardized_fitness));
|
||||
return static_cast<size_t>(fitness.hits) == training_cases.size();
|
||||
}
|
||||
|
||||
void blt::gp::example::symbolic_regression_t::setup_operations()
|
||||
{
|
||||
BLT_DEBUG("Setup Types and Operators");
|
||||
static operation_t add{[](const float a, const float b) { return a + b; }, "add"};
|
||||
static operation_t sub([](const float a, const float b) { return a - b; }, "sub");
|
||||
static operation_t mul([](const float a, const float b) { return a * b; }, "mul");
|
||||
static operation_t pro_div([](const float a, const float b) { return b == 0.0f ? 0.0f : a / b; }, "div");
|
||||
static operation_t op_sin([](const float a) { return std::sin(a); }, "sin");
|
||||
static operation_t op_cos([](const float a) { return std::cos(a); }, "cos");
|
||||
static operation_t op_exp([](const float a) { return std::exp(a); }, "exp");
|
||||
static operation_t op_log([](const float a) { return a == 0.0f ? 0.0f : std::log(a); }, "log");
|
||||
static auto lit = operation_t([this]()
|
||||
{
|
||||
return program.get_random().get_float(-1.0f, 1.0f);
|
||||
}, "lit").set_ephemeral();
|
||||
|
||||
static operation_t op_x([](const context& context)
|
||||
{
|
||||
return context.x;
|
||||
}, "x");
|
||||
|
||||
operator_builder<context> builder{};
|
||||
builder.build(add, sub, mul, pro_div, op_sin, op_cos, op_exp, op_log, lit, op_x);
|
||||
program.set_operations(builder.grab());
|
||||
}
|
||||
|
|
|
@ -35,25 +35,7 @@ namespace blt::gp::example
|
|||
};
|
||||
|
||||
private:
|
||||
bool fitness_function(const tree_t& current_tree, fitness_t& fitness, size_t) const
|
||||
{
|
||||
constexpr static double value_cutoff = 1.e15;
|
||||
for (auto& fitness_case : training_cases)
|
||||
{
|
||||
const auto diff = std::abs(fitness_case.y - current_tree.get_evaluation_value<float>(fitness_case));
|
||||
if (diff < value_cutoff)
|
||||
{
|
||||
fitness.raw_fitness += diff;
|
||||
if (diff <= 0.01)
|
||||
fitness.hits++;
|
||||
}
|
||||
else
|
||||
fitness.raw_fitness += value_cutoff;
|
||||
}
|
||||
fitness.standardized_fitness = fitness.raw_fitness;
|
||||
fitness.adjusted_fitness = (1.0 / (1.0 + fitness.standardized_fitness));
|
||||
return static_cast<size_t>(fitness.hits) == training_cases.size();
|
||||
}
|
||||
bool fitness_function(const tree_t& current_tree, fitness_t& fitness, size_t) const;
|
||||
|
||||
static float example_function(const float x)
|
||||
{
|
||||
|
@ -61,7 +43,7 @@ namespace blt::gp::example
|
|||
}
|
||||
|
||||
public:
|
||||
template<typename SEED>
|
||||
template <typename SEED>
|
||||
symbolic_regression_t(SEED seed, const prog_config_t& config): example_base_t{std::forward<SEED>(seed), config}
|
||||
{
|
||||
BLT_INFO("Starting BLT-GP Symbolic Regression Example");
|
||||
|
@ -81,37 +63,7 @@ namespace blt::gp::example
|
|||
};
|
||||
}
|
||||
|
||||
template <typename Ctx>
|
||||
auto make_operations(operator_builder<Ctx>& builder)
|
||||
{
|
||||
static operation_t add{[](const float a, const float b) { return a + b; }, "add"};
|
||||
static operation_t sub([](const float a, const float b) { return a - b; }, "sub");
|
||||
static operation_t mul([](const float a, const float b) { return a * b; }, "mul");
|
||||
static operation_t pro_div([](const float a, const float b) { return b == 0.0f ? 1.0f : a / b; }, "div");
|
||||
static operation_t op_sin([](const float a) { return std::sin(a); }, "sin");
|
||||
static operation_t op_cos([](const float a) { return std::cos(a); }, "cos");
|
||||
static operation_t op_exp([](const float a) { return std::exp(a); }, "exp");
|
||||
static operation_t op_log([](const float a) { return a == 0.0f ? 0.0f : std::log(a); }, "log");
|
||||
static auto lit = operation_t([this]()
|
||||
{
|
||||
return program.get_random().get_float(-1.0f, 1.0f);
|
||||
}, "lit").set_ephemeral();
|
||||
|
||||
static operation_t op_x([](const context& context)
|
||||
{
|
||||
return context.x;
|
||||
}, "x");
|
||||
|
||||
return builder.build(add, sub, mul, pro_div, op_sin, op_cos, op_exp, op_log, lit, op_x);
|
||||
}
|
||||
|
||||
void setup_operations()
|
||||
{
|
||||
BLT_DEBUG("Setup Types and Operators");
|
||||
operator_builder<context> builder{};
|
||||
make_operations(builder);
|
||||
program.set_operations(builder.grab());
|
||||
}
|
||||
void setup_operations();
|
||||
|
||||
void generate_initial_population()
|
||||
{
|
||||
|
|
|
@ -86,6 +86,26 @@ namespace blt::gp
|
|||
return is_B_pred_A;
|
||||
}
|
||||
|
||||
// Number of correct predictions: A predicted as A plus B predicted as B.
[[nodiscard]] u64 get_hits() const
{
    const u64 correct_a = is_A_pred_A;
    const u64 correct_b = is_B_pred_B;
    return correct_a + correct_b;
}
|
||||
|
||||
// Number of wrong predictions: B predicted as A plus A predicted as B.
[[nodiscard]] u64 get_misses() const
{
    const u64 wrong_b = is_B_pred_A;
    const u64 wrong_a = is_A_pred_B;
    return wrong_b + wrong_a;
}
|
||||
|
||||
// Number of recorded predictions, correct and wrong together.
[[nodiscard]] u64 get_total() const
{
    const u64 hits = get_hits();
    const u64 misses = get_misses();
    return hits + misses;
}
|
||||
|
||||
[[nodiscard]] double get_percent_hit() const
|
||||
{
|
||||
return static_cast<double>(get_hits()) / static_cast<double>(get_total());
|
||||
}
|
||||
|
||||
confusion_matrix_t& operator+=(const confusion_matrix_t& op)
|
||||
{
|
||||
is_A_pred_A += op.is_A_pred_A;
|
||||
|
@ -118,6 +138,16 @@ namespace blt::gp
|
|||
return result;
|
||||
}
|
||||
|
||||
friend bool operator<(const confusion_matrix_t& a, const confusion_matrix_t& b)
|
||||
{
|
||||
return a.get_percent_hit() < b.get_percent_hit();
|
||||
}
|
||||
|
||||
friend bool operator>(const confusion_matrix_t& a, const confusion_matrix_t& b)
|
||||
{
|
||||
return a.get_percent_hit() > b.get_percent_hit();
|
||||
}
|
||||
|
||||
[[nodiscard]] std::string pretty_print(const std::string& table_name = "Confusion Matrix") const;
|
||||
|
||||
private:
|
||||
|
@ -129,40 +159,6 @@ namespace blt::gp
|
|||
std::string name_B = "B";
|
||||
};
|
||||
|
||||
struct classifier_results_t : public confusion_matrix_t
|
||||
{
|
||||
public:
|
||||
[[nodiscard]] u64 get_hits() const
|
||||
{
|
||||
return hits;
|
||||
}
|
||||
|
||||
[[nodiscard]] u64 get_size() const
|
||||
{
|
||||
return size;
|
||||
}
|
||||
|
||||
[[nodiscard]] double get_percent_hit() const
|
||||
{
|
||||
return static_cast<double>(hits) / static_cast<double>(hits + misses);
|
||||
}
|
||||
|
||||
void hit()
|
||||
{
|
||||
++hits;
|
||||
}
|
||||
|
||||
void miss()
|
||||
{
|
||||
++misses;
|
||||
}
|
||||
|
||||
|
||||
private:
|
||||
u64 hits = 0;
|
||||
u64 misses = 0;
|
||||
};
|
||||
|
||||
struct population_stats
|
||||
{
|
||||
population_stats() = default;
|
||||
|
|
|
@ -26,18 +26,18 @@ namespace blt::gp {
|
|||
string::TableFormatter formatter{table_name};
|
||||
formatter.addColumn("Predicted " + name_A);
|
||||
formatter.addColumn("Predicted " + name_B);
|
||||
formatter.addColumn("");
|
||||
formatter.addColumn("Actual Class");
|
||||
|
||||
string::TableRow row;
|
||||
row.rowValues.push_back(std::to_string(is_A_pred_A));
|
||||
row.rowValues.push_back(std::to_string(is_A_pred_B));
|
||||
row.rowValues.push_back("Actual" + name_A);
|
||||
row.rowValues.push_back(name_A);
|
||||
formatter.addRow(row);
|
||||
|
||||
string::TableRow row2;
|
||||
row2.rowValues.push_back(std::to_string(is_B_pred_A));
|
||||
row2.rowValues.push_back(std::to_string(is_B_pred_B));
|
||||
row2.rowValues.push_back("Actual" + name_B);
|
||||
row2.rowValues.push_back(name_B);
|
||||
formatter.addRow(row2);
|
||||
|
||||
auto tbl = formatter.createTable(true, true);
|
||||
|
|
Loading…
Reference in New Issue