From a93bb0a692197e2f3aefeba3f73699df0651bf7f Mon Sep 17 00:00:00 2001 From: Lenoctambule <106790775+lenoctambule@users.noreply.github.com> Date: Sun, 29 Mar 2026 21:12:19 +0200 Subject: [PATCH] feat: error handling and re-train in mnist_test --- autoencoder.py | 4 ++++ mnist_test.py | 31 +++++++++++++++++++------------ 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/autoencoder.py b/autoencoder.py index 5486ece..42ded53 100644 --- a/autoencoder.py +++ b/autoencoder.py @@ -15,6 +15,10 @@ class Autoencoder: decoder_layers: list[int], lr: float, activation_func: ActivationFunc): + if encoder_layers[-1] != decoder_layers[0]: + raise Exception( + f"Encoder output and decoder input don't match {encoder_layers[-1]} != {encoder_layers[0]}" # noqa + ) self.encoder = DeepNNLayer(encoder_layers, lr, activation_func) self.decoder = DeepNNLayer(decoder_layers, lr, activation_func) diff --git a/mnist_test.py b/mnist_test.py index 97f6cf7..b9f788b 100644 --- a/mnist_test.py +++ b/mnist_test.py @@ -2,10 +2,10 @@ import matplotlib.pyplot as plt import numpy as np from autoencoder import Autoencoder from activations import LeakyReLU +import os def load_mnist() -> list[np.ndarray]: - import os import requests mnist_path = "./mnist.npz" @@ -21,28 +21,32 @@ def mnist_train( filename: str, max_epoch: int, patience: int, - ): + ) -> Autoencoder: x_train, _, x_test, _ = load_mnist() in_len = x_train[0].shape[0] * x_train[0].shape[0] x_train.resize(x_train.shape[0], in_len) x_test.resize(x_test.shape[0], in_len) x_train = x_train / 255 x_test = x_test / 255 - autoencoder = Autoencoder( - [in_len, 64, 16], - [16, 64, in_len], - 0.01, - LeakyReLU() - ) + if os.path.exists(filename): + autoencoder = Autoencoder.load(filename) + else: + autoencoder = Autoencoder( + [in_len, 64, 16], + [16, 64, in_len], + 0.01, + LeakyReLU() + ) autoencoder.train_dataset( x_train, max_epoch, patience, display_loss=True) autoencoder.save(filename) + return autoencoder -def mnist_test(filename: str): +def mnist_test(model: str | Autoencoder): x_train, _, x_test, y_test = load_mnist() in_len = x_train[0].shape[0] * x_train[0].shape[0] img_shape = x_train[0].shape @@ -50,7 +54,10 @@ def mnist_test(filename: str): x_test.resize(x_test.shape[0], in_len) x_train = x_train / 255 x_test = x_test / 255 - autoencoder: Autoencoder = Autoencoder.load(filename) + if isinstance(model, str): + autoencoder: Autoencoder = Autoencoder.load(model) + else: + autoencoder = model print(autoencoder) idx = np.random.randint(0, len(x_test)) example: np.ndarray = x_test[idx] @@ -107,5 +114,5 @@ if __name__ == "__main__": if args.r: mnist_test(args.m) else: - mnist_train(args.m, args.e, args.p) - mnist_test(args.m) + autoencoder = mnist_train(args.m, args.e, args.p) + mnist_test(autoencoder)