blt-gp/tests/config_from_args.h

119 lines
7.3 KiB
C++

#pragma once
/*
* Copyright (C) 2024 Brett Terpstra
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
#ifndef CONFIG_FROM_ARGS_H
#define CONFIG_FROM_ARGS_H
#include <blt/parse/argparse_v2.h>
#include <blt/gp/program.h>
namespace blt::gp
{
/**
 * Register the crossover tuning flags on a (sub)parser.
 *
 * Called once for every crossover sub-command so each of them accepts the same options.
 *
 * @param parser parser (typically a crossover sub-command parser) to add the flags to
 */
inline void setup_crossover_parser(argparse::argument_parser_t& parser)
{
    // Default literals use the same type as as_type<T>() (u32 -> 5u, float -> 0.25f) so the
    // default value and a user-supplied value are stored with the same type.
    parser.add_flag("--max_crossover_tries").set_default(5u).as_type<u32>().set_help(
        "number of times crossover will try to pick a valid point in the tree.");
    parser.add_flag("--max_crossover_iterations").set_default(10u).as_type<u32>().set_help(
        "how many times the crossover function can fail before we will skip this operation.");
    parser.add_flag("--min_tree_size").set_default(5u).as_type<u32>().set_help("the minimum size of"
        " the tree to be considered for crossover.");
    parser.add_flag("--depth_multiplier").set_default(0.25f).as_type<float>().set_help(
        "at each depth level, what chance do we have to exit with this as our point?");
    parser.add_flag("--terminal_chance").set_default(0.1f).as_type<float>().set_help(
        "how often should we select terminals over functions. By default, we only allow selection of terminals 10% of the time. "
        "This applies to both types of crossover point functions. Traversal will use the parent if it should not pick a terminal.");
    parser.add_flag("--traverse").make_flag().set_help("use traversal to select instead of random point selection.");
}
inline std::tuple<prog_config_t, selection_t*, crossover_t*, mutation_t*> make_config(const int argc, const char** argv)
{
argparse::argument_parser_t parser;
parser.add_flag("--initial_tree_min").set_default(2).as_type<i32>().set_help("The minimum number of nodes in the initial trees");
parser.add_flag("--initial_tree_max").set_default(6).as_type<i32>().set_help("The maximum number of nodes in the initial trees");
parser.add_flag("--elites").set_default(2).as_type<i32>().set_help("Number of best fitness individuals to keep each generation");
parser.add_flag("--max_generations", "-g").set_default(50).as_type<u32>().set_help("The maximum number of generations to run");
parser.add_flag("--population_size", "-p").set_default(500).as_type<u32>().set_help("The size of the population");
parser.add_flag("--threads", "-t").set_default(0).as_type<u32>().set_help("The number of threads to use");
parser.add_flag("--crossover_rate", "-c").set_default(0.8).as_type<float>().set_help("The rate of crossover");
parser.add_flag("--mutation_rate", "-m").set_default(0.1).as_type<float>().set_help("The rate of mutation");
parser.add_flag("--reproduction_rate", "-r").set_default(0.1).as_type<float>().set_help("The rate of reproduction");
const auto mode = parser.add_subparser("mode")->set_help("Select the mode to run the program in.");
mode->add_parser("default");
const auto manual = mode->add_parser("manual");
const auto selection_parser = manual->add_subparser("selection_type")->set_help("Select the type of selection to use.");
const auto tournament_selection = selection_parser->add_parser("tournament");
tournament_selection->add_flag("--tournament_size", "-s").set_default(5).as_type<i32>().set_help("The size of the tournament");
selection_parser->add_parser("best");
selection_parser->add_parser("worst");
selection_parser->add_parser("roulette", "fitness");
const auto crossover_parser = manual->add_subparser("crossover_type")->set_help("Select the type of crossover to use.");
const auto subtree_crossover = crossover_parser->add_parser("subtree_crossover");
setup_crossover_parser(*subtree_crossover);
const auto one_point_crossover = crossover_parser->add_parser("one_point_crossover");
setup_crossover_parser(*one_point_crossover);
const auto advanced_crossover = crossover_parser->add_parser("advanced_crossover");
setup_crossover_parser(*advanced_crossover);
const auto mutation_parser = manual->add_subparser("mutation_type")->set_help("Select the type of mutation to use.");
const auto single_point_mutation = mutation_parser->add_parser("single_point_mutation");
single_point_mutation->add_flag("--replacement_min_depth").set_default(2).as_type<u32>().set_help("Minimum depth of the generated replacement tree");
single_point_mutation->add_flag("--replacement_max_depth").set_default(2).as_type<u32>().set_help("Maximum depth of the generated replacement tree");
single_point_mutation->add_positional("generator").set_choices("grow", "full").set_default("grow");
const auto advanced_mutation = mutation_parser->add_parser("advanced_mutation");
advanced_mutation->add_flag("--replacement_min_depth").set_default(2).as_type<u32>().set_help("Minimum depth of the generated replacement tree");
advanced_mutation->add_flag("--replacement_max_depth").set_default(2).as_type<u32>().set_help("Maximum depth of the generated replacement tree");
advanced_mutation->add_positional("generator").set_choices("grow", "full").set_default("grow");
auto args = parser.parse(argc, argv);
auto config = prog_config_t()
.set_initial_min_tree_size(args.get<i32>("initial_tree_min"))
.set_initial_max_tree_size(args.get<i32>("initial_tree_max"))
.set_elite_count(args.get<i32>("elites"))
.set_crossover_chance(args.get<float>("crossover_rate"))
.set_mutation_chance(args.get<float>("mutation_rate"))
.set_reproduction_chance(args.get<float>("reproduction_rate"))
.set_max_generations(args.get<u32>("max_generations"))
.set_pop_size(args.get<u32>("population_size"))
.set_thread_count(args.get<u32>("threads"));
thread_local select_tournament_t s_tournament_selection;
thread_local select_best_t s_best_selection;
thread_local select_worst_t s_worst_selection;
thread_local select_fitness_proportionate_t s_roulette_selection;
thread_local subtree_crossover_t s_subtree_crossover;
thread_local one_point_crossover_t s_one_point_crossover;
thread_local advanced_crossover_t s_advanced_crossover;
thread_local mutation_t s_single_point_mutation;
thread_local advanced_mutation_t s_advanced_mutation;
if (args.get("mode") == "default")
return {config, &s_tournament_selection, &s_subtree_crossover, &s_advanced_mutation};
if (args.get("mode") == "manual")
{
}
}
}
#endif //CONFIG_FROM_ARGS_H