feat: loss method + mv data reshaping out of Autoencoder class
This commit is contained in:
@@ -4,7 +4,7 @@ from autoencoder import Autoencoder
|
||||
from utils import relu
|
||||
|
||||
|
||||
def load_mnist():
|
||||
def load_mnist() -> list[np.ndarray]:
|
||||
import os
|
||||
import requests
|
||||
|
||||
@@ -23,9 +23,13 @@ def mnist_test(
|
||||
patience: int,
|
||||
):
|
||||
x_train, _, x_test, _ = load_mnist()
|
||||
in_len = x_train[0].shape[0] * x_train[0].shape[0]
|
||||
img_shape = x_train[0].shape
|
||||
x_train.resize(x_train.shape[0], in_len)
|
||||
x_test.resize(x_test.shape[0], in_len)
|
||||
x_train = np.divide(x_train, 255)
|
||||
x_test = np.divide(x_train, 255)
|
||||
in_len = x_train[0].shape[0] * x_train[0].shape[0]
|
||||
x_train = x_train[:1000]
|
||||
autoencoder = Autoencoder(
|
||||
[in_len, bottleneck],
|
||||
[bottleneck, in_len],
|
||||
@@ -41,9 +45,9 @@ def mnist_test(
|
||||
code = autoencoder.encode(example.flatten())
|
||||
output = autoencoder.decode(code)
|
||||
plt.subplot(1, 2, 1)
|
||||
plt.matshow(example, fignum=False)
|
||||
plt.matshow(example.reshape(img_shape), fignum=False)
|
||||
plt.subplot(1, 2, 2)
|
||||
plt.matshow(output.reshape(example.shape), fignum=False)
|
||||
plt.matshow(output.reshape(img_shape), fignum=False)
|
||||
plt.show()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user