/* * * Copyright (C) 2024 Brett Terpstra * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program. If not, see . */ #include #include #include #include #include #include #include blt::bump_allocator node_allocator(32000); node* createNode(function_t type) { auto* n = node_allocator.allocate(1); node_allocator.construct(n, type); return n; } void destroyNode(node* n) { if (n == nullptr) return; node_allocator.destroy(n); node_allocator.deallocate(n, 1); } void node::grow(bool use_terminals) { auto min_children = function_arg_min_map[to_underlying(type)]; auto max_children = function_arg_max_map[to_underlying(type)]; static thread_local std::random_device dev; static thread_local std::mt19937_64 engine{dev()}; std::uniform_int_distribution dist(min_children, max_children); static std::uniform_int_distribution choice(0, 1); argc = dist(engine); if (argc == 0) return; const auto& allowed_args_args = function_arg_allowed_map[to_underlying(type)]; for (size_t i = 0; i < argc; i++) { // 50/50 chance to either use a terminal or use from the function list. populate_node(i, engine, allowed_args_args[i], choice(engine) || use_terminals); } } void node::full(bool use_terminals) { static thread_local std::random_device dev; static thread_local std::mt19937_64 engine{dev()}; argc = function_arg_max_map[to_underlying(type)]; if (argc == 0) return; const auto& allowed_args_args = function_arg_allowed_map[to_underlying(type)]; for (size_t i = 0; i < argc; i++) populate_node(i, engine, allowed_args_args[i], use_terminals); } void node::populate_node(size_t i, std::mt19937_64& engine, const allowed_funcs& allowed_args, bool use_terminal) { if (use_terminal) { auto terminals = intersection(allowed_args, FUNC_ALLOW_TERMINALS_SET); if (terminals.empty()) { terminals = allowed_args; } std::uniform_int_distribution select(0, terminals.size() - 1); sub_nodes[i] = createNode(terminals[select(engine)]); } else { auto non_terminals = intersection_comp(allowed_args, FUNC_ALLOW_TERMINALS_SET); if (non_terminals.empty()) { //BLT_WARN("Empty non-terminals set for node '%s'! Filling from terminals!", function_name_map[to_underlying(type)].c_str()); std::uniform_int_distribution select(0, allowed_args.size() - 1); sub_nodes[i] = createNode(allowed_args[select(engine)]); } else { std::uniform_int_distribution select(0, non_terminals.size() - 1); sub_nodes[i] = createNode(non_terminals[select(engine)]); } } } std::string node::get_type_name() { if (argc > 0) return function_name_map[to_underlying(type)]; else { if (type == function_t::SCALAR) { evaluate(); return std::to_string(img->get().x()); } else if (type == function_t::COLOR) { evaluate(); return '{' + std::to_string(img->get().x()) + ", " + std::to_string(img->get().y()) + ", " + std::to_string(img->get().z()) + "}"; } else return function_name_map[to_underlying(type)]; } } void node::print_tree() { struct stack_info { node* n; blt::i64 layer = 0; }; std::stack> nodes; std::vector stack_nodes; nodes.emplace(0, this); while (!nodes.empty()) { auto top = nodes.top(); stack_nodes.push_back({top.second, top.first}); nodes.pop(); for (size_t i = 0; i < top.second->argc; i++) nodes.emplace(top.first + 1, top.second->sub_nodes[i]); } for (auto e : blt::enumerate(stack_nodes)) { auto i = e.first; auto& v = e.second; std::cout << std::endl; for (blt::i64 j = 0; j < v.layer; j++) std::cout << "| "; if (v.n->argc != 0) { std::cout << "("; } std::cout << v.n->get_type_name(); if (i + 1 < stack_nodes.size()) { auto& n = stack_nodes[i + 1]; if (n.layer < v.layer) { //std::cout << " "; // if (v.layer - n.layer > 2) // { for (auto j = v.layer - 1; j >= n.layer; j--) { std::cout << std::endl; for (blt::i64 k = 0; k < j; k++) { std::cout << "| "; } std::cout << ")"; } // } } else std::cout << " "; } else { for (auto j = v.layer - 1; j >= 0; j--) { std::cout << std::endl; for (blt::i64 k = 0; k < j; k++) std::cout << "| "; std::cout << ")"; } } } std::cout << std::endl; } node* node::construct_random_tree(blt::size_t max_depth) { static std::random_device dev; static std::mt19937_64 engine{dev()}; std::uniform_int_distribution choice(0, 1); static auto NON_TERMINALS = intersection_comp(FUNC_ALLOW_ANY, FUNC_ALLOW_TERMINALS_SET); std::uniform_int_distribution select(0, static_cast(NON_TERMINALS.size()) - 1); node* n = createNode(NON_TERMINALS[select(engine)]); std::queue> grow_queue; size_t current_depth = 0; grow_queue.emplace(n, current_depth); while (!grow_queue.empty()) { auto front = grow_queue.front(); if (front.second != current_depth) current_depth++; if (choice(engine) && current_depth >= MIN_DEPTH) front.first->grow(front.second >= max_depth); else front.first->full(front.second >= max_depth); for (size_t i = 0; i < front.first->argc; i++) grow_queue.emplace(front.first->sub_nodes[i], current_depth + 1); grow_queue.pop(); } return n; } void node::evaluate() { img = image{}; std::array sub_node_images{nullptr}; for (size_t i = 0; i < argc; i++) { BLT_ASSERT_MSG(sub_nodes[i] != nullptr && sub_nodes[i]->img.has_value() && "Node must have evaluated children!", ("Failed at arg: " + std::to_string(i) + " with type: " + function_name_map[to_underlying(type)] + " child type: " + function_name_map[to_underlying(sub_nodes[i]->type)]).c_str()); sub_node_images[i] = &sub_nodes[i]->img.value(); } #define FUNC_DEFINE(NAME, MIN_ARGS, MAX_ARGS, FUNC, ...) case function_t::NAME: { \ if (function_t::NAME == function_t::IF) { \ /*std::cout << "__:" << function_name_map[to_underlying(this->sub_nodes[0]->type)] << std::endl;*/ \ }\ if (FUNC_ALLOW_TERMINALS_SET.contains(function_t::NAME)){ \ FUNC(img.value(), 0, 0, argc, const_cast(sub_node_images.data()), data); \ } else { \ for (blt::i32 y = 0; y < height; y++) { \ for (blt::i32 x = 0; x < width; x++) { \ FUNC(img.value(), static_cast(x), static_cast(y), argc, const_cast(sub_node_images.data()), data); \ } \ } \ } \ } \ break; switch (type) { FUNC_FUNCTIONS default: BLT_WARN("How did we get here?"); break; } #undef FUNC_DEFINE reset_children(); } node::node(const node& copy) { argc = copy.argc; type = copy.type; static_assert(sizeof(data_t) == sizeof(float) * 3 && "Uhhh something is wrong here!"); std::memcpy(data.data(), copy.data.data(), sizeof(data_t)); for (blt::size_t i = 0; i < argc; i++) sub_nodes[i] = copy.sub_nodes[i]->clone(); } node* node::clone() { auto np = node_allocator.allocate(1); // tee hee ::new(np) node(*this); return np; } void node::evaluate_tree() { std::stack nodes; std::stack node_stack; nodes.push(this); while (!nodes.empty()) { auto* top = nodes.top(); node_stack.push(top); nodes.pop(); for (size_t i = 0; i < top->argc; i++) nodes.push(top->sub_nodes[i]); } while (!node_stack.empty()) { node_stack.top()->evaluate(); node_stack.pop(); } }