feat: mnist test

This commit is contained in:
Lenoctambule
2026-03-27 04:22:49 +01:00
parent 69607d89c2
commit 439a11a828
2 changed files with 52 additions and 1 deletions

View File

@@ -71,6 +71,6 @@ class Autoencoder:
def encode(self, v: np.ndarray) -> np.ndarray:
return self.encoder.forward(v)
def decode(self, v: np.ndarray) -> np.ndarray:
return self.decoder.forward(v)

51
mnist_test.py Normal file
View File

@@ -0,0 +1,51 @@
import matplotlib.pyplot as plt
import numpy as np
from autoencoder import Autoencoder
from utils import (relu,
dynamic_loss_plot_init,
dynamic_loss_plot_update,
dynamic_loss_plot_finish)
def mnist_embed():
import keras
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
IN_LEN = x_train[0].flatten().shape[0]
DIM = 50
autoencoder = Autoencoder(IN_LEN, DIM, 0.001, relu)
ax, line = dynamic_loss_plot_init()
NO_IMPROV = 0
prev_error = float('inf')
losses = []
epoch = 0
x_train = x_train[:1_000]
while True:
error = 0
for x in x_train:
input = x.flatten() / 255
error += autoencoder.train(input)
error /= len(x_train)
if error >= prev_error:
NO_IMPROV += 1
prev_error = error
losses.append(error)
dynamic_loss_plot_update(ax, line, losses)
if NO_IMPROV > 10:
print('Done !')
break
if epoch > 200:
break
epoch += 1
dynamic_loss_plot_finish(ax, line)
example: np.ndarray = x_test[np.random.randint(0, len(x_test))]
code = autoencoder.encode(example.flatten() / 255)
output = autoencoder.decode(code)
plt.subplot(1, 2, 1)
plt.matshow(example, fignum=False)
plt.subplot(1, 2, 2)
plt.matshow(output.reshape(example.shape), fignum=False)
plt.show()
if __name__ == "__main__":
mnist_embed()