From 439a11a8288ab5ec416ede5ba0bc3266b40afe54 Mon Sep 17 00:00:00 2001 From: Lenoctambule <106790775+lenoctambule@users.noreply.github.com> Date: Fri, 27 Mar 2026 04:22:49 +0100 Subject: [PATCH] feat: mnist test --- autoencoder.py | 2 +- mnist_test.py | 51 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 1 deletion(-) create mode 100644 mnist_test.py diff --git a/autoencoder.py b/autoencoder.py index 77d9fb1..b64dd35 100644 --- a/autoencoder.py +++ b/autoencoder.py @@ -71,6 +71,6 @@ class Autoencoder: def encode(self, v: np.ndarray) -> np.ndarray: return self.encoder.forward(v) - + def decode(self, v: np.ndarray) -> np.ndarray: return self.decoder.forward(v) diff --git a/mnist_test.py b/mnist_test.py new file mode 100644 index 0000000..3a7ed85 --- /dev/null +++ b/mnist_test.py @@ -0,0 +1,51 @@ +import matplotlib.pyplot as plt +import numpy as np +from autoencoder import Autoencoder +from utils import (relu, + dynamic_loss_plot_init, + dynamic_loss_plot_update, + dynamic_loss_plot_finish) + + +def mnist_embed(): + import keras + (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data() + IN_LEN = x_train[0].flatten().shape[0] + DIM = 50 + autoencoder = Autoencoder(IN_LEN, DIM, 0.001, relu) + ax, line = dynamic_loss_plot_init() + NO_IMPROV = 0 + prev_error = float('inf') + losses = [] + epoch = 0 + x_train = x_train[:1_000] + while True: + error = 0 + for x in x_train: + input = x.flatten() / 255 + error += autoencoder.train(input) + error /= len(x_train) + if error >= prev_error: + NO_IMPROV += 1 + prev_error = error + losses.append(error) + dynamic_loss_plot_update(ax, line, losses) + if NO_IMPROV > 10: + print('Done !') + break + if epoch > 200: + break + epoch += 1 + dynamic_loss_plot_finish(ax, line) + example: np.ndarray = x_test[np.random.randint(0, len(x_test))] + code = autoencoder.encode(example.flatten() / 255) + output = autoencoder.decode(code) + plt.subplot(1, 2, 1) + plt.matshow(example, fignum=False) + plt.subplot(1, 2, 2) + plt.matshow(output.reshape(example.shape), fignum=False) + plt.show() + + +if __name__ == "__main__": + mnist_embed()