feat: custom load_mnist function and rm useless dependency
This commit is contained in:
@@ -91,7 +91,6 @@ class Autoencoder:
|
|||||||
while True:
|
while True:
|
||||||
print(
|
print(
|
||||||
f"{LOADER[epoch % len(LOADER)]} Training \t({epoch=} error={prev_error:.2f})", # noqa
|
f"{LOADER[epoch % len(LOADER)]} Training \t({epoch=} error={prev_error:.2f})", # noqa
|
||||||
end="\r"
|
|
||||||
)
|
)
|
||||||
error = 0
|
error = 0
|
||||||
for x in data_set:
|
for x in data_set:
|
||||||
@@ -113,7 +112,7 @@ class Autoencoder:
|
|||||||
epoch += 1
|
epoch += 1
|
||||||
if display_loss is True:
|
if display_loss is True:
|
||||||
dynamic_loss_plot_finish(ax, line)
|
dynamic_loss_plot_finish(ax, line)
|
||||||
print("\r#Training complete !")
|
print("#Training complete !")
|
||||||
return losses
|
return losses
|
||||||
|
|
||||||
def encode(self, v: np.ndarray) -> np.ndarray:
|
def encode(self, v: np.ndarray) -> np.ndarray:
|
||||||
|
|||||||
@@ -1,16 +1,28 @@
|
|||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import keras
|
|
||||||
from autoencoder import Autoencoder
|
from autoencoder import Autoencoder
|
||||||
from utils import relu
|
from utils import relu
|
||||||
|
|
||||||
|
|
||||||
|
def load_mnist():
|
||||||
|
import os
|
||||||
|
import requests
|
||||||
|
|
||||||
|
mnist_path = "./mnist.npz"
|
||||||
|
mnist_url = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz"
|
||||||
|
if not os.path.exists(mnist_path):
|
||||||
|
with open(mnist_path, "w+b") as f:
|
||||||
|
f.write(requests.get(mnist_url, stream=True).content)
|
||||||
|
res = np.load(mnist_path)
|
||||||
|
return res["x_train"], res["y_train"], res["x_test"], res["y_test"]
|
||||||
|
|
||||||
|
|
||||||
def mnist_test(
|
def mnist_test(
|
||||||
bottleneck: int,
|
bottleneck: int,
|
||||||
max_epoch: int,
|
max_epoch: int,
|
||||||
patience: int,
|
patience: int,
|
||||||
):
|
):
|
||||||
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
|
x_train, _, x_test, _ = load_mnist()
|
||||||
x_train = np.divide(x_train, 255)
|
x_train = np.divide(x_train, 255)
|
||||||
x_test = np.divide(x_train, 255)
|
x_test = np.divide(x_train, 255)
|
||||||
in_len = x_train[0].flatten().shape[0]
|
in_len = x_train[0].flatten().shape[0]
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
numpy
|
numpy
|
||||||
matplotlib
|
matplotlib
|
||||||
keras
|
requests
|
||||||
tensorflow
|
|
||||||
Reference in New Issue
Block a user