working on transformers

main
Brett 2025-04-18 13:34:32 -04:00
parent b9c535f6c9
commit 487f771377
8 changed files with 121 additions and 19 deletions

View File

@ -5,6 +5,8 @@
<option name="/Default/CodeInspection/Highlighting/InspectionSeverities/=CppDeprecatedOverridenMethod/@EntryIndexedValue" value="WARNING" type="string" /> <option name="/Default/CodeInspection/Highlighting/InspectionSeverities/=CppDeprecatedOverridenMethod/@EntryIndexedValue" value="WARNING" type="string" />
<option name="/Default/CodeInspection/Highlighting/InspectionSeverities/=CppRedundantFwdClassOrEnumSpecifier/@EntryIndexedValue" value="SUGGESTION" type="string" /> <option name="/Default/CodeInspection/Highlighting/InspectionSeverities/=CppRedundantFwdClassOrEnumSpecifier/@EntryIndexedValue" value="SUGGESTION" type="string" />
<option name="/Default/CodeInspection/Highlighting/InspectionSeverities/=CppRedundantQualifierADL/@EntryIndexedValue" value="DO_NOT_SHOW" type="string" /> <option name="/Default/CodeInspection/Highlighting/InspectionSeverities/=CppRedundantQualifierADL/@EntryIndexedValue" value="DO_NOT_SHOW" type="string" />
<option name="/Default/CodeStyle/Generate/=CppDefinitions/@KeyIndexDefined" value="true" type="bool" />
<option name="/Default/CodeStyle/Generate/=CppDefinitions/Options/=GenerateInlineDefinitions/@EntryIndexedValue" value="False" type="string" />
<option name="/Default/CodeStyle/Naming/CppNaming/Rules/=Class_0020and_0020struct_0020fields/@EntryIndexedValue" value="&lt;NamingElement Priority=&quot;11&quot;&gt;&lt;Descriptor Static=&quot;Indeterminate&quot; Constexpr=&quot;Indeterminate&quot; Const=&quot;Indeterminate&quot; Volatile=&quot;Indeterminate&quot; Accessibility=&quot;NOT_APPLICABLE&quot;&gt;&lt;type Name=&quot;class field&quot; /&gt;&lt;type Name=&quot;struct field&quot; /&gt;&lt;/Descriptor&gt;&lt;Policy Inspect=&quot;True&quot; WarnAboutPrefixesAndSuffixes=&quot;False&quot; Prefix=&quot;m_&quot; Suffix=&quot;&quot; Style=&quot;aa_bb&quot; /&gt;&lt;/NamingElement&gt;" type="string" /> <option name="/Default/CodeStyle/Naming/CppNaming/Rules/=Class_0020and_0020struct_0020fields/@EntryIndexedValue" value="&lt;NamingElement Priority=&quot;11&quot;&gt;&lt;Descriptor Static=&quot;Indeterminate&quot; Constexpr=&quot;Indeterminate&quot; Const=&quot;Indeterminate&quot; Volatile=&quot;Indeterminate&quot; Accessibility=&quot;NOT_APPLICABLE&quot;&gt;&lt;type Name=&quot;class field&quot; /&gt;&lt;type Name=&quot;struct field&quot; /&gt;&lt;/Descriptor&gt;&lt;Policy Inspect=&quot;True&quot; WarnAboutPrefixesAndSuffixes=&quot;False&quot; Prefix=&quot;m_&quot; Suffix=&quot;&quot; Style=&quot;aa_bb&quot; /&gt;&lt;/NamingElement&gt;" type="string" />
<option name="/Default/CodeStyle/Naming/CppNaming/Rules/=Enums/@EntryIndexedValue" value="&lt;NamingElement Priority=&quot;3&quot;&gt;&lt;Descriptor Static=&quot;Indeterminate&quot; Constexpr=&quot;Indeterminate&quot; Const=&quot;Indeterminate&quot; Volatile=&quot;Indeterminate&quot; Accessibility=&quot;NOT_APPLICABLE&quot;&gt;&lt;type Name=&quot;enum&quot; /&gt;&lt;/Descriptor&gt;&lt;Policy Inspect=&quot;True&quot; WarnAboutPrefixesAndSuffixes=&quot;False&quot; Prefix=&quot;&quot; Suffix=&quot;&quot; Style=&quot;AA_BB&quot; /&gt;&lt;/NamingElement&gt;" type="string" /> <option name="/Default/CodeStyle/Naming/CppNaming/Rules/=Enums/@EntryIndexedValue" value="&lt;NamingElement Priority=&quot;3&quot;&gt;&lt;Descriptor Static=&quot;Indeterminate&quot; Constexpr=&quot;Indeterminate&quot; Const=&quot;Indeterminate&quot; Volatile=&quot;Indeterminate&quot; Accessibility=&quot;NOT_APPLICABLE&quot;&gt;&lt;type Name=&quot;enum&quot; /&gt;&lt;/Descriptor&gt;&lt;Policy Inspect=&quot;True&quot; WarnAboutPrefixesAndSuffixes=&quot;False&quot; Prefix=&quot;&quot; Suffix=&quot;&quot; Style=&quot;AA_BB&quot; /&gt;&lt;/NamingElement&gt;" type="string" />
</component> </component>

View File

@ -27,7 +27,7 @@ macro(compile_options target_name)
sanitizers(${target_name}) sanitizers(${target_name})
endmacro() endmacro()
project(blt-gp VERSION 0.5.10) project(blt-gp VERSION 0.5.11)
include(CTest) include(CTest)

View File

@ -760,6 +760,11 @@ namespace blt::gp
return current_stats; return current_stats;
} }
[[nodiscard]] const tracked_vector<population_stats>& get_stats_histories() const
{
return statistic_history;
}
[[nodiscard]] bool is_operator_ephemeral(const operator_id id) const [[nodiscard]] bool is_operator_ephemeral(const operator_id id) const
{ {
return storage.operator_flags.find(static_cast<size_t>(id))->second.is_ephemeral(); return storage.operator_flags.find(static_cast<size_t>(id))->second.is_ephemeral();

View File

@ -55,6 +55,9 @@ namespace blt::gp
} }
} }
/**
* Base class for crossover which performs basic subtree crossover on two random nodes in the parent tree
*/
class crossover_t class crossover_t
{ {
public: public:
@ -123,6 +126,11 @@ namespace blt::gp
config_t config; config_t config;
}; };
class advanced_crossover_t : public crossover_t
{
bool apply(gp_program& program, const tree_t& p1, const tree_t& p2, tree_t& c1, tree_t& c2) override;
};
class mutation_t class mutation_t
{ {
public: public:

View File

@ -196,6 +196,20 @@ namespace blt::gp
worst_fitness = 0; worst_fitness = 0;
normalized_fitness.clear(); normalized_fitness.clear();
} }
friend bool operator==(const population_stats& a, const population_stats& b)
{
return a.overall_fitness.load(std::memory_order_relaxed) == b.overall_fitness.load(std::memory_order_relaxed) &&
a.average_fitness.load(std::memory_order_relaxed) == b.average_fitness.load(std::memory_order_relaxed) &&
a.best_fitness.load(std::memory_order_relaxed) == b.best_fitness.load(std::memory_order_relaxed) &&
a.worst_fitness.load(std::memory_order_relaxed) == b.worst_fitness.load(std::memory_order_relaxed) &&
a.normalized_fitness == b.normalized_fitness;
}
friend bool operator!=(const population_stats& a, const population_stats& b)
{
return !(a == b);
}
}; };
} }

View File

@ -168,7 +168,7 @@ namespace blt::gp
BLT_ASSERT(reader.read(&size, sizeof(size)) == sizeof(size)); BLT_ASSERT(reader.read(&size, sizeof(size)) == sizeof(size));
std::string name; std::string name;
name.resize(size); name.resize(size);
BLT_ASSERT(reader.read(name.data(), size) == size); BLT_ASSERT(reader.read(name.data(), size) == static_cast<i64>(size));
if (!storage.names[i].has_value()) if (!storage.names[i].has_value())
throw std::runtime_error("Expected operator ID " + std::to_string(i) + " to have name " + name); throw std::runtime_error("Expected operator ID " + std::to_string(i) + " to have name " + name);
if (name != *storage.names[i]) if (name != *storage.names[i])
@ -210,11 +210,11 @@ namespace blt::gp
"Operator ID " + std::to_string(i) + " expected return type " + std::to_string(op.return_type) + " but got " + std::to_string( "Operator ID " + std::to_string(i) + " expected return type " + std::to_string(op.return_type) + " but got " + std::to_string(
return_type)); return_type));
size_t arg_type_count; size_t arg_type_count;
BLT_ASSERT(reader.read(&arg_type_count, sizeof(arg_type_count)) == sizeof(return_type));
if (arg_type_count != op.argument_types.size()) if (arg_type_count != op.argument_types.size())
throw std::runtime_error( throw std::runtime_error(
"Operator ID " + std::to_string(i) + " expected " + std::to_string(op.argument_types.size()) + " arguments but got " + "Operator ID " + std::to_string(i) + " expected " + std::to_string(op.argument_types.size()) + " arguments but got " +
std::to_string(arg_type_count)); std::to_string(arg_type_count));
BLT_ASSERT(reader.read(&arg_type_count, sizeof(arg_type_count)) == sizeof(return_type));
for (size_t j = 0; j < arg_type_count; j++) for (size_t j = 0; j < arg_type_count; j++)
{ {
type_id type; type_id type;

View File

@ -80,20 +80,7 @@ namespace blt::gp
if (!point) if (!point)
return false; return false;
// TODO: more crossover! c1.swap_subtrees(point->p1_crossover_point, c2, point->p2_crossover_point);
switch (program.get_random().get_u32(0, 2))
{
case 0:
case 1:
c1.swap_subtrees(point->p1_crossover_point, c2, point->p2_crossover_point);
break;
default:
#if BLT_DEBUG_LEVEL > 0
BLT_ABORT("This place should be unreachable!");
#else
BLT_UNREACHABLE;
#endif
}
#if BLT_DEBUG_LEVEL >= 2 #if BLT_DEBUG_LEVEL >= 2
if (!c1.check(detail::debug::context_ptr) || !c2.check(detail::debug::context_ptr)) if (!c1.check(detail::debug::context_ptr) || !c2.check(detail::debug::context_ptr))
@ -134,6 +121,57 @@ namespace blt::gp
return t.select_subtree_traverse(config.terminal_chance, config.depth_multiplier); return t.select_subtree_traverse(config.terminal_chance, config.depth_multiplier);
} }
bool advanced_crossover_t::apply(gp_program& program, const tree_t& p1, const tree_t& p2, tree_t& c1, tree_t& c2)
{
if (p1.size() < config.min_tree_size || p2.size() < config.min_tree_size)
return false;
// TODO: more crossover!
switch (program.get_random().get_u32(0, 2))
{
// single point crossover (only if operators at this point are "compatible")
case 0:
case0:
{
std::optional<crossover_point_t> point;
if (config.traverse)
point = get_crossover_point_traverse(p1, p2);
else
point = get_crossover_point(p1, p2);
if (!point)
return false;
// check if can work
// otherwise goto case2
}
// Mating crossover analogs to same species breeding. Only works if tree is mostly similar
case 1:
case1:
{
// if fails got to case0
if (false)
goto case0;
}
// Subtree crossover, select random points inside trees and swap their subtrees
case 2:
case2:
return crossover_t::apply(program, p1, p2, c1, c2);
default:
#if BLT_DEBUG_LEVEL > 0
BLT_ABORT("This place should be unreachable!");
#else
BLT_UNREACHABLE;
#endif
}
#if BLT_DEBUG_LEVEL >= 2
if (!c1.check(detail::debug::context_ptr) || !c2.check(detail::debug::context_ptr))
throw std::runtime_error("Tree check failed");
#endif
}
bool mutation_t::apply(gp_program& program, const tree_t&, tree_t& c) bool mutation_t::apply(gp_program& program, const tree_t&, tree_t& c)
{ {
// TODO: options for this? // TODO: options for this?

View File

@ -100,6 +100,12 @@ int main()
test_program.set_operations(operators); test_program.set_operations(operators);
test_program.setup_generational_evaluation(fitness_function, sel, sel, sel, false); test_program.setup_generational_evaluation(fitness_function, sel, sel, sel, false);
// simulate a program which is similar but incompatible with the other programs.
operator_builder<context> builder2{};
gp_program bad_program{691};
bad_program.set_operations(builder2.build(addf, subf, mulf, op_sinf, op_cosf, litf, op_xf));
bad_program.setup_generational_evaluation(fitness_function, sel, sel, sel, false);
program.generate_initial_population(program.get_typesystem().get_type<float>().id()); program.generate_initial_population(program.get_typesystem().get_type<float>().id());
program.setup_generational_evaluation(fitness_function, sel, sel, sel); program.setup_generational_evaluation(fitness_function, sel, sel, sel);
while (!program.should_terminate()) while (!program.should_terminate())
@ -112,8 +118,7 @@ int main()
BLT_TRACE("Evaluate Fitness"); BLT_TRACE("Evaluate Fitness");
program.evaluate_fitness(); program.evaluate_fitness();
{ {
std::filesystem::remove("serialization_test.data"); std::ofstream stream{"serialization_test.data", std::ios::binary | std::ios::trunc};
std::ofstream stream{"serialization_test.data", std::ios::binary};
blt::fs::fstream_writer_t writer{stream}; blt::fs::fstream_writer_t writer{stream};
program.save_generation(writer); program.save_generation(writer);
} }
@ -132,4 +137,34 @@ int main()
} }
} }
} }
{
std::ofstream stream{"serialization_test2.data", std::ios::binary | std::ios::trunc};
blt::fs::fstream_writer_t writer{stream};
program.save_state(writer);
}
{
std::ifstream stream{"serialization_test2.data", std::ios::binary};
blt::fs::fstream_reader_t reader{stream};
test_program.load_state(reader);
for (const auto [saved, loaded] : blt::zip(program.get_stats_histories(), test_program.get_stats_histories()))
{
if (saved != loaded)
{
BLT_ERROR("Serializer Failed to correctly serialize histories to disk, histories are not equal!");
std::exit(1);
}
}
}
try {
std::ifstream stream{"serialization_test2.data", std::ios::binary};
blt::fs::fstream_reader_t reader{stream};
bad_program.load_state(reader);
} catch (const std::runtime_error&)
{
// TODO: use blt::expected so this isn't required + better design.
goto exit;
}
BLT_ASSERT(false && "Expected program to throw an exception when parsing state data into an incompatible program!");
exit:
} }