commit 3ac5de603ae52cd8a7d91972a1cc19053e74ed26 Author: Brett Date: Tue Dec 10 14:13:55 2024 -0500 hello diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..0ad02d1 --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +cmake-build*/ +build/ +out/ +./cmake-build*/ +./build/ +./out/ diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..70a5151 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "lib/blt-with-graphics"] + path = lib/blt-with-graphics + url = https://git.tpgc.me/tri11paragon/BLT-With-Graphics-Template diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/COSC-4P80-Final-Project.iml b/.idea/COSC-4P80-Final-Project.iml new file mode 100644 index 0000000..f08604b --- /dev/null +++ b/.idea/COSC-4P80-Final-Project.iml @@ -0,0 +1,2 @@ + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..0b76fe5 --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,7 @@ + + + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..37eb8ca --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..9521f67 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,29 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..7587985 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,44 @@ +cmake_minimum_required(VERSION 3.29) +project(COSC-4P80-Final-Project VERSION 0.0.2) + +option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF) +option(ENABLE_UBSAN "Enable the ub sanitizer" OFF) +option(ENABLE_TSAN "Enable the thread data race sanitizer" OFF) + +set(CMAKE_CXX_STANDARD 17) + +add_subdirectory(lib/blt-with-graphics) + +add_compile_options("-fopenmp") +add_link_options("-fopenmp") + +fetchcontent_declare(dlib + URL http://dlib.net/files/dlib-19.24.tar.bz2 + URL_HASH MD5=8a98957a73eebd3cd7431c2bac79665f + FIND_PACKAGE_ARGS) +fetchcontent_makeavailable(dlib) + +include_directories(include/) +file(GLOB_RECURSE PROJECT_BUILD_FILES "${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp") + +add_executable(COSC-4P80-Final-Project ${PROJECT_BUILD_FILES}) + +target_compile_options(COSC-4P80-Final-Project PRIVATE -Wall -Wextra -Wpedantic -Wno-comment) +target_link_options(COSC-4P80-Final-Project PRIVATE -Wall -Wextra -Wpedantic -Wno-comment) + +target_link_libraries(COSC-4P80-Final-Project PRIVATE BLT_WITH_GRAPHICS dlib) + +if (${ENABLE_ADDRSAN} MATCHES ON) + target_compile_options(COSC-4P80-Final-Project PRIVATE -fsanitize=address) + target_link_options(COSC-4P80-Final-Project PRIVATE -fsanitize=address) +endif () + +if (${ENABLE_UBSAN} MATCHES ON) + target_compile_options(COSC-4P80-Final-Project PRIVATE -fsanitize=undefined) + target_link_options(COSC-4P80-Final-Project PRIVATE -fsanitize=undefined) +endif () + +if (${ENABLE_TSAN} MATCHES ON) + target_compile_options(COSC-4P80-Final-Project PRIVATE -fsanitize=thread) + target_link_options(COSC-4P80-Final-Project PRIVATE -fsanitize=thread) +endif () diff --git a/commit.py b/commit.py new file mode 100755 index 0000000..440e164 --- /dev/null +++ b/commit.py @@ -0,0 +1,285 @@ +#!python3 + +import subprocess +import argparse +import sys +import os +import itertools +import requests +import json +from pathlib import Path + +#--------------------------------------- +# CONFIG +#--------------------------------------- + +VERSION_BEGIN_STR = " VERSION " +VERSION_END_STR = ")" + +#--------------------------------------- +# DO NOT TOUCH +#--------------------------------------- + +USER_HOME = Path.home() +ENVIRONMENT_DATA_LOCATION = USER_HOME / ".brett_scripts.env" + +if sys.platform.startswith("win"): + CONFIG_FILE_DIRECTORY = Path(os.getenv('APPDATA') + "\BLT") + CONFIG_FILE_LOCATION = Path(CONFIG_FILE_DIRECTORY + "\commit_config.env") +else: + XDG_CONFIG_HOME = os.environ.get('XDG_CONFIG_HOME') + if XDG_CONFIG_HOME is None: + XDG_CONFIG_HOME = USER_HOME / ".config" + else: + XDG_CONFIG_HOME = Path(XDG_CONFIG_HOME) + + if len(str(XDG_CONFIG_HOME)) == 0: + XDG_CONFIG_HOME = USER_HOME + CONFIG_FILE_DIRECTORY = XDG_CONFIG_HOME / "blt" + CONFIG_FILE_LOCATION = CONFIG_FILE_DIRECTORY / "commit_config.env" + +class Config: + def __init__(self): + # Inline with semantic versioning it doesn't make sense to branch / release on minor + self.branch_on_major = True + self.branch_on_minor = False + self.release_on_major = True + self.release_on_minor = False + self.main_branch = "main" + self.patch_limit = -1 + + def toJSON(self): + return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4) + + def fromJSON(file): + with open(file, "r") as f: + j = json.load(f) + obj = Config() + [setattr(obj, key, val) for key, val in j.items() if hasattr(obj, key)] + return obj + + def from_file(file): + values = {} + if (not os.path.exists(file)): + return Config() + + with open(file, "r") as f: + j = json.load(f) + obj = Config() + [setattr(obj, key, val) for key, val in j.items() if hasattr(obj, key)] + return obj + + def save_to_file(self, file): + dir_index = str(file).rfind("/") + dir = str(file)[:dir_index] + if not os.path.exists(dir): + print(f"Creating config directory {dir}") + os.makedirs(dir) + with open(file, "w") as f: + json.dump(self, f, default=lambda o: o.__dict__, sort_keys=True, indent=4) + + +class EnvData: + def __init__(self, github_username = '', github_token = ''): + self.github_token = github_token + self.github_username = github_username + + def get_env_from_file(file): + f = open(file, "rt") + values = {} + for line in f: + if line.startswith("export"): + content = line.split("=") + for idx, c in enumerate(content): + content[idx] = c.replace("export", "").strip() + values[content[0]] = content[1].replace("\"", "").replace("'", "") + try: + github_token = values["github_token"] + except Exception: + print("Failed to parse github token!") + try: + github_username = values["github_username"] + except: + print("Failed to parse github username! Assuming you are me!") + github_username = "Tri11Paragon" + return EnvData(github_username=github_username, github_token=github_token) + +def open_process(command, print_out = True): + process = subprocess.Popen(command, stderr=subprocess.PIPE, stdout=subprocess.PIPE) + stdout, stderr = process.communicate() + exit_code = process.wait() + str_out = stdout.decode('utf8') + str_err = stderr.decode('utf8') + if print_out and len(str_out) > 0: + print(str_out, end='') + if print_out and len(str_err) > 0: + print(str_err, end='') + #print(stdout, stderr, exit_code) + return (stdout, stderr, exit_code) + +def load_cmake(): + cmake_file = open("CMakeLists.txt", 'r') + cmake_text = cmake_file.read() + cmake_file.close() + return cmake_text + +def write_cmake(cmake_text): + cmake_file = open("CMakeLists.txt", 'w') + cmake_file.write(cmake_text) + cmake_file.close() + +def get_version(cmake_text): + begin = cmake_text.find(VERSION_BEGIN_STR) + len(VERSION_BEGIN_STR) + end = cmake_text.find(VERSION_END_STR, begin) + return (cmake_text[begin:end], begin, end) + +def split_version(cmake_text): + version, begin, end = get_version(cmake_text) + version_parts = version.split('.') + return (version_parts, begin, end) + +def recombine(cmake_text, version_parts, begin, end): + constructed_version = version_parts[0] + '.' + version_parts[1] + '.' + version_parts[2] + constructed_text_begin = cmake_text[0:begin] + constructed_text_end = cmake_text[end::] + return constructed_text_begin + constructed_version + constructed_text_end + +def inc_major(cmake_text): + version_parts, begin, end = split_version(cmake_text) + version_parts[0] = str(int(version_parts[0]) + 1) + version_parts[1] = '0' + version_parts[2] = '0' + return recombine(cmake_text, version_parts, begin, end) + +def inc_minor(cmake_text): + version_parts, begin, end = split_version(cmake_text) + version_parts[1] = str(int(version_parts[1]) + 1) + version_parts[2] = '0' + return recombine(cmake_text, version_parts, begin, end) + +def inc_patch(config: Config, cmake_text): + version_parts, begin, end = split_version(cmake_text) + if config.patch_limit > 0 and int(version_parts[2]) + 1 >= config.patch_limit: + return inc_minor(cmake_text) + version_parts[2] = str(int(version_parts[2]) + 1) + return recombine(cmake_text, version_parts, begin, end) + +def make_branch(config: Config, name): + print(f"Making new branch {name}") + subprocess.call(["git", "checkout", "-b", name]) + subprocess.call(["git", "merge", config.main_branch]) + subprocess.call(["git", "checkout", config.main_branch]) + +def make_release(env: EnvData, name): + print(f"Making new release {name}") + repos_v = open_process(["git", "remote", "-v"])[0].splitlines() + urls = [] + for line in repos_v: + origin = ''.join(itertools.takewhile(str.isalpha, line.decode('utf8'))) + urls.append("https://api.github.com/repos/" + open_process(["git", "remote", "get-url", origin], False)[0].decode('utf8').replace("\n", "").replace("https://github.com/", "") + "/releases") + urls = set(urls) + data = { + 'tag_name': name, + 'name': name, + 'body': "Automated Release '" + name + "'", + 'draft': False, + 'prerelease': False + } + headers = { + 'Authorization': f'Bearer {env.github_token}', + 'Accept': 'application/vnd.github+json', + 'X-GitHub-Api-Version': '2022-11-28' + } + for url in urls: + response = requests.post(url, headers=headers, data=json.dumps(data)) + if response.status_code == 201: + print('Release created successfully!') + release_data = response.json() + print(f"Release URL: {release_data['html_url']}") + else: + print(f"Failed to create release: {response.status_code}") + print(response.json()) + + +def main(): + parser = argparse.ArgumentParser( + prog="Commit Helper", + description="Help you make pretty commits :3") + + parser.add_argument("action", nargs='?', default=None) + parser.add_argument("-p", "--patch", action='store_true', default=False, required=False) + parser.add_argument("-m", "--minor", action='store_true', default=False, required=False) + parser.add_argument("-M", "--major", action='store_true', default=False, required=False) + parser.add_argument('-e', "--env", help="environment file", required=False, default=None) + parser.add_argument('-c', "--config", help="config file", required=False, default=None) + parser.add_argument("--create_default_config", action="store_true", default=False, required=False) + + args = parser.parse_args() + + if args.env is not None: + env = EnvData.get_env_from_file(args.e) + else: + env = EnvData.get_env_from_file(ENVIRONMENT_DATA_LOCATION) + + if args.config is not None: + config = Config.from_file(args.config) + else: + config = Config.from_file(CONFIG_FILE_LOCATION) + + if args.create_default_config: + config.save_to_file(args.config if args.config is not None else CONFIG_FILE_LOCATION) + + cmake_text = load_cmake() + cmake_version = get_version(cmake_text)[0] + print(f"Current Version: {cmake_version}") + + if not (args.patch or args.minor or args.major): + try: + if args.action is not None: + type = args.action + else: + type = input("What kind of commit is this ((M)ajor, (m)inor, (p)atch)? ") + + if type.startswith('M'): + args.major = True + elif type.startswith('m'): + args.minor = True + elif type.startswith('p') or type.startswith('P') or len(type) == 0: + args.patch = True + except KeyboardInterrupt: + print("\nCancelling!") + return + + if args.major: + print("Selected major") + write_cmake(inc_major(cmake_text)) + elif args.minor: + print("Selected minor") + write_cmake(inc_minor(cmake_text)) + elif args.patch: + print("Selected patch") + write_cmake(inc_patch(config, cmake_text)) + + subprocess.call(["git", "add", "*"]) + subprocess.call(["git", "commit"]) + + cmake_text = load_cmake() + version_parts = split_version(cmake_text)[0] + if args.major: + if config.branch_on_major: + make_branch(config, "v" + str(version_parts[0])) + if args.minor: + if config.branch_on_minor: + make_branch(config, "v" + str(version_parts[0]) + "." + str(version_parts[1])) + + subprocess.call(["sh", "-c", "git remote | xargs -L1 git push --all"]) + + if args.major: + if config.release_on_major: + make_release(env, "v" + str(version_parts[0])) + if args.minor: + if config.release_on_minor: + make_release(env, "v" + str(version_parts[0]) + "." + str(version_parts[1])) + +if __name__ == "__main__": + main() diff --git a/default.nix b/default.nix new file mode 100644 index 0000000..083d0ca --- /dev/null +++ b/default.nix @@ -0,0 +1,56 @@ +{ pkgs ? (import { + config.allowUnfree = true; + config.segger-jlink.acceptLicense = true; +}), customPkgs ? (import /home/brett/my-nixpkgs { + config.allowUnfree = true; + config.segger-jlink.acceptLicense = true; +}), ... }: +pkgs.mkShell +{ + buildInputs = with pkgs; [ + cmake + gcc + clang + emscripten + ninja + customPkgs.jetbrains.clion + #clion = import ~/my-nixpkgs/pkgs/applications/editors/jetbrains {}; + renderdoc + valgrind + ]; + propagatedBuildInputs = with pkgs; [ + xorg.libX11 + xorg.libX11.dev + xorg.libXcursor + xorg.libXcursor.dev + xorg.libXext + xorg.libXext.dev + xorg.libXinerama + xorg.libXinerama.dev + xorg.libXrandr + xorg.libXrandr.dev + xorg.libXrender + xorg.libXrender.dev + xorg.libxcb + xorg.libxcb.dev + xorg.libXi + xorg.libXi.dev + harfbuzz + harfbuzz.dev + zlib + zlib.dev + bzip2 + bzip2.dev + pngpp + brotli + brotli.dev + pulseaudio.dev + git + libGL + libGL.dev + glfw + openblas + openblas.dev + ]; + LD_LIBRARY_PATH="/run/opengl-driver/lib:/run/opengl-driver-32/lib"; +} diff --git a/include/MNIST.h b/include/MNIST.h new file mode 100644 index 0000000..e788f7c --- /dev/null +++ b/include/MNIST.h @@ -0,0 +1,27 @@ +#pragma once +/* + * Copyright (C) 2024 Brett Terpstra + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#ifndef MNIST_H +#define MNIST_H + +namespace fp +{ + void run_mnist(); +} + +#endif //MNIST_H diff --git a/lib/blt-with-graphics b/lib/blt-with-graphics new file mode 160000 index 0000000..29286e6 --- /dev/null +++ b/lib/blt-with-graphics @@ -0,0 +1 @@ +Subproject commit 29286e66daa724ef08692d9a65e9c88e5467d9b2 diff --git a/problems/mnist/mnist-dataset.zip b/problems/mnist/mnist-dataset.zip new file mode 100644 index 0000000..188e947 Binary files /dev/null and b/problems/mnist/mnist-dataset.zip differ diff --git a/problems/mnist/t10k-images-idx3-ubyte/t10k-images-idx3-ubyte b/problems/mnist/t10k-images-idx3-ubyte/t10k-images-idx3-ubyte new file mode 100644 index 0000000..1170b2c Binary files /dev/null and b/problems/mnist/t10k-images-idx3-ubyte/t10k-images-idx3-ubyte differ diff --git a/problems/mnist/t10k-images.idx3-ubyte b/problems/mnist/t10k-images.idx3-ubyte new file mode 100644 index 0000000..1170b2c Binary files /dev/null and b/problems/mnist/t10k-images.idx3-ubyte differ diff --git a/problems/mnist/t10k-labels-idx1-ubyte/t10k-labels-idx1-ubyte b/problems/mnist/t10k-labels-idx1-ubyte/t10k-labels-idx1-ubyte new file mode 100644 index 0000000..d1c3a97 Binary files /dev/null and b/problems/mnist/t10k-labels-idx1-ubyte/t10k-labels-idx1-ubyte differ diff --git a/problems/mnist/t10k-labels.idx1-ubyte b/problems/mnist/t10k-labels.idx1-ubyte new file mode 100644 index 0000000..d1c3a97 Binary files /dev/null and b/problems/mnist/t10k-labels.idx1-ubyte differ diff --git a/problems/mnist/train-images-idx3-ubyte/train-images-idx3-ubyte b/problems/mnist/train-images-idx3-ubyte/train-images-idx3-ubyte new file mode 100644 index 0000000..bbce276 Binary files /dev/null and b/problems/mnist/train-images-idx3-ubyte/train-images-idx3-ubyte differ diff --git a/problems/mnist/train-images.idx3-ubyte b/problems/mnist/train-images.idx3-ubyte new file mode 100644 index 0000000..bbce276 Binary files /dev/null and b/problems/mnist/train-images.idx3-ubyte differ diff --git a/problems/mnist/train-labels-idx1-ubyte/train-labels-idx1-ubyte b/problems/mnist/train-labels-idx1-ubyte/train-labels-idx1-ubyte new file mode 100644 index 0000000..d6b4c5d Binary files /dev/null and b/problems/mnist/train-labels-idx1-ubyte/train-labels-idx1-ubyte differ diff --git a/problems/mnist/train-labels.idx1-ubyte b/problems/mnist/train-labels.idx1-ubyte new file mode 100644 index 0000000..d6b4c5d Binary files /dev/null and b/problems/mnist/train-labels.idx1-ubyte differ diff --git a/src/MNIST.cpp b/src/MNIST.cpp new file mode 100644 index 0000000..482e418 --- /dev/null +++ b/src/MNIST.cpp @@ -0,0 +1,285 @@ +/* + * + * Copyright (C) 2024 Brett Terpstra + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +#include +#include +#include +#include +#include +#include +#include +#include + +namespace fp +{ + class idx_file_t + { + template + using mk_v = std::vector; + using vec_t = std::variant, mk_v, mk_v, mk_v, mk_v, mk_v>; + + public: + explicit idx_file_t(const std::string& path) + { + std::ifstream file{path, std::ios::in | std::ios::binary}; + + using char_type = std::ifstream::char_type; + char_type magic_arr[4]; + file.read(magic_arr, 4); + BLT_ASSERT(magic_arr[0] == 0 && magic_arr[1] == 0); + + blt::u8 dims = magic_arr[3]; + blt::size_t total_size = 1; + + for (blt::i32 i = 0; i < dims; i++) + { + char_type dim_arr[4]; + file.read(dim_arr, 4); + blt::u32 dim; + blt::mem::fromBytes(dim_arr, dim); + dimensions.push_back(dim); + total_size *= dim; + } + + switch (magic_arr[2]) + { + // unsigned char + case 0x08: + data = mk_v{}; + read_data(file, total_size); + break; + // signed char + case 0x09: + data = mk_v{}; + read_data(file, total_size); + break; + // short + case 0x0B: + data = mk_v{}; + read_data(file, total_size); + reverse_data(); + break; + // int + case 0x0C: + data = mk_v{}; + read_data(file, total_size); + reverse_data(); + break; + // float + case 0x0D: + data = mk_v{}; + read_data(file, total_size); + reverse_data(); + break; + // double + case 0x0E: + data = mk_v{}; + read_data(file, total_size); + reverse_data(); + break; + default: + BLT_ERROR("Unspported idx file type!"); + } + if (file.eof()) + { + BLT_ERROR("EOF reached. It's unlikely your file was read correctly!"); + } + } + + template + [[nodiscard]] const std::vector& get_data_as() const + { + return std::get>(data); + } + + template + std::vector> get_as_spans() const + { + std::vector> spans; + + blt::size_t total_size = data_size(1); + + for (blt::size_t i = 0; i < dimensions[0]; i++) + { + auto& array = std::get>(data); + spans.push_back({&array[i * total_size], total_size}); + } + + return spans; + } + + [[nodiscard]] const std::vector& get_dimensions() const + { + return dimensions; + } + + [[nodiscard]] blt::size_t data_size(const blt::size_t starting_dimension = 0) const + { + blt::size_t total_size = 1; + for (const auto d : blt::iterate(dimensions).skip(starting_dimension)) + total_size *= d; + return total_size; + } + + private: + template + void read_data(std::ifstream& file, blt::size_t total_size) + { + auto& array = std::get>(data); + array.resize(total_size); + file.read(reinterpret_cast(array.data()), static_cast(total_size) * sizeof(T)); + } + + template + void reverse_data() + { + auto& array = std::get>(data); + for (auto& v : array) + blt::mem::reverse(v); + } + + std::vector dimensions; + vec_t data; + }; + + class image_t + { + public: + static constexpr blt::u32 target_size = 10; + + image_t(const idx_file_t& image_data, const idx_file_t& label_data): samples(image_data.get_dimensions()[0]), + input_size(image_data.data_size(1)) + { + BLT_ASSERT_MSG(samples == label_data.get_dimensions()[0], + ("Mismatch in data sample sizes! " + std::to_string(samples) + " vs " + std::to_string(label_data.get_dimensions()[0])). + c_str()); + auto& image_array = image_data.get_data_as(); + auto& label_array = label_data.get_data_as(); + + for (const auto label : label_array) + image_labels.push_back(label); + + const auto row_length = image_data.get_dimensions()[2]; + const auto number_of_rows = image_data.get_dimensions()[1]; + + for (blt::u32 i = 0; i < samples; i++) + { + dlib::matrix mat(number_of_rows, row_length); + for (blt::u32 y = 0; y < number_of_rows; y++) + { + for (blt::u32 x = 0; x < row_length; x++) + { + mat(x, y) = image_array[i * input_size + y * row_length + x]; + } + } + data.push_back(mat); + } + } + + [[nodiscard]] const std::vector>& get_image_data() const + { + return data; + } + + [[nodiscard]] const std::vector& get_image_labels() const + { + return image_labels; + } + + private: + blt::u32 samples; + blt::u32 input_size; + std::vector> data; + std::vector image_labels; + }; + + void run_mnist() + { + using namespace dlib; + + idx_file_t test_images{"../problems/mnist/t10k-images.idx3-ubyte"}; + idx_file_t test_labels{"../problems/mnist/t10k-labels.idx1-ubyte"}; + idx_file_t train_images{"../problems/mnist/train-images.idx3-ubyte"}; + idx_file_t train_labels{"../problems/mnist/train-labels.idx1-ubyte"}; + + auto train_samples = train_images.get_dimensions()[0]; + auto test_samples = test_images.get_dimensions()[0]; + + auto columns = train_images.get_dimensions()[1]; + auto rows = train_images.get_dimensions()[2]; + + auto input_size = static_cast(train_images.data_size(1)); + + image_t train_image{train_images, train_labels}; + image_t test_image{test_images, test_labels}; + + using net_type = loss_multiclass_log< + fc<10, + relu>>>>>>>>>>>>>; + + net_type network{}; + + dnn_trainer trainer(network); + trainer.set_learning_rate(0.01); + trainer.set_min_learning_rate(0.00001); + trainer.set_mini_batch_size(128); + trainer.be_verbose(); + + trainer.set_synchronization_file("mnist_sync", std::chrono::seconds(20)); + + trainer.train(train_image.get_image_data(), train_image.get_image_labels()); + trainer.train_one_step(train_image.get_image_data(), train_image.get_image_labels()); + + network.clean(); + serialize("mnist_network.dat") << network; + + std::vector predicted_labels = network(train_image.get_image_data()); + int num_right = 0; + int num_wrong = 0; + // And then let's see if it classified them correctly. + for (size_t i = 0; i < train_image.get_image_data().size(); ++i) + { + if (predicted_labels[i] == train_image.get_image_labels()[i]) + ++num_right; + else + ++num_wrong; + + } + std::cout << "training num_right: " << num_right << std::endl; + std::cout << "training num_wrong: " << num_wrong << std::endl; + std::cout << "training accuracy: " << num_right / static_cast(num_right + num_wrong) << std::endl; + + predicted_labels = network(test_image.get_image_data()); + num_right = 0; + num_wrong = 0; + for (size_t i = 0; i < test_image.get_image_data().size(); ++i) + { + if (predicted_labels[i] == test_image.get_image_labels()[i]) + ++num_right; + else + ++num_wrong; + + } + std::cout << "testing num_right: " << num_right << std::endl; + std::cout << "testing num_wrong: " << num_wrong << std::endl; + std::cout << "testing accuracy: " << num_right / static_cast(num_right + num_wrong) << std::endl; + } +} diff --git a/src/main.cpp b/src/main.cpp new file mode 100644 index 0000000..18fbfbe --- /dev/null +++ b/src/main.cpp @@ -0,0 +1,8 @@ +#include + +#include + +int main() +{ + fp::run_mnist(); +}