Things are working now. The missing fitness reset was part of the problem.

dev-0.2.1
Brett 2024-12-24 13:06:09 -05:00
parent e1083426fc
commit 3e0fe06017
8 changed files with 210 additions and 186 deletions

View File

@ -27,7 +27,7 @@ macro(compile_options target_name)
sanitizers(${target_name})
endmacro()
project(blt-gp VERSION 0.2.4)
project(blt-gp VERSION 0.2.5)
include(CTest)

View File

@ -45,6 +45,7 @@ namespace blt::gp::example
};
bool fitness_function(const tree_t& current_tree, fitness_t& fitness, size_t) const;
public:
template <typename SEED>
rice_classification_t(SEED&& seed, const prog_config_t& config): example_base_t{std::forward<SEED>(seed), config}
@ -60,7 +61,7 @@ namespace blt::gp::example
void load_rice_data(std::string_view rice_file_path);
confusion_matrix_t test_individual(const individual_t& individual) const;
[[nodiscard]] confusion_matrix_t test_individual(const individual_t& individual) const;
void execute(const std::string_view rice_file_path)
{
@ -166,6 +167,9 @@ namespace blt::gp::example
std::cout << "\n";
}
auto& get_results() { return results; }
const auto& get_results() const { return results; }
private:
std::vector<rice_record> training_cases;
std::vector<rice_record> testing_cases;

View File

@ -122,7 +122,7 @@ bool blt::gp::example::rice_classification_t::fitness_function(const tree_t& cur
{
for (auto& training_case : training_cases)
{
auto v = current_tree.get_evaluation_value<float>(training_case);
const auto v = current_tree.get_evaluation_value<float>(training_case);
switch (training_case.type)
{
case rice_type_t::Cammeo:
@ -137,7 +137,8 @@ bool blt::gp::example::rice_classification_t::fitness_function(const tree_t& cur
}
fitness.raw_fitness = static_cast<double>(fitness.hits);
fitness.standardized_fitness = fitness.raw_fitness;
fitness.adjusted_fitness = 1.0 - (1.0 / (1.0 + fitness.standardized_fitness));
// fitness.adjusted_fitness = 1.0 - (1.0 / (1.0 + fitness.standardized_fitness));
fitness.adjusted_fitness = fitness.standardized_fitness / static_cast<double>(training_cases.size());
return static_cast<size_t>(fitness.hits) == training_cases.size();
}
@ -169,20 +170,20 @@ void blt::gp::example::rice_classification_t::load_rice_data(const std::string_v
}
}
size_t total_records = c.size() + o.size();
size_t training_size = std::min(total_records / 3, 1000ul);
for (size_t i = 0; i < training_size; i++)
const size_t total_records = c.size() + o.size();
const size_t testing_size = total_records / 3;
for (size_t i = 0; i < testing_size; i++)
{
auto& random = program.get_random();
auto& vec = random.choice() ? c : o;
auto pos = random.get_i64(0, static_cast<i64>(vec.size()));
training_cases.push_back(vec[pos]);
const auto pos = random.get_i64(0, static_cast<i64>(vec.size()));
testing_cases.push_back(vec[pos]);
vec.erase(vec.begin() + pos);
}
testing_cases.insert(testing_cases.end(), c.begin(), c.end());
testing_cases.insert(testing_cases.end(), o.begin(), o.end());
std::shuffle(testing_cases.begin(), testing_cases.end(), program.get_random());
BLT_INFO("Created training set of size %ld, testing set is of size %ld", training_size, testing_cases.size());
training_cases.insert(training_cases.end(), c.begin(), c.end());
training_cases.insert(training_cases.end(), o.begin(), o.end());
std::shuffle(training_cases.begin(), training_cases.end(), program.get_random());
BLT_INFO("Created testing set of size %ld, training set is of size %ld", testing_cases.size(), training_cases.size());
}
blt::gp::confusion_matrix_t blt::gp::example::rice_classification_t::test_individual(const individual_t& individual) const

View File

@ -35,14 +35,14 @@ int main()
.set_elite_count(2)
.set_crossover_chance(0.9)
.set_mutation_chance(0.1)
.set_reproduction_chance(0.25)
.set_reproduction_chance(0.0)
.set_max_generations(50)
.set_pop_size(500)
.set_thread_count(16);
// example on how you can change the mutation config
blt::gp::mutation_t::config_t mut_config{};
mut_config.generator = full_generator;
mut_config.generator = grow_generator;
mut_config.replacement_min_depth = 2;
mut_config.replacement_max_depth = 6;

View File

@ -490,6 +490,7 @@ namespace blt::gp
double sum_of_prob = 0;
for (const auto& [index, ind] : blt::enumerate(current_pop.get_individuals()))
{
ind.fitness = {};
if constexpr (std::is_same_v<LambdaReturn, bool> || std::is_convertible_v<LambdaReturn, bool>)
{
auto result = fitness_function(ind.tree, ind.fitness, index);
@ -565,7 +566,7 @@ namespace blt::gp
{
auto& ind = current_pop.get_individuals()[i];
ind.fitness = {};
if constexpr (std::is_same_v<LambdaReturn, bool> || std::is_convertible_v<LambdaReturn, bool>)
{
auto result = fitness_function(ind.tree, ind.fitness, i);

View File

@ -71,9 +71,14 @@ namespace blt::gp
struct config_t
{
// number of times crossover will try to pick a valid point in the tree. this is purely based on the return type of the operators
blt::u16 max_crossover_tries = 5;
blt::f32 traverse_chance = 0.5;
blt::u32 min_tree_size = 5;
u32 max_crossover_tries = 5;
// if a tree has fewer nodes than this number, it will not be considered for crossover
u32 min_tree_size = 3;
// used by the traverse version of get_crossover_point
// at each depth level, this is the chance that we continue traversing into a child of the current node;
// otherwise we stop and exit with the current node as our point.
f32 traverse_chance = 0.5;
// legacy settings:
@ -88,9 +93,9 @@ namespace blt::gp
explicit crossover_t(const config_t& config): config(config)
{}
std::optional<crossover_t::crossover_point_t> get_crossover_point(gp_program& program, const tree_t& c1, const tree_t& c2) const;
std::optional<crossover_point_t> get_crossover_point(gp_program& program, const tree_t& c1, const tree_t& c2) const;
std::optional<crossover_t::crossover_point_t> get_crossover_point_traverse(gp_program& program, const tree_t& c1, const tree_t& c2) const;
std::optional<crossover_point_t> get_crossover_point_traverse(gp_program& program, const tree_t& c1, const tree_t& c2) const;
std::optional<point_info_t> get_point_traverse(gp_program& program, const tree_t& t, std::optional<type_id> type) const;

View File

@ -213,7 +213,7 @@ namespace blt::gp
double raw_fitness = 0;
double standardized_fitness = 0;
double adjusted_fitness = 0;
blt::i64 hits = 0;
i64 hits = 0;
};
struct individual_t

View File

@ -26,19 +26,18 @@
namespace blt::gp
{
grow_generator_t grow_generator;
inline tree_t& get_static_tree_tl(gp_program& program)
{
static thread_local tree_t new_tree{program};
thread_local tree_t new_tree{program};
new_tree.clear(program);
return new_tree;
}
inline blt::size_t accumulate_type_sizes(detail::op_iter_t begin, detail::op_iter_t end)
inline size_t accumulate_type_sizes(detail::op_iter_t begin, detail::op_iter_t end)
{
blt::size_t total = 0;
size_t total = 0;
for (auto it = begin; it != end; ++it)
{
if (it->is_value)
@ -48,16 +47,17 @@ namespace blt::gp
}
template <typename>
blt::u8* get_thread_pointer_for_size(blt::size_t bytes)
u8* get_thread_pointer_for_size(size_t bytes)
{
static thread_local blt::expanding_buffer<blt::u8> buffer;
thread_local expanding_buffer<u8> buffer;
if (bytes > buffer.size())
buffer.resize(bytes);
return buffer.data();
}
mutation_t::config_t::config_t(): generator(grow_generator)
{}
{
}
bool crossover_t::apply(gp_program& program, const tree_t& p1, const tree_t& p2, tree_t& c1, tree_t& c2) // NOLINT
{
@ -67,17 +67,17 @@ namespace blt::gp
auto& c1_ops = c1.get_operations();
auto& c2_ops = c2.get_operations();
auto point = get_crossover_point(program, p1, p2);
const auto point = get_crossover_point(program, p1, p2);
if (!point)
return false;
auto selection = program.get_random().get_u32(0, 2);
const auto selection = program.get_random().get_u32(0, 2);
// Used to make copies of operators. Statically stored for memory caching purposes.
// Thread local as this function cannot have external modifications / will be called from multiple threads.
static thread_local tracked_vector<op_container_t> c1_operators;
static thread_local tracked_vector<op_container_t> c2_operators;
thread_local tracked_vector<op_container_t> c1_operators;
thread_local tracked_vector<op_container_t> c2_operators;
c1_operators.clear();
c2_operators.clear();
@ -179,12 +179,12 @@ namespace blt::gp
auto& c1_ops = c1.get_operations();
auto& c2_ops = c2.get_operations();
blt::size_t crossover_point = program.get_random().get_size_t(1ul, c1_ops.size());
size_t crossover_point = program.get_random().get_size_t(1ul, c1_ops.size());
while (config.avoid_terminals && program.get_operator_info(c1_ops[crossover_point].id).argc.is_terminal())
crossover_point = program.get_random().get_size_t(1ul, c1_ops.size());
blt::size_t attempted_point = 0;
size_t attempted_point = 0;
const auto& crossover_point_type = program.get_operator_info(c1_ops[crossover_point].id);
operator_info_t* attempted_point_type = nullptr;
@ -215,8 +215,7 @@ namespace blt::gp
}
// should we try again over the whole tree? probably not.
return {};
} else
{
}
attempted_point = program.get_random().get_size_t(1ul, c2_ops.size());
attempted_point_type = &program.get_operator_info(c2_ops[attempted_point].id);
if (config.avoid_terminals && attempted_point_type->argc.is_terminal())
@ -225,7 +224,7 @@ namespace blt::gp
break;
counter++;
}
} while (true);
while (true);
return crossover_point_t{static_cast<blt::ptrdiff_t>(crossover_point), static_cast<blt::ptrdiff_t>(attempted_point)};
}
@ -233,13 +232,12 @@ namespace blt::gp
std::optional<crossover_t::crossover_point_t> crossover_t::get_crossover_point_traverse(gp_program& program, const tree_t& c1,
const tree_t& c2) const
{
auto c1_point_o = get_point_traverse_retry(program, c1, {});
const auto c1_point_o = get_point_traverse_retry(program, c1, {});
if (!c1_point_o)
return {};
auto c2_point_o = get_point_traverse_retry(program, c2, c1_point_o->type_operator_info.return_type);
const auto c2_point_o = get_point_traverse_retry(program, c2, c1_point_o->type_operator_info.return_type);
if (!c2_point_o)
return {};
return {{c1_point_o->point, c2_point_o->point}};
}
@ -256,7 +254,7 @@ namespace blt::gp
{
auto& random = program.get_random();
blt::ptrdiff_t point = 0;
ptrdiff_t point = 0;
while (true)
{
auto& current_op_type = program.get_operator_info(t.get_operations()[point].id);
@ -269,13 +267,13 @@ namespace blt::gp
// traverse to a child
if (random.choice(config.traverse_chance))
{
auto args = current_op_type.argc.argc;
auto argument = random.get_size_t(0, args);
const auto args = current_op_type.argc.argc;
const auto argument = random.get_size_t(0, args);
// move to the first child
point += 1;
// loop through all the children we wish to skip. The result will be the first node of the next child, becoming the new parent
for (blt::size_t i = 0; i < argument; i++)
for (size_t i = 0; i < argument; i++)
point = t.find_endpoint(program, point);
continue;
@ -288,7 +286,7 @@ namespace blt::gp
std::optional<crossover_t::point_info_t> crossover_t::get_point_traverse_retry(gp_program& program, const tree_t& t,
std::optional<type_id> type) const
{
for (blt::size_t i = 0; i < config.max_crossover_tries; i++)
for (size_t i = 0; i < config.max_crossover_tries; i++)
{
if (auto found = get_point_traverse(program, t, type))
return found;
@ -412,7 +410,8 @@ namespace blt::gp
selected_point = static_cast<blt::i32>(index);
break;
}
} else
}
else
{
if (choice > mutation_operator_chances[index - 1] && choice <= value)
{
@ -516,11 +515,13 @@ namespace blt::gp
vals.pop_bytes(static_cast<blt::ptrdiff_t>(total_bytes_after + total_bytes_for));
vals.copy_from(data, total_bytes_after);
ops.erase(ops.begin() + static_cast<blt::ptrdiff_t>(start_index), ops.begin() + static_cast<blt::ptrdiff_t>(end_index));
} else if (current_func_info.argc.argc == replacement_func_info.argc.argc)
}
else if (current_func_info.argc.argc == replacement_func_info.argc.argc)
{
// exactly enough args
// return types should have been replaced if needed. this part should do nothing?
} else
}
else
{
// not enough args
blt::size_t start_index = c_node + 1;
@ -534,8 +535,10 @@ namespace blt::gp
{
auto& tree = get_static_tree_tl(program);
config.generator.get().generate(tree,
{program, replacement_func_info.argument_types[i].id, config.replacement_min_depth,
config.replacement_max_depth});
{
program, replacement_func_info.argument_types[i].id, config.replacement_min_depth,
config.replacement_max_depth
});
vals.insert(tree.get_values());
ops.insert(ops.begin() + static_cast<blt::ptrdiff_t>(start_index), tree.get_operations().begin(),
tree.get_operations().end());
@ -544,8 +547,10 @@ namespace blt::gp
vals.copy_from(data, total_bytes_after);
}
// now finally update the type.
ops[c_node] = {program.get_typesystem().get_type(replacement_func_info.return_type).size(), random_replacement,
program.is_operator_ephemeral(random_replacement)};
ops[c_node] = {
program.get_typesystem().get_type(replacement_func_info.return_type).size(), random_replacement,
program.is_operator_ephemeral(random_replacement)
};
}
#if BLT_DEBUG_LEVEL >= 2
if (!c.check(program, nullptr))
@ -587,7 +592,8 @@ namespace blt::gp
}
}
random_replacement = program.get_random().select(program.get_type_non_terminals(current_func_info.return_type.id));
} while (true);
}
while (true);
exit:
auto& replacement_func_info = program.get_operator_info(random_replacement);
auto new_argc = replacement_func_info.argc.argc;
@ -607,8 +613,10 @@ namespace blt::gp
{
auto& tree = get_static_tree_tl(program);
config.generator.get().generate(tree,
{program, replacement_func_info.argument_types[i].id, config.replacement_min_depth,
config.replacement_max_depth});
{
program, replacement_func_info.argument_types[i].id, config.replacement_min_depth,
config.replacement_max_depth
});
blt::size_t total_bytes_for = tree.total_value_bytes();
vals.copy_from(tree.get_values(), total_bytes_for);
ops.insert(ops.begin() + static_cast<blt::ptrdiff_t>(start_index), tree.get_operations().begin(),
@ -621,8 +629,10 @@ namespace blt::gp
{
auto& tree = get_static_tree_tl(program);
config.generator.get().generate(tree,
{program, replacement_func_info.argument_types[i].id, config.replacement_min_depth,
config.replacement_max_depth});
{
program, replacement_func_info.argument_types[i].id, config.replacement_min_depth,
config.replacement_max_depth
});
blt::size_t total_bytes_for = tree.total_value_bytes();
vals.copy_from(tree.get_values(), total_bytes_for);
ops.insert(ops.begin() + static_cast<blt::ptrdiff_t>(start_index), tree.get_operations().begin(),
@ -632,8 +642,10 @@ namespace blt::gp
vals.copy_from(combined_ptr + for_bytes, after_bytes);
ops.insert(ops.begin() + static_cast<blt::ptrdiff_t>(c_node),
{program.get_typesystem().get_type(replacement_func_info.return_type).size(),
random_replacement, program.is_operator_ephemeral(random_replacement)});
{
program.get_typesystem().get_type(replacement_func_info.return_type).size(),
random_replacement, program.is_operator_ephemeral(random_replacement)
});
#if BLT_DEBUG_LEVEL >= 2
if (!c.check(program, nullptr))
@ -743,7 +755,8 @@ namespace blt::gp
{
from = pt;
to = pf;
} else
}
else
{
from = pf;
to = pt;