#pragma once /* * Copyright (C) 2024 Brett Terpstra * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program. If not, see . */ #ifndef BLT_GP_EXAMPLES_RICE_CLASSIFICATION_H #define BLT_GP_EXAMPLES_RICE_CLASSIFICATION_H #include "examples_base.h" namespace blt::gp::example { class rice_classification_t : public example_base_t { private: enum class rice_type_t { Cammeo, Osmancik }; struct rice_record { float area; float perimeter; float major_axis_length; float minor_axis_length; float eccentricity; float convex_area; float extent; rice_type_t type; }; void make_operators() { static operation_t add{[](const float a, const float b) { return a + b; }, "add"}; static operation_t sub([](const float a, const float b) { return a - b; }, "sub"); static operation_t mul([](const float a, const float b) { return a * b; }, "mul"); static operation_t pro_div([](const float a, const float b) { return b == 0.0f ? 1.0f : a / b; }, "div"); static operation_t op_sin([](const float a) { return std::sin(a); }, "sin"); static operation_t op_cos([](const float a) { return std::cos(a); }, "cos"); static operation_t op_exp([](const float a) { return std::exp(a); }, "exp"); static operation_t op_log([](const float a) { return a == 0.0f ? 0.0f : std::log(a); }, "log"); static auto lit = blt::gp::operation_t([]() { return program.get_random().get_float(-32000.0f, 32000.0f); }, "lit").set_ephemeral(); static operation_t op_area([](const rice_record& rice_data) { return rice_data.area; }, "area"); static operation_t op_perimeter([](const rice_record& rice_data) { return rice_data.perimeter; }, "perimeter"); static operation_t op_major_axis_length([](const rice_record& rice_data) { return rice_data.major_axis_length; }, "major_axis_length"); static operation_t op_minor_axis_length([](const rice_record& rice_data) { return rice_data.minor_axis_length; }, "minor_axis_length"); static operation_t op_eccentricity([](const rice_record& rice_data) { return rice_data.eccentricity; }, "eccentricity"); static operation_t op_convex_area([](const rice_record& rice_data) { return rice_data.convex_area; }, "convex_area"); static operation_t op_extent([](const rice_record& rice_data) { return rice_data.extent; }, "extent"); } bool fitness_function(const tree_t& current_tree, fitness_t& fitness, size_t) const { for (auto& training_case : training_cases) { auto v = current_tree.get_evaluation_value(training_case); switch (training_case.type) { case rice_type_t::Cammeo: if (v >= 0) fitness.hits++; break; case rice_type_t::Osmancik: if (v < 0) fitness.hits++; break; } } fitness.raw_fitness = static_cast(fitness.hits); fitness.standardized_fitness = fitness.raw_fitness; fitness.adjusted_fitness = 1.0 - (1.0 / (1.0 + fitness.standardized_fitness)); return static_cast(fitness.hits) == training_cases.size(); } void load_rice_data(std::string_view rice_file_path); public: template rice_classification_t(SEED&& seed, const prog_config_t& config): example_base_t{std::forward(seed), config} { fitness_function_ref = [this](const tree_t& t, fitness_t& f, const size_t i) { return fitness_function(t, f, i); }; } private: std::vector training_cases; std::vector testing_cases; }; } #endif //BLT_GP_EXAMPLES_RICE_CLASSIFICATION_H