Files
autoencoder/examples/mnist_test.py

156 lines
4.1 KiB
Python

import matplotlib.pyplot as plt
import numpy as np
import os
from easyvae.autoencoder import ( # noqa
VariationalAutoencoder,
ClassicalAutoencoder,
AAutoencoder
)
from easyvae.activations import LeakyReLU
def load_mnist(
    path: str = "./mnist.npz",
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Return the MNIST arrays, downloading and caching them on first use.

    If ``path`` does not exist, the archive is fetched from the Keras dataset
    bucket and written to ``path``; later calls read the cached file.

    Args:
        path: Location of the cached ``.npz`` archive (default ``./mnist.npz``).

    Returns:
        ``(x_train, y_train, x_test, y_test)`` as stored in the archive.

    Raises:
        requests.HTTPError: If the download returns an HTTP error status.
    """
    mnist_url = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz"  # noqa
    if not os.path.exists(path):
        # Imported lazily so that using an existing cache does not require
        # `requests` to be installed.
        import requests

        response = requests.get(mnist_url, timeout=60)
        # Never cache an HTTP error page as if it were the dataset.
        response.raise_for_status()
        # Download first, write to a temporary name, then rename. A failed or
        # interrupted download therefore never leaves an empty or truncated
        # cache file that every later run would try (and fail) to load.
        tmp_path = f"{path}.part"
        with open(tmp_path, "wb") as f:
            f.write(response.content)
        os.replace(tmp_path, path)
    # NpzFile holds an open zip handle; the context manager closes it.
    # Indexing inside the `with` reads each array fully into memory first.
    with np.load(path) as data:
        return data["x_train"], data["y_train"], data["x_test"], data["y_test"]
def mnist_train(
    filename: str,
    max_epoch: int,
    patience: int,
    cls: type[AAutoencoder],
) -> AAutoencoder:
    """Train (or resume training) an autoencoder on MNIST, then save it.

    If ``filename`` exists, the model is loaded with ``cls.load(filename)`` and
    training continues from it. Otherwise a new ``cls`` is built with encoder
    layers ``[in_len, 256, 2]``, decoder layers ``[2, 256, in_len]``,
    ``0.0001`` as the third constructor argument (presumably the learning
    rate; confirm against ``AAutoencoder``) and a ``LeakyReLU`` activation.

    Args:
        filename: Model file to resume from (if present) and to save to.
        max_epoch: Maximum number of epochs, forwarded to ``train_dataset``.
        patience: Patience value, forwarded to ``train_dataset``.
        cls: Autoencoder class to instantiate or load.

    Returns:
        The trained autoencoder, after it has been saved to ``filename``.
    """
    x_train, _, _, _ = load_mnist()
    # Flatten every image into a vector. The length is the product of all
    # per-sample dimensions, which is correct for non-square images too.
    x_train = x_train.reshape(len(x_train), -1)
    in_len = x_train.shape[1]
    x_train = x_train / 255

    if os.path.exists(filename):
        autoencoder = cls.load(filename)
    else:
        autoencoder = cls(
            [in_len, 256, 2],
            [2, 256, in_len],
            0.0001,
            LeakyReLU(),
        )

    # NOTE(review): presumably train_dataset handles KeyboardInterrupt so
    # that the save below still runs; confirm in AAutoencoder.
    print("CTRL+C to interrupt training.")
    autoencoder.train_dataset(
        x_train,
        max_epoch,
        patience,
        display_loss=True,
    )
    autoencoder.save(filename)
    print("Training complete !")
    return autoencoder
def plot_mnist_latent_space(autoencoder: "AAutoencoder", x: np.ndarray, y):
    """Scatter-plot the latent codes of ``x``, coloured by the labels ``y``.

    Each sample is flattened and passed to ``autoencoder.forward``; the second
    element of its return value is used as the latent code. The plot is drawn
    and shown only when the codes form an ``(N, 2)`` array. Otherwise,
    including when ``x`` is empty, the function returns without plotting.

    Args:
        autoencoder: Model exposing ``forward(vector) -> (output, code)``.
        x: Samples to encode; each one is flattened before encoding.
        y: Per-sample labels, passed as ``c=`` to ``plt.scatter``.
    """
    # `sample` avoids shadowing the `x` parameter inside the loop.
    codes = np.array([autoencoder.forward(sample.flatten())[1] for sample in x])
    # An empty `x` produces a 1-D array, so check ndim before reading shape[1].
    if codes.ndim != 2 or codes.shape[1] != 2:
        return

    plt.figure(figsize=(6, 6))
    scatter = plt.scatter(
        codes[:, 0],
        codes[:, 1],
        c=y,
        cmap='tab10',
        s=5,
        alpha=0.7,
    )
    plt.colorbar(scatter)
    plt.grid(True)
    plt.show()
def plot_random_reconstruction(
    autoencoder: AAutoencoder,
    example: np.ndarray,
    img_shape,
    y,
):
    """Draw one input image next to its reconstruction and print its code.

    A new figure with two side-by-side panels is created: the input on the
    left and the autoencoder output on the right, both titled with ``y``. The
    figure is not shown here; the caller must call ``plt.show()``. The latent
    code returned by ``forward`` is printed as a list.

    Args:
        autoencoder: Model exposing ``forward(vector) -> (output, code)``.
        example: A single sample; flattened before encoding.
        img_shape: Shape used to display both the input and the output.
        y: Label shown in both panel titles.
    """
    output, code = autoencoder.forward(example.flatten())

    # Use a dedicated figure so the panels never land on a figure that
    # happens to be current already.
    plt.figure()
    panels = ((example, "Input"), (output, "Output"))
    for position, (image, caption) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.matshow(image.reshape(img_shape), fignum=False)
        plt.title(f"{caption} ({y})")

    print(code.tolist())
def mnist_test(model: str | AAutoencoder):
    """Show a random MNIST test reconstruction and, for 2-D models, the latent space.

    Args:
        model: Either a path passed to ``AAutoencoder.load`` or an
            already-built autoencoder instance.
    """
    # Only the test split is needed; the image shape is read from it too.
    _, _, x_test, y_test = load_mnist()
    img_shape = x_test.shape[1:]
    # Flatten each image into a vector and scale pixel values by 1/255.
    x_test = x_test.reshape(len(x_test), -1) / 255

    if isinstance(model, str):
        autoencoder: AAutoencoder = AAutoencoder.load(model)
    else:
        autoencoder = model

    print("Testing model ...\n")
    print(autoencoder)

    idx = np.random.randint(0, len(x_test))
    example: np.ndarray = x_test[idx]
    plot_random_reconstruction(autoencoder, example, img_shape, y_test[idx])
    if autoencoder.space_dim == 2:
        plot_mnist_latent_space(autoencoder, x_test, y_test)
    # Make sure the reconstruction is displayed even when no latent-space
    # plot (and therefore no earlier plt.show()) happened. If the figures
    # were already shown and closed, this call does nothing.
    plt.show()
if __name__ == "__main__":
    import argparse

    # Without -r: train a VariationalAutoencoder (resuming from -m if that
    # file exists), save it, then run the test visualisation.
    # With -r: load the model from -m and run the test visualisation only.
    parser = argparse.ArgumentParser(
        description="Train an autoencoder on MNIST, or run a saved model (-r)."
    )
    # No nargs='?': a bare flag would otherwise yield None instead of a value.
    parser.add_argument('-e', type=int, default=30, help='Max epochs')
    parser.add_argument('-p', type=int, default=30, help='Patience')
    parser.add_argument(
        '-m',
        type=str,
        default='autoencoder_mnist.npy',
        help=(
            'Model filename: loaded in run mode; in training mode, '
            'resumed from if it exists and saved after training'
        ),
    )
    parser.add_argument('-r', action='store_true', help='Run the model')
    args = parser.parse_args()

    if args.r:
        mnist_test(args.m)
    else:
        autoencoder = mnist_train(
            args.m,
            args.e,
            args.p,
            VariationalAutoencoder,
        )
        mnist_test(autoencoder)