working on tree generation

main
Brett 2024-03-19 12:13:33 -04:00
parent 319d385cd9
commit ef8b03ae66
5 changed files with 120 additions and 93 deletions

View File

@ -1,5 +1,5 @@
cmake_minimum_required(VERSION 3.25)
project(lilfbtf5 VERSION 0.1.30)
project(lilfbtf5 VERSION 0.1.31)
option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF)
option(ENABLE_UBSAN "Enable the ub sanitizer" OFF)

View File

@ -34,6 +34,7 @@ namespace fb
void reset();
bool choice();
bool choice(double d);
bool chance(double chance = 0.5);
float random_float(float min = 0, float max = 1);

View File

@ -122,18 +122,24 @@ namespace fb
}
};
struct node_construction_info_t
{
blt::bump_allocator<blt::BLT_2MB_SIZE, false>& alloc;
random& engine;
type_engine_t& types;
};
struct tree_construction_info_t
{
tree_init_t tree_type;
random& engine;
type_engine_t& types;
double terminal_chance = 0.5;
};
struct node_construction_info_t
{
tree_t& tree;
random& engine;
type_engine_t& types;
double terminal_chance;
node_construction_info_t(tree_t& tree, const tree_construction_info_t& info):
tree(tree), engine(info.engine), types(info.types), terminal_chance(info.terminal_chance)
{}
};
}
@ -153,7 +159,9 @@ namespace fb
static detail::node_t* allocate_terminal(detail::node_construction_info_t info, type_id type);
static void grow(detail::node_construction_info_t info, blt::size_t min_depth, blt::size_t max_depth);
static void full(detail::node_construction_info_t info, blt::size_t depth);
public:
explicit tree_t(type_engine_t& types);

View File

@ -63,4 +63,10 @@ namespace fb
{
return random_double() <= chance;
}
bool random::choice(double d)
{
std::uniform_real_distribution<double> dist(0, 1);\
return dist(engine) < d;
}
}

View File

@ -27,99 +27,38 @@ namespace fb
tree_t::tree_t(type_engine_t& types): alloc(), types(types)
{}
tree_t tree_t::make_tree(type_engine_t& types, random& engine,
tree_t tree_t::make_tree(detail::tree_construction_info_t tree_info,
blt::size_t min_depth, blt::size_t max_depth, std::optional<type_id> starting_type)
{
using detail::node_t;
tree_t tree(types);
tree_t tree(tree_info.types);
{
if (starting_type)
tree.root = allocate_non_terminal({tree, tree_info}, starting_type.value());
else
{
auto& non_terminals = types.get_non_terminals(starting_type.value());
auto selection = non_terminals[engine.random_long(0, non_terminals.size() - 1)];
func_t func(types.get_function_argc(selection), types.get_function(selection), starting_type.value(), selection);
if (const auto& func_init = types.get_function_initializer(selection))
func_init.value()(func);
tree.root = tree.alloc.template emplace<node_t>(func, tree.alloc);
} else
{
auto& non_terminals = types.get_all_non_terminals();
auto selection = non_terminals[engine.random_long(0, non_terminals.size() - 1)];
func_t func(types.get_function_argc(selection.second), types.get_function(selection.second), selection.first, selection.second);
if (const auto& func_init = types.get_function_initializer(selection.second))
auto& non_terminals = tree_info.types.get_all_non_terminals();
auto selection = non_terminals[tree_info.engine.random_long(0, non_terminals.size() - 1)];
func_t func(tree_info.types.get_function_argc(selection.second), tree_info.types.get_function(selection.second), selection.first,
selection.second);
if (const auto& func_init = tree_info.types.get_function_initializer(selection.second))
func_init.value()(func);
tree.root = tree.alloc.template emplace<node_t>(func, tree.alloc);
}
}
std::stack<std::pair<node_t*, blt::size_t>> stack;
stack.emplace(tree.root, 0);
while (!stack.empty())
{
auto top = stack.top();
auto* node = top.first;
auto depth = top.second;
stack.pop();
const auto& allowed_types = types.get_function_allowed_arguments(node->type.getFunction());
// we need to make sure there is at least one non-terminal generation, until we hit the min height
bool has_one_non_terminal = false;
for (blt::size_t i = 0; i < node->type.argc(); i++)
switch (tree_info.tree_type)
{
type_id type_category = allowed_types[i];
const auto& terminals = types.get_terminals(type_category);
const auto& non_terminals = types.get_non_terminals(type_category);
if (depth < min_depth && !has_one_non_terminal)
{
// make sure we have at least min height possible by using at least one non terminal
function_id selection = non_terminals[engine.random_long(0, non_terminals.size() - 1)];
func_t func(types.get_function_argc(selection), types.get_function(selection), type_category, selection);
if (const auto& func_init = types.get_function_initializer(selection))
func_init.value()(func);
node->children[i] = tree.alloc.template emplace<node_t>(func, tree.alloc);
has_one_non_terminal = true;
} else if (depth >= max_depth)
{
// if we are above the max_height select only terminals
function_id selection = terminals[engine.random_long(0, terminals.size() - 1)];
func_t func(types.get_function_argc(selection), types.get_function(selection), type_category, selection);
if (const auto& func_init = types.get_function_initializer(selection))
func_init.value()(func);
node->children[i] = tree.alloc.template emplace<node_t>(func, tree.alloc);
} else if (engine.choice())
{
// otherwise select between use full() method
function_id selection = non_terminals[engine.random_long(0, non_terminals.size() - 1)];
func_t func(types.get_function_argc(selection), types.get_function(selection), type_category, selection);
if (const auto& func_init = types.get_function_initializer(selection))
func_init.value()(func);
node->children[i] = tree.alloc.template emplace<node_t>(func, tree.alloc);
} else
{
// and use grow() method, meaning select choice again
if (engine.choice())
{
// to use non-terminals
function_id selection = non_terminals[engine.random_long(0, non_terminals.size() - 1)];
func_t func(types.get_function_argc(selection), types.get_function(selection), type_category, selection);
if (const auto& func_init = types.get_function_initializer(selection))
func_init.value()(func);
node->children[i] = tree.alloc.template emplace<node_t>(func, tree.alloc);
} else
{
// or use terminals
function_id selection = terminals[engine.random_long(0, terminals.size() - 1)];
func_t func(types.get_function_argc(selection), types.get_function(selection), type_category, selection);
if (const auto& func_init = types.get_function_initializer(selection))
func_init.value()(func);
node->children[i] = tree.alloc.template emplace<node_t>(func, tree.alloc);
}
}
// node has children that need populated
if (node->children[i]->type.argc() != 0)
stack.emplace(node->children[i], depth + 1);
}
case tree_init_t::GROW:
grow({tree, tree_info}, min_depth, max_depth);
break;
case tree_init_t::FULL:
full({tree, tree_info}, tree_info.engine.random_long(min_depth, max_depth));
break;
case tree_init_t::RAMPED_HALF_HALF:
break;
case tree_init_t::BRETT_HALF_HALF:
break;
}
return tree;
@ -159,7 +98,7 @@ namespace fb
func_t func(info.types.get_function_argc(selection), info.types.get_function(selection), type, selection);
if (const auto& func_init = info.types.get_function_initializer(selection))
func_init.value()(func);
return info.alloc.template emplace<detail::node_t>(func, info.alloc);
return info.tree.alloc.template emplace<detail::node_t>(func, info.tree.alloc);
}
detail::node_t* tree_t::allocate_terminal(detail::node_construction_info_t info, type_id type)
@ -175,7 +114,7 @@ namespace fb
func_t func(info.types.get_function_argc(selection), info.types.get_function(selection), type, selection);
if (const auto& func_init = info.types.get_function_initializer(selection))
func_init.value()(func);
return info.alloc.template emplace<detail::node_t>(func, info.alloc);
return info.tree.alloc.template emplace<detail::node_t>(func, info.tree.alloc);
}
detail::node_t* tree_t::allocate_non_terminal_restricted(detail::node_construction_info_t info, type_id type)
@ -194,7 +133,80 @@ namespace fb
func_t func(info.types.get_function_argc(selection), info.types.get_function(selection), type, selection);
if (const auto& func_init = info.types.get_function_initializer(selection))
(*func_init)(func);
return info.alloc.template emplace<detail::node_t>(func, info.alloc);
return info.tree.alloc.template emplace<detail::node_t>(func, info.tree.alloc);
}
void tree_t::grow(detail::node_construction_info_t info, blt::size_t min_depth, blt::size_t max_depth)
{
using namespace detail;
std::stack<std::pair<node_t*, blt::size_t>> stack;
stack.emplace(info.tree.root, 0);
while (!stack.empty())
{
auto top = stack.top();
auto* node = top.first;
auto depth = top.second;
stack.pop();
const auto& allowed_types = info.types.get_function_allowed_arguments(node->type.getFunction());
// we need to make sure there is at least one non-terminal generation, until we hit the min height
bool has_one_non_terminal = false;
for (blt::size_t i = 0; i < node->type.argc(); i++)
{
type_id type_category = allowed_types[i];
if (depth < min_depth && !has_one_non_terminal)
{
// make sure we have at least min height possible by using at least one non terminal
node->children[i] = allocate_non_terminal(info, type_category);
has_one_non_terminal = true;
} else if (depth >= max_depth || info.engine.choice(info.terminal_chance))
{
// if we are above the max_height select only terminals or otherwise select between use of terminals
node->children[i] = allocate_terminal(info, type_category);
} else
{
// and use of non-terminals method
node->children[i] = allocate_non_terminal(info, type_category);
}
// node has children that need populated
if (node->children[i]->type.argc() != 0)
stack.emplace(node->children[i], depth + 1);
}
}
}
void tree_t::full(detail::node_construction_info_t info, blt::size_t select_depth)
{
using namespace detail;
std::stack<std::pair<node_t*, blt::size_t>> stack;
stack.emplace(info.tree.root, 0);
while (!stack.empty())
{
auto top = stack.top();
auto* node = top.first;
auto depth = top.second;
stack.pop();
const auto& allowed_types = info.types.get_function_allowed_arguments(node->type.getFunction());
for (blt::size_t i = 0; i < node->type.argc(); i++)
{
type_id type_category = allowed_types[i];
if (depth >= select_depth)
{
// if we are above the max_height select only terminals
node->children[i] = allocate_terminal(info, type_category);
} else
{
// otherwise only non-terminals can be used
node->children[i] = allocate_non_terminal(info, type_category);
}
// node has children that need populated
if (node->children[i]->type.argc() != 0)
stack.emplace(node->children[i], depth + 1);
}
}
}