feat: error handling and re-train in mnist_test

This commit is contained in:
Lenoctambule
2026-03-29 21:12:19 +02:00
parent 8a3d408b7a
commit a93bb0a692
2 changed files with 23 additions and 12 deletions

View File

@@ -2,10 +2,10 @@ import matplotlib.pyplot as plt
import numpy as np
from autoencoder import Autoencoder
from activations import LeakyReLU
import os
def load_mnist() -> list[np.ndarray]:
import os
import requests
mnist_path = "./mnist.npz"
@@ -21,28 +21,32 @@ def mnist_train(
filename: str,
max_epoch: int,
patience: int,
):
) -> Autoencoder:
x_train, _, x_test, _ = load_mnist()
in_len = x_train[0].shape[0] * x_train[0].shape[0]
x_train.resize(x_train.shape[0], in_len)
x_test.resize(x_test.shape[0], in_len)
x_train = x_train / 255
x_test = x_test / 255
autoencoder = Autoencoder(
[in_len, 64, 16],
[16, 64, in_len],
0.01,
LeakyReLU()
)
if os.path.exists(filename):
autoencoder = Autoencoder.load(filename)
else:
autoencoder = Autoencoder(
[in_len, 64, 16],
[16, 64, in_len],
0.01,
LeakyReLU()
)
autoencoder.train_dataset(
x_train,
max_epoch,
patience,
display_loss=True)
autoencoder.save(filename)
return autoencoder
def mnist_test(filename: str):
def mnist_test(model: str | Autoencoder):
x_train, _, x_test, y_test = load_mnist()
in_len = x_train[0].shape[0] * x_train[0].shape[0]
img_shape = x_train[0].shape
@@ -50,7 +54,10 @@ def mnist_test(filename: str):
x_test.resize(x_test.shape[0], in_len)
x_train = x_train / 255
x_test = x_test / 255
autoencoder: Autoencoder = Autoencoder.load(filename)
if isinstance(model, str):
autoencoder: Autoencoder = Autoencoder.load(model)
else:
autoencoder = model
print(autoencoder)
idx = np.random.randint(0, len(x_test))
example: np.ndarray = x_test[idx]
@@ -107,5 +114,5 @@ if __name__ == "__main__":
if args.r:
mnist_test(args.m)
else:
mnist_train(args.m, args.e, args.p)
mnist_test(args.m)
autoencoder = mnist_train(args.m, args.e, args.p)
mnist_test(autoencoder)