main
Brett 2024-01-30 12:21:02 -05:00
parent ce53aace34
commit 68dad0f963
4 changed files with 193 additions and 27 deletions

View File

@ -23,5 +23,10 @@
inline constexpr blt::i32 MAX_DEPTH = 12;
inline constexpr blt::i32 width = 256, height = 256;
inline constexpr blt::i32 POPULATION_SIZE = 20;
inline constexpr blt::i32 GEN_COUNT = 20;
inline constexpr blt::i32 TOURNAMENT_SIZE = 3;
inline constexpr float CROSSOVER_RATE = 0.9;
inline constexpr float MUTATION_RATE = 0.9;
#endif //GP_IMAGE_TEST_CONFIG_H

View File

@ -65,7 +65,7 @@ struct node
}
}
static node* construct_random_tree();
static node* construct_random_tree(blt::size_t max_depth = MAX_DEPTH);
void evaluate();

View File

@ -127,7 +127,7 @@ void node::print_tree()
std::cout << ") ";
}
node* node::construct_random_tree()
node* node::construct_random_tree(blt::size_t max_depth)
{
static std::random_device dev;
static std::mt19937_64 engine{dev()};
@ -149,9 +149,9 @@ node* node::construct_random_tree()
current_depth++;
if (choice(engine))
front.first->grow(front.second >= MAX_DEPTH);
front.first->grow(front.second >= max_depth);
else
front.first->full(front.second >= MAX_DEPTH);
front.first->full(front.second >= max_depth);
for (size_t i = 0; i < front.first->argc; i++)
grow_queue.emplace(front.first->sub_nodes[i], current_depth + 1);

View File

@ -23,9 +23,11 @@ blt::gfx::batch_renderer_2d renderer_2d(resources);
class tree
{
public:
struct crossover_result_t
struct search_result_t
{
std::unique_ptr<tree> c1, c2;
node* child;
node* parent;
blt::size_t index;
};
private:
std::unique_ptr<node, node_deleter> root = nullptr;
@ -64,20 +66,29 @@ class tree
}
}
node* select_random_child()
search_result_t select_random_child()
{
static std::random_device dev;
static std::mt19937_64 engine{dev()};
auto d = depth();
std::uniform_int_distribution depth_dist(0, static_cast<int>(d));
std::uniform_int_distribution select(0, 3);
node* current = root.get();
node* parent = nullptr;
blt::size_t index = 0;
while (true)
{
std::uniform_int_distribution<size_t> children(0, current->argc - 1);
if (select(engine) == 0 || current->argc == 0)
break;
index = children(engine);
auto* next = current->sub_nodes[index];
if (next == nullptr)
break;
parent = current;
current = next;
}
return {current, parent, index};
}
public:
@ -86,13 +97,13 @@ class tree
return std::make_unique<tree>(tree{node::construct_random_tree()});
}
blt::size_t depth()
static blt::size_t depth(node* root_node)
{
// depth -> node
std::stack<std::pair<blt::size_t, node*>> stack;
blt::size_t max_depth = 0;
stack.emplace(0, root.get());
stack.emplace(0, root_node);
while (!stack.empty())
{
auto top = stack.top();
@ -112,9 +123,38 @@ class tree
return max_depth;
}
static crossover_result_t crossover(tree* p1, tree* p2)
static bool crossover(tree* p1, tree* p2)
{
return {};
auto n1 = p1->select_random_child();
auto n2 = p2->select_random_child();
const auto& p1_allowed = function_arg_allowed_set_map[to_underlying(n1.parent->type)][n1.index];
const auto& p2_allowed = function_arg_allowed_set_map[to_underlying(n2.parent->type)][n2.index];
if (!p1_allowed.contains(n2.child->type))
return false;
if (!p2_allowed.contains(n1.child->type))
return false;
n1.parent->sub_nodes[n1.index] = n2.child;
n2.parent->sub_nodes[n2.index] = n1.child;
return true;
}
static void mutate(tree* p)
{
static std::random_device dev;
static std::mt19937_64 engine{dev()};
std::uniform_int_distribution choice(0, 1);
auto n = p->select_random_child();
auto d = depth(n.child);
node* new_subtree = node::construct_random_tree(d + 1);
n.parent->sub_nodes[n.index] = new_subtree;
destroyNode(n.child);
}
void evaluate()
@ -149,7 +189,120 @@ class tree
}
};
std::unique_ptr<tree> test_tree;
class gp_population
{
public:
struct gp_i
{
std::unique_ptr<tree> t = nullptr;
float fitness = 0;
};
private:
std::array<gp_i, POPULATION_SIZE> pop;
public:
gp_population() = default;
static blt::size_t choice()
{
static std::random_device dev;
static std::mt19937_64 engine{dev()};
static std::uniform_int_distribution c(0, POPULATION_SIZE - 1);
return c(engine);
}
void init()
{
for (auto& v : pop)
v = {tree::construct_random_tree()};
evaluate_population();
}
tree* select()
{
tree* n = nullptr;
float fitness = 0;
for (int i = 0; i < TOURNAMENT_SIZE; i++)
{
auto& v = pop[choice()];
if (n != v.t.get() && v.fitness > fitness)
{
n = v.t.get();
fitness = v.fitness;
}
}
return n;
}
void run_step()
{
static std::random_device dev;
static std::mt19937_64 engine{dev()};
static std::uniform_real_distribution rand(0.0, 1.0);
for (blt::size_t i = 0; i < POPULATION_SIZE; i++)
{
if (rand(engine) < CROSSOVER_RATE)
{
auto* p1 = select();
auto* p2 = select();
tree::crossover(p1, p2);
}
}
for (blt::size_t i = 0; i < POPULATION_SIZE; i++)
{
if (rand(engine) < MUTATION_RATE)
{
auto* p1 = select();
tree::mutate(p1);
}
}
}
void evaluate_population()
{
for (auto& v : pop)
{
v.t->evaluate();
v.fitness = v.t->fitness();
}
}
image& display(blt::size_t i)
{
return pop[i].t->getImage();
}
std::pair<blt::size_t, float> get_best()
{
blt::size_t i = 0;
float fitness = 0;
for (blt::size_t j = 0; j < POPULATION_SIZE; j++)
{
if (pop[j].fitness > fitness)
{
i = j;
fitness = pop[j].fitness;
}
}
return {i, fitness};
}
image& display_best()
{
return display(get_best().first);
}
float best_fitness()
{
return get_best().second;
}
};
gp_population pop;
blt::gfx::texture_gl2D* texture;
void print_bits(float f)
@ -178,6 +331,8 @@ void init()
resources.load_resources();
}
float best = 0;
void update(std::int32_t w, std::int32_t h)
{
global_matrices.update_perspectives(w, h, 90, 0.1, 2000);
@ -186,23 +341,29 @@ void update(std::int32_t w, std::int32_t h)
if (ImGui::Button("Regenerate"))
{
BLT_INFO("Regen tree");
test_tree = tree::construct_random_tree();
test_tree->evaluate();
BLT_INFO("Uploading");
texture->upload((void*) test_tree->getImage().getData().data(), GL_RGB, 0, 0, 0, -1, -1, GL_FLOAT);
pop.init();
best = pop.best_fitness();
}
if (ImGui::Button("Run Step"))
{
BLT_INFO("Running Step");
pop.run_step();
BLT_INFO("Evaluating Population");
pop.evaluate_population();
best = pop.best_fitness();
}
if (ImGui::Button("Display"))
{
if (test_tree->hasTree())
test_tree->printTree();
}
if (ImGui::Button("Eval"))
{
if (test_tree->hasImage())
BLT_DEBUG(test_tree->fitness());
BLT_INFO("Uploading");
texture->upload((void*) pop.display_best().getData().data(), GL_RGB, 0, 0, 0, -1, -1, GL_FLOAT);
;
// if (test_tree->hasTree())
// test_tree->printTree();
}
ImGui::Text("Best Fitness: %f", best);
auto lw = 512.0f;
auto lh = 512.0f;
@ -216,7 +377,7 @@ void update(std::int32_t w, std::int32_t h)
int main()
{
shapiro_test_run();
//shapiro_test_run();
blt::gfx::init(blt::gfx::window_data{"Window of GP test", init, update}.setSyncInterval(1));
global_matrices.cleanup();
resources.cleanup();