working on transformers
parent
b9c535f6c9
commit
487f771377
|
@ -5,6 +5,8 @@
|
||||||
<option name="/Default/CodeInspection/Highlighting/InspectionSeverities/=CppDeprecatedOverridenMethod/@EntryIndexedValue" value="WARNING" type="string" />
|
<option name="/Default/CodeInspection/Highlighting/InspectionSeverities/=CppDeprecatedOverridenMethod/@EntryIndexedValue" value="WARNING" type="string" />
|
||||||
<option name="/Default/CodeInspection/Highlighting/InspectionSeverities/=CppRedundantFwdClassOrEnumSpecifier/@EntryIndexedValue" value="SUGGESTION" type="string" />
|
<option name="/Default/CodeInspection/Highlighting/InspectionSeverities/=CppRedundantFwdClassOrEnumSpecifier/@EntryIndexedValue" value="SUGGESTION" type="string" />
|
||||||
<option name="/Default/CodeInspection/Highlighting/InspectionSeverities/=CppRedundantQualifierADL/@EntryIndexedValue" value="DO_NOT_SHOW" type="string" />
|
<option name="/Default/CodeInspection/Highlighting/InspectionSeverities/=CppRedundantQualifierADL/@EntryIndexedValue" value="DO_NOT_SHOW" type="string" />
|
||||||
|
<option name="/Default/CodeStyle/Generate/=CppDefinitions/@KeyIndexDefined" value="true" type="bool" />
|
||||||
|
<option name="/Default/CodeStyle/Generate/=CppDefinitions/Options/=GenerateInlineDefinitions/@EntryIndexedValue" value="False" type="string" />
|
||||||
<option name="/Default/CodeStyle/Naming/CppNaming/Rules/=Class_0020and_0020struct_0020fields/@EntryIndexedValue" value="<NamingElement Priority="11"><Descriptor Static="Indeterminate" Constexpr="Indeterminate" Const="Indeterminate" Volatile="Indeterminate" Accessibility="NOT_APPLICABLE"><type Name="class field" /><type Name="struct field" /></Descriptor><Policy Inspect="True" WarnAboutPrefixesAndSuffixes="False" Prefix="m_" Suffix="" Style="aa_bb" /></NamingElement>" type="string" />
|
<option name="/Default/CodeStyle/Naming/CppNaming/Rules/=Class_0020and_0020struct_0020fields/@EntryIndexedValue" value="<NamingElement Priority="11"><Descriptor Static="Indeterminate" Constexpr="Indeterminate" Const="Indeterminate" Volatile="Indeterminate" Accessibility="NOT_APPLICABLE"><type Name="class field" /><type Name="struct field" /></Descriptor><Policy Inspect="True" WarnAboutPrefixesAndSuffixes="False" Prefix="m_" Suffix="" Style="aa_bb" /></NamingElement>" type="string" />
|
||||||
<option name="/Default/CodeStyle/Naming/CppNaming/Rules/=Enums/@EntryIndexedValue" value="<NamingElement Priority="3"><Descriptor Static="Indeterminate" Constexpr="Indeterminate" Const="Indeterminate" Volatile="Indeterminate" Accessibility="NOT_APPLICABLE"><type Name="enum" /></Descriptor><Policy Inspect="True" WarnAboutPrefixesAndSuffixes="False" Prefix="" Suffix="" Style="AA_BB" /></NamingElement>" type="string" />
|
<option name="/Default/CodeStyle/Naming/CppNaming/Rules/=Enums/@EntryIndexedValue" value="<NamingElement Priority="3"><Descriptor Static="Indeterminate" Constexpr="Indeterminate" Const="Indeterminate" Volatile="Indeterminate" Accessibility="NOT_APPLICABLE"><type Name="enum" /></Descriptor><Policy Inspect="True" WarnAboutPrefixesAndSuffixes="False" Prefix="" Suffix="" Style="AA_BB" /></NamingElement>" type="string" />
|
||||||
</component>
|
</component>
|
||||||
|
|
|
@ -27,7 +27,7 @@ macro(compile_options target_name)
|
||||||
sanitizers(${target_name})
|
sanitizers(${target_name})
|
||||||
endmacro()
|
endmacro()
|
||||||
|
|
||||||
project(blt-gp VERSION 0.5.10)
|
project(blt-gp VERSION 0.5.11)
|
||||||
|
|
||||||
include(CTest)
|
include(CTest)
|
||||||
|
|
||||||
|
|
|
@ -760,6 +760,11 @@ namespace blt::gp
|
||||||
return current_stats;
|
return current_stats;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] const tracked_vector<population_stats>& get_stats_histories() const
|
||||||
|
{
|
||||||
|
return statistic_history;
|
||||||
|
}
|
||||||
|
|
||||||
[[nodiscard]] bool is_operator_ephemeral(const operator_id id) const
|
[[nodiscard]] bool is_operator_ephemeral(const operator_id id) const
|
||||||
{
|
{
|
||||||
return storage.operator_flags.find(static_cast<size_t>(id))->second.is_ephemeral();
|
return storage.operator_flags.find(static_cast<size_t>(id))->second.is_ephemeral();
|
||||||
|
|
|
@ -55,6 +55,9 @@ namespace blt::gp
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Base class for crossover which performs basic subtree crossover on two random nodes in the parent tree
|
||||||
|
*/
|
||||||
class crossover_t
|
class crossover_t
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
|
@ -123,6 +126,11 @@ namespace blt::gp
|
||||||
config_t config;
|
config_t config;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class advanced_crossover_t : public crossover_t
|
||||||
|
{
|
||||||
|
bool apply(gp_program& program, const tree_t& p1, const tree_t& p2, tree_t& c1, tree_t& c2) override;
|
||||||
|
};
|
||||||
|
|
||||||
class mutation_t
|
class mutation_t
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
|
|
|
@ -196,6 +196,20 @@ namespace blt::gp
|
||||||
worst_fitness = 0;
|
worst_fitness = 0;
|
||||||
normalized_fitness.clear();
|
normalized_fitness.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
friend bool operator==(const population_stats& a, const population_stats& b)
|
||||||
|
{
|
||||||
|
return a.overall_fitness.load(std::memory_order_relaxed) == b.overall_fitness.load(std::memory_order_relaxed) &&
|
||||||
|
a.average_fitness.load(std::memory_order_relaxed) == b.average_fitness.load(std::memory_order_relaxed) &&
|
||||||
|
a.best_fitness.load(std::memory_order_relaxed) == b.best_fitness.load(std::memory_order_relaxed) &&
|
||||||
|
a.worst_fitness.load(std::memory_order_relaxed) == b.worst_fitness.load(std::memory_order_relaxed) &&
|
||||||
|
a.normalized_fitness == b.normalized_fitness;
|
||||||
|
}
|
||||||
|
|
||||||
|
friend bool operator!=(const population_stats& a, const population_stats& b)
|
||||||
|
{
|
||||||
|
return !(a == b);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -168,7 +168,7 @@ namespace blt::gp
|
||||||
BLT_ASSERT(reader.read(&size, sizeof(size)) == sizeof(size));
|
BLT_ASSERT(reader.read(&size, sizeof(size)) == sizeof(size));
|
||||||
std::string name;
|
std::string name;
|
||||||
name.resize(size);
|
name.resize(size);
|
||||||
BLT_ASSERT(reader.read(name.data(), size) == size);
|
BLT_ASSERT(reader.read(name.data(), size) == static_cast<i64>(size));
|
||||||
if (!storage.names[i].has_value())
|
if (!storage.names[i].has_value())
|
||||||
throw std::runtime_error("Expected operator ID " + std::to_string(i) + " to have name " + name);
|
throw std::runtime_error("Expected operator ID " + std::to_string(i) + " to have name " + name);
|
||||||
if (name != *storage.names[i])
|
if (name != *storage.names[i])
|
||||||
|
@ -210,11 +210,11 @@ namespace blt::gp
|
||||||
"Operator ID " + std::to_string(i) + " expected return type " + std::to_string(op.return_type) + " but got " + std::to_string(
|
"Operator ID " + std::to_string(i) + " expected return type " + std::to_string(op.return_type) + " but got " + std::to_string(
|
||||||
return_type));
|
return_type));
|
||||||
size_t arg_type_count;
|
size_t arg_type_count;
|
||||||
|
BLT_ASSERT(reader.read(&arg_type_count, sizeof(arg_type_count)) == sizeof(return_type));
|
||||||
if (arg_type_count != op.argument_types.size())
|
if (arg_type_count != op.argument_types.size())
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"Operator ID " + std::to_string(i) + " expected " + std::to_string(op.argument_types.size()) + " arguments but got " +
|
"Operator ID " + std::to_string(i) + " expected " + std::to_string(op.argument_types.size()) + " arguments but got " +
|
||||||
std::to_string(arg_type_count));
|
std::to_string(arg_type_count));
|
||||||
BLT_ASSERT(reader.read(&arg_type_count, sizeof(arg_type_count)) == sizeof(return_type));
|
|
||||||
for (size_t j = 0; j < arg_type_count; j++)
|
for (size_t j = 0; j < arg_type_count; j++)
|
||||||
{
|
{
|
||||||
type_id type;
|
type_id type;
|
||||||
|
|
|
@ -80,20 +80,7 @@ namespace blt::gp
|
||||||
if (!point)
|
if (!point)
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
// TODO: more crossover!
|
c1.swap_subtrees(point->p1_crossover_point, c2, point->p2_crossover_point);
|
||||||
switch (program.get_random().get_u32(0, 2))
|
|
||||||
{
|
|
||||||
case 0:
|
|
||||||
case 1:
|
|
||||||
c1.swap_subtrees(point->p1_crossover_point, c2, point->p2_crossover_point);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
#if BLT_DEBUG_LEVEL > 0
|
|
||||||
BLT_ABORT("This place should be unreachable!");
|
|
||||||
#else
|
|
||||||
BLT_UNREACHABLE;
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
#if BLT_DEBUG_LEVEL >= 2
|
#if BLT_DEBUG_LEVEL >= 2
|
||||||
if (!c1.check(detail::debug::context_ptr) || !c2.check(detail::debug::context_ptr))
|
if (!c1.check(detail::debug::context_ptr) || !c2.check(detail::debug::context_ptr))
|
||||||
|
@ -134,6 +121,57 @@ namespace blt::gp
|
||||||
return t.select_subtree_traverse(config.terminal_chance, config.depth_multiplier);
|
return t.select_subtree_traverse(config.terminal_chance, config.depth_multiplier);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool advanced_crossover_t::apply(gp_program& program, const tree_t& p1, const tree_t& p2, tree_t& c1, tree_t& c2)
|
||||||
|
{
|
||||||
|
if (p1.size() < config.min_tree_size || p2.size() < config.min_tree_size)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
// TODO: more crossover!
|
||||||
|
switch (program.get_random().get_u32(0, 2))
|
||||||
|
{
|
||||||
|
// single point crossover (only if operators at this point are "compatible")
|
||||||
|
case 0:
|
||||||
|
case0:
|
||||||
|
{
|
||||||
|
std::optional<crossover_point_t> point;
|
||||||
|
|
||||||
|
if (config.traverse)
|
||||||
|
point = get_crossover_point_traverse(p1, p2);
|
||||||
|
else
|
||||||
|
point = get_crossover_point(p1, p2);
|
||||||
|
|
||||||
|
if (!point)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
// check if can work
|
||||||
|
// otherwise goto case2
|
||||||
|
}
|
||||||
|
// Mating crossover analogs to same species breeding. Only works if tree is mostly similar
|
||||||
|
case 1:
|
||||||
|
case1:
|
||||||
|
{
|
||||||
|
// if fails got to case0
|
||||||
|
if (false)
|
||||||
|
goto case0;
|
||||||
|
}
|
||||||
|
// Subtree crossover, select random points inside trees and swap their subtrees
|
||||||
|
case 2:
|
||||||
|
case2:
|
||||||
|
return crossover_t::apply(program, p1, p2, c1, c2);
|
||||||
|
default:
|
||||||
|
#if BLT_DEBUG_LEVEL > 0
|
||||||
|
BLT_ABORT("This place should be unreachable!");
|
||||||
|
#else
|
||||||
|
BLT_UNREACHABLE;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
#if BLT_DEBUG_LEVEL >= 2
|
||||||
|
if (!c1.check(detail::debug::context_ptr) || !c2.check(detail::debug::context_ptr))
|
||||||
|
throw std::runtime_error("Tree check failed");
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
bool mutation_t::apply(gp_program& program, const tree_t&, tree_t& c)
|
bool mutation_t::apply(gp_program& program, const tree_t&, tree_t& c)
|
||||||
{
|
{
|
||||||
// TODO: options for this?
|
// TODO: options for this?
|
||||||
|
|
|
@ -100,6 +100,12 @@ int main()
|
||||||
test_program.set_operations(operators);
|
test_program.set_operations(operators);
|
||||||
test_program.setup_generational_evaluation(fitness_function, sel, sel, sel, false);
|
test_program.setup_generational_evaluation(fitness_function, sel, sel, sel, false);
|
||||||
|
|
||||||
|
// simulate a program which is similar but incompatible with the other programs.
|
||||||
|
operator_builder<context> builder2{};
|
||||||
|
gp_program bad_program{691};
|
||||||
|
bad_program.set_operations(builder2.build(addf, subf, mulf, op_sinf, op_cosf, litf, op_xf));
|
||||||
|
bad_program.setup_generational_evaluation(fitness_function, sel, sel, sel, false);
|
||||||
|
|
||||||
program.generate_initial_population(program.get_typesystem().get_type<float>().id());
|
program.generate_initial_population(program.get_typesystem().get_type<float>().id());
|
||||||
program.setup_generational_evaluation(fitness_function, sel, sel, sel);
|
program.setup_generational_evaluation(fitness_function, sel, sel, sel);
|
||||||
while (!program.should_terminate())
|
while (!program.should_terminate())
|
||||||
|
@ -112,8 +118,7 @@ int main()
|
||||||
BLT_TRACE("Evaluate Fitness");
|
BLT_TRACE("Evaluate Fitness");
|
||||||
program.evaluate_fitness();
|
program.evaluate_fitness();
|
||||||
{
|
{
|
||||||
std::filesystem::remove("serialization_test.data");
|
std::ofstream stream{"serialization_test.data", std::ios::binary | std::ios::trunc};
|
||||||
std::ofstream stream{"serialization_test.data", std::ios::binary};
|
|
||||||
blt::fs::fstream_writer_t writer{stream};
|
blt::fs::fstream_writer_t writer{stream};
|
||||||
program.save_generation(writer);
|
program.save_generation(writer);
|
||||||
}
|
}
|
||||||
|
@ -132,4 +137,34 @@ int main()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
{
|
||||||
|
std::ofstream stream{"serialization_test2.data", std::ios::binary | std::ios::trunc};
|
||||||
|
blt::fs::fstream_writer_t writer{stream};
|
||||||
|
program.save_state(writer);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
std::ifstream stream{"serialization_test2.data", std::ios::binary};
|
||||||
|
blt::fs::fstream_reader_t reader{stream};
|
||||||
|
test_program.load_state(reader);
|
||||||
|
|
||||||
|
for (const auto [saved, loaded] : blt::zip(program.get_stats_histories(), test_program.get_stats_histories()))
|
||||||
|
{
|
||||||
|
if (saved != loaded)
|
||||||
|
{
|
||||||
|
BLT_ERROR("Serializer Failed to correctly serialize histories to disk, histories are not equal!");
|
||||||
|
std::exit(1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
std::ifstream stream{"serialization_test2.data", std::ios::binary};
|
||||||
|
blt::fs::fstream_reader_t reader{stream};
|
||||||
|
bad_program.load_state(reader);
|
||||||
|
} catch (const std::runtime_error&)
|
||||||
|
{
|
||||||
|
// TODO: use blt::expected so this isn't required + better design.
|
||||||
|
goto exit;
|
||||||
|
}
|
||||||
|
BLT_ASSERT(false && "Expected program to throw an exception when parsing state data into an incompatible program!");
|
||||||
|
exit:
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue