// COSC-4P82-Final-Project/src/runner/main.cpp
//
// 485 lines
// 16 KiB
// C++

/*
* Copyright (C) 2024 Brett Terpstra
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
#include "blt/std/assert.h"
#include "blt/std/memory.h"
#include "blt/std/types.h"
#include <aggregation.h>
#include <algorithm>
#include <blt/fs/loader.h>
#include <blt/parse/argparse.h>
#include <blt/profiling/profiler_v2.h>
#include <blt/std/logging.h>
#include <blt/std/utility.h>
#include <cerrno>
#include <chrono>
#include <cstdio>
#include <cstring>
#include <fcntl.h>
#include <ipc.h>
#include <limits>
#include <memory>
#include <poll.h>
#include <random>
#include <string>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/un.h>
#include <sys/wait.h>
#include <thread>
#include <unistd.h>
#include <vector>
/**
 * Timing measurements reported for a single run.
 *
 * Each field is copied from the `timer` value of the matching packet
 * (EXEC_TIME, CPU_TIME, CPU_CYCLES). The values are stored exactly as
 * received; no unit conversion happens here.
 *
 * Every field defaults to zero so that an entry which never received a given
 * packet reads as 0 instead of holding an indeterminate value.
 */
struct timer_info
{
    blt::u64 wall_time = 0;
    blt::u64 cpu_time = 0;
    // The spelling "cylces" is kept on purpose: other code refers to this member by name.
    blt::u64 cpu_cylces = 0;
};
// run id -> timer info
// Keyed by the run index given to child_t's constructor; written by
// child_t::processPackets() whenever a timing packet is seen for that run.
blt::hashmap_t<blt::i32, timer_info> run_timers;
class child_t
{
private:
blt::i32 run;
int socket = 0;
bool socket_closed = false;
double fitness = 0;
std::vector<packet_t> unprocessed_packets;
public:
explicit child_t(blt::i32 run): run(run)
{};
void open(int sock)
{
socket = sock;
}
void processPackets()
{
for (const auto& v : unprocessed_packets)
{
if (v.id == packet_id::EXEC_TIME)
run_timers[run].wall_time = v.timer;
else if (v.id == packet_id::CPU_TIME)
run_timers[run].cpu_time = v.timer;
else if (v.id == packet_id::CPU_CYCLES)
run_timers[run].cpu_cylces = v.timer;
}
clearPackets(packet_id::EXEC_TIME);
clearPackets(packet_id::CPU_TIME);
clearPackets(packet_id::CPU_CYCLES);
}
ssize_t write(unsigned char* buffer, blt::size_t count)
{
pollfd fds_write{socket, POLLOUT, 0};
if (poll(&fds_write, 1, 1) < 0)
BLT_WARN("Error polling write %d", errno);
if (fds_write.revents & POLLHUP)
socket_closed = true;
if (fds_write.revents & POLLERR || fds_write.revents & POLLHUP || fds_write.revents & POLLNVAL)
return 0;
if (fds_write.revents & POLLOUT)
return ::write(socket, buffer, count);
return 0;
}
ssize_t read(unsigned char* buffer, blt::size_t count)
{
pollfd fds_read{socket, POLLIN, 0};
if (poll(&fds_read, 1, 1) < 0)
BLT_WARN("Error polling read %d", errno);
if (fds_read.revents & POLLHUP)
socket_closed = true;
if (fds_read.revents & POLLERR || fds_read.revents & POLLHUP || fds_read.revents & POLLNVAL)
return 0;
if (fds_read.revents & POLLIN)
return ::read(socket, buffer, count);
return 0;
}
void handlePacket(packet_t packet)
{
unprocessed_packets.push_back(packet);
}
[[nodiscard]] inline bool isSocketClosed() const
{
return socket_closed;
}
[[nodiscard]] inline const std::vector<packet_t>& pendingPackets() const
{
return unprocessed_packets;
}
inline void clearPackets(packet_id id)
{
auto it = unprocessed_packets.begin();
do
{
it = std::find_if(unprocessed_packets.begin(), unprocessed_packets.end(), [id](const auto& v) { return v.id == id; });
if (it == unprocessed_packets.end())
break;
std::iter_swap(it, unprocessed_packets.end() - 1);
unprocessed_packets.pop_back();
} while (true);
}
void setFitness(double f)
{
fitness = f;
}
[[nodiscard]] inline double getFitness() const
{
return fitness;
}
~child_t()
{
close(socket);
}
};
// Child process PID -> handle for that child. Entries are inserted by main()
// after each successful fork and erased once the process has finished or its
// socket has closed.
blt::hashmap_t<std::int32_t, std::unique_ptr<child_t>> children;
// Address of the parent's listening Unix socket (filled in by create_parent_socket()).
sockaddr_un name{};
// Listening socket descriptor created by create_parent_socket().
int host_socket = 0;
// Filesystem path of the Unix socket; assigned in main().
std::string SOCKET_LOCATION;
// State of the run / evaluate / prune cycle advanced by tick_state().
state_t current_state = state_t::RUN_GENERATIONS;
// Fitness threshold chosen during CHILD_EVALUATION and used during PRUNE.
double fitness_cutoff = 0;
// Fitness values collected from children during the current evaluation round.
std::vector<double> fitness_storage;
/**
 * Body of a forked child process: runs the GP program for one run and relays
 * its standard output to the trace log.
 *
 * The program is started from inside a per-run directory "./run_<run_id>",
 * which is why the program, input file and rice file paths are prefixed with
 * "../". The PID of this process is passed to the program as `process_id`.
 *
 * @param args            parsed command line arguments ("program", "file", "rice" are read)
 * @param run_id          index of this run; selects the working directory
 * @param socket_location path of the parent's socket, forwarded to the program
 * @return 0 once the program's output has been fully read, 1 if the run
 *         directory could not be entered or the program could not be started.
 */
int child_fp(blt::arg_parse::arg_results& args, int run_id, const std::string& socket_location)
{
    auto program = "../" + args.get<std::string>("program");
    auto file = "../" + args.get<std::string>("file");
    auto rice_file = "../" + args.get<std::string>("rice");
    auto dir = "./run_" + std::to_string(run_id);
    BLT_DEBUG("Running GP program '%s' on run %d", program.c_str(), run_id);
    // the result is not checked here: an already existing directory is fine and
    // any real failure shows up in the chdir below
    mkdir(dir.c_str(), S_IREAD | S_IWRITE | S_IEXEC | S_IRGRP | S_IWGRP | S_IXGRP | S_IROTH | S_IXOTH);
    if (chdir(dir.c_str()))
    {
        BLT_ERROR(errno);
        return 1;
    }
    auto command = program + " -f " + file + " -p rice_file='" + rice_file + "' -p socket_location='" + socket_location + "' -p process_id=" +
                   std::to_string(getpid());
    BLT_TRACE("Running command %s", command.c_str());
    FILE* process = popen(command.c_str(), "r");
    // popen returns nullptr when the pipe or the process could not be created;
    // reading from it would be undefined behaviour
    if (process == nullptr)
    {
        BLT_ERROR("Failed to start command, error %d", errno);
        return 1;
    }
    char buffer[4096];
    while (fgets(buffer, sizeof(buffer), process) != nullptr)
    {
        BLT_TRACE_STREAM << buffer;
    }
    const int status = pclose(process);
    if (status != 0)
        BLT_WARN("Command for run %d finished with status %d", run_id, status);
    return 0;
}
void create_child_sockets()
{
blt::i64 ret;
unsigned char data[sizeof(pid_t)];
for (blt::u64 i = 0; i < children.size(); i++)
{
int socket_fd = accept(host_socket, nullptr, nullptr);
BLT_ASSERT(socket_fd != -1 && "Failed to create data socket!");
// wait until client sends pid data.
do
{
std::memset(data, 0, sizeof(data));
ret = read(socket_fd, data, sizeof(data));
} while (ret != sizeof(data));
pid_t pid;
blt::mem::fromBytes(data, pid);
//pid -= 1;
if (!children.contains(pid))
BLT_WARN("This PID '%d' does not exist as a child!", pid);
else
BLT_INFO("Established connection to child %d", pid);
children[pid]->open(socket_fd);
}
}
/**
 * Reap every child process that has terminated and drop it from `children`.
 *
 * Uses a non-blocking waitpid per child. The map is walked with an explicit
 * iterator that is replaced by the result of erase(), because erasing inside a
 * range-based for loop would leave the loop advancing an invalid iterator.
 */
void remove_pending_finished_child_process()
{
    for (auto it = children.begin(); it != children.end();)
    {
        int status = 0;
        const auto pid = waitpid(it->first, &status, WNOHANG);
        // pid == 0: still running. pid < 0: waitpid failed and `status` carries
        // no information, so the entry is left alone (tick_state() also erases
        // children whose socket has closed).
        if (pid > 0 && (WIFEXITED(status) || WIFSIGNALED(status)))
        {
            BLT_TRACE("Process %d exited? %b signaled? %b", pid, WIFEXITED(status), WIFSIGNALED(status));
            it = children.erase(it);
            BLT_TRACE("Closing process %d finished!", pid);
            continue;
        }
        ++it;
    }
}
/**
 * Create the parent's listening Unix-domain socket.
 *
 * Fills the global `name` with an AF_UNIX address for SOCKET_LOCATION, then
 * creates a SOCK_SEQPACKET socket in `host_socket`, binds it to that address
 * and starts listening with a backlog of 20. Each step is checked with
 * BLT_ASSERT.
 */
void create_parent_socket()
{
    // start from an all-zero address so the path stays NUL terminated after the bounded copy
    std::memset(&name, 0, sizeof(sockaddr_un));
    std::strncpy(name.sun_path, SOCKET_LOCATION.c_str(), sizeof(name.sun_path) - 1);
    name.sun_family = AF_UNIX;

    BLT_INFO("Creating socket for %s", SOCKET_LOCATION.c_str());

    host_socket = socket(AF_UNIX, SOCK_SEQPACKET, 0);
    BLT_ASSERT(host_socket != -1 && "Failed to create socket!");

    const auto* address = reinterpret_cast<const struct sockaddr*>(&name);
    int ret = bind(host_socket, address, sizeof(sockaddr_un));
    BLT_ASSERT(ret == 0 && "Failed to bind socket");

    ret = listen(host_socket, 20);
    BLT_ASSERT(ret == 0 && "Failed to listen socket");
}
/**
 * Send an EXECUTE_RUN packet carrying `numGens` to every child.
 *
 * A child whose write fails is removed from `children` if its socket has been
 * flagged as closed; otherwise a warning is logged and the child is kept.
 *
 * @param numGens value placed in the packet's numOfGens field
 */
void send_execution_command(blt::i32 numGens)
{
    // the packet is the same for every child, so serialise it once
    packet_t command{};
    command.id = packet_id::EXECUTE_RUN;
    command.numOfGens = numGens;
    unsigned char bytes[sizeof(packet_t)];
    std::memcpy(bytes, &command, sizeof(bytes));

    for (auto cursor = children.begin(); cursor != children.end();)
    {
        auto& target = cursor->second;
        const bool sent = target->write(bytes, sizeof(bytes)) > 0;
        if (!sent && target->isSocketClosed())
        {
            // erase() hands back the next valid position
            cursor = children.erase(cursor);
            continue;
        }
        if (!sent)
            BLT_WARN("Failed to write to child error %d", errno);
        ++cursor;
    }
}
/**
 * Run one iteration of the parent's control loop.
 *
 * First, all available packets are read from every child and timing packets
 * are processed. Children whose socket has closed are removed afterwards, so
 * packets received before the hang-up are still recorded. Then one step of the
 * state machine is executed:
 *
 *  - RUN_GENERATIONS:  tell every child to run "--num_gen" generations.
 *  - CHILD_EVALUATION: collect one CHILD_FIT value per child; once every child
 *                      has reported, sort the values and pick the cutoff at
 *                      index size * "--prune_ratio" (clamped to a valid index).
 *  - PRUNE:            send a PRUNE packet to every child whose fitness is
 *                      <= the cutoff, or, if only one child is left, tell it to
 *                      run without further synchronisation and go IDLE.
 *  - IDLE:             sleep briefly; packets are still drained each tick.
 *
 * @param args parsed command line arguments ("--num_gen", "--prune_ratio" are read)
 */
void tick_state(blt::arg_parse::arg_results& args)
{
    unsigned char buffer[sizeof(packet_t)];

    // Receive into a dedicated packet so incoming data never leaks into the
    // packet that is sent out during PRUNE.
    packet_t incoming{};
    auto it = children.begin();
    while (it != children.end())
    {
        auto& child = *it->second;
        ssize_t ret;
        // read all packets
        do
        {
            ret = child.read(buffer, sizeof(buffer));
            if (ret > 0)
            {
                std::memcpy(&incoming, buffer, sizeof(buffer));
                child.handlePacket(incoming);
            } else if (ret < 0 && !child.isSocketClosed())
            {
                // errno is only meaningful when the read itself failed
                BLT_WARN("Failed to read to child error %d", errno);
            }
        } while (ret > 0);
        child.processPackets();
        if (child.isSocketClosed())
        {
            it = children.erase(it);
            continue;
        }
        ++it;
    }

    switch (current_state)
    {
        case state_t::RUN_GENERATIONS:
        {
            send_execution_command(args.get<blt::i32>("--num_gen"));
            current_state = state_t::CHILD_EVALUATION;
            break;
        }
        case state_t::CHILD_EVALUATION:
        {
            for (auto& child : children)
            {
                // only the first fitness packet of a child counts for this round
                for (const auto& p : child.second->pendingPackets())
                {
                    if (p.id == packet_id::CHILD_FIT)
                    {
                        child.second->setFitness(p.fitness);
                        fitness_storage.push_back(p.fitness);
                        break;
                    }
                }
                child.second->clearPackets(packet_id::CHILD_FIT);
            }
            // wait for further ticks until every remaining child has reported
            if (fitness_storage.size() < children.size())
                break;
            std::sort(fitness_storage.begin(), fitness_storage.end());
            auto ratio = args.get<double>("--prune_ratio");
            auto cutoff = static_cast<long>(static_cast<double>(fitness_storage.size()) * ratio);
            if (!fitness_storage.empty())
            {
                // a ratio >= 1 (or < 0) would otherwise index outside the vector
                const auto last = static_cast<long>(fitness_storage.size()) - 1;
                if (cutoff > last)
                    cutoff = last;
                if (cutoff < 0)
                    cutoff = 0;
                fitness_cutoff = fitness_storage[cutoff];
            } else
                BLT_WARN("Running with no active populations?");
            BLT_INFO("Cutoff value %d, current size %d, fitness: %f", cutoff, fitness_storage.size(), fitness_cutoff);
            current_state = state_t::PRUNE;
            fitness_storage.clear();
            break;
        }
        case state_t::PRUNE:
        {
            if (children.size() == 1)
            {
                // run to completion, we no longer need to sync with the server.
                send_execution_command(std::numeric_limits<blt::i32>::max());
                // keep the server in idle state, this way we can still handle incoming packets
                // since we will need to get information about pop stats
                current_state = state_t::IDLE;
                break;
            }
            BLT_DEBUG("Pruning with fitness %f", fitness_cutoff);
            // same packet for every pruned child; build and serialise it once
            packet_t packet{};
            packet.state = current_state;
            packet.id = packet_id::PRUNE;
            packet.fitness = fitness_cutoff;
            std::memcpy(buffer, &packet, sizeof(buffer));
            auto cursor = children.begin();
            while (cursor != children.end())
            {
                auto& child = *cursor->second;
                if (child.getFitness() <= fitness_cutoff)
                {
                    if (child.write(buffer, sizeof(buffer)) <= 0)
                    {
                        if (child.isSocketClosed())
                        {
                            cursor = children.erase(cursor);
                            continue;
                        }
                        BLT_WARN("Failed to write to child error %d", errno);
                    }
                }
                ++cursor;
            }
            current_state = state_t::RUN_GENERATIONS;
            break;
        }
        case state_t::IDLE:
            std::this_thread::sleep_for(std::chrono::milliseconds(1));
            break;
    }
}
/**
 * Connect to all children, then drive the control loop until none are left.
 *
 * Every iteration reaps finished child processes, advances the state machine
 * by one tick and sleeps for one millisecond. When `children` is empty the
 * socket file is unlinked and the listening socket is closed.
 *
 * @param args parsed command line arguments, forwarded to tick_state()
 */
void init_sockets(blt::arg_parse::arg_results& args)
{
    create_child_sockets();

    const std::chrono::milliseconds tick_delay{1};
    for (;;)
    {
        if (children.empty())
            break;
        remove_pending_finished_child_process();
        tick_state(args);
        std::this_thread::sleep_for(tick_delay);
    }

    // all children are gone: remove the socket file and release the listener
    unlink(name.sun_path);
    close(host_socket);
}
int main(int argc, const char** argv)
{
blt::arg_parse parser;
parser.addArgument(blt::arg_builder("-n", "--num_pops").setDefault("120").setHelp("Number of populations to start").build());
parser.addArgument(blt::arg_builder("-g", "--num_gen").setDefault("5").setHelp("Number of generations between pruning").build());
parser.addArgument(blt::arg_builder("-p", "--prune_ratio").setDefault("0.2").setHelp("Number of generations to run before pruning").build());
parser.addArgument(blt::arg_builder("--program").setDefault("./FinalProject").setHelp("GP Program to execute per run").build());
parser.addArgument(blt::arg_builder("--out_file").setDefault("regress")
.setHelp("Name of the stats file (without extension) to use in building the final data")
.build());
parser.addArgument(
blt::arg_builder("--write_file").setDefault("aggregated").setHelp("Name of the file to write the aggregated data to (without extension)")
.build());
parser.addArgument(blt::arg_builder("--file").setDefault("../input.file").setHelp("File to run the GP on").build());
parser.addArgument(blt::arg_builder("--rice").setDefault("../Rice_Cammeo_Osmancik.arff").setHelp("Rice file to run the GP on").build());
auto args = parser.parse_args(argc, argv);
BLT_INFO("%b", args.contains("--num_pops"));
BLT_INFO(args.get<std::string>("--write_file"));
BLT_INFO("Parsing user arguments:");
for (auto& v : args.data)
BLT_INFO("\t%s = %s", v.first.c_str(), blt::to_string(v.second).c_str());
std::string random_id;
std::random_device dev;
std::mt19937_64 engine(dev());
std::uniform_int_distribution charGenLower('a', 'z');
std::uniform_int_distribution charGenUpper('A', 'Z');
std::uniform_int_distribution choice(0, 1);
for (int i = 0; i < 5; i++)
{
if (choice(engine))
random_id += static_cast<char>(charGenLower(engine));
else
random_id += static_cast<char>(charGenUpper(engine));
}
auto runs = args.get<std::int32_t>("num_pops");
BLT_DEBUG("Running with %d runs", runs);
SOCKET_LOCATION = "/tmp/gp_program_" + random_id + ".socket";
create_parent_socket();
for (auto i = 0; i < runs; i++)
{
auto pid = fork();
if (pid == 0)
return child_fp(args, i, SOCKET_LOCATION);
else if (pid > 0)
{
// parent
children.insert({pid, std::make_unique<child_t>(i)});
BLT_TRACE("Forked child to %d", pid);
} else
{
// failure
BLT_ERROR("Failed to fork process! Error: %d", errno);
return 1;
}
}
init_sockets(args);
process_files(args.get<std::string>("--out_file"), args.get<std::string>("--write_file"), args.get<int>("--num_pops"));
}