begin working on config from args

dev
Brett 2025-05-26 21:49:59 -04:00
parent ffeed055e0
commit f14545a4c2
4 changed files with 123 additions and 2 deletions

View File

@ -27,7 +27,7 @@ macro(compile_options target_name)
sanitizers(${target_name}) sanitizers(${target_name})
endmacro() endmacro()
project(blt-gp VERSION 0.5.36) project(blt-gp VERSION 0.5.37)
include(CTest) include(CTest)

View File

@ -170,10 +170,11 @@ namespace blt::gp
class advanced_crossover_t : public crossover_t class advanced_crossover_t : public crossover_t
{ {
public:
advanced_crossover_t(): crossover_t(config_t{}) advanced_crossover_t(): crossover_t(config_t{})
{ {
} }
public:
bool apply(gp_program& program, const tree_t& p1, const tree_t& p2, tree_t& c1, tree_t& c2) override; bool apply(gp_program& program, const tree_t& p1, const tree_t& p2, tree_t& c1, tree_t& c2) override;
}; };

View File

@ -419,6 +419,7 @@ namespace blt::gp
config.replacement_max_depth config.replacement_max_depth
}); });
start_index = c.manipulate().easy_manipulator().insert_subtree(subtree_point_t(static_cast<ptrdiff_t>(start_index)), tree); start_index = c.manipulate().easy_manipulator().insert_subtree(subtree_point_t(static_cast<ptrdiff_t>(start_index)), tree);
tree.clear(program);
} }
start_index += size; start_index += size;
// vals.copy_from(combined_ptr, for_bytes); // vals.copy_from(combined_ptr, for_bytes);
@ -431,6 +432,7 @@ namespace blt::gp
config.replacement_max_depth config.replacement_max_depth
}); });
start_index = c.manipulate().easy_manipulator().insert_subtree(subtree_point_t(static_cast<ptrdiff_t>(start_index)), tree); start_index = c.manipulate().easy_manipulator().insert_subtree(subtree_point_t(static_cast<ptrdiff_t>(start_index)), tree);
tree.clear(program);
} }
// vals.copy_from(combined_ptr + for_bytes, after_bytes); // vals.copy_from(combined_ptr + for_bytes, after_bytes);

118
tests/config_from_args.h Normal file
View File

@ -0,0 +1,118 @@
#pragma once
/*
* Copyright (C) 2024 Brett Terpstra
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
#ifndef CONFIG_FROM_ARGS_H
#define CONFIG_FROM_ARGS_H
#include <blt/parse/argparse_v2.h>
#include <blt/gp/program.h>
namespace blt::gp
{
inline void setup_crossover_parser(argparse::argument_parser_t& parser)
{
parser.add_flag("--max_crossover_tries").set_default(5u).as_type<u32>().set_help(
"number of times crossover will try to pick a valid point in the tree.");
parser.add_flag("--max_crossover_iterations").set_default(10u).as_type<u32>().set_help(
"how many times the crossover function can fail before we will skip this operation.");
parser.add_flag("--min_tree_size").set_default(5u).as_type<u32>().set_help("the minimum size of"
" the tree to be considered for crossover.");
parser.add_flag("--depth_multiplier").set_default(0.25).as_type<float>().set_help(
"at each depth level, what chance do we have to exit with this as our point?");
parser.add_flag("--terminal_chance").set_default(0.1).as_type<float>().set_help(
"how often should we select terminals over functions. By default, we only allow selection of terminals 10% of the time. "
"This applies to both types of crossover point functions. Traversal will use the parent if it should not pick a terminal.");
parser.add_flag("--traverse").make_flag().set_help(" use traversal to select instead of random point selection.");
}
inline std::tuple<prog_config_t, selection_t*, crossover_t*, mutation_t*> make_config(const int argc, const char** argv)
{
argparse::argument_parser_t parser;
parser.add_flag("--initial_tree_min").set_default(2).as_type<i32>().set_help("The minimum number of nodes in the initial trees");
parser.add_flag("--initial_tree_max").set_default(6).as_type<i32>().set_help("The maximum number of nodes in the initial trees");
parser.add_flag("--elites").set_default(2).as_type<i32>().set_help("Number of best fitness individuals to keep each generation");
parser.add_flag("--max_generations", "-g").set_default(50).as_type<u32>().set_help("The maximum number of generations to run");
parser.add_flag("--population_size", "-p").set_default(500).as_type<u32>().set_help("The size of the population");
parser.add_flag("--threads", "-t").set_default(0).as_type<u32>().set_help("The number of threads to use");
parser.add_flag("--crossover_rate", "-c").set_default(0.8).as_type<float>().set_help("The rate of crossover");
parser.add_flag("--mutation_rate", "-m").set_default(0.1).as_type<float>().set_help("The rate of mutation");
parser.add_flag("--reproduction_rate", "-r").set_default(0.1).as_type<float>().set_help("The rate of reproduction");
const auto mode = parser.add_subparser("mode")->set_help("Select the mode to run the program in.");
mode->add_parser("default");
const auto manual = mode->add_parser("manual");
const auto selection_parser = manual->add_subparser("selection_type")->set_help("Select the type of selection to use.");
const auto tournament_selection = selection_parser->add_parser("tournament");
tournament_selection->add_flag("--tournament_size", "-s").set_default(5).as_type<i32>().set_help("The size of the tournament");
selection_parser->add_parser("best");
selection_parser->add_parser("worst");
selection_parser->add_parser("roulette", "fitness");
const auto crossover_parser = manual->add_subparser("crossover_type")->set_help("Select the type of crossover to use.");
const auto subtree_crossover = crossover_parser->add_parser("subtree_crossover");
setup_crossover_parser(*subtree_crossover);
const auto one_point_crossover = crossover_parser->add_parser("one_point_crossover");
setup_crossover_parser(*one_point_crossover);
const auto advanced_crossover = crossover_parser->add_parser("advanced_crossover");
setup_crossover_parser(*advanced_crossover);
const auto mutation_parser = manual->add_subparser("mutation_type")->set_help("Select the type of mutation to use.");
const auto single_point_mutation = mutation_parser->add_parser("single_point_mutation");
single_point_mutation->add_flag("--replacement_min_depth").set_default(2).as_type<u32>().set_help("Minimum depth of the generated replacement tree");
single_point_mutation->add_flag("--replacement_max_depth").set_default(2).as_type<u32>().set_help("Maximum depth of the generated replacement tree");
single_point_mutation->add_positional("generator").set_choices("grow", "full").set_default("grow");
const auto advanced_mutation = mutation_parser->add_parser("advanced_mutation");
advanced_mutation->add_flag("--replacement_min_depth").set_default(2).as_type<u32>().set_help("Minimum depth of the generated replacement tree");
advanced_mutation->add_flag("--replacement_max_depth").set_default(2).as_type<u32>().set_help("Maximum depth of the generated replacement tree");
advanced_mutation->add_positional("generator").set_choices("grow", "full").set_default("grow");
auto args = parser.parse(argc, argv);
auto config = prog_config_t()
.set_initial_min_tree_size(args.get<i32>("initial_tree_min"))
.set_initial_max_tree_size(args.get<i32>("initial_tree_max"))
.set_elite_count(args.get<i32>("elites"))
.set_crossover_chance(args.get<float>("crossover_rate"))
.set_mutation_chance(args.get<float>("mutation_rate"))
.set_reproduction_chance(args.get<float>("reproduction_rate"))
.set_max_generations(args.get<u32>("max_generations"))
.set_pop_size(args.get<u32>("population_size"))
.set_thread_count(args.get<u32>("threads"));
thread_local select_tournament_t s_tournament_selection;
thread_local select_best_t s_best_selection;
thread_local select_worst_t s_worst_selection;
thread_local select_fitness_proportionate_t s_roulette_selection;
thread_local subtree_crossover_t s_subtree_crossover;
thread_local one_point_crossover_t s_one_point_crossover;
thread_local advanced_crossover_t s_advanced_crossover;
thread_local mutation_t s_single_point_mutation;
thread_local advanced_mutation_t s_advanced_mutation;
if (args.get("mode") == "default")
return {config, &s_tournament_selection, &s_subtree_crossover, &s_advanced_mutation};
if (args.get("mode") == "manual")
{
}
}
}
#endif //CONFIG_FROM_ARGS_H