feat: custom load_mnist function and rm useless dependency

This commit is contained in:
Lenoctambule
2026-03-27 21:54:17 +01:00
parent fa2bfe4ef5
commit 32e27b4b62
3 changed files with 16 additions and 6 deletions

View File

@@ -91,7 +91,6 @@ class Autoencoder:
while True: while True:
print( print(
f"{LOADER[epoch % len(LOADER)]} Training \t({epoch=} error={prev_error:.2f})", # noqa f"{LOADER[epoch % len(LOADER)]} Training \t({epoch=} error={prev_error:.2f})", # noqa
end="\r"
) )
error = 0 error = 0
for x in data_set: for x in data_set:
@@ -113,7 +112,7 @@ class Autoencoder:
epoch += 1 epoch += 1
if display_loss is True: if display_loss is True:
dynamic_loss_plot_finish(ax, line) dynamic_loss_plot_finish(ax, line)
print("\r#Training complete !") print("#Training complete !")
return losses return losses
def encode(self, v: np.ndarray) -> np.ndarray: def encode(self, v: np.ndarray) -> np.ndarray:

View File

@@ -1,16 +1,28 @@
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import keras
from autoencoder import Autoencoder from autoencoder import Autoencoder
from utils import relu from utils import relu
def load_mnist():
import os
import requests
mnist_path = "./mnist.npz"
mnist_url = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz"
if not os.path.exists(mnist_path):
with open(mnist_path, "w+b") as f:
f.write(requests.get(mnist_url, stream=True).content)
res = np.load(mnist_path)
return res["x_train"], res["y_train"], res["x_test"], res["y_test"]
def mnist_test( def mnist_test(
bottleneck: int, bottleneck: int,
max_epoch: int, max_epoch: int,
patience: int, patience: int,
): ):
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data() x_train, _, x_test, _ = load_mnist()
x_train = np.divide(x_train, 255) x_train = np.divide(x_train, 255)
x_test = np.divide(x_train, 255) x_test = np.divide(x_train, 255)
in_len = x_train[0].flatten().shape[0] in_len = x_train[0].flatten().shape[0]

View File

@@ -1,4 +1,3 @@
numpy numpy
matplotlib matplotlib
keras requests
tensorflow