feat: mnist test
This commit is contained in:
@@ -71,6 +71,6 @@ class Autoencoder:
|
|||||||
|
|
||||||
def encode(self, v: np.ndarray) -> np.ndarray:
|
def encode(self, v: np.ndarray) -> np.ndarray:
|
||||||
return self.encoder.forward(v)
|
return self.encoder.forward(v)
|
||||||
|
|
||||||
def decode(self, v: np.ndarray) -> np.ndarray:
|
def decode(self, v: np.ndarray) -> np.ndarray:
|
||||||
return self.decoder.forward(v)
|
return self.decoder.forward(v)
|
||||||
|
|||||||
51
mnist_test.py
Normal file
51
mnist_test.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
from autoencoder import Autoencoder
|
||||||
|
from utils import (relu,
|
||||||
|
dynamic_loss_plot_init,
|
||||||
|
dynamic_loss_plot_update,
|
||||||
|
dynamic_loss_plot_finish)
|
||||||
|
|
||||||
|
|
||||||
|
def mnist_embed():
|
||||||
|
import keras
|
||||||
|
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
|
||||||
|
IN_LEN = x_train[0].flatten().shape[0]
|
||||||
|
DIM = 50
|
||||||
|
autoencoder = Autoencoder(IN_LEN, DIM, 0.001, relu)
|
||||||
|
ax, line = dynamic_loss_plot_init()
|
||||||
|
NO_IMPROV = 0
|
||||||
|
prev_error = float('inf')
|
||||||
|
losses = []
|
||||||
|
epoch = 0
|
||||||
|
x_train = x_train[:1_000]
|
||||||
|
while True:
|
||||||
|
error = 0
|
||||||
|
for x in x_train:
|
||||||
|
input = x.flatten() / 255
|
||||||
|
error += autoencoder.train(input)
|
||||||
|
error /= len(x_train)
|
||||||
|
if error >= prev_error:
|
||||||
|
NO_IMPROV += 1
|
||||||
|
prev_error = error
|
||||||
|
losses.append(error)
|
||||||
|
dynamic_loss_plot_update(ax, line, losses)
|
||||||
|
if NO_IMPROV > 10:
|
||||||
|
print('Done !')
|
||||||
|
break
|
||||||
|
if epoch > 200:
|
||||||
|
break
|
||||||
|
epoch += 1
|
||||||
|
dynamic_loss_plot_finish(ax, line)
|
||||||
|
example: np.ndarray = x_test[np.random.randint(0, len(x_test))]
|
||||||
|
code = autoencoder.encode(example.flatten() / 255)
|
||||||
|
output = autoencoder.decode(code)
|
||||||
|
plt.subplot(1, 2, 1)
|
||||||
|
plt.matshow(example, fignum=False)
|
||||||
|
plt.subplot(1, 2, 2)
|
||||||
|
plt.matshow(output.reshape(example.shape), fignum=False)
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
mnist_embed()
|
||||||
Reference in New Issue
Block a user