blt-gp/src/transformers.cpp

567 lines
27 KiB
C++

/*
* Crossover and mutation operators for blt-gp program trees.
* Copyright (C) 2024 Brett Terpstra
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
#include <blt/gp/transformers.h>
#include <blt/gp/program.h>
#include <blt/std/ranges.h>
#include <blt/std/utility.h>
#include <algorithm>
#include <blt/std/memory.h>
#include <blt/profiling/profiler_v2.h>
#include <random>
namespace blt::gp
{
#if BLT_DEBUG_LEVEL >= 2 || defined(BLT_TRACK_ALLOCATIONS)
// Successful-mutation counters, one per operator kind. Only compiled in when the debug level is >= 2 or
// allocation tracking is enabled; the mutation code below increments them under the same condition.
std::atomic_uint64_t mutate_point_counter = 0;
std::atomic_uint64_t mutate_expression_counter = 0;
std::atomic_uint64_t mutate_adjust_counter = 0;
std::atomic_uint64_t mutate_sub_func_counter = 0;
std::atomic_uint64_t mutate_jump_counter = 0;
std::atomic_uint64_t mutate_copy_counter = 0;

/**
 * Writes the mutation counters to std::cerr.
 *
 * Emits a header line containing the sum of all counters, followed by one tab-indented line per mutation kind.
 */
inline void print_mutate_stats()
{
    struct stat_row_t
    {
        const char* label;
        const std::atomic_uint64_t& count;
    };

    const stat_row_t rows[] = {
        {"\tSuccessful Point Mutations: ", mutate_point_counter},
        {"\tSuccessful Expression Mutations: ", mutate_expression_counter},
        {"\tSuccessful Adjust Mutations: ", mutate_adjust_counter},
        {"\tSuccessful Sub Func Mutations: ", mutate_sub_func_counter},
        {"\tSuccessful Jump Mutations: ", mutate_jump_counter},
        {"\tSuccessful Copy Mutations: ", mutate_copy_counter},
    };

    std::atomic_uint64_t::value_type total = 0;
    for (const auto& row : rows)
        total += row.count.load();

    std::cerr << "Mutation statistics (Total: " << total << "):" << std::endl;
    for (const auto& row : rows)
        std::cerr << row.label << row.count.load() << std::endl;
}
#ifdef BLT_TRACK_ALLOCATIONS
// Global whose destructor runs during static destruction, printing the statistics as the program exits.
struct run_me_baby
{
    ~run_me_baby()
    {
        print_mutate_stats();
    }
};
run_me_baby this_will_run_when_program_exits;
#endif
#endif
// File-level default tree generator. mutation_t::config_t's default constructor below initialises its
// `generator` member from this object.
grow_generator_t grow_generator;
/**
 * Default mutation configuration: initialises `generator` from the file-level grow_generator.
 * NOTE(review): call sites use `config.generator.get()`, so `generator` is presumably a reference wrapper that
 * keeps referring to grow_generator rather than holding a copy — confirm against the header.
 */
mutation_t::config_t::config_t(): generator(grow_generator)
{
}
/**
 * Crosses parents `p1` and `p2`, writing the result into children `c1` and `c2`.
 *
 * Crossover points are chosen on the parents (by traversal when config.traverse is set, otherwise by direct
 * subtree selection), and the corresponding subtrees of `c1` and `c2` are then swapped.
 * NOTE(review): this presumes c1/c2 start out as copies of p1/p2 so the parent-derived points are valid in the
 * children — confirm with the caller.
 *
 * @return false if either parent is smaller than config.min_tree_size or no compatible point pair exists,
 *         true once the swap has been performed.
 */
bool crossover_t::apply(gp_program& program, const tree_t& p1, const tree_t& p2, tree_t& c1, tree_t& c2) // NOLINT
{
    const bool first_too_small = p1.size() < config.min_tree_size;
    const bool second_too_small = p2.size() < config.min_tree_size;
    if (first_too_small || second_too_small)
        return false;

    auto point = config.traverse ? get_crossover_point_traverse(p1, p2) : get_crossover_point(p1, p2);
    if (!point.has_value())
        return false;

    // TODO: more crossover!
    // The draw is kept so the random stream is consumed as before; outcomes 0 and 1 both swap subtrees.
    const auto strategy = program.get_random().get_u32(0, 2);
    if (strategy == 0 || strategy == 1)
    {
        c1.swap_subtrees(point->p1_crossover_point, c2, point->p2_crossover_point);
    }
    else
    {
#if BLT_DEBUG_LEVEL > 0
        BLT_ABORT("This place should be unreachable!");
#else
        BLT_UNREACHABLE;
#endif
    }

#if BLT_DEBUG_LEVEL >= 2
    if (!c1.check(detail::debug::context_ptr) || !c2.check(detail::debug::context_ptr))
        throw std::runtime_error("Tree check failed");
#endif
    return true;
}
std::optional<crossover_t::crossover_point_t> crossover_t::get_crossover_point(const tree_t& c1,
const tree_t& c2) const
{
auto first = c1.select_subtree(config.terminal_chance);
auto second = c2.select_subtree(first.type, config.max_crossover_tries, config.terminal_chance);
if (!second)
return {};
return {{first, *second}};
}
std::optional<crossover_t::crossover_point_t> crossover_t::get_crossover_point_traverse(const tree_t& c1,
const tree_t& c2) const
{
auto c1_point_o = get_point_traverse_retry(c1, {});
if (!c1_point_o)
return {};
auto c2_point_o = get_point_traverse_retry(c2, c1_point_o->type);
if (!c2_point_o)
return {};
return {{*c1_point_o, *c2_point_o}};
}
/**
 * Selects a subtree of `t` by traversal.
 *
 * Without `type`, an unconstrained selection is made. With `type`, only subtrees of that type are accepted,
 * passing config.max_crossover_tries to the tree's selection routine.
 */
std::optional<tree_t::subtree_point_t> crossover_t::get_point_traverse_retry(const tree_t& t, const std::optional<type_id> type) const
{
    if (!type)
        return t.select_subtree_traverse(config.terminal_chance, config.depth_multiplier);
    return t.select_subtree_traverse(*type, config.max_crossover_tries, config.terminal_chance, config.depth_multiplier);
}
/**
 * Basic mutation: regrows one randomly selected subtree of the child `c` in place.
 * The parent tree is ignored.
 *
 * @return always true.
 */
bool mutation_t::apply(gp_program& program, const tree_t&, tree_t& c)
{
    // TODO: options for this?
    const auto target = c.select_subtree();
    mutate_point(program, c, target);
    return true;
}
/**
 * Replaces the subtree of `c` at `node` with a freshly generated tree of the same type.
 *
 * The replacement is generated by config.generator into the thread-local scratch tree, using
 * config.replacement_min_depth / config.replacement_max_depth.
 *
 * @return node.pos plus the size of the generated tree; callers treat this as the index one past the
 *         regrown subtree.
 */
size_t mutation_t::mutate_point(gp_program& program, tree_t& c, const tree_t::subtree_point_t node) const
{
    auto& replacement = tree_t::get_thread_local(program);
    config.generator.get().generate(replacement,
                                    {program, node.type, config.replacement_min_depth, config.replacement_max_depth});
    c.replace_subtree(node, replacement);

    // this will check to make sure that the tree is in a correct and executable state. it requires that the evaluation is context free!
#if BLT_DEBUG_LEVEL >= 2
    if (!c.check(detail::debug::context_ptr))
    {
        print_mutate_stats();
        throw std::runtime_error("Mutate Point tree check failed");
    }
#endif
#if defined(BLT_TRACK_ALLOCATIONS) || BLT_DEBUG_LEVEL >= 2
    ++mutate_point_counter;
#endif

    const size_t end_of_replacement = node.pos + replacement.size();
    return end_of_replacement;
}
/**
 * Applies per-node structural mutations to the child tree `c`.
 *
 * The child's operators are visited front to back. Each node is mutated with probability
 * `per_node_mutation_chance / c.size()`; the size is re-read on every iteration because mutations grow and
 * shrink the tree. The operator used is the first index of `mutation_operator_chances` whose value is >= a
 * `get_double()` draw, falling back to COPY when no entry qualifies.
 *
 * Operators:
 *  - EXPRESSION: regrow the subtree rooted at the node (via mutate_point) and continue after the new subtree.
 *  - ADJUST:     swap a non-value operator for another non-terminal with the same return type, regenerating
 *                children whose argument type changes and adding/removing children to match the new arity.
 *  - SUB_FUNC:   insert a new non-terminal of the same return type above the node, keeping the node's subtree
 *                as one of its arguments and generating the others.
 *  - JUMP_FUNC:  replace the node's subtree with its first child whose type equals the node's return type.
 *  - COPY:       overwrite one child with a copy of a different child of the same type.
 *
 * Throughout, results of find_child_extends are indexed so that argument i lives at
 * `children[children.size() - 1 - i]`.
 *
 * @param program owning program: randomness, operator metadata and thread-local scratch trees.
 * @param p       parent tree; only read for diagnostics in debug builds.
 * @param c       child tree, mutated in place.
 * @return always true.
 */
bool advanced_mutation_t::apply(gp_program& program, [[maybe_unused]] const tree_t& p, tree_t& c)
{
    for (size_t c_node = 0; c_node < c.size(); c_node++)
    {
        if (!program.get_random().choice(per_node_mutation_chance / static_cast<double>(c.size())))
            continue;
#if BLT_DEBUG_LEVEL >= 2
        // snapshot of the tree before this node is mutated; printed if the ADJUST check fails.
        auto c_copy = c;
#endif

        // select an operator to apply
        auto selected_point = static_cast<i32>(mutation_operator::COPY);
        const auto choice = program.get_random().get_double();
        for (const auto& [index, value] : enumerate(mutation_operator_chances))
        {
            if (choice <= value)
            {
                selected_point = static_cast<i32>(index);
                break;
            }
        }

        switch (static_cast<mutation_operator>(selected_point))
        {
            case mutation_operator::EXPRESSION:
            {
                // mutate_point returns node.pos + <new subtree size>, i.e. an absolute index one past the regrown
                // subtree. Park the cursor on the subtree's last node so the loop increment resumes right after it.
                // (Adding the returned index to c_node would count c_node twice and skip unrelated nodes.)
                const size_t subtree_end = mutate_point(program, c, c.subtree_from_point(static_cast<ptrdiff_t>(c_node)));
                c_node = subtree_end - 1;
#if defined(BLT_TRACK_ALLOCATIONS) || BLT_DEBUG_LEVEL >= 2
                ++mutate_expression_counter;
#endif
            }
            break;
            case mutation_operator::ADJUST:
            {
                // this is going to be evil >:3
                const auto& node = c.get_operator(c_node);
                if (!node.is_value())
                {
                    auto& current_func_info = program.get_operator_info(node.id());
                    auto& non_terminals = program.get_type_non_terminals(current_func_info.return_type.id);
                    // no non-terminal of this return type exists, so there is nothing to adjust to
                    // (and select() must not be called on an empty set).
                    if (non_terminals.empty())
                        continue;
                    operator_id random_replacement = program.get_random().select(non_terminals);
                    auto& replacement_func_info = program.get_operator_info(random_replacement);

                    // cache memory used for offset data.
                    thread_local tracked_vector<tree_t::child_t> children_data;
                    children_data.clear();
                    c.find_child_extends(children_data, c_node, current_func_info.argument_types.size());

                    // regenerate every shared argument whose type differs between the current and replacement operator.
                    for (const auto& [index, val] : blt::enumerate(replacement_func_info.argument_types))
                    {
                        if (index < current_func_info.argument_types.size() && val.id != current_func_info.argument_types[index].id)
                        {
                            // TODO: new config?
                            auto& tree = tree_t::get_thread_local(program);
                            config.generator.get().generate(tree,
                                                            {program, val.id, config.replacement_min_depth, config.replacement_max_depth});
                            auto& [child_start, child_end] = children_data[children_data.size() - 1 - index];
                            c.replace_subtree(c.subtree_from_point(child_start), child_end, tree);
                            // shift over everybody after.
                            if (index > 0)
                            {
                                // don't need to update if the index is the last
                                for (auto& new_child : iterate(children_data.end() - static_cast<ptrdiff_t>(index),
                                                               children_data.end()))
                                {
                                    // remove the old tree size, then add the new tree size to get the correct positions.
                                    new_child.start = new_child.start - (child_end - child_start) + static_cast<ptrdiff_t>(tree.size());
                                    new_child.end = new_child.end - (child_end - child_start) + static_cast<ptrdiff_t>(tree.size());
                                }
                            }
                            child_end = static_cast<ptrdiff_t>(child_start + tree.size());
                        }
                    }

                    if (current_func_info.argc.argc > replacement_func_info.argc.argc)
                    {
                        // too many args: delete the contiguous span covering the first (current - replacement)
                        // entries of children_data, i.e. the highest-index arguments.
                        auto end_index = children_data[(current_func_info.argc.argc - replacement_func_info.argc.argc) - 1].end;
                        auto start_index = children_data.begin()->start;
                        c.delete_subtree(tree_t::subtree_point_t(start_index), end_index);
                    }
                    else if (current_func_info.argc.argc == replacement_func_info.argc.argc)
                    {
                        // exactly enough args
                        // return types should have been replaced if needed. this part should do nothing?
                    }
                    else
                    {
                        // not enough args: generate the missing arguments, highest index first, directly after the operator.
                        size_t start_index = c_node + 1;
                        // TODO: transactions?
                        for (ptrdiff_t i = static_cast<ptrdiff_t>(replacement_func_info.argc.argc) - 1;
                             i >= static_cast<ptrdiff_t>(current_func_info.argc.argc); i--)
                        {
                            auto& tree = tree_t::get_thread_local(program);
                            config.generator.get().generate(tree,
                                                            {
                                                                program, replacement_func_info.argument_types[i].id,
                                                                config.replacement_min_depth, config.replacement_max_depth
                                                            });
                            start_index = c.insert_subtree(tree_t::subtree_point_t(static_cast<ptrdiff_t>(start_index)), tree);
                        }
                    }
                    // now finally update the type.
                    c.modify_operator(c_node, random_replacement, replacement_func_info.return_type);
                }
#if BLT_DEBUG_LEVEL >= 2
                if (!c.check(detail::debug::context_ptr))
                {
                    std::cout << "Parent: " << std::endl;
                    c_copy.print(std::cout, false, false);
                    std::cout << "Child Values:" << std::endl;
                    c.print(std::cout, false, false);
                    std::cout << std::endl;
                    print_mutate_stats();
                    BLT_ABORT("Adjust Tree Check Failed.");
                }
#endif
#if defined(BLT_TRACK_ALLOCATIONS) || BLT_DEBUG_LEVEL >= 2
                ++mutate_adjust_counter;
#endif
            }
            break;
            case mutation_operator::SUB_FUNC:
            {
                auto& current_func_info = program.get_operator_info(c.get_operator(c_node).id());
                // need to:
                // mutate the current function.
                // current function is moved to one of the arguments.
                // other arguments are generated.
                // get a replacement which returns the same type.
                auto& non_terminals = program.get_type_non_terminals(current_func_info.return_type.id);

                // index of the first argument of `id` that accepts the current node's return type, if any.
                const auto find_arg_position = [&program, &current_func_info](const operator_id id) -> std::optional<size_t> {
                    for (const auto& [index, v] : enumerate(program.get_operator_info(id).argument_types))
                    {
                        if (v.id == current_func_info.return_type.id)
                            return index;
                    }
                    return {};
                };

                // the rejection sampling below only terminates if some candidate can adopt the current subtree.
                // none_of is also true for an empty candidate set.
                if (std::none_of(non_terminals.begin(), non_terminals.end(),
                                 [&find_arg_position](const auto& id) { return find_arg_position(id).has_value(); }))
                    continue;

                operator_id random_replacement = program.get_random().select(non_terminals);
                auto arg_position_o = find_arg_position(random_replacement);
                while (!arg_position_o)
                {
                    random_replacement = program.get_random().select(non_terminals);
                    arg_position_o = find_arg_position(random_replacement);
                }
                const size_t arg_position = *arg_position_o;

                auto& replacement_func_info = program.get_operator_info(random_replacement);
                auto new_argc = replacement_func_info.argc.argc;
                // replacement function should be valid. let's make a copy of us.
                auto current_end = c.find_endpoint(static_cast<ptrdiff_t>(c_node));
                auto size = current_end - c_node;
                size_t start_index = c_node;
                // arguments with a higher index than arg_position go in front of the existing subtree.
                for (ptrdiff_t i = static_cast<ptrdiff_t>(new_argc) - 1; i > static_cast<ptrdiff_t>(arg_position); i--)
                {
                    auto& tree = tree_t::get_thread_local(program);
                    config.generator.get().generate(tree,
                                                    {
                                                        program, replacement_func_info.argument_types[i].id,
                                                        config.replacement_min_depth, config.replacement_max_depth
                                                    });
                    start_index = c.insert_subtree(tree_t::subtree_point_t(static_cast<ptrdiff_t>(start_index)), tree);
                }
                // skip over the existing subtree, which becomes argument arg_position.
                start_index += size;
                // arguments with a lower index than arg_position go after it.
                for (ptrdiff_t i = static_cast<ptrdiff_t>(arg_position) - 1; i >= 0; i--)
                {
                    auto& tree = tree_t::get_thread_local(program);
                    config.generator.get().generate(tree,
                                                    {
                                                        program, replacement_func_info.argument_types[i].id,
                                                        config.replacement_min_depth, config.replacement_max_depth
                                                    });
                    start_index = c.insert_subtree(tree_t::subtree_point_t(static_cast<ptrdiff_t>(start_index)), tree);
                }
                c.insert_operator(c_node, {
                                      program.get_typesystem().get_type(replacement_func_info.return_type).size(),
                                      random_replacement,
                                      program.is_operator_ephemeral(random_replacement),
                                      program.get_operator_flags(random_replacement)
                                  });
#if BLT_DEBUG_LEVEL >= 2
                if (!c.check(detail::debug::context_ptr))
                {
                    std::cout << "Parent: " << std::endl;
                    p.print(std::cout, false, false);
                    std::cout << "Child:" << std::endl;
                    c.print(std::cout, false, false);
                    std::cout << std::endl;
                    print_mutate_stats();
                    BLT_ABORT("SUB_FUNC Tree Check Failed.");
                }
#endif
#if defined(BLT_TRACK_ALLOCATIONS) || BLT_DEBUG_LEVEL >= 2
                ++mutate_sub_func_counter;
#endif
            }
            break;
            case mutation_operator::JUMP_FUNC:
            {
                auto& info = program.get_operator_info(c.get_operator(c_node).id());
                // first argument whose type equals the operator's return type; that child replaces the whole subtree.
                std::optional<size_t> argument_index;
                for (const auto& [index, v] : enumerate(info.argument_types))
                {
                    if (v.id == info.return_type.id)
                    {
                        argument_index = index;
                        break;
                    }
                }
                if (!argument_index)
                    continue;

                thread_local tracked_vector<tree_t::child_t> child_data;
                child_data.clear();
                c.find_child_extends(child_data, c_node, info.argument_types.size());
                const auto child_index = child_data.size() - 1 - *argument_index;
                const auto child = child_data[child_index];
                thread_local tree_t child_tree{program};
                c.copy_subtree(tree_t::subtree_point_t(child.start), child.end, child_tree);
                c.delete_subtree(tree_t::subtree_point_t(static_cast<ptrdiff_t>(c_node)));
                c.insert_subtree(tree_t::subtree_point_t(static_cast<ptrdiff_t>(c_node)), child_tree);
                child_tree.clear(program);
#if BLT_DEBUG_LEVEL >= 2
                if (!c.check(detail::debug::context_ptr))
                {
                    std::cout << "Parent: " << std::endl;
                    p.print(std::cout, false, false, false, static_cast<ptrdiff_t>(c_node));
                    std::cout << "Child Values:" << std::endl;
                    c.print(std::cout, false, false);
                    std::cout << std::endl;
                    BLT_ERROR("Failed at mutation index %lu/%lu", c_node, c.size());
                    print_mutate_stats();
                    BLT_ABORT("JUMP_FUNC Tree Check Failed.");
                }
#endif
#if defined(BLT_TRACK_ALLOCATIONS) || BLT_DEBUG_LEVEL >= 2
                ++mutate_jump_counter;
#endif
            }
            break;
            case mutation_operator::COPY:
            {
                auto& info = program.get_operator_info(c.get_operator(c_node).id());
                // copying needs two distinct arguments. This also keeps get_u64 below from being called with the
                // empty range (0, 0) for zero-argument, non-value operators.
                if (c.get_operator(c_node).is_value() || info.argument_types.size() < 2)
                    continue;

                thread_local tracked_vector<size_t> potential_indexes;
                potential_indexes.clear();
                const auto from_index = program.get_random().get_u64(0, info.argument_types.size());
                for (const auto& [index, type] : enumerate(info.argument_types))
                {
                    if (index == from_index)
                        continue;
                    if (info.argument_types[from_index] == type)
                        potential_indexes.push_back(index);
                }
                if (potential_indexes.empty())
                    continue;

                const auto to_index = program.get_random().select(potential_indexes);
                thread_local tracked_vector<tree_t::child_t> child_data;
                child_data.clear();
                c.find_child_extends(child_data, c_node, info.argument_types.size());
                const auto child_from_index = child_data.size() - 1 - from_index;
                const auto child_to_index = child_data.size() - 1 - to_index;
                const auto& [from_start, from_end] = child_data[child_from_index];
                const auto& [to_start, to_end] = child_data[child_to_index];
                thread_local tree_t copy_tree{program};
                c.copy_subtree(tree_t::subtree_point_t{from_start}, from_end, copy_tree);
                c.replace_subtree(tree_t::subtree_point_t{to_start}, to_end, copy_tree);
                copy_tree.clear(program);
#if BLT_DEBUG_LEVEL >= 2
                if (!c.check(detail::debug::context_ptr))
                {
                    std::cout << "Parent: " << std::endl;
                    p.print(std::cout, false, false);
                    std::cout << "Child Values:" << std::endl;
                    c.print(std::cout, false, false);
                    std::cout << std::endl;
                    print_mutate_stats();
                    BLT_ABORT("COPY Tree Check Failed.");
                }
#endif
#if defined(BLT_TRACK_ALLOCATIONS) || BLT_DEBUG_LEVEL >= 2
                ++mutate_copy_counter;
#endif
            }
            break;
            case mutation_operator::END:
            default:
#if BLT_DEBUG_LEVEL > 1
                BLT_ABORT("You shouldn't be able to get here!");
#else
                BLT_UNREACHABLE;
#endif
        }
    }

#if BLT_DEBUG_LEVEL >= 2
    if (!c.check(detail::debug::context_ptr))
    {
        std::cout << "Parent: " << std::endl;
        p.print(std::cout, false, false);
        std::cout << "Child Values:" << std::endl;
        c.print(std::cout, false, false);
        std::cout << std::endl;
        BLT_ABORT("Advanced Mutation Tree Check Failed.");
    }
#endif
    return true;
}
}