GP_Image_Test/src/gp.cpp

302 lines
10 KiB
C++

/*
* Genetic-programming expression-tree nodes: allocation, random construction, printing, cloning and evaluation.
* Copyright (C) 2024 Brett Terpstra
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
#include <gp.h>
#include <blt/std/allocator.h>
#include <blt/std/logging.h>
#include <functions.h>
#include <random>
#include <queue>
#include <stack>
// File-wide bump allocator that backs every node created in this file
// (createNode and node::clone allocate from it, destroyNode releases to it).
// NOTE(review): 32000 is presumably the allocator's block size - confirm against blt::bump_allocator.
blt::bump_allocator<node> node_allocator(32000);
/**
 * Allocates storage for one node from node_allocator and constructs it in place.
 *
 * @param type function type forwarded to the node constructor
 * @return pointer to the freshly constructed node; hand it to destroyNode() to release it
 */
node* createNode(function_t type)
{
    auto* const fresh = node_allocator.allocate(1);
    node_allocator.construct(fresh, type);
    return fresh;
}
/**
 * Destroys a node and gives its storage back to node_allocator.
 * Passing nullptr is allowed and does nothing.
 *
 * @param n node to release, or nullptr
 */
void destroyNode(node* n)
{
    if (n != nullptr)
    {
        node_allocator.destroy(n);
        node_allocator.deallocate(n, 1);
    }
}
/**
 * Gives this node a random number of children, chosen uniformly between the minimum and
 * maximum argument count registered for its function type, and fills every child slot.
 *
 * @param use_terminals when true every child is requested as a terminal; when false each child
 *                      is requested as a terminal or a non-terminal with equal probability
 */
void node::grow(bool use_terminals)
{
    static thread_local std::random_device dev;
    static thread_local std::mt19937_64 engine{dev()};
    // The coin flip is kept per-thread like the engine that drives it: operator() is non-const,
    // so a single distribution shared between threads would be mutated concurrently.
    static thread_local std::uniform_int_distribution<int> choice(0, 1);

    const auto min_children = function_arg_min_map[to_underlying(type)];
    const auto max_children = function_arg_max_map[to_underlying(type)];
    std::uniform_int_distribution<blt::i32> dist(min_children, max_children);

    argc = dist(engine);
    if (argc == 0)
        return;

    const auto& allowed_args_args = function_arg_allowed_map[to_underlying(type)];
    for (size_t i = 0; i < argc; i++)
    {
        // The coin is flipped for every child, even when use_terminals already forces a
        // terminal, so the sequence of random draws does not depend on the flag.
        const bool want_terminal = choice(engine) || use_terminals;
        populate_node(i, engine, allowed_args_args[i], want_terminal);
    }
}
/**
 * Gives this node the maximum number of children registered for its function type and
 * fills every child slot.
 *
 * @param use_terminals forwarded to populate_node for every child: true requests terminals,
 *                      false requests non-terminals
 */
void node::full(bool use_terminals)
{
    static thread_local std::random_device seed_source;
    static thread_local std::mt19937_64 rng{seed_source()};

    argc = function_arg_max_map[to_underlying(type)];
    if (argc != 0)
    {
        const auto& per_slot_allowed = function_arg_allowed_map[to_underlying(type)];
        for (size_t slot = 0; slot < argc; ++slot)
        {
            populate_node(slot, rng, per_slot_allowed[slot], use_terminals);
        }
    }
}
void node::populate_node(size_t i, std::mt19937_64& engine, const allowed_funcs<function_t>& allowed_args, bool use_terminal)
{
if (use_terminal)
{
auto terminals = intersection(allowed_args, FUNC_ALLOW_TERMINALS_SET);
if (terminals.empty())
{
terminals = allowed_args;
}
std::uniform_int_distribution<blt::size_t> select(0, terminals.size() - 1);
sub_nodes[i] = createNode(terminals[select(engine)]);
} else
{
auto non_terminals = intersection_comp(allowed_args, FUNC_ALLOW_TERMINALS_SET);
if (non_terminals.empty())
{
//BLT_WARN("Empty non-terminals set for node '%s'! Filling from terminals!", function_name_map[to_underlying(type)].c_str());
std::uniform_int_distribution<blt::size_t> select(0, allowed_args.size() - 1);
sub_nodes[i] = createNode(allowed_args[select(engine)]);
} else
{
std::uniform_int_distribution<blt::size_t> select(0, non_terminals.size() - 1);
sub_nodes[i] = createNode(non_terminals[select(engine)]);
}
}
}
/**
 * Produces the label used when printing this node.
 *
 * Nodes with children, and childless nodes other than SCALAR / COLOR, are labelled with
 * their entry in function_name_map. A childless SCALAR or COLOR node is evaluated first and
 * labelled with the resulting value: the x component for SCALAR, "{x, y, z}" for COLOR.
 *
 * @return the label text
 */
std::string node::get_type_name()
{
    if (argc > 0)
        return function_name_map[to_underlying(type)];

    switch (type)
    {
        case function_t::SCALAR:
        {
            evaluate();
            return std::to_string(img->get().x());
        }
        case function_t::COLOR:
        {
            evaluate();
            std::string label(1, '{');
            label += std::to_string(img->get().x());
            label += ", ";
            label += std::to_string(img->get().y());
            label += ", ";
            label += std::to_string(img->get().z());
            label += "}";
            return label;
        }
        default:
            return function_name_map[to_underlying(type)];
    }
}
// Prints the tree rooted at this node to std::cout, one node per line.
// Each line is indented with one "| " per layer of depth; a node that has children is
// prefixed with "(" and the matching ")" is printed on its own line, at the parent's
// indentation, once the last node of that subtree has been written.
void node::print_tree()
{
// a node together with its depth below the root (root is layer 0)
struct stack_info
{
node* n;
blt::i64 layer = 0;
};
// work stack for the traversal, holding (layer, node) pairs
std::stack<std::pair<blt::i64, node*>> nodes;
// nodes in the order they will be printed
std::vector<stack_info> stack_nodes;
nodes.emplace(0, this);
// Depth-first flattening: every popped node is recorded, then its children are pushed in
// index order, which means they are popped (and later printed) highest index first.
while (!nodes.empty())
{
auto top = nodes.top();
stack_nodes.push_back({top.second, top.first});
nodes.pop();
for (size_t i = 0; i < top.second->argc; i++)
nodes.emplace(top.first + 1, top.second->sub_nodes[i]);
}
for (auto e : blt::enumerate(stack_nodes))
{
auto i = e.first;
auto& v = e.second;
// start a new line and indent it to this node's layer
std::cout << std::endl;
for (blt::i64 j = 0; j < v.layer; j++)
std::cout << "| ";
if (v.n->argc != 0)
{
std::cout << "(";
}
std::cout << v.n->get_type_name();
if (i + 1 < stack_nodes.size())
{
auto& n = stack_nodes[i + 1];
if (n.layer < v.layer)
{
// The next node is shallower: close every layer from this node's parent
// (v.layer - 1) down to the next node's layer, one ")" per line.
//std::cout << " ";
// if (v.layer - n.layer > 2)
// {
for (auto j = v.layer - 1; j >= n.layer; j--)
{
std::cout << std::endl;
for (blt::i64 k = 0; k < j; k++)
{
std::cout << "| ";
}
std::cout << ")";
}
// }
} else
std::cout << " ";
} else
{
// last node of the listing: close every layer that is still open, down to layer 0
for (auto j = v.layer - 1; j >= 0; j--)
{
std::cout << std::endl;
for (blt::i64 k = 0; k < j; k++)
std::cout << "| ";
std::cout << ")";
}
}
}
std::cout << std::endl;
}
/**
 * Builds a random tree breadth-first, starting from a randomly chosen non-terminal root.
 *
 * Each node taken from the queue is expanded with either grow() or full(): grow() is used
 * only when a coin flip succeeds and the current depth has reached MIN_DEPTH, otherwise
 * full() is used. Nodes whose depth is at or beyond max_depth are expanded with
 * use_terminals set, so their children are requested as terminals.
 *
 * @param max_depth depth from which children are requested as terminals
 * @return the root of the new tree; nodes are allocated through createNode()
 */
node* node::construct_random_tree(blt::size_t max_depth)
{
    // Per-thread generator, matching grow() and full(): a single process-wide engine would be
    // mutated without synchronisation if trees are built from several threads.
    static thread_local std::random_device dev;
    static thread_local std::mt19937_64 engine{dev()};
    std::uniform_int_distribution<int> choice(0, 1);

    // every function type that is not a terminal; computed once on first use
    static auto NON_TERMINALS = intersection_comp(FUNC_ALLOW_ANY, FUNC_ALLOW_TERMINALS_SET);
    std::uniform_int_distribution<int> select(0, static_cast<int>(NON_TERMINALS.size()) - 1);

    node* n = createNode(NON_TERMINALS[select(engine)]);

    // queue of (node, depth) pairs still waiting to receive their children
    std::queue<std::pair<node*, size_t>> grow_queue;
    size_t current_depth = 0;
    grow_queue.emplace(n, current_depth);

    while (!grow_queue.empty())
    {
        auto front = grow_queue.front();
        // entries leave the queue in non-decreasing depth order, so the depth advances
        // by at most one level at a time
        if (front.second != current_depth)
            current_depth++;

        if (choice(engine) && current_depth >= MIN_DEPTH)
            front.first->grow(front.second >= max_depth);
        else
            front.first->full(front.second >= max_depth);

        for (size_t i = 0; i < front.first->argc; i++)
            grow_queue.emplace(front.first->sub_nodes[i], current_depth + 1);
        grow_queue.pop();
    }
    return n;
}
// Computes this node's image from the already-evaluated images of its children.
// img is replaced with a fresh image, a pointer to each child's image is collected, and the
// function registered for this node's type is dispatched through the FUNC_FUNCTIONS list.
// Every child must be non-null and must already hold an image; this is asserted.
// Finishes by calling reset_children().
// NOTE(review): reset_children() is defined elsewhere; presumably it drops the children's
// cached images once they have been consumed - confirm.
void node::evaluate()
{
img = image{};
// one slot per possible argument; slots beyond argc stay nullptr
std::array<image*, MAX_ARGS> sub_node_images{nullptr};
for (size_t i = 0; i < argc; i++)
{
BLT_ASSERT_MSG(sub_nodes[i] != nullptr && sub_nodes[i]->img.has_value() && "Node must have evaluated children!",
("Failed at arg: " + std::to_string(i) + " with type: " + function_name_map[to_underlying(type)] + " child type: " +
function_name_map[to_underlying(sub_nodes[i]->type)]).c_str());
sub_node_images[i] = &sub_nodes[i]->img.value();
}
// FUNC_DEFINE expands each entry of FUNC_FUNCTIONS into one case of the switch below.
// Types contained in FUNC_ALLOW_TERMINALS_SET call FUNC once with coordinates (0, 0);
// every other type calls FUNC once per (x, y) with x in [0, width) and y in [0, height).
// The block testing for function_t::IF only holds a commented-out debug print.
#define FUNC_DEFINE(NAME, MIN_ARGS, MAX_ARGS, FUNC, ...) case function_t::NAME: { \
if (function_t::NAME == function_t::IF) { \
/*std::cout << "__:" << function_name_map[to_underlying(this->sub_nodes[0]->type)] << std::endl;*/ \
}\
if (FUNC_ALLOW_TERMINALS_SET.contains(function_t::NAME)){ \
FUNC(img.value(), 0, 0, argc, const_cast<const image**>(sub_node_images.data()), data); \
} else { \
for (blt::i32 y = 0; y < height; y++) { \
for (blt::i32 x = 0; x < width; x++) { \
FUNC(img.value(), static_cast<float>(x), static_cast<float>(y), argc, const_cast<const image**>(sub_node_images.data()), data); \
} \
} \
} \
} \
break;
// dispatch on this node's function type; a type with no generated case only logs a warning
switch (type)
{
FUNC_FUNCTIONS
default:
BLT_WARN("How did we get here?");
break;
}
#undef FUNC_DEFINE
reset_children();
}
/**
 * Deep-copy constructor: copies the function type, argument count and raw data of `copy`
 * and clones each of its children. The cached image (img) is not copied.
 *
 * @param copy node to duplicate
 */
node::node(const node& copy)
{
    // the raw byte copy below relies on data_t being exactly three floats wide
    static_assert(sizeof(data_t) == sizeof(float) * 3 && "Uhhh something is wrong here!");

    type = copy.type;
    argc = copy.argc;
    std::memcpy(data.data(), copy.data.data(), sizeof(data_t));

    for (blt::size_t child = 0; child < argc; ++child)
    {
        sub_nodes[child] = copy.sub_nodes[child]->clone();
    }
}
/**
 * Creates a deep copy of this node in storage taken from node_allocator.
 *
 * @return pointer to the copy; the copy constructor clones the children as well
 */
node* node::clone()
{
    auto* const storage = node_allocator.allocate(1);
    // placement-new runs the copy constructor directly inside the allocator's storage
    return ::new(storage) node(*this);
}
void node::evaluate_tree()
{
std::stack<node*> nodes;
std::stack<node*> node_stack;
nodes.push(this);
while (!nodes.empty())
{
auto* top = nodes.top();
node_stack.push(top);
nodes.pop();
for (size_t i = 0; i < top->argc; i++)
nodes.push(top->sub_nodes[i]);
}
while (!node_stack.empty())
{
node_stack.top()->evaluate();
node_stack.pop();
}
}