working on transformers
parent
b9c535f6c9
commit
487f771377
|
@ -5,6 +5,8 @@
|
|||
<option name="/Default/CodeInspection/Highlighting/InspectionSeverities/=CppDeprecatedOverridenMethod/@EntryIndexedValue" value="WARNING" type="string" />
|
||||
<option name="/Default/CodeInspection/Highlighting/InspectionSeverities/=CppRedundantFwdClassOrEnumSpecifier/@EntryIndexedValue" value="SUGGESTION" type="string" />
|
||||
<option name="/Default/CodeInspection/Highlighting/InspectionSeverities/=CppRedundantQualifierADL/@EntryIndexedValue" value="DO_NOT_SHOW" type="string" />
|
||||
<option name="/Default/CodeStyle/Generate/=CppDefinitions/@KeyIndexDefined" value="true" type="bool" />
|
||||
<option name="/Default/CodeStyle/Generate/=CppDefinitions/Options/=GenerateInlineDefinitions/@EntryIndexedValue" value="False" type="string" />
|
||||
<option name="/Default/CodeStyle/Naming/CppNaming/Rules/=Class_0020and_0020struct_0020fields/@EntryIndexedValue" value="<NamingElement Priority="11"><Descriptor Static="Indeterminate" Constexpr="Indeterminate" Const="Indeterminate" Volatile="Indeterminate" Accessibility="NOT_APPLICABLE"><type Name="class field" /><type Name="struct field" /></Descriptor><Policy Inspect="True" WarnAboutPrefixesAndSuffixes="False" Prefix="m_" Suffix="" Style="aa_bb" /></NamingElement>" type="string" />
|
||||
<option name="/Default/CodeStyle/Naming/CppNaming/Rules/=Enums/@EntryIndexedValue" value="<NamingElement Priority="3"><Descriptor Static="Indeterminate" Constexpr="Indeterminate" Const="Indeterminate" Volatile="Indeterminate" Accessibility="NOT_APPLICABLE"><type Name="enum" /></Descriptor><Policy Inspect="True" WarnAboutPrefixesAndSuffixes="False" Prefix="" Suffix="" Style="AA_BB" /></NamingElement>" type="string" />
|
||||
</component>
|
||||
|
|
|
@ -27,7 +27,7 @@ macro(compile_options target_name)
|
|||
sanitizers(${target_name})
|
||||
endmacro()
|
||||
|
||||
project(blt-gp VERSION 0.5.10)
|
||||
project(blt-gp VERSION 0.5.11)
|
||||
|
||||
include(CTest)
|
||||
|
||||
|
|
|
@ -760,6 +760,11 @@ namespace blt::gp
|
|||
return current_stats;
|
||||
}
|
||||
|
||||
[[nodiscard]] const tracked_vector<population_stats>& get_stats_histories() const
|
||||
{
|
||||
return statistic_history;
|
||||
}
|
||||
|
||||
[[nodiscard]] bool is_operator_ephemeral(const operator_id id) const
|
||||
{
|
||||
return storage.operator_flags.find(static_cast<size_t>(id))->second.is_ephemeral();
|
||||
|
|
|
@ -55,6 +55,9 @@ namespace blt::gp
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Base class for crossover which performs basic subtree crossover on two random nodes in the parent tree
|
||||
*/
|
||||
class crossover_t
|
||||
{
|
||||
public:
|
||||
|
@ -123,6 +126,11 @@ namespace blt::gp
|
|||
config_t config;
|
||||
};
|
||||
|
||||
class advanced_crossover_t : public crossover_t
|
||||
{
|
||||
bool apply(gp_program& program, const tree_t& p1, const tree_t& p2, tree_t& c1, tree_t& c2) override;
|
||||
};
|
||||
|
||||
class mutation_t
|
||||
{
|
||||
public:
|
||||
|
|
|
@ -196,6 +196,20 @@ namespace blt::gp
|
|||
worst_fitness = 0;
|
||||
normalized_fitness.clear();
|
||||
}
|
||||
|
||||
friend bool operator==(const population_stats& a, const population_stats& b)
|
||||
{
|
||||
return a.overall_fitness.load(std::memory_order_relaxed) == b.overall_fitness.load(std::memory_order_relaxed) &&
|
||||
a.average_fitness.load(std::memory_order_relaxed) == b.average_fitness.load(std::memory_order_relaxed) &&
|
||||
a.best_fitness.load(std::memory_order_relaxed) == b.best_fitness.load(std::memory_order_relaxed) &&
|
||||
a.worst_fitness.load(std::memory_order_relaxed) == b.worst_fitness.load(std::memory_order_relaxed) &&
|
||||
a.normalized_fitness == b.normalized_fitness;
|
||||
}
|
||||
|
||||
friend bool operator!=(const population_stats& a, const population_stats& b)
|
||||
{
|
||||
return !(a == b);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
@ -168,7 +168,7 @@ namespace blt::gp
|
|||
BLT_ASSERT(reader.read(&size, sizeof(size)) == sizeof(size));
|
||||
std::string name;
|
||||
name.resize(size);
|
||||
BLT_ASSERT(reader.read(name.data(), size) == size);
|
||||
BLT_ASSERT(reader.read(name.data(), size) == static_cast<i64>(size));
|
||||
if (!storage.names[i].has_value())
|
||||
throw std::runtime_error("Expected operator ID " + std::to_string(i) + " to have name " + name);
|
||||
if (name != *storage.names[i])
|
||||
|
@ -210,11 +210,11 @@ namespace blt::gp
|
|||
"Operator ID " + std::to_string(i) + " expected return type " + std::to_string(op.return_type) + " but got " + std::to_string(
|
||||
return_type));
|
||||
size_t arg_type_count;
|
||||
BLT_ASSERT(reader.read(&arg_type_count, sizeof(arg_type_count)) == sizeof(return_type));
|
||||
if (arg_type_count != op.argument_types.size())
|
||||
throw std::runtime_error(
|
||||
"Operator ID " + std::to_string(i) + " expected " + std::to_string(op.argument_types.size()) + " arguments but got " +
|
||||
std::to_string(arg_type_count));
|
||||
BLT_ASSERT(reader.read(&arg_type_count, sizeof(arg_type_count)) == sizeof(return_type));
|
||||
for (size_t j = 0; j < arg_type_count; j++)
|
||||
{
|
||||
type_id type;
|
||||
|
|
|
@ -80,20 +80,7 @@ namespace blt::gp
|
|||
if (!point)
|
||||
return false;
|
||||
|
||||
// TODO: more crossover!
|
||||
switch (program.get_random().get_u32(0, 2))
|
||||
{
|
||||
case 0:
|
||||
case 1:
|
||||
c1.swap_subtrees(point->p1_crossover_point, c2, point->p2_crossover_point);
|
||||
break;
|
||||
default:
|
||||
#if BLT_DEBUG_LEVEL > 0
|
||||
BLT_ABORT("This place should be unreachable!");
|
||||
#else
|
||||
BLT_UNREACHABLE;
|
||||
#endif
|
||||
}
|
||||
|
||||
#if BLT_DEBUG_LEVEL >= 2
|
||||
if (!c1.check(detail::debug::context_ptr) || !c2.check(detail::debug::context_ptr))
|
||||
|
@ -134,6 +121,57 @@ namespace blt::gp
|
|||
return t.select_subtree_traverse(config.terminal_chance, config.depth_multiplier);
|
||||
}
|
||||
|
||||
bool advanced_crossover_t::apply(gp_program& program, const tree_t& p1, const tree_t& p2, tree_t& c1, tree_t& c2)
|
||||
{
|
||||
if (p1.size() < config.min_tree_size || p2.size() < config.min_tree_size)
|
||||
return false;
|
||||
|
||||
// TODO: more crossover!
|
||||
switch (program.get_random().get_u32(0, 2))
|
||||
{
|
||||
// single point crossover (only if operators at this point are "compatible")
|
||||
case 0:
|
||||
case0:
|
||||
{
|
||||
std::optional<crossover_point_t> point;
|
||||
|
||||
if (config.traverse)
|
||||
point = get_crossover_point_traverse(p1, p2);
|
||||
else
|
||||
point = get_crossover_point(p1, p2);
|
||||
|
||||
if (!point)
|
||||
return false;
|
||||
|
||||
// check if can work
|
||||
// otherwise goto case2
|
||||
}
|
||||
// Mating crossover analogs to same species breeding. Only works if tree is mostly similar
|
||||
case 1:
|
||||
case1:
|
||||
{
|
||||
// if fails got to case0
|
||||
if (false)
|
||||
goto case0;
|
||||
}
|
||||
// Subtree crossover, select random points inside trees and swap their subtrees
|
||||
case 2:
|
||||
case2:
|
||||
return crossover_t::apply(program, p1, p2, c1, c2);
|
||||
default:
|
||||
#if BLT_DEBUG_LEVEL > 0
|
||||
BLT_ABORT("This place should be unreachable!");
|
||||
#else
|
||||
BLT_UNREACHABLE;
|
||||
#endif
|
||||
}
|
||||
|
||||
#if BLT_DEBUG_LEVEL >= 2
|
||||
if (!c1.check(detail::debug::context_ptr) || !c2.check(detail::debug::context_ptr))
|
||||
throw std::runtime_error("Tree check failed");
|
||||
#endif
|
||||
}
|
||||
|
||||
bool mutation_t::apply(gp_program& program, const tree_t&, tree_t& c)
|
||||
{
|
||||
// TODO: options for this?
|
||||
|
|
|
@ -100,6 +100,12 @@ int main()
|
|||
test_program.set_operations(operators);
|
||||
test_program.setup_generational_evaluation(fitness_function, sel, sel, sel, false);
|
||||
|
||||
// simulate a program which is similar but incompatible with the other programs.
|
||||
operator_builder<context> builder2{};
|
||||
gp_program bad_program{691};
|
||||
bad_program.set_operations(builder2.build(addf, subf, mulf, op_sinf, op_cosf, litf, op_xf));
|
||||
bad_program.setup_generational_evaluation(fitness_function, sel, sel, sel, false);
|
||||
|
||||
program.generate_initial_population(program.get_typesystem().get_type<float>().id());
|
||||
program.setup_generational_evaluation(fitness_function, sel, sel, sel);
|
||||
while (!program.should_terminate())
|
||||
|
@ -112,8 +118,7 @@ int main()
|
|||
BLT_TRACE("Evaluate Fitness");
|
||||
program.evaluate_fitness();
|
||||
{
|
||||
std::filesystem::remove("serialization_test.data");
|
||||
std::ofstream stream{"serialization_test.data", std::ios::binary};
|
||||
std::ofstream stream{"serialization_test.data", std::ios::binary | std::ios::trunc};
|
||||
blt::fs::fstream_writer_t writer{stream};
|
||||
program.save_generation(writer);
|
||||
}
|
||||
|
@ -132,4 +137,34 @@ int main()
|
|||
}
|
||||
}
|
||||
}
|
||||
{
|
||||
std::ofstream stream{"serialization_test2.data", std::ios::binary | std::ios::trunc};
|
||||
blt::fs::fstream_writer_t writer{stream};
|
||||
program.save_state(writer);
|
||||
}
|
||||
{
|
||||
std::ifstream stream{"serialization_test2.data", std::ios::binary};
|
||||
blt::fs::fstream_reader_t reader{stream};
|
||||
test_program.load_state(reader);
|
||||
|
||||
for (const auto [saved, loaded] : blt::zip(program.get_stats_histories(), test_program.get_stats_histories()))
|
||||
{
|
||||
if (saved != loaded)
|
||||
{
|
||||
BLT_ERROR("Serializer Failed to correctly serialize histories to disk, histories are not equal!");
|
||||
std::exit(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
try {
|
||||
std::ifstream stream{"serialization_test2.data", std::ios::binary};
|
||||
blt::fs::fstream_reader_t reader{stream};
|
||||
bad_program.load_state(reader);
|
||||
} catch (const std::runtime_error&)
|
||||
{
|
||||
// TODO: use blt::expected so this isn't required + better design.
|
||||
goto exit;
|
||||
}
|
||||
BLT_ASSERT(false && "Expected program to throw an exception when parsing state data into an incompatible program!");
|
||||
exit:
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue