diff --git a/.idea/editor.xml b/.idea/editor.xml
index 6df7d16..5603a5a 100644
--- a/.idea/editor.xml
+++ b/.idea/editor.xml
@@ -5,6 +5,8 @@
+
+
diff --git a/CMakeLists.txt b/CMakeLists.txt
index a1cf2a1..eff7083 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -27,7 +27,7 @@ macro(compile_options target_name)
sanitizers(${target_name})
endmacro()
-project(blt-gp VERSION 0.5.10)
+project(blt-gp VERSION 0.5.11)
include(CTest)
diff --git a/include/blt/gp/program.h b/include/blt/gp/program.h
index 25132ca..bc352b5 100644
--- a/include/blt/gp/program.h
+++ b/include/blt/gp/program.h
@@ -760,6 +760,11 @@ namespace blt::gp
return current_stats;
}
+ [[nodiscard]] const tracked_vector& get_stats_histories() const
+ {
+ return statistic_history;
+ }
+
[[nodiscard]] bool is_operator_ephemeral(const operator_id id) const
{
return storage.operator_flags.find(static_cast(id))->second.is_ephemeral();
diff --git a/include/blt/gp/transformers.h b/include/blt/gp/transformers.h
index 2da1aed..e7970f4 100644
--- a/include/blt/gp/transformers.h
+++ b/include/blt/gp/transformers.h
@@ -55,6 +55,9 @@ namespace blt::gp
}
}
+ /**
+ * Base class for crossover which performs basic subtree crossover on two random nodes in the parent tree
+ */
class crossover_t
{
public:
@@ -123,6 +126,11 @@ namespace blt::gp
config_t config;
};
+ class advanced_crossover_t : public crossover_t
+ {
+ bool apply(gp_program& program, const tree_t& p1, const tree_t& p2, tree_t& c1, tree_t& c2) override;
+ };
+
class mutation_t
{
public:
diff --git a/include/blt/gp/util/statistics.h b/include/blt/gp/util/statistics.h
index 3fd868d..1670ede 100644
--- a/include/blt/gp/util/statistics.h
+++ b/include/blt/gp/util/statistics.h
@@ -196,6 +196,20 @@ namespace blt::gp
worst_fitness = 0;
normalized_fitness.clear();
}
+
+ friend bool operator==(const population_stats& a, const population_stats& b)
+ {
+ return a.overall_fitness.load(std::memory_order_relaxed) == b.overall_fitness.load(std::memory_order_relaxed) &&
+ a.average_fitness.load(std::memory_order_relaxed) == b.average_fitness.load(std::memory_order_relaxed) &&
+ a.best_fitness.load(std::memory_order_relaxed) == b.best_fitness.load(std::memory_order_relaxed) &&
+ a.worst_fitness.load(std::memory_order_relaxed) == b.worst_fitness.load(std::memory_order_relaxed) &&
+ a.normalized_fitness == b.normalized_fitness;
+ }
+
+ friend bool operator!=(const population_stats& a, const population_stats& b)
+ {
+ return !(a == b);
+ }
};
}
diff --git a/src/program.cpp b/src/program.cpp
index 40ad5b5..556ed01 100644
--- a/src/program.cpp
+++ b/src/program.cpp
@@ -168,7 +168,7 @@ namespace blt::gp
BLT_ASSERT(reader.read(&size, sizeof(size)) == sizeof(size));
std::string name;
name.resize(size);
- BLT_ASSERT(reader.read(name.data(), size) == size);
+ BLT_ASSERT(reader.read(name.data(), size) == static_cast(size));
if (!storage.names[i].has_value())
throw std::runtime_error("Expected operator ID " + std::to_string(i) + " to have name " + name);
if (name != *storage.names[i])
@@ -210,11 +210,11 @@ namespace blt::gp
"Operator ID " + std::to_string(i) + " expected return type " + std::to_string(op.return_type) + " but got " + std::to_string(
return_type));
size_t arg_type_count;
+ BLT_ASSERT(reader.read(&arg_type_count, sizeof(arg_type_count)) == sizeof(return_type));
if (arg_type_count != op.argument_types.size())
throw std::runtime_error(
"Operator ID " + std::to_string(i) + " expected " + std::to_string(op.argument_types.size()) + " arguments but got " +
std::to_string(arg_type_count));
- BLT_ASSERT(reader.read(&arg_type_count, sizeof(arg_type_count)) == sizeof(return_type));
for (size_t j = 0; j < arg_type_count; j++)
{
type_id type;
diff --git a/src/transformers.cpp b/src/transformers.cpp
index b7bf704..36da17a 100644
--- a/src/transformers.cpp
+++ b/src/transformers.cpp
@@ -80,20 +80,7 @@ namespace blt::gp
if (!point)
return false;
- // TODO: more crossover!
- switch (program.get_random().get_u32(0, 2))
- {
- case 0:
- case 1:
- c1.swap_subtrees(point->p1_crossover_point, c2, point->p2_crossover_point);
- break;
- default:
-#if BLT_DEBUG_LEVEL > 0
- BLT_ABORT("This place should be unreachable!");
-#else
- BLT_UNREACHABLE;
-#endif
- }
+ c1.swap_subtrees(point->p1_crossover_point, c2, point->p2_crossover_point);
#if BLT_DEBUG_LEVEL >= 2
if (!c1.check(detail::debug::context_ptr) || !c2.check(detail::debug::context_ptr))
@@ -134,6 +121,57 @@ namespace blt::gp
return t.select_subtree_traverse(config.terminal_chance, config.depth_multiplier);
}
+ bool advanced_crossover_t::apply(gp_program& program, const tree_t& p1, const tree_t& p2, tree_t& c1, tree_t& c2)
+ {
+ if (p1.size() < config.min_tree_size || p2.size() < config.min_tree_size)
+ return false;
+
+ // TODO: more crossover!
+ switch (program.get_random().get_u32(0, 2))
+ {
+ // single point crossover (only if operators at this point are "compatible")
+ case 0:
+ case0:
+ {
+ std::optional point;
+
+ if (config.traverse)
+ point = get_crossover_point_traverse(p1, p2);
+ else
+ point = get_crossover_point(p1, p2);
+
+ if (!point)
+ return false;
+
+ // check if can work
+ // otherwise goto case2
+ }
+ // Mating crossover analogs to same species breeding. Only works if tree is mostly similar
+ case 1:
+ case1:
+ {
+ // if fails got to case0
+ if (false)
+ goto case0;
+ }
+ // Subtree crossover, select random points inside trees and swap their subtrees
+ case 2:
+ case2:
+ return crossover_t::apply(program, p1, p2, c1, c2);
+ default:
+#if BLT_DEBUG_LEVEL > 0
+ BLT_ABORT("This place should be unreachable!");
+#else
+ BLT_UNREACHABLE;
+#endif
+ }
+
+#if BLT_DEBUG_LEVEL >= 2
+ if (!c1.check(detail::debug::context_ptr) || !c2.check(detail::debug::context_ptr))
+ throw std::runtime_error("Tree check failed");
+#endif
+ }
+
bool mutation_t::apply(gp_program& program, const tree_t&, tree_t& c)
{
// TODO: options for this?
diff --git a/tests/serialization_test.cpp b/tests/serialization_test.cpp
index c101fa8..9bf5327 100644
--- a/tests/serialization_test.cpp
+++ b/tests/serialization_test.cpp
@@ -100,6 +100,12 @@ int main()
test_program.set_operations(operators);
test_program.setup_generational_evaluation(fitness_function, sel, sel, sel, false);
+ // simulate a program which is similar but incompatible with the other programs.
+ operator_builder builder2{};
+ gp_program bad_program{691};
+ bad_program.set_operations(builder2.build(addf, subf, mulf, op_sinf, op_cosf, litf, op_xf));
+ bad_program.setup_generational_evaluation(fitness_function, sel, sel, sel, false);
+
program.generate_initial_population(program.get_typesystem().get_type().id());
program.setup_generational_evaluation(fitness_function, sel, sel, sel);
while (!program.should_terminate())
@@ -112,8 +118,7 @@ int main()
BLT_TRACE("Evaluate Fitness");
program.evaluate_fitness();
{
- std::filesystem::remove("serialization_test.data");
- std::ofstream stream{"serialization_test.data", std::ios::binary};
+ std::ofstream stream{"serialization_test.data", std::ios::binary | std::ios::trunc};
blt::fs::fstream_writer_t writer{stream};
program.save_generation(writer);
}
@@ -132,4 +137,34 @@ int main()
}
}
}
+ {
+ std::ofstream stream{"serialization_test2.data", std::ios::binary | std::ios::trunc};
+ blt::fs::fstream_writer_t writer{stream};
+ program.save_state(writer);
+ }
+ {
+ std::ifstream stream{"serialization_test2.data", std::ios::binary};
+ blt::fs::fstream_reader_t reader{stream};
+ test_program.load_state(reader);
+
+ for (const auto [saved, loaded] : blt::zip(program.get_stats_histories(), test_program.get_stats_histories()))
+ {
+ if (saved != loaded)
+ {
+ BLT_ERROR("Serializer Failed to correctly serialize histories to disk, histories are not equal!");
+ std::exit(1);
+ }
+ }
+ }
+ try {
+ std::ifstream stream{"serialization_test2.data", std::ios::binary};
+ blt::fs::fstream_reader_t reader{stream};
+ bad_program.load_state(reader);
+ } catch (const std::runtime_error&)
+ {
+ // TODO: use blt::expected so this isn't required + better design.
+ goto exit;
+ }
+ BLT_ASSERT(false && "Expected program to throw an exception when parsing state data into an incompatible program!");
+ exit:
}