hello
commit
3ac5de603a
|
@ -0,0 +1,6 @@
|
|||
cmake-build*/
|
||||
build/
|
||||
out/
|
||||
./cmake-build*/
|
||||
./build/
|
||||
./out/
|
|
@ -0,0 +1,3 @@
|
|||
[submodule "lib/blt-with-graphics"]
|
||||
path = lib/blt-with-graphics
|
||||
url = https://git.tpgc.me/tri11paragon/BLT-With-Graphics-Template
|
|
@ -0,0 +1,8 @@
|
|||
# Default ignored files
|
||||
/shelf/
|
||||
/workspace.xml
|
||||
# Editor-based HTTP Client requests
|
||||
/httpRequests/
|
||||
# Datasource local storage ignored files
|
||||
/dataSources/
|
||||
/dataSources.local.xml
|
|
@ -0,0 +1,2 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module classpath="CMake" type="CPP_MODULE" version="4" />
|
|
@ -0,0 +1,7 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="CMakePythonSetting">
|
||||
<option name="pythonIntegrationState" value="YES" />
|
||||
</component>
|
||||
<component name="CMakeWorkspace" PROJECT_DIR="$PROJECT_DIR$" />
|
||||
</project>
|
|
@ -0,0 +1,8 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectModuleManager">
|
||||
<modules>
|
||||
<module fileurl="file://$PROJECT_DIR$/.idea/COSC-4P80-Final-Project.iml" filepath="$PROJECT_DIR$/.idea/COSC-4P80-Final-Project.iml" />
|
||||
</modules>
|
||||
</component>
|
||||
</project>
|
|
@ -0,0 +1,29 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="VcsDirectoryMappings">
|
||||
<mapping directory="" vcs="Git" />
|
||||
<mapping directory="$PROJECT_DIR$/cmake-build-debug/_deps/freetype-src" vcs="Git" />
|
||||
<mapping directory="$PROJECT_DIR$/cmake-build-debug/_deps/freetype-src/subprojects/dlg" vcs="Git" />
|
||||
<mapping directory="$PROJECT_DIR$/cmake-build-debug/_deps/glfw3-src" vcs="Git" />
|
||||
<mapping directory="$PROJECT_DIR$/cmake-build-debug/_deps/imgui-src" vcs="Git" />
|
||||
<mapping directory="$PROJECT_DIR$/cmake-build-debug/_deps/opennn-src" vcs="Git" />
|
||||
<mapping directory="$PROJECT_DIR$/cmake-build-release/_deps/freetype-src" vcs="Git" />
|
||||
<mapping directory="$PROJECT_DIR$/cmake-build-release/_deps/freetype-src/subprojects/dlg" vcs="Git" />
|
||||
<mapping directory="$PROJECT_DIR$/cmake-build-release/_deps/glfw3-src" vcs="Git" />
|
||||
<mapping directory="$PROJECT_DIR$/cmake-build-release/_deps/imgui-src" vcs="Git" />
|
||||
<mapping directory="$PROJECT_DIR$/cmake-build-release/_deps/opennn-src" vcs="Git" />
|
||||
<mapping directory="$PROJECT_DIR$/cmake-build-relwithdebinfo-addrsan/_deps/freetype-src" vcs="Git" />
|
||||
<mapping directory="$PROJECT_DIR$/cmake-build-relwithdebinfo-addrsan/_deps/freetype-src/subprojects/dlg" vcs="Git" />
|
||||
<mapping directory="$PROJECT_DIR$/cmake-build-relwithdebinfo-addrsan/_deps/glfw3-src" vcs="Git" />
|
||||
<mapping directory="$PROJECT_DIR$/cmake-build-relwithdebinfo-addrsan/_deps/imgui-src" vcs="Git" />
|
||||
<mapping directory="$PROJECT_DIR$/cmake-build-relwithdebinfo-addrsan/_deps/opennn-src" vcs="Git" />
|
||||
<mapping directory="$PROJECT_DIR$/cmake-build-relwithdebinfo/_deps/freetype-src" vcs="Git" />
|
||||
<mapping directory="$PROJECT_DIR$/cmake-build-relwithdebinfo/_deps/freetype-src/subprojects/dlg" vcs="Git" />
|
||||
<mapping directory="$PROJECT_DIR$/cmake-build-relwithdebinfo/_deps/glfw3-src" vcs="Git" />
|
||||
<mapping directory="$PROJECT_DIR$/cmake-build-relwithdebinfo/_deps/imgui-src" vcs="Git" />
|
||||
<mapping directory="$PROJECT_DIR$/cmake-build-relwithdebinfo/_deps/opennn-src" vcs="Git" />
|
||||
<mapping directory="$PROJECT_DIR$/lib/blt-with-graphics" vcs="Git" />
|
||||
<mapping directory="$PROJECT_DIR$/lib/blt-with-graphics/libraries/BLT" vcs="Git" />
|
||||
<mapping directory="$PROJECT_DIR$/lib/blt-with-graphics/libraries/BLT/libraries/parallel-hashmap" vcs="Git" />
|
||||
</component>
|
||||
</project>
|
|
@ -0,0 +1,44 @@
|
|||
cmake_minimum_required(VERSION 3.29)
|
||||
project(COSC-4P80-Final-Project VERSION 0.0.2)
|
||||
|
||||
option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF)
|
||||
option(ENABLE_UBSAN "Enable the ub sanitizer" OFF)
|
||||
option(ENABLE_TSAN "Enable the thread data race sanitizer" OFF)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
|
||||
add_subdirectory(lib/blt-with-graphics)
|
||||
|
||||
add_compile_options("-fopenmp")
|
||||
add_link_options("-fopenmp")
|
||||
|
||||
fetchcontent_declare(dlib
|
||||
URL http://dlib.net/files/dlib-19.24.tar.bz2
|
||||
URL_HASH MD5=8a98957a73eebd3cd7431c2bac79665f
|
||||
FIND_PACKAGE_ARGS)
|
||||
fetchcontent_makeavailable(dlib)
|
||||
|
||||
include_directories(include/)
|
||||
file(GLOB_RECURSE PROJECT_BUILD_FILES "${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp")
|
||||
|
||||
add_executable(COSC-4P80-Final-Project ${PROJECT_BUILD_FILES})
|
||||
|
||||
target_compile_options(COSC-4P80-Final-Project PRIVATE -Wall -Wextra -Wpedantic -Wno-comment)
|
||||
target_link_options(COSC-4P80-Final-Project PRIVATE -Wall -Wextra -Wpedantic -Wno-comment)
|
||||
|
||||
target_link_libraries(COSC-4P80-Final-Project PRIVATE BLT_WITH_GRAPHICS dlib)
|
||||
|
||||
if (${ENABLE_ADDRSAN} MATCHES ON)
|
||||
target_compile_options(COSC-4P80-Final-Project PRIVATE -fsanitize=address)
|
||||
target_link_options(COSC-4P80-Final-Project PRIVATE -fsanitize=address)
|
||||
endif ()
|
||||
|
||||
if (${ENABLE_UBSAN} MATCHES ON)
|
||||
target_compile_options(COSC-4P80-Final-Project PRIVATE -fsanitize=undefined)
|
||||
target_link_options(COSC-4P80-Final-Project PRIVATE -fsanitize=undefined)
|
||||
endif ()
|
||||
|
||||
if (${ENABLE_TSAN} MATCHES ON)
|
||||
target_compile_options(COSC-4P80-Final-Project PRIVATE -fsanitize=thread)
|
||||
target_link_options(COSC-4P80-Final-Project PRIVATE -fsanitize=thread)
|
||||
endif ()
|
|
@ -0,0 +1,285 @@
|
|||
#!python3
|
||||
|
||||
import subprocess
|
||||
import argparse
|
||||
import sys
|
||||
import os
|
||||
import itertools
|
||||
import requests
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
#---------------------------------------
|
||||
# CONFIG
|
||||
#---------------------------------------
|
||||
|
||||
VERSION_BEGIN_STR = " VERSION "
|
||||
VERSION_END_STR = ")"
|
||||
|
||||
#---------------------------------------
|
||||
# DO NOT TOUCH
|
||||
#---------------------------------------
|
||||
|
||||
USER_HOME = Path.home()
|
||||
ENVIRONMENT_DATA_LOCATION = USER_HOME / ".brett_scripts.env"
|
||||
|
||||
if sys.platform.startswith("win"):
|
||||
CONFIG_FILE_DIRECTORY = Path(os.getenv('APPDATA') + "\BLT")
|
||||
CONFIG_FILE_LOCATION = Path(CONFIG_FILE_DIRECTORY + "\commit_config.env")
|
||||
else:
|
||||
XDG_CONFIG_HOME = os.environ.get('XDG_CONFIG_HOME')
|
||||
if XDG_CONFIG_HOME is None:
|
||||
XDG_CONFIG_HOME = USER_HOME / ".config"
|
||||
else:
|
||||
XDG_CONFIG_HOME = Path(XDG_CONFIG_HOME)
|
||||
|
||||
if len(str(XDG_CONFIG_HOME)) == 0:
|
||||
XDG_CONFIG_HOME = USER_HOME
|
||||
CONFIG_FILE_DIRECTORY = XDG_CONFIG_HOME / "blt"
|
||||
CONFIG_FILE_LOCATION = CONFIG_FILE_DIRECTORY / "commit_config.env"
|
||||
|
||||
class Config:
|
||||
def __init__(self):
|
||||
# Inline with semantic versioning it doesn't make sense to branch / release on minor
|
||||
self.branch_on_major = True
|
||||
self.branch_on_minor = False
|
||||
self.release_on_major = True
|
||||
self.release_on_minor = False
|
||||
self.main_branch = "main"
|
||||
self.patch_limit = -1
|
||||
|
||||
def toJSON(self):
|
||||
return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4)
|
||||
|
||||
def fromJSON(file):
|
||||
with open(file, "r") as f:
|
||||
j = json.load(f)
|
||||
obj = Config()
|
||||
[setattr(obj, key, val) for key, val in j.items() if hasattr(obj, key)]
|
||||
return obj
|
||||
|
||||
def from_file(file):
|
||||
values = {}
|
||||
if (not os.path.exists(file)):
|
||||
return Config()
|
||||
|
||||
with open(file, "r") as f:
|
||||
j = json.load(f)
|
||||
obj = Config()
|
||||
[setattr(obj, key, val) for key, val in j.items() if hasattr(obj, key)]
|
||||
return obj
|
||||
|
||||
def save_to_file(self, file):
|
||||
dir_index = str(file).rfind("/")
|
||||
dir = str(file)[:dir_index]
|
||||
if not os.path.exists(dir):
|
||||
print(f"Creating config directory {dir}")
|
||||
os.makedirs(dir)
|
||||
with open(file, "w") as f:
|
||||
json.dump(self, f, default=lambda o: o.__dict__, sort_keys=True, indent=4)
|
||||
|
||||
|
||||
class EnvData:
|
||||
def __init__(self, github_username = '', github_token = ''):
|
||||
self.github_token = github_token
|
||||
self.github_username = github_username
|
||||
|
||||
def get_env_from_file(file):
|
||||
f = open(file, "rt")
|
||||
values = {}
|
||||
for line in f:
|
||||
if line.startswith("export"):
|
||||
content = line.split("=")
|
||||
for idx, c in enumerate(content):
|
||||
content[idx] = c.replace("export", "").strip()
|
||||
values[content[0]] = content[1].replace("\"", "").replace("'", "")
|
||||
try:
|
||||
github_token = values["github_token"]
|
||||
except Exception:
|
||||
print("Failed to parse github token!")
|
||||
try:
|
||||
github_username = values["github_username"]
|
||||
except:
|
||||
print("Failed to parse github username! Assuming you are me!")
|
||||
github_username = "Tri11Paragon"
|
||||
return EnvData(github_username=github_username, github_token=github_token)
|
||||
|
||||
def open_process(command, print_out = True):
|
||||
process = subprocess.Popen(command, stderr=subprocess.PIPE, stdout=subprocess.PIPE)
|
||||
stdout, stderr = process.communicate()
|
||||
exit_code = process.wait()
|
||||
str_out = stdout.decode('utf8')
|
||||
str_err = stderr.decode('utf8')
|
||||
if print_out and len(str_out) > 0:
|
||||
print(str_out, end='')
|
||||
if print_out and len(str_err) > 0:
|
||||
print(str_err, end='')
|
||||
#print(stdout, stderr, exit_code)
|
||||
return (stdout, stderr, exit_code)
|
||||
|
||||
def load_cmake():
|
||||
cmake_file = open("CMakeLists.txt", 'r')
|
||||
cmake_text = cmake_file.read()
|
||||
cmake_file.close()
|
||||
return cmake_text
|
||||
|
||||
def write_cmake(cmake_text):
|
||||
cmake_file = open("CMakeLists.txt", 'w')
|
||||
cmake_file.write(cmake_text)
|
||||
cmake_file.close()
|
||||
|
||||
def get_version(cmake_text):
|
||||
begin = cmake_text.find(VERSION_BEGIN_STR) + len(VERSION_BEGIN_STR)
|
||||
end = cmake_text.find(VERSION_END_STR, begin)
|
||||
return (cmake_text[begin:end], begin, end)
|
||||
|
||||
def split_version(cmake_text):
|
||||
version, begin, end = get_version(cmake_text)
|
||||
version_parts = version.split('.')
|
||||
return (version_parts, begin, end)
|
||||
|
||||
def recombine(cmake_text, version_parts, begin, end):
|
||||
constructed_version = version_parts[0] + '.' + version_parts[1] + '.' + version_parts[2]
|
||||
constructed_text_begin = cmake_text[0:begin]
|
||||
constructed_text_end = cmake_text[end::]
|
||||
return constructed_text_begin + constructed_version + constructed_text_end
|
||||
|
||||
def inc_major(cmake_text):
|
||||
version_parts, begin, end = split_version(cmake_text)
|
||||
version_parts[0] = str(int(version_parts[0]) + 1)
|
||||
version_parts[1] = '0'
|
||||
version_parts[2] = '0'
|
||||
return recombine(cmake_text, version_parts, begin, end)
|
||||
|
||||
def inc_minor(cmake_text):
|
||||
version_parts, begin, end = split_version(cmake_text)
|
||||
version_parts[1] = str(int(version_parts[1]) + 1)
|
||||
version_parts[2] = '0'
|
||||
return recombine(cmake_text, version_parts, begin, end)
|
||||
|
||||
def inc_patch(config: Config, cmake_text):
|
||||
version_parts, begin, end = split_version(cmake_text)
|
||||
if config.patch_limit > 0 and int(version_parts[2]) + 1 >= config.patch_limit:
|
||||
return inc_minor(cmake_text)
|
||||
version_parts[2] = str(int(version_parts[2]) + 1)
|
||||
return recombine(cmake_text, version_parts, begin, end)
|
||||
|
||||
def make_branch(config: Config, name):
|
||||
print(f"Making new branch {name}")
|
||||
subprocess.call(["git", "checkout", "-b", name])
|
||||
subprocess.call(["git", "merge", config.main_branch])
|
||||
subprocess.call(["git", "checkout", config.main_branch])
|
||||
|
||||
def make_release(env: EnvData, name):
|
||||
print(f"Making new release {name}")
|
||||
repos_v = open_process(["git", "remote", "-v"])[0].splitlines()
|
||||
urls = []
|
||||
for line in repos_v:
|
||||
origin = ''.join(itertools.takewhile(str.isalpha, line.decode('utf8')))
|
||||
urls.append("https://api.github.com/repos/" + open_process(["git", "remote", "get-url", origin], False)[0].decode('utf8').replace("\n", "").replace("https://github.com/", "") + "/releases")
|
||||
urls = set(urls)
|
||||
data = {
|
||||
'tag_name': name,
|
||||
'name': name,
|
||||
'body': "Automated Release '" + name + "'",
|
||||
'draft': False,
|
||||
'prerelease': False
|
||||
}
|
||||
headers = {
|
||||
'Authorization': f'Bearer {env.github_token}',
|
||||
'Accept': 'application/vnd.github+json',
|
||||
'X-GitHub-Api-Version': '2022-11-28'
|
||||
}
|
||||
for url in urls:
|
||||
response = requests.post(url, headers=headers, data=json.dumps(data))
|
||||
if response.status_code == 201:
|
||||
print('Release created successfully!')
|
||||
release_data = response.json()
|
||||
print(f"Release URL: {release_data['html_url']}")
|
||||
else:
|
||||
print(f"Failed to create release: {response.status_code}")
|
||||
print(response.json())
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="Commit Helper",
|
||||
description="Help you make pretty commits :3")
|
||||
|
||||
parser.add_argument("action", nargs='?', default=None)
|
||||
parser.add_argument("-p", "--patch", action='store_true', default=False, required=False)
|
||||
parser.add_argument("-m", "--minor", action='store_true', default=False, required=False)
|
||||
parser.add_argument("-M", "--major", action='store_true', default=False, required=False)
|
||||
parser.add_argument('-e', "--env", help="environment file", required=False, default=None)
|
||||
parser.add_argument('-c', "--config", help="config file", required=False, default=None)
|
||||
parser.add_argument("--create_default_config", action="store_true", default=False, required=False)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.env is not None:
|
||||
env = EnvData.get_env_from_file(args.e)
|
||||
else:
|
||||
env = EnvData.get_env_from_file(ENVIRONMENT_DATA_LOCATION)
|
||||
|
||||
if args.config is not None:
|
||||
config = Config.from_file(args.config)
|
||||
else:
|
||||
config = Config.from_file(CONFIG_FILE_LOCATION)
|
||||
|
||||
if args.create_default_config:
|
||||
config.save_to_file(args.config if args.config is not None else CONFIG_FILE_LOCATION)
|
||||
|
||||
cmake_text = load_cmake()
|
||||
cmake_version = get_version(cmake_text)[0]
|
||||
print(f"Current Version: {cmake_version}")
|
||||
|
||||
if not (args.patch or args.minor or args.major):
|
||||
try:
|
||||
if args.action is not None:
|
||||
type = args.action
|
||||
else:
|
||||
type = input("What kind of commit is this ((M)ajor, (m)inor, (p)atch)? ")
|
||||
|
||||
if type.startswith('M'):
|
||||
args.major = True
|
||||
elif type.startswith('m'):
|
||||
args.minor = True
|
||||
elif type.startswith('p') or type.startswith('P') or len(type) == 0:
|
||||
args.patch = True
|
||||
except KeyboardInterrupt:
|
||||
print("\nCancelling!")
|
||||
return
|
||||
|
||||
if args.major:
|
||||
print("Selected major")
|
||||
write_cmake(inc_major(cmake_text))
|
||||
elif args.minor:
|
||||
print("Selected minor")
|
||||
write_cmake(inc_minor(cmake_text))
|
||||
elif args.patch:
|
||||
print("Selected patch")
|
||||
write_cmake(inc_patch(config, cmake_text))
|
||||
|
||||
subprocess.call(["git", "add", "*"])
|
||||
subprocess.call(["git", "commit"])
|
||||
|
||||
cmake_text = load_cmake()
|
||||
version_parts = split_version(cmake_text)[0]
|
||||
if args.major:
|
||||
if config.branch_on_major:
|
||||
make_branch(config, "v" + str(version_parts[0]))
|
||||
if args.minor:
|
||||
if config.branch_on_minor:
|
||||
make_branch(config, "v" + str(version_parts[0]) + "." + str(version_parts[1]))
|
||||
|
||||
subprocess.call(["sh", "-c", "git remote | xargs -L1 git push --all"])
|
||||
|
||||
if args.major:
|
||||
if config.release_on_major:
|
||||
make_release(env, "v" + str(version_parts[0]))
|
||||
if args.minor:
|
||||
if config.release_on_minor:
|
||||
make_release(env, "v" + str(version_parts[0]) + "." + str(version_parts[1]))
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,56 @@
|
|||
{ pkgs ? (import <nixpkgs> {
|
||||
config.allowUnfree = true;
|
||||
config.segger-jlink.acceptLicense = true;
|
||||
}), customPkgs ? (import /home/brett/my-nixpkgs {
|
||||
config.allowUnfree = true;
|
||||
config.segger-jlink.acceptLicense = true;
|
||||
}), ... }:
|
||||
pkgs.mkShell
|
||||
{
|
||||
buildInputs = with pkgs; [
|
||||
cmake
|
||||
gcc
|
||||
clang
|
||||
emscripten
|
||||
ninja
|
||||
customPkgs.jetbrains.clion
|
||||
#clion = import ~/my-nixpkgs/pkgs/applications/editors/jetbrains {};
|
||||
renderdoc
|
||||
valgrind
|
||||
];
|
||||
propagatedBuildInputs = with pkgs; [
|
||||
xorg.libX11
|
||||
xorg.libX11.dev
|
||||
xorg.libXcursor
|
||||
xorg.libXcursor.dev
|
||||
xorg.libXext
|
||||
xorg.libXext.dev
|
||||
xorg.libXinerama
|
||||
xorg.libXinerama.dev
|
||||
xorg.libXrandr
|
||||
xorg.libXrandr.dev
|
||||
xorg.libXrender
|
||||
xorg.libXrender.dev
|
||||
xorg.libxcb
|
||||
xorg.libxcb.dev
|
||||
xorg.libXi
|
||||
xorg.libXi.dev
|
||||
harfbuzz
|
||||
harfbuzz.dev
|
||||
zlib
|
||||
zlib.dev
|
||||
bzip2
|
||||
bzip2.dev
|
||||
pngpp
|
||||
brotli
|
||||
brotli.dev
|
||||
pulseaudio.dev
|
||||
git
|
||||
libGL
|
||||
libGL.dev
|
||||
glfw
|
||||
openblas
|
||||
openblas.dev
|
||||
];
|
||||
LD_LIBRARY_PATH="/run/opengl-driver/lib:/run/opengl-driver-32/lib";
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
#pragma once
|
||||
/*
|
||||
* Copyright (C) 2024 Brett Terpstra
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
#ifndef MNIST_H
|
||||
#define MNIST_H
|
||||
|
||||
namespace fp
|
||||
{
|
||||
void run_mnist();
|
||||
}
|
||||
|
||||
#endif //MNIST_H
|
|
@ -0,0 +1 @@
|
|||
Subproject commit 29286e66daa724ef08692d9a65e9c88e5467d9b2
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,285 @@
|
|||
/*
|
||||
* <Short Description>
|
||||
* Copyright (C) 2024 Brett Terpstra
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
#include <MNIST.h>
|
||||
#include <blt/fs/loader.h>
|
||||
#include <blt/std/memory.h>
|
||||
#include <blt/std/memory_util.h>
|
||||
#include <variant>
|
||||
#include <blt/iterator/iterator.h>
|
||||
#include <dlib/dnn.h>
|
||||
#include <dlib/data_io.h>
|
||||
|
||||
namespace fp
|
||||
{
|
||||
class idx_file_t
|
||||
{
|
||||
template <typename T>
|
||||
using mk_v = std::vector<T>;
|
||||
using vec_t = std::variant<mk_v<blt::u8>, mk_v<blt::i8>, mk_v<blt::u16>, mk_v<blt::u32>, mk_v<blt::f32>, mk_v<blt::f64>>;
|
||||
|
||||
public:
|
||||
explicit idx_file_t(const std::string& path)
|
||||
{
|
||||
std::ifstream file{path, std::ios::in | std::ios::binary};
|
||||
|
||||
using char_type = std::ifstream::char_type;
|
||||
char_type magic_arr[4];
|
||||
file.read(magic_arr, 4);
|
||||
BLT_ASSERT(magic_arr[0] == 0 && magic_arr[1] == 0);
|
||||
|
||||
blt::u8 dims = magic_arr[3];
|
||||
blt::size_t total_size = 1;
|
||||
|
||||
for (blt::i32 i = 0; i < dims; i++)
|
||||
{
|
||||
char_type dim_arr[4];
|
||||
file.read(dim_arr, 4);
|
||||
blt::u32 dim;
|
||||
blt::mem::fromBytes(dim_arr, dim);
|
||||
dimensions.push_back(dim);
|
||||
total_size *= dim;
|
||||
}
|
||||
|
||||
switch (magic_arr[2])
|
||||
{
|
||||
// unsigned char
|
||||
case 0x08:
|
||||
data = mk_v<blt::u8>{};
|
||||
read_data<blt::u8>(file, total_size);
|
||||
break;
|
||||
// signed char
|
||||
case 0x09:
|
||||
data = mk_v<blt::i8>{};
|
||||
read_data<blt::i8>(file, total_size);
|
||||
break;
|
||||
// short
|
||||
case 0x0B:
|
||||
data = mk_v<blt::u16>{};
|
||||
read_data<blt::u16>(file, total_size);
|
||||
reverse_data<blt::u16>();
|
||||
break;
|
||||
// int
|
||||
case 0x0C:
|
||||
data = mk_v<blt::u32>{};
|
||||
read_data<blt::u32>(file, total_size);
|
||||
reverse_data<blt::u32>();
|
||||
break;
|
||||
// float
|
||||
case 0x0D:
|
||||
data = mk_v<blt::f32>{};
|
||||
read_data<blt::f32>(file, total_size);
|
||||
reverse_data<blt::f32>();
|
||||
break;
|
||||
// double
|
||||
case 0x0E:
|
||||
data = mk_v<blt::f64>{};
|
||||
read_data<blt::f64>(file, total_size);
|
||||
reverse_data<blt::f64>();
|
||||
break;
|
||||
default:
|
||||
BLT_ERROR("Unspported idx file type!");
|
||||
}
|
||||
if (file.eof())
|
||||
{
|
||||
BLT_ERROR("EOF reached. It's unlikely your file was read correctly!");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
[[nodiscard]] const std::vector<T>& get_data_as() const
|
||||
{
|
||||
return std::get<mk_v<T>>(data);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<blt::span<T>> get_as_spans() const
|
||||
{
|
||||
std::vector<blt::span<T>> spans;
|
||||
|
||||
blt::size_t total_size = data_size(1);
|
||||
|
||||
for (blt::size_t i = 0; i < dimensions[0]; i++)
|
||||
{
|
||||
auto& array = std::get<mk_v<T>>(data);
|
||||
spans.push_back({&array[i * total_size], total_size});
|
||||
}
|
||||
|
||||
return spans;
|
||||
}
|
||||
|
||||
[[nodiscard]] const std::vector<blt::u32>& get_dimensions() const
|
||||
{
|
||||
return dimensions;
|
||||
}
|
||||
|
||||
[[nodiscard]] blt::size_t data_size(const blt::size_t starting_dimension = 0) const
|
||||
{
|
||||
blt::size_t total_size = 1;
|
||||
for (const auto d : blt::iterate(dimensions).skip(starting_dimension))
|
||||
total_size *= d;
|
||||
return total_size;
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
void read_data(std::ifstream& file, blt::size_t total_size)
|
||||
{
|
||||
auto& array = std::get<mk_v<T>>(data);
|
||||
array.resize(total_size);
|
||||
file.read(reinterpret_cast<char*>(array.data()), static_cast<std::streamsize>(total_size) * sizeof(T));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void reverse_data()
|
||||
{
|
||||
auto& array = std::get<mk_v<T>>(data);
|
||||
for (auto& v : array)
|
||||
blt::mem::reverse(v);
|
||||
}
|
||||
|
||||
std::vector<blt::u32> dimensions;
|
||||
vec_t data;
|
||||
};
|
||||
|
||||
class image_t
|
||||
{
|
||||
public:
|
||||
static constexpr blt::u32 target_size = 10;
|
||||
|
||||
image_t(const idx_file_t& image_data, const idx_file_t& label_data): samples(image_data.get_dimensions()[0]),
|
||||
input_size(image_data.data_size(1))
|
||||
{
|
||||
BLT_ASSERT_MSG(samples == label_data.get_dimensions()[0],
|
||||
("Mismatch in data sample sizes! " + std::to_string(samples) + " vs " + std::to_string(label_data.get_dimensions()[0])).
|
||||
c_str());
|
||||
auto& image_array = image_data.get_data_as<blt::u8>();
|
||||
auto& label_array = label_data.get_data_as<blt::u8>();
|
||||
|
||||
for (const auto label : label_array)
|
||||
image_labels.push_back(label);
|
||||
|
||||
const auto row_length = image_data.get_dimensions()[2];
|
||||
const auto number_of_rows = image_data.get_dimensions()[1];
|
||||
|
||||
for (blt::u32 i = 0; i < samples; i++)
|
||||
{
|
||||
dlib::matrix<blt::u8> mat(number_of_rows, row_length);
|
||||
for (blt::u32 y = 0; y < number_of_rows; y++)
|
||||
{
|
||||
for (blt::u32 x = 0; x < row_length; x++)
|
||||
{
|
||||
mat(x, y) = image_array[i * input_size + y * row_length + x];
|
||||
}
|
||||
}
|
||||
data.push_back(mat);
|
||||
}
|
||||
}
|
||||
|
||||
[[nodiscard]] const std::vector<dlib::matrix<blt::u8>>& get_image_data() const
|
||||
{
|
||||
return data;
|
||||
}
|
||||
|
||||
[[nodiscard]] const std::vector<blt::u64>& get_image_labels() const
|
||||
{
|
||||
return image_labels;
|
||||
}
|
||||
|
||||
private:
|
||||
blt::u32 samples;
|
||||
blt::u32 input_size;
|
||||
std::vector<dlib::matrix<blt::u8>> data;
|
||||
std::vector<blt::u64> image_labels;
|
||||
};
|
||||
|
||||
void run_mnist()
|
||||
{
|
||||
using namespace dlib;
|
||||
|
||||
idx_file_t test_images{"../problems/mnist/t10k-images.idx3-ubyte"};
|
||||
idx_file_t test_labels{"../problems/mnist/t10k-labels.idx1-ubyte"};
|
||||
idx_file_t train_images{"../problems/mnist/train-images.idx3-ubyte"};
|
||||
idx_file_t train_labels{"../problems/mnist/train-labels.idx1-ubyte"};
|
||||
|
||||
auto train_samples = train_images.get_dimensions()[0];
|
||||
auto test_samples = test_images.get_dimensions()[0];
|
||||
|
||||
auto columns = train_images.get_dimensions()[1];
|
||||
auto rows = train_images.get_dimensions()[2];
|
||||
|
||||
auto input_size = static_cast<blt::u32>(train_images.data_size(1));
|
||||
|
||||
image_t train_image{train_images, train_labels};
|
||||
image_t test_image{test_images, test_labels};
|
||||
|
||||
using net_type = loss_multiclass_log<
|
||||
fc<10,
|
||||
relu<fc<84,
|
||||
relu<fc<120,
|
||||
max_pool<2,2,2,2,relu<con<16,5,5,1,1,
|
||||
max_pool<2,2,2,2,relu<con<6,5,5,1,1,
|
||||
input<matrix<blt::u8>>>>>>>>>>>>>>;
|
||||
|
||||
net_type network{};
|
||||
|
||||
dnn_trainer trainer(network);
|
||||
trainer.set_learning_rate(0.01);
|
||||
trainer.set_min_learning_rate(0.00001);
|
||||
trainer.set_mini_batch_size(128);
|
||||
trainer.be_verbose();
|
||||
|
||||
trainer.set_synchronization_file("mnist_sync", std::chrono::seconds(20));
|
||||
|
||||
trainer.train(train_image.get_image_data(), train_image.get_image_labels());
|
||||
trainer.train_one_step(train_image.get_image_data(), train_image.get_image_labels());
|
||||
|
||||
network.clean();
|
||||
serialize("mnist_network.dat") << network;
|
||||
|
||||
std::vector<unsigned long> predicted_labels = network(train_image.get_image_data());
|
||||
int num_right = 0;
|
||||
int num_wrong = 0;
|
||||
// And then let's see if it classified them correctly.
|
||||
for (size_t i = 0; i < train_image.get_image_data().size(); ++i)
|
||||
{
|
||||
if (predicted_labels[i] == train_image.get_image_labels()[i])
|
||||
++num_right;
|
||||
else
|
||||
++num_wrong;
|
||||
|
||||
}
|
||||
std::cout << "training num_right: " << num_right << std::endl;
|
||||
std::cout << "training num_wrong: " << num_wrong << std::endl;
|
||||
std::cout << "training accuracy: " << num_right / static_cast<double>(num_right + num_wrong) << std::endl;
|
||||
|
||||
predicted_labels = network(test_image.get_image_data());
|
||||
num_right = 0;
|
||||
num_wrong = 0;
|
||||
for (size_t i = 0; i < test_image.get_image_data().size(); ++i)
|
||||
{
|
||||
if (predicted_labels[i] == test_image.get_image_labels()[i])
|
||||
++num_right;
|
||||
else
|
||||
++num_wrong;
|
||||
|
||||
}
|
||||
std::cout << "testing num_right: " << num_right << std::endl;
|
||||
std::cout << "testing num_wrong: " << num_wrong << std::endl;
|
||||
std::cout << "testing accuracy: " << num_right / static_cast<double>(num_right + num_wrong) << std::endl;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,8 @@
|
|||
#include <iostream>
|
||||
|
||||
#include <MNIST.h>
|
||||
|
||||
int main()
|
||||
{
|
||||
fp::run_mnist();
|
||||
}
|
Loading…
Reference in New Issue