feat: error handling and re-train in mnist_test
This commit is contained in:
@@ -15,6 +15,10 @@ class Autoencoder:
|
|||||||
decoder_layers: list[int],
|
decoder_layers: list[int],
|
||||||
lr: float,
|
lr: float,
|
||||||
activation_func: ActivationFunc):
|
activation_func: ActivationFunc):
|
||||||
|
if encoder_layers[-1] != decoder_layers[0]:
|
||||||
|
raise Exception(
|
||||||
|
f"Encoder output and decoder input don't match {encoder_layers[-1]} != {encoder_layers[0]}" # noqa
|
||||||
|
)
|
||||||
self.encoder = DeepNNLayer(encoder_layers, lr, activation_func)
|
self.encoder = DeepNNLayer(encoder_layers, lr, activation_func)
|
||||||
self.decoder = DeepNNLayer(decoder_layers, lr, activation_func)
|
self.decoder = DeepNNLayer(decoder_layers, lr, activation_func)
|
||||||
|
|
||||||
|
|||||||
@@ -2,10 +2,10 @@ import matplotlib.pyplot as plt
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from autoencoder import Autoencoder
|
from autoencoder import Autoencoder
|
||||||
from activations import LeakyReLU
|
from activations import LeakyReLU
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
def load_mnist() -> list[np.ndarray]:
|
def load_mnist() -> list[np.ndarray]:
|
||||||
import os
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
mnist_path = "./mnist.npz"
|
mnist_path = "./mnist.npz"
|
||||||
@@ -21,13 +21,16 @@ def mnist_train(
|
|||||||
filename: str,
|
filename: str,
|
||||||
max_epoch: int,
|
max_epoch: int,
|
||||||
patience: int,
|
patience: int,
|
||||||
):
|
) -> Autoencoder:
|
||||||
x_train, _, x_test, _ = load_mnist()
|
x_train, _, x_test, _ = load_mnist()
|
||||||
in_len = x_train[0].shape[0] * x_train[0].shape[0]
|
in_len = x_train[0].shape[0] * x_train[0].shape[0]
|
||||||
x_train.resize(x_train.shape[0], in_len)
|
x_train.resize(x_train.shape[0], in_len)
|
||||||
x_test.resize(x_test.shape[0], in_len)
|
x_test.resize(x_test.shape[0], in_len)
|
||||||
x_train = x_train / 255
|
x_train = x_train / 255
|
||||||
x_test = x_test / 255
|
x_test = x_test / 255
|
||||||
|
if os.path.exists(filename):
|
||||||
|
autoencoder = Autoencoder.load(filename)
|
||||||
|
else:
|
||||||
autoencoder = Autoencoder(
|
autoencoder = Autoencoder(
|
||||||
[in_len, 64, 16],
|
[in_len, 64, 16],
|
||||||
[16, 64, in_len],
|
[16, 64, in_len],
|
||||||
@@ -40,9 +43,10 @@ def mnist_train(
|
|||||||
patience,
|
patience,
|
||||||
display_loss=True)
|
display_loss=True)
|
||||||
autoencoder.save(filename)
|
autoencoder.save(filename)
|
||||||
|
return autoencoder
|
||||||
|
|
||||||
|
|
||||||
def mnist_test(filename: str):
|
def mnist_test(model: str | Autoencoder):
|
||||||
x_train, _, x_test, y_test = load_mnist()
|
x_train, _, x_test, y_test = load_mnist()
|
||||||
in_len = x_train[0].shape[0] * x_train[0].shape[0]
|
in_len = x_train[0].shape[0] * x_train[0].shape[0]
|
||||||
img_shape = x_train[0].shape
|
img_shape = x_train[0].shape
|
||||||
@@ -50,7 +54,10 @@ def mnist_test(filename: str):
|
|||||||
x_test.resize(x_test.shape[0], in_len)
|
x_test.resize(x_test.shape[0], in_len)
|
||||||
x_train = x_train / 255
|
x_train = x_train / 255
|
||||||
x_test = x_test / 255
|
x_test = x_test / 255
|
||||||
autoencoder: Autoencoder = Autoencoder.load(filename)
|
if isinstance(model, str):
|
||||||
|
autoencoder: Autoencoder = Autoencoder.load(model)
|
||||||
|
else:
|
||||||
|
autoencoder = model
|
||||||
print(autoencoder)
|
print(autoencoder)
|
||||||
idx = np.random.randint(0, len(x_test))
|
idx = np.random.randint(0, len(x_test))
|
||||||
example: np.ndarray = x_test[idx]
|
example: np.ndarray = x_test[idx]
|
||||||
@@ -107,5 +114,5 @@ if __name__ == "__main__":
|
|||||||
if args.r:
|
if args.r:
|
||||||
mnist_test(args.m)
|
mnist_test(args.m)
|
||||||
else:
|
else:
|
||||||
mnist_train(args.m, args.e, args.p)
|
autoencoder = mnist_train(args.m, args.e, args.p)
|
||||||
mnist_test(args.m)
|
mnist_test(autoencoder)
|
||||||
|
|||||||
Reference in New Issue
Block a user