working on tree generation

main
Brett 2024-03-19 12:13:33 -04:00
parent 319d385cd9
commit ef8b03ae66
5 changed files with 120 additions and 93 deletions

View File

@ -1,5 +1,5 @@
cmake_minimum_required(VERSION 3.25) cmake_minimum_required(VERSION 3.25)
project(lilfbtf5 VERSION 0.1.30) project(lilfbtf5 VERSION 0.1.31)
option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF) option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF)
option(ENABLE_UBSAN "Enable the ub sanitizer" OFF) option(ENABLE_UBSAN "Enable the ub sanitizer" OFF)

View File

@ -34,6 +34,7 @@ namespace fb
void reset(); void reset();
bool choice(); bool choice();
bool choice(double d);
bool chance(double chance = 0.5); bool chance(double chance = 0.5);
float random_float(float min = 0, float max = 1); float random_float(float min = 0, float max = 1);

View File

@ -122,18 +122,24 @@ namespace fb
} }
}; };
struct node_construction_info_t
{
blt::bump_allocator<blt::BLT_2MB_SIZE, false>& alloc;
random& engine;
type_engine_t& types;
};
struct tree_construction_info_t struct tree_construction_info_t
{ {
tree_init_t tree_type; tree_init_t tree_type;
random& engine; random& engine;
type_engine_t& types; type_engine_t& types;
double terminal_chance = 0.5;
};
struct node_construction_info_t
{
tree_t& tree;
random& engine;
type_engine_t& types;
double terminal_chance;
node_construction_info_t(tree_t& tree, const tree_construction_info_t& info):
tree(tree), engine(info.engine), types(info.types), terminal_chance(info.terminal_chance)
{}
}; };
} }
@ -153,7 +159,9 @@ namespace fb
static detail::node_t* allocate_terminal(detail::node_construction_info_t info, type_id type); static detail::node_t* allocate_terminal(detail::node_construction_info_t info, type_id type);
static void grow(detail::node_construction_info_t info, blt::size_t min_depth, blt::size_t max_depth);
static void full(detail::node_construction_info_t info, blt::size_t depth);
public: public:
explicit tree_t(type_engine_t& types); explicit tree_t(type_engine_t& types);

View File

@ -63,4 +63,10 @@ namespace fb
{ {
return random_double() <= chance; return random_double() <= chance;
} }
bool random::choice(double d)
{
std::uniform_real_distribution<double> dist(0, 1);\
return dist(engine) < d;
}
} }

View File

@ -27,99 +27,38 @@ namespace fb
tree_t::tree_t(type_engine_t& types): alloc(), types(types) tree_t::tree_t(type_engine_t& types): alloc(), types(types)
{} {}
tree_t tree_t::make_tree(type_engine_t& types, random& engine, tree_t tree_t::make_tree(detail::tree_construction_info_t tree_info,
blt::size_t min_depth, blt::size_t max_depth, std::optional<type_id> starting_type) blt::size_t min_depth, blt::size_t max_depth, std::optional<type_id> starting_type)
{ {
using detail::node_t; using detail::node_t;
tree_t tree(types); tree_t tree(tree_info.types);
{ {
if (starting_type) if (starting_type)
tree.root = allocate_non_terminal({tree, tree_info}, starting_type.value());
else
{ {
auto& non_terminals = types.get_non_terminals(starting_type.value()); auto& non_terminals = tree_info.types.get_all_non_terminals();
auto selection = non_terminals[engine.random_long(0, non_terminals.size() - 1)]; auto selection = non_terminals[tree_info.engine.random_long(0, non_terminals.size() - 1)];
func_t func(types.get_function_argc(selection), types.get_function(selection), starting_type.value(), selection); func_t func(tree_info.types.get_function_argc(selection.second), tree_info.types.get_function(selection.second), selection.first,
if (const auto& func_init = types.get_function_initializer(selection)) selection.second);
func_init.value()(func); if (const auto& func_init = tree_info.types.get_function_initializer(selection.second))
tree.root = tree.alloc.template emplace<node_t>(func, tree.alloc);
} else
{
auto& non_terminals = types.get_all_non_terminals();
auto selection = non_terminals[engine.random_long(0, non_terminals.size() - 1)];
func_t func(types.get_function_argc(selection.second), types.get_function(selection.second), selection.first, selection.second);
if (const auto& func_init = types.get_function_initializer(selection.second))
func_init.value()(func); func_init.value()(func);
tree.root = tree.alloc.template emplace<node_t>(func, tree.alloc); tree.root = tree.alloc.template emplace<node_t>(func, tree.alloc);
} }
} }
std::stack<std::pair<node_t*, blt::size_t>> stack;
stack.emplace(tree.root, 0); switch (tree_info.tree_type)
while (!stack.empty())
{ {
auto top = stack.top(); case tree_init_t::GROW:
auto* node = top.first; grow({tree, tree_info}, min_depth, max_depth);
auto depth = top.second; break;
stack.pop(); case tree_init_t::FULL:
full({tree, tree_info}, tree_info.engine.random_long(min_depth, max_depth));
const auto& allowed_types = types.get_function_allowed_arguments(node->type.getFunction()); break;
// we need to make sure there is at least one non-terminal generation, until we hit the min height case tree_init_t::RAMPED_HALF_HALF:
bool has_one_non_terminal = false; break;
for (blt::size_t i = 0; i < node->type.argc(); i++) case tree_init_t::BRETT_HALF_HALF:
{ break;
type_id type_category = allowed_types[i];
const auto& terminals = types.get_terminals(type_category);
const auto& non_terminals = types.get_non_terminals(type_category);
if (depth < min_depth && !has_one_non_terminal)
{
// make sure we have at least min height possible by using at least one non terminal
function_id selection = non_terminals[engine.random_long(0, non_terminals.size() - 1)];
func_t func(types.get_function_argc(selection), types.get_function(selection), type_category, selection);
if (const auto& func_init = types.get_function_initializer(selection))
func_init.value()(func);
node->children[i] = tree.alloc.template emplace<node_t>(func, tree.alloc);
has_one_non_terminal = true;
} else if (depth >= max_depth)
{
// if we are above the max_height select only terminals
function_id selection = terminals[engine.random_long(0, terminals.size() - 1)];
func_t func(types.get_function_argc(selection), types.get_function(selection), type_category, selection);
if (const auto& func_init = types.get_function_initializer(selection))
func_init.value()(func);
node->children[i] = tree.alloc.template emplace<node_t>(func, tree.alloc);
} else if (engine.choice())
{
// otherwise select between use full() method
function_id selection = non_terminals[engine.random_long(0, non_terminals.size() - 1)];
func_t func(types.get_function_argc(selection), types.get_function(selection), type_category, selection);
if (const auto& func_init = types.get_function_initializer(selection))
func_init.value()(func);
node->children[i] = tree.alloc.template emplace<node_t>(func, tree.alloc);
} else
{
// and use grow() method, meaning select choice again
if (engine.choice())
{
// to use non-terminals
function_id selection = non_terminals[engine.random_long(0, non_terminals.size() - 1)];
func_t func(types.get_function_argc(selection), types.get_function(selection), type_category, selection);
if (const auto& func_init = types.get_function_initializer(selection))
func_init.value()(func);
node->children[i] = tree.alloc.template emplace<node_t>(func, tree.alloc);
} else
{
// or use terminals
function_id selection = terminals[engine.random_long(0, terminals.size() - 1)];
func_t func(types.get_function_argc(selection), types.get_function(selection), type_category, selection);
if (const auto& func_init = types.get_function_initializer(selection))
func_init.value()(func);
node->children[i] = tree.alloc.template emplace<node_t>(func, tree.alloc);
}
}
// node has children that need populated
if (node->children[i]->type.argc() != 0)
stack.emplace(node->children[i], depth + 1);
}
} }
return tree; return tree;
@ -159,7 +98,7 @@ namespace fb
func_t func(info.types.get_function_argc(selection), info.types.get_function(selection), type, selection); func_t func(info.types.get_function_argc(selection), info.types.get_function(selection), type, selection);
if (const auto& func_init = info.types.get_function_initializer(selection)) if (const auto& func_init = info.types.get_function_initializer(selection))
func_init.value()(func); func_init.value()(func);
return info.alloc.template emplace<detail::node_t>(func, info.alloc); return info.tree.alloc.template emplace<detail::node_t>(func, info.tree.alloc);
} }
detail::node_t* tree_t::allocate_terminal(detail::node_construction_info_t info, type_id type) detail::node_t* tree_t::allocate_terminal(detail::node_construction_info_t info, type_id type)
@ -175,7 +114,7 @@ namespace fb
func_t func(info.types.get_function_argc(selection), info.types.get_function(selection), type, selection); func_t func(info.types.get_function_argc(selection), info.types.get_function(selection), type, selection);
if (const auto& func_init = info.types.get_function_initializer(selection)) if (const auto& func_init = info.types.get_function_initializer(selection))
func_init.value()(func); func_init.value()(func);
return info.alloc.template emplace<detail::node_t>(func, info.alloc); return info.tree.alloc.template emplace<detail::node_t>(func, info.tree.alloc);
} }
detail::node_t* tree_t::allocate_non_terminal_restricted(detail::node_construction_info_t info, type_id type) detail::node_t* tree_t::allocate_non_terminal_restricted(detail::node_construction_info_t info, type_id type)
@ -194,7 +133,80 @@ namespace fb
func_t func(info.types.get_function_argc(selection), info.types.get_function(selection), type, selection); func_t func(info.types.get_function_argc(selection), info.types.get_function(selection), type, selection);
if (const auto& func_init = info.types.get_function_initializer(selection)) if (const auto& func_init = info.types.get_function_initializer(selection))
(*func_init)(func); (*func_init)(func);
return info.alloc.template emplace<detail::node_t>(func, info.alloc); return info.tree.alloc.template emplace<detail::node_t>(func, info.tree.alloc);
}
void tree_t::grow(detail::node_construction_info_t info, blt::size_t min_depth, blt::size_t max_depth)
{
using namespace detail;
std::stack<std::pair<node_t*, blt::size_t>> stack;
stack.emplace(info.tree.root, 0);
while (!stack.empty())
{
auto top = stack.top();
auto* node = top.first;
auto depth = top.second;
stack.pop();
const auto& allowed_types = info.types.get_function_allowed_arguments(node->type.getFunction());
// we need to make sure there is at least one non-terminal generation, until we hit the min height
bool has_one_non_terminal = false;
for (blt::size_t i = 0; i < node->type.argc(); i++)
{
type_id type_category = allowed_types[i];
if (depth < min_depth && !has_one_non_terminal)
{
// make sure we have at least min height possible by using at least one non terminal
node->children[i] = allocate_non_terminal(info, type_category);
has_one_non_terminal = true;
} else if (depth >= max_depth || info.engine.choice(info.terminal_chance))
{
// if we are above the max_height select only terminals or otherwise select between use of terminals
node->children[i] = allocate_terminal(info, type_category);
} else
{
// and use of non-terminals method
node->children[i] = allocate_non_terminal(info, type_category);
}
// node has children that need populated
if (node->children[i]->type.argc() != 0)
stack.emplace(node->children[i], depth + 1);
}
}
}
void tree_t::full(detail::node_construction_info_t info, blt::size_t select_depth)
{
using namespace detail;
std::stack<std::pair<node_t*, blt::size_t>> stack;
stack.emplace(info.tree.root, 0);
while (!stack.empty())
{
auto top = stack.top();
auto* node = top.first;
auto depth = top.second;
stack.pop();
const auto& allowed_types = info.types.get_function_allowed_arguments(node->type.getFunction());
for (blt::size_t i = 0; i < node->type.argc(); i++)
{
type_id type_category = allowed_types[i];
if (depth >= select_depth)
{
// if we are above the max_height select only terminals
node->children[i] = allocate_terminal(info, type_category);
} else
{
// otherwise only non-terminals can be used
node->children[i] = allocate_non_terminal(info, type_category);
}
// node has children that need populated
if (node->children[i]->type.argc() != 0)
stack.emplace(node->children[i], depth + 1);
}
}
} }