diff --git a/CMakeLists.txt b/CMakeLists.txt index 6ab37f9..fac6e7b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,5 +1,5 @@ cmake_minimum_required(VERSION 3.25) -project(lilfbtf5 VERSION 0.1.30) +project(lilfbtf5 VERSION 0.1.31) option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF) option(ENABLE_UBSAN "Enable the ub sanitizer" OFF) diff --git a/include/lilfbtf/random.h b/include/lilfbtf/random.h index c61e66f..c7eed34 100644 --- a/include/lilfbtf/random.h +++ b/include/lilfbtf/random.h @@ -34,6 +34,7 @@ namespace fb void reset(); bool choice(); + bool choice(double d); bool chance(double chance = 0.5); float random_float(float min = 0, float max = 1); diff --git a/include/lilfbtf/tree.h b/include/lilfbtf/tree.h index 862df00..68968fa 100644 --- a/include/lilfbtf/tree.h +++ b/include/lilfbtf/tree.h @@ -122,18 +122,24 @@ namespace fb } }; - struct node_construction_info_t - { - blt::bump_allocator& alloc; - random& engine; - type_engine_t& types; - }; - struct tree_construction_info_t { tree_init_t tree_type; random& engine; type_engine_t& types; + double terminal_chance = 0.5; + }; + + struct node_construction_info_t + { + tree_t& tree; + random& engine; + type_engine_t& types; + double terminal_chance; + + node_construction_info_t(tree_t& tree, const tree_construction_info_t& info): + tree(tree), engine(info.engine), types(info.types), terminal_chance(info.terminal_chance) + {} }; } @@ -153,7 +159,9 @@ namespace fb static detail::node_t* allocate_terminal(detail::node_construction_info_t info, type_id type); + static void grow(detail::node_construction_info_t info, blt::size_t min_depth, blt::size_t max_depth); + static void full(detail::node_construction_info_t info, blt::size_t depth); public: explicit tree_t(type_engine_t& types); diff --git a/src/random.cpp b/src/random.cpp index f552547..90589b1 100644 --- a/src/random.cpp +++ b/src/random.cpp @@ -63,4 +63,10 @@ namespace fb { return random_double() <= chance; } + + bool random::choice(double d) + { + std::uniform_real_distribution dist(0, 1);\ + return dist(engine) < d; + } } \ No newline at end of file diff --git a/src/tree.cpp b/src/tree.cpp index b82d667..de7503b 100644 --- a/src/tree.cpp +++ b/src/tree.cpp @@ -27,99 +27,38 @@ namespace fb tree_t::tree_t(type_engine_t& types): alloc(), types(types) {} - tree_t tree_t::make_tree(type_engine_t& types, random& engine, + tree_t tree_t::make_tree(detail::tree_construction_info_t tree_info, blt::size_t min_depth, blt::size_t max_depth, std::optional starting_type) { using detail::node_t; - tree_t tree(types); - + tree_t tree(tree_info.types); { if (starting_type) + tree.root = allocate_non_terminal({tree, tree_info}, starting_type.value()); + else { - auto& non_terminals = types.get_non_terminals(starting_type.value()); - auto selection = non_terminals[engine.random_long(0, non_terminals.size() - 1)]; - func_t func(types.get_function_argc(selection), types.get_function(selection), starting_type.value(), selection); - if (const auto& func_init = types.get_function_initializer(selection)) - func_init.value()(func); - tree.root = tree.alloc.template emplace(func, tree.alloc); - } else - { - auto& non_terminals = types.get_all_non_terminals(); - auto selection = non_terminals[engine.random_long(0, non_terminals.size() - 1)]; - func_t func(types.get_function_argc(selection.second), types.get_function(selection.second), selection.first, selection.second); - if (const auto& func_init = types.get_function_initializer(selection.second)) + auto& non_terminals = tree_info.types.get_all_non_terminals(); + auto selection = non_terminals[tree_info.engine.random_long(0, non_terminals.size() - 1)]; + func_t func(tree_info.types.get_function_argc(selection.second), tree_info.types.get_function(selection.second), selection.first, + selection.second); + if (const auto& func_init = tree_info.types.get_function_initializer(selection.second)) func_init.value()(func); tree.root = tree.alloc.template emplace(func, tree.alloc); } } - std::stack> stack; - stack.emplace(tree.root, 0); - while (!stack.empty()) + + switch (tree_info.tree_type) { - auto top = stack.top(); - auto* node = top.first; - auto depth = top.second; - stack.pop(); - - const auto& allowed_types = types.get_function_allowed_arguments(node->type.getFunction()); - // we need to make sure there is at least one non-terminal generation, until we hit the min height - bool has_one_non_terminal = false; - for (blt::size_t i = 0; i < node->type.argc(); i++) - { - type_id type_category = allowed_types[i]; - const auto& terminals = types.get_terminals(type_category); - const auto& non_terminals = types.get_non_terminals(type_category); - - if (depth < min_depth && !has_one_non_terminal) - { - // make sure we have at least min height possible by using at least one non terminal - function_id selection = non_terminals[engine.random_long(0, non_terminals.size() - 1)]; - func_t func(types.get_function_argc(selection), types.get_function(selection), type_category, selection); - if (const auto& func_init = types.get_function_initializer(selection)) - func_init.value()(func); - node->children[i] = tree.alloc.template emplace(func, tree.alloc); - has_one_non_terminal = true; - } else if (depth >= max_depth) - { - // if we are above the max_height select only terminals - function_id selection = terminals[engine.random_long(0, terminals.size() - 1)]; - func_t func(types.get_function_argc(selection), types.get_function(selection), type_category, selection); - if (const auto& func_init = types.get_function_initializer(selection)) - func_init.value()(func); - node->children[i] = tree.alloc.template emplace(func, tree.alloc); - } else if (engine.choice()) - { - // otherwise select between use full() method - function_id selection = non_terminals[engine.random_long(0, non_terminals.size() - 1)]; - func_t func(types.get_function_argc(selection), types.get_function(selection), type_category, selection); - if (const auto& func_init = types.get_function_initializer(selection)) - func_init.value()(func); - node->children[i] = tree.alloc.template emplace(func, tree.alloc); - } else - { - // and use grow() method, meaning select choice again - if (engine.choice()) - { - // to use non-terminals - function_id selection = non_terminals[engine.random_long(0, non_terminals.size() - 1)]; - func_t func(types.get_function_argc(selection), types.get_function(selection), type_category, selection); - if (const auto& func_init = types.get_function_initializer(selection)) - func_init.value()(func); - node->children[i] = tree.alloc.template emplace(func, tree.alloc); - } else - { - // or use terminals - function_id selection = terminals[engine.random_long(0, terminals.size() - 1)]; - func_t func(types.get_function_argc(selection), types.get_function(selection), type_category, selection); - if (const auto& func_init = types.get_function_initializer(selection)) - func_init.value()(func); - node->children[i] = tree.alloc.template emplace(func, tree.alloc); - } - } - // node has children that need populated - if (node->children[i]->type.argc() != 0) - stack.emplace(node->children[i], depth + 1); - } + case tree_init_t::GROW: + grow({tree, tree_info}, min_depth, max_depth); + break; + case tree_init_t::FULL: + full({tree, tree_info}, tree_info.engine.random_long(min_depth, max_depth)); + break; + case tree_init_t::RAMPED_HALF_HALF: + break; + case tree_init_t::BRETT_HALF_HALF: + break; } return tree; @@ -159,7 +98,7 @@ namespace fb func_t func(info.types.get_function_argc(selection), info.types.get_function(selection), type, selection); if (const auto& func_init = info.types.get_function_initializer(selection)) func_init.value()(func); - return info.alloc.template emplace(func, info.alloc); + return info.tree.alloc.template emplace(func, info.tree.alloc); } detail::node_t* tree_t::allocate_terminal(detail::node_construction_info_t info, type_id type) @@ -175,7 +114,7 @@ namespace fb func_t func(info.types.get_function_argc(selection), info.types.get_function(selection), type, selection); if (const auto& func_init = info.types.get_function_initializer(selection)) func_init.value()(func); - return info.alloc.template emplace(func, info.alloc); + return info.tree.alloc.template emplace(func, info.tree.alloc); } detail::node_t* tree_t::allocate_non_terminal_restricted(detail::node_construction_info_t info, type_id type) @@ -194,7 +133,80 @@ namespace fb func_t func(info.types.get_function_argc(selection), info.types.get_function(selection), type, selection); if (const auto& func_init = info.types.get_function_initializer(selection)) (*func_init)(func); - return info.alloc.template emplace(func, info.alloc); + return info.tree.alloc.template emplace(func, info.tree.alloc); + } + + void tree_t::grow(detail::node_construction_info_t info, blt::size_t min_depth, blt::size_t max_depth) + { + using namespace detail; + std::stack> stack; + stack.emplace(info.tree.root, 0); + while (!stack.empty()) + { + auto top = stack.top(); + auto* node = top.first; + auto depth = top.second; + stack.pop(); + + const auto& allowed_types = info.types.get_function_allowed_arguments(node->type.getFunction()); + // we need to make sure there is at least one non-terminal generation, until we hit the min height + bool has_one_non_terminal = false; + for (blt::size_t i = 0; i < node->type.argc(); i++) + { + type_id type_category = allowed_types[i]; + + if (depth < min_depth && !has_one_non_terminal) + { + // make sure we have at least min height possible by using at least one non terminal + node->children[i] = allocate_non_terminal(info, type_category); + has_one_non_terminal = true; + } else if (depth >= max_depth || info.engine.choice(info.terminal_chance)) + { + // if we are above the max_height select only terminals or otherwise select between use of terminals + node->children[i] = allocate_terminal(info, type_category); + } else + { + // and use of non-terminals method + node->children[i] = allocate_non_terminal(info, type_category); + } + // node has children that need populated + if (node->children[i]->type.argc() != 0) + stack.emplace(node->children[i], depth + 1); + } + } + } + + void tree_t::full(detail::node_construction_info_t info, blt::size_t select_depth) + { + using namespace detail; + std::stack> stack; + stack.emplace(info.tree.root, 0); + while (!stack.empty()) + { + auto top = stack.top(); + auto* node = top.first; + auto depth = top.second; + stack.pop(); + + const auto& allowed_types = info.types.get_function_allowed_arguments(node->type.getFunction()); + for (blt::size_t i = 0; i < node->type.argc(); i++) + { + type_id type_category = allowed_types[i]; + + if (depth >= select_depth) + { + // if we are above the max_height select only terminals + node->children[i] = allocate_terminal(info, type_category); + } else + { + // otherwise only non-terminals can be used + node->children[i] = allocate_non_terminal(info, type_category); + } + // node has children that need populated + if (node->children[i]->type.argc() != 0) + stack.emplace(node->children[i], depth + 1); + } + } }