feat: error handling and re-train in mnist_test

This commit is contained in:
Lenoctambule
2026-03-29 21:12:19 +02:00
parent 8a3d408b7a
commit a93bb0a692
2 changed files with 23 additions and 12 deletions

View File

@@ -15,6 +15,10 @@ class Autoencoder:
decoder_layers: list[int], decoder_layers: list[int],
lr: float, lr: float,
activation_func: ActivationFunc): activation_func: ActivationFunc):
if encoder_layers[-1] != decoder_layers[0]:
raise Exception(
f"Encoder output and decoder input don't match {encoder_layers[-1]} != {encoder_layers[0]}" # noqa
)
self.encoder = DeepNNLayer(encoder_layers, lr, activation_func) self.encoder = DeepNNLayer(encoder_layers, lr, activation_func)
self.decoder = DeepNNLayer(decoder_layers, lr, activation_func) self.decoder = DeepNNLayer(decoder_layers, lr, activation_func)

View File

@@ -2,10 +2,10 @@ import matplotlib.pyplot as plt
import numpy as np import numpy as np
from autoencoder import Autoencoder from autoencoder import Autoencoder
from activations import LeakyReLU from activations import LeakyReLU
import os
def load_mnist() -> list[np.ndarray]: def load_mnist() -> list[np.ndarray]:
import os
import requests import requests
mnist_path = "./mnist.npz" mnist_path = "./mnist.npz"
@@ -21,28 +21,32 @@ def mnist_train(
filename: str, filename: str,
max_epoch: int, max_epoch: int,
patience: int, patience: int,
): ) -> Autoencoder:
x_train, _, x_test, _ = load_mnist() x_train, _, x_test, _ = load_mnist()
in_len = x_train[0].shape[0] * x_train[0].shape[0] in_len = x_train[0].shape[0] * x_train[0].shape[0]
x_train.resize(x_train.shape[0], in_len) x_train.resize(x_train.shape[0], in_len)
x_test.resize(x_test.shape[0], in_len) x_test.resize(x_test.shape[0], in_len)
x_train = x_train / 255 x_train = x_train / 255
x_test = x_test / 255 x_test = x_test / 255
autoencoder = Autoencoder( if os.path.exists(filename):
[in_len, 64, 16], autoencoder = Autoencoder.load(filename)
[16, 64, in_len], else:
0.01, autoencoder = Autoencoder(
LeakyReLU() [in_len, 64, 16],
) [16, 64, in_len],
0.01,
LeakyReLU()
)
autoencoder.train_dataset( autoencoder.train_dataset(
x_train, x_train,
max_epoch, max_epoch,
patience, patience,
display_loss=True) display_loss=True)
autoencoder.save(filename) autoencoder.save(filename)
return autoencoder
def mnist_test(filename: str): def mnist_test(model: str | Autoencoder):
x_train, _, x_test, y_test = load_mnist() x_train, _, x_test, y_test = load_mnist()
in_len = x_train[0].shape[0] * x_train[0].shape[0] in_len = x_train[0].shape[0] * x_train[0].shape[0]
img_shape = x_train[0].shape img_shape = x_train[0].shape
@@ -50,7 +54,10 @@ def mnist_test(filename: str):
x_test.resize(x_test.shape[0], in_len) x_test.resize(x_test.shape[0], in_len)
x_train = x_train / 255 x_train = x_train / 255
x_test = x_test / 255 x_test = x_test / 255
autoencoder: Autoencoder = Autoencoder.load(filename) if isinstance(model, str):
autoencoder: Autoencoder = Autoencoder.load(model)
else:
autoencoder = model
print(autoencoder) print(autoencoder)
idx = np.random.randint(0, len(x_test)) idx = np.random.randint(0, len(x_test))
example: np.ndarray = x_test[idx] example: np.ndarray = x_test[idx]
@@ -107,5 +114,5 @@ if __name__ == "__main__":
if args.r: if args.r:
mnist_test(args.m) mnist_test(args.m)
else: else:
mnist_train(args.m, args.e, args.p) autoencoder = mnist_train(args.m, args.e, args.p)
mnist_test(args.m) mnist_test(autoencoder)