feat: error handling and re-train in mnist_test
This commit is contained in:
@@ -2,10 +2,10 @@ import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from autoencoder import Autoencoder
|
||||
from activations import LeakyReLU
|
||||
import os
|
||||
|
||||
|
||||
def load_mnist() -> list[np.ndarray]:
|
||||
import os
|
||||
import requests
|
||||
|
||||
mnist_path = "./mnist.npz"
|
||||
@@ -21,28 +21,32 @@ def mnist_train(
|
||||
filename: str,
|
||||
max_epoch: int,
|
||||
patience: int,
|
||||
):
|
||||
) -> Autoencoder:
|
||||
x_train, _, x_test, _ = load_mnist()
|
||||
in_len = x_train[0].shape[0] * x_train[0].shape[0]
|
||||
x_train.resize(x_train.shape[0], in_len)
|
||||
x_test.resize(x_test.shape[0], in_len)
|
||||
x_train = x_train / 255
|
||||
x_test = x_test / 255
|
||||
autoencoder = Autoencoder(
|
||||
[in_len, 64, 16],
|
||||
[16, 64, in_len],
|
||||
0.01,
|
||||
LeakyReLU()
|
||||
)
|
||||
if os.path.exists(filename):
|
||||
autoencoder = Autoencoder.load(filename)
|
||||
else:
|
||||
autoencoder = Autoencoder(
|
||||
[in_len, 64, 16],
|
||||
[16, 64, in_len],
|
||||
0.01,
|
||||
LeakyReLU()
|
||||
)
|
||||
autoencoder.train_dataset(
|
||||
x_train,
|
||||
max_epoch,
|
||||
patience,
|
||||
display_loss=True)
|
||||
autoencoder.save(filename)
|
||||
return autoencoder
|
||||
|
||||
|
||||
def mnist_test(filename: str):
|
||||
def mnist_test(model: str | Autoencoder):
|
||||
x_train, _, x_test, y_test = load_mnist()
|
||||
in_len = x_train[0].shape[0] * x_train[0].shape[0]
|
||||
img_shape = x_train[0].shape
|
||||
@@ -50,7 +54,10 @@ def mnist_test(filename: str):
|
||||
x_test.resize(x_test.shape[0], in_len)
|
||||
x_train = x_train / 255
|
||||
x_test = x_test / 255
|
||||
autoencoder: Autoencoder = Autoencoder.load(filename)
|
||||
if isinstance(model, str):
|
||||
autoencoder: Autoencoder = Autoencoder.load(model)
|
||||
else:
|
||||
autoencoder = model
|
||||
print(autoencoder)
|
||||
idx = np.random.randint(0, len(x_test))
|
||||
example: np.ndarray = x_test[idx]
|
||||
@@ -107,5 +114,5 @@ if __name__ == "__main__":
|
||||
if args.r:
|
||||
mnist_test(args.m)
|
||||
else:
|
||||
mnist_train(args.m, args.e, args.p)
|
||||
mnist_test(args.m)
|
||||
autoencoder = mnist_train(args.m, args.e, args.p)
|
||||
mnist_test(autoencoder)
|
||||
|
||||
Reference in New Issue
Block a user