177 lines
6.6 KiB
C++
177 lines
6.6 KiB
C++
#pragma once
|
|
/*
 * Copyright (C) 2024 Brett Terpstra
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program. If not, see <https://www.gnu.org/licenses/>.
 */
|
|
|
|
#ifndef BLT_GP_EXAMPLES_RICE_CLASSIFICATION_H
|
|
#define BLT_GP_EXAMPLES_RICE_CLASSIFICATION_H
|
|
|
|
#include "examples_base.h"
|
|
|
|
namespace blt::gp::example
|
|
{
|
|
class rice_classification_t : public example_base_t
|
|
{
|
|
private:
|
|
enum class rice_type_t
|
|
{
|
|
Cammeo,
|
|
Osmancik
|
|
};
|
|
|
|
struct rice_record
|
|
{
|
|
float area;
|
|
float perimeter;
|
|
float major_axis_length;
|
|
float minor_axis_length;
|
|
float eccentricity;
|
|
float convex_area;
|
|
float extent;
|
|
rice_type_t type;
|
|
};
|
|
|
|
bool fitness_function(const tree_t& current_tree, fitness_t& fitness, size_t) const;
|
|
public:
|
|
template <typename SEED>
|
|
rice_classification_t(SEED&& seed, const prog_config_t& config): example_base_t{std::forward<SEED>(seed), config}
|
|
{
|
|
BLT_INFO("Starting BLT-GP Rice Classification Example");
|
|
fitness_function_ref = [this](const tree_t& t, fitness_t& f, const size_t i)
|
|
{
|
|
return fitness_function(t, f, i);
|
|
};
|
|
}
|
|
|
|
void make_operators();
|
|
|
|
void load_rice_data(std::string_view rice_file_path);
|
|
|
|
confusion_matrix_t test_individual(const individual_t& individual) const;
|
|
|
|
void execute(const std::string_view rice_file_path)
|
|
{
|
|
load_rice_data(rice_file_path);
|
|
make_operators();
|
|
generate_initial_population();
|
|
run_generation_loop();
|
|
evaluate_individuals();
|
|
print_best();
|
|
print_average();
|
|
}
|
|
|
|
void run_generation_loop()
|
|
{
|
|
BLT_DEBUG("Begin Generation Loop");
|
|
while (!program.should_terminate())
|
|
{
|
|
BLT_TRACE("------------{Begin Generation %ld}------------", program.get_current_generation());
|
|
BLT_TRACE("Creating next generation");
|
|
program.create_next_generation();
|
|
BLT_TRACE("Move to next generation");
|
|
program.next_generation();
|
|
BLT_TRACE("Evaluate Fitness");
|
|
program.evaluate_fitness();
|
|
auto& stats = program.get_population_stats();
|
|
BLT_TRACE("Avg Fit: %lf, Best Fit: %lf, Worst Fit: %lf, Overall Fit: %lf",
|
|
stats.average_fitness.load(std::memory_order_relaxed), stats.best_fitness.load(std::memory_order_relaxed),
|
|
stats.worst_fitness.load(std::memory_order_relaxed), stats.overall_fitness.load(std::memory_order_relaxed));
|
|
BLT_TRACE("----------------------------------------------");
|
|
std::cout << std::endl;
|
|
}
|
|
}
|
|
|
|
void evaluate_individuals()
|
|
{
|
|
results.clear();
|
|
for (auto& i : program.get_current_pop().get_individuals())
|
|
results.emplace_back(test_individual(i), &i);
|
|
std::sort(results.begin(), results.end(), [](const auto& a, const auto& b)
|
|
{
|
|
return a.first > b.first;
|
|
});
|
|
}
|
|
|
|
void generate_initial_population()
|
|
{
|
|
BLT_DEBUG("Generate Initial Population");
|
|
static auto sel = select_tournament_t{};
|
|
if (crossover_sel == nullptr)
|
|
crossover_sel = &sel;
|
|
if (mutation_sel == nullptr)
|
|
mutation_sel = &sel;
|
|
if (reproduction_sel == nullptr)
|
|
reproduction_sel = &sel;
|
|
program.generate_population(program.get_typesystem().get_type<float>().id(), fitness_function_ref, *crossover_sel, *mutation_sel,
|
|
*reproduction_sel);
|
|
}
|
|
|
|
void print_best(const size_t amount = 3)
|
|
{
|
|
BLT_INFO("Best results:");
|
|
for (size_t index = 0; index < amount; index++)
|
|
{
|
|
const auto& record = results[index].first;
|
|
const auto& i = *results[index].second;
|
|
|
|
BLT_INFO("Hits %ld, Total Cases %ld, Percent Hit: %lf", record.get_hits(), record.get_total(), record.get_percent_hit());
|
|
std::cout << record.pretty_print() << std::endl;
|
|
BLT_DEBUG("Fitness: %lf, stand: %lf, raw: %lf", i.fitness.adjusted_fitness, i.fitness.standardized_fitness, i.fitness.raw_fitness);
|
|
i.tree.print(program, std::cout);
|
|
|
|
std::cout << "\n";
|
|
}
|
|
}
|
|
|
|
void print_worst(const size_t amount = 3) const
|
|
{
|
|
BLT_INFO("Worst Results:");
|
|
for (size_t index = 0; index < amount; index++)
|
|
{
|
|
const auto& record = results[results.size() - 1 - index].first;
|
|
const auto& i = *results[results.size() - 1 - index].second;
|
|
|
|
BLT_INFO("Hits %ld, Total Cases %ld, Percent Hit: %lf", record.get_hits(), record.get_total(), record.get_percent_hit());
|
|
std::cout << record.pretty_print() << std::endl;
|
|
BLT_DEBUG("Fitness: %lf, stand: %lf, raw: %lf", i.fitness.adjusted_fitness, i.fitness.standardized_fitness, i.fitness.raw_fitness);
|
|
|
|
std::cout << "\n";
|
|
}
|
|
}
|
|
|
|
void print_average()
|
|
{
|
|
BLT_INFO("Average Results");
|
|
confusion_matrix_t avg{};
|
|
avg.set_name_a("cammeo");
|
|
avg.set_name_b("osmancik");
|
|
for (const auto& [matrix, _] : results)
|
|
avg += matrix;
|
|
avg /= results.size();
|
|
BLT_INFO("Hits %ld, Total Cases %ld, Percent Hit: %lf", avg.get_hits(), avg.get_total(), avg.get_percent_hit());
|
|
std::cout << avg.pretty_print() << std::endl;
|
|
std::cout << "\n";
|
|
}
|
|
|
|
private:
|
|
std::vector<rice_record> training_cases;
|
|
std::vector<rice_record> testing_cases;
|
|
std::vector<std::pair<confusion_matrix_t, individual_t*>> results;
|
|
};
|
|
}
|
|
|
|
#endif //BLT_GP_EXAMPLES_RICE_CLASSIFICATION_H
|