302 lines
10 KiB
C++
302 lines
10 KiB
C++
/*
|
|
* Expression-tree nodes for genetic programming: random tree construction (grow/full), evaluation, cloning and printing.
|
|
* Copyright (C) 2024 Brett Terpstra
|
|
*
|
|
* This program is free software: you can redistribute it and/or modify
|
|
* it under the terms of the GNU General Public License as published by
|
|
* the Free Software Foundation, either version 3 of the License, or
|
|
* (at your option) any later version.
|
|
*
|
|
* This program is distributed in the hope that it will be useful,
|
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
* GNU General Public License for more details.
|
|
*
|
|
* You should have received a copy of the GNU General Public License
|
|
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
*/
|
|
#include <gp.h>
|
|
#include <blt/std/allocator.h>
|
|
#include <blt/std/logging.h>
|
|
#include <functions.h>
|
|
#include <random>
|
|
#include <queue>
|
|
#include <stack>
|
|
|
|
// Shared allocator that provides storage for every node created in this file
// (createNode() and node::clone()) and that destroyNode() returns storage to.
// NOTE(review): 32000 is forwarded unchanged to blt::bump_allocator's constructor; presumably a
// block/arena size, but the unit (elements vs bytes) should be confirmed against blt/std/allocator.h.
// NOTE(review): this is a single global (not thread_local); concurrent node creation from several
// threads depends on bump_allocator's own thread-safety, which is not visible here.
blt::bump_allocator<node> node_allocator(32000);
|
|
|
|
/**
 * Allocates storage for a single node from node_allocator and constructs it in place.
 * @param type the function this node will represent
 * @return pointer to the freshly constructed node (release it with destroyNode())
 */
node* createNode(function_t type)
{
    auto* const fresh = node_allocator.allocate(1);
    node_allocator.construct(fresh, type);
    return fresh;
}
|
|
|
|
/**
 * Destroys a node created by createNode()/node::clone() and returns its storage to node_allocator.
 * Passing nullptr is a no-op.
 * NOTE(review): sub_nodes are not touched here; whether children are released depends on node's
 * destructor, which is not visible in this file. Confirm.
 */
void destroyNode(node* n)
{
    if (n != nullptr)
    {
        node_allocator.destroy(n);
        node_allocator.deallocate(n, 1);
    }
}
|
|
|
|
void node::grow(bool use_terminals)
|
|
{
|
|
auto min_children = function_arg_min_map[to_underlying(type)];
|
|
auto max_children = function_arg_max_map[to_underlying(type)];
|
|
|
|
static thread_local std::random_device dev;
|
|
static thread_local std::mt19937_64 engine{dev()};
|
|
std::uniform_int_distribution<blt::i32> dist(min_children, max_children);
|
|
static std::uniform_int_distribution<int> choice(0, 1);
|
|
|
|
argc = dist(engine);
|
|
if (argc == 0)
|
|
return;
|
|
const auto& allowed_args_args = function_arg_allowed_map[to_underlying(type)];
|
|
for (size_t i = 0; i < argc; i++)
|
|
{
|
|
// 50/50 chance to either use a terminal or use from the function list.
|
|
populate_node(i, engine, allowed_args_args[i], choice(engine) || use_terminals);
|
|
}
|
|
}
|
|
|
|
void node::full(bool use_terminals)
|
|
{
|
|
static thread_local std::random_device dev;
|
|
static thread_local std::mt19937_64 engine{dev()};
|
|
|
|
argc = function_arg_max_map[to_underlying(type)];
|
|
if (argc == 0)
|
|
return;
|
|
|
|
const auto& allowed_args_args = function_arg_allowed_map[to_underlying(type)];
|
|
for (size_t i = 0; i < argc; i++)
|
|
populate_node(i, engine, allowed_args_args[i], use_terminals);
|
|
}
|
|
|
|
void node::populate_node(size_t i, std::mt19937_64& engine, const allowed_funcs<function_t>& allowed_args, bool use_terminal)
|
|
{
|
|
if (use_terminal)
|
|
{
|
|
auto terminals = intersection(allowed_args, FUNC_ALLOW_TERMINALS_SET);
|
|
if (terminals.empty())
|
|
{
|
|
terminals = allowed_args;
|
|
}
|
|
std::uniform_int_distribution<blt::size_t> select(0, terminals.size() - 1);
|
|
sub_nodes[i] = createNode(terminals[select(engine)]);
|
|
} else
|
|
{
|
|
auto non_terminals = intersection_comp(allowed_args, FUNC_ALLOW_TERMINALS_SET);
|
|
if (non_terminals.empty())
|
|
{
|
|
//BLT_WARN("Empty non-terminals set for node '%s'! Filling from terminals!", function_name_map[to_underlying(type)].c_str());
|
|
std::uniform_int_distribution<blt::size_t> select(0, allowed_args.size() - 1);
|
|
sub_nodes[i] = createNode(allowed_args[select(engine)]);
|
|
} else
|
|
{
|
|
std::uniform_int_distribution<blt::size_t> select(0, non_terminals.size() - 1);
|
|
sub_nodes[i] = createNode(non_terminals[select(engine)]);
|
|
}
|
|
}
|
|
}
|
|
|
|
std::string node::get_type_name()
|
|
{
|
|
if (argc > 0)
|
|
return function_name_map[to_underlying(type)];
|
|
else
|
|
{
|
|
if (type == function_t::SCALAR)
|
|
{
|
|
evaluate();
|
|
return std::to_string(img->get().x());
|
|
} else if (type == function_t::COLOR)
|
|
{
|
|
evaluate();
|
|
return '{' + std::to_string(img->get().x()) + ", " + std::to_string(img->get().y()) + ", " + std::to_string(img->get().z()) + "}";
|
|
} else
|
|
return function_name_map[to_underlying(type)];
|
|
}
|
|
}
|
|
|
|
void node::print_tree()
|
|
{
|
|
struct stack_info
|
|
{
|
|
node* n;
|
|
blt::i64 layer = 0;
|
|
};
|
|
std::stack<std::pair<blt::i64, node*>> nodes;
|
|
std::vector<stack_info> stack_nodes;
|
|
nodes.emplace(0, this);
|
|
|
|
while (!nodes.empty())
|
|
{
|
|
auto top = nodes.top();
|
|
stack_nodes.push_back({top.second, top.first});
|
|
nodes.pop();
|
|
for (size_t i = 0; i < top.second->argc; i++)
|
|
nodes.emplace(top.first + 1, top.second->sub_nodes[i]);
|
|
}
|
|
|
|
for (auto e : blt::enumerate(stack_nodes))
|
|
{
|
|
auto i = e.first;
|
|
auto& v = e.second;
|
|
std::cout << std::endl;
|
|
for (blt::i64 j = 0; j < v.layer; j++)
|
|
std::cout << "| ";
|
|
if (v.n->argc != 0)
|
|
{
|
|
std::cout << "(";
|
|
}
|
|
std::cout << v.n->get_type_name();
|
|
if (i + 1 < stack_nodes.size())
|
|
{
|
|
auto& n = stack_nodes[i + 1];
|
|
if (n.layer < v.layer)
|
|
{
|
|
//std::cout << " ";
|
|
// if (v.layer - n.layer > 2)
|
|
// {
|
|
for (auto j = v.layer - 1; j >= n.layer; j--)
|
|
{
|
|
std::cout << std::endl;
|
|
for (blt::i64 k = 0; k < j; k++)
|
|
{
|
|
std::cout << "| ";
|
|
}
|
|
std::cout << ")";
|
|
}
|
|
// }
|
|
} else
|
|
std::cout << " ";
|
|
} else
|
|
{
|
|
for (auto j = v.layer - 1; j >= 0; j--)
|
|
{
|
|
std::cout << std::endl;
|
|
for (blt::i64 k = 0; k < j; k++)
|
|
std::cout << "| ";
|
|
std::cout << ")";
|
|
}
|
|
}
|
|
}
|
|
std::cout << std::endl;
|
|
}
|
|
|
|
/**
 * Builds a random tree breadth-first, starting from a random non-terminal root.
 * Each node is expanded with grow() (50% chance, only once MIN_DEPTH is reached) or full().
 * Nodes at depth >= max_depth request terminal children.
 * @param max_depth depth (root = 0) from which children are requested from the terminal set
 * @return the root node; allocated via createNode(), so it must be released by the caller
 */
node* node::construct_random_tree(blt::size_t max_depth)
{
    // thread_local to match grow()/full(): a single shared engine would be a data race if trees are
    // built from several threads.
    static thread_local std::random_device dev;
    static thread_local std::mt19937_64 engine{dev()};
    std::uniform_int_distribution<int> choice(0, 1);

    static auto NON_TERMINALS = intersection_comp(FUNC_ALLOW_ANY, FUNC_ALLOW_TERMINALS_SET);

    std::uniform_int_distribution<int> select(0, static_cast<int>(NON_TERMINALS.size()) - 1);

    node* n = createNode(NON_TERMINALS[select(engine)]);

    // BFS queue of (node, depth). Each entry already carries its own depth, so no separate
    // "current depth" counter is needed.
    std::queue<std::pair<node*, size_t>> grow_queue;
    grow_queue.emplace(n, 0);
    while (!grow_queue.empty())
    {
        const auto [current, depth] = grow_queue.front();
        grow_queue.pop();

        const bool use_terminals = depth >= max_depth;
        if (choice(engine) && depth >= MIN_DEPTH)
            current->grow(use_terminals);
        else
            current->full(use_terminals);

        for (size_t i = 0; i < current->argc; i++)
            grow_queue.emplace(current->sub_nodes[i], depth + 1);
    }

    return n;
}
|
|
|
|
void node::evaluate()
|
|
{
|
|
img = image{};
|
|
std::array<image*, MAX_ARGS> sub_node_images{nullptr};
|
|
for (size_t i = 0; i < argc; i++)
|
|
{
|
|
BLT_ASSERT_MSG(sub_nodes[i] != nullptr && sub_nodes[i]->img.has_value() && "Node must have evaluated children!",
|
|
("Failed at arg: " + std::to_string(i) + " with type: " + function_name_map[to_underlying(type)] + " child type: " +
|
|
function_name_map[to_underlying(sub_nodes[i]->type)]).c_str());
|
|
sub_node_images[i] = &sub_nodes[i]->img.value();
|
|
}
|
|
#define FUNC_DEFINE(NAME, MIN_ARGS, MAX_ARGS, FUNC, ...) case function_t::NAME: { \
|
|
if (function_t::NAME == function_t::IF) { \
|
|
/*std::cout << "__:" << function_name_map[to_underlying(this->sub_nodes[0]->type)] << std::endl;*/ \
|
|
}\
|
|
if (FUNC_ALLOW_TERMINALS_SET.contains(function_t::NAME)){ \
|
|
FUNC(img.value(), 0, 0, argc, const_cast<const image**>(sub_node_images.data()), data); \
|
|
} else { \
|
|
for (blt::i32 y = 0; y < height; y++) { \
|
|
for (blt::i32 x = 0; x < width; x++) { \
|
|
FUNC(img.value(), static_cast<float>(x), static_cast<float>(y), argc, const_cast<const image**>(sub_node_images.data()), data); \
|
|
} \
|
|
} \
|
|
} \
|
|
} \
|
|
break;
|
|
|
|
switch (type)
|
|
{
|
|
FUNC_FUNCTIONS
|
|
default:
|
|
BLT_WARN("How did we get here?");
|
|
break;
|
|
}
|
|
#undef FUNC_DEFINE
|
|
reset_children();
|
|
}
|
|
|
|
/**
 * Deep copy: copies argc, type and data, and clones each of the first argc children.
 * The cached image (img) is not copied by this constructor.
 */
node::node(const node& copy)
{
    argc = copy.argc;
    type = copy.type;

    // data_t is required to be exactly three floats, so a raw byte copy reproduces it.
    static_assert(sizeof(data_t) == sizeof(float) * 3 && "Uhhh something is wrong here!");
    std::memcpy(data.data(), copy.data.data(), sizeof(data_t));

    // Each live child is cloned, so the new node owns an independent subtree.
    for (blt::size_t child = 0; child < argc; ++child)
    {
        sub_nodes[child] = copy.sub_nodes[child]->clone();
    }
}
|
|
|
|
/**
 * Returns a deep copy of this node allocated from node_allocator.
 * Uses placement new with the copy constructor, which in turn clones every child.
 */
node* node::clone()
{
    node* storage = node_allocator.allocate(1);
    return ::new(storage) node(*this);
}
|
|
|
|
void node::evaluate_tree()
|
|
{
|
|
std::stack<node*> nodes;
|
|
std::stack<node*> node_stack;
|
|
|
|
nodes.push(this);
|
|
|
|
while (!nodes.empty())
|
|
{
|
|
auto* top = nodes.top();
|
|
node_stack.push(top);
|
|
nodes.pop();
|
|
for (size_t i = 0; i < top->argc; i++)
|
|
nodes.push(top->sub_nodes[i]);
|
|
}
|
|
|
|
while (!node_stack.empty())
|
|
{
|
|
node_stack.top()->evaluate();
|
|
node_stack.pop();
|
|
}
|
|
}
|