#pragma once /* * Copyright (C) 2024 Brett Terpstra * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program. If not, see . */ #ifndef COSC_4P80_ASSIGNMENT_2_COMMON_H #define COSC_4P80_ASSIGNMENT_2_COMMON_H #include #include namespace assign2 { using Scalar = float; const inline Scalar learn_rate = 0.1; template decltype(std::cout)& print_vec(const std::vector& vec) { for (auto [i, v] : blt::enumerate(vec)) { std::cout << v; if (i != vec.size() - 1) std::cout << ", "; } return std::cout; } struct data_t { bool is_bad = false; std::vector bins; }; struct data_file_t { std::vector data_points; }; class layer_t; class network_t; struct function_t { [[nodiscard]] virtual Scalar call(Scalar) const = 0; [[nodiscard]] virtual Scalar derivative(Scalar) const = 0; }; struct weight_view { public: weight_view(Scalar* data, blt::size_t size): m_data(data), m_size(size) {} inline Scalar& operator[](blt::size_t index) const { #if BLT_DEBUG_LEVEL > 0 if (index >= size) throw std::runtime_error("Index is out of bounds!"); #endif return m_data[index]; } [[nodiscard]] inline blt::size_t size() const { return m_size; } [[nodiscard]] auto begin() const { return m_data; } [[nodiscard]] auto end() const { return m_data + m_size; } private: Scalar* m_data; blt::size_t m_size; }; /** * this class exists purely as an optimization */ class weight_t { public: void preallocate(blt::size_t amount) { data.resize(amount); } weight_view allocate_view(blt::size_t count) { auto size = place; place += count; return {&data[size], count}; } void debug() const { std::cout << "Weights: "; print_vec(data) << std::endl; } private: blt::size_t place = 0; std::vector data; }; } #endif //COSC_4P80_ASSIGNMENT_2_COMMON_H