feat: custom load_mnist function and rm useless dependency
This commit is contained in:
@@ -1,16 +1,28 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import keras
|
||||
from autoencoder import Autoencoder
|
||||
from utils import relu
|
||||
|
||||
|
||||
def load_mnist():
|
||||
import os
|
||||
import requests
|
||||
|
||||
mnist_path = "./mnist.npz"
|
||||
mnist_url = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz"
|
||||
if not os.path.exists(mnist_path):
|
||||
with open(mnist_path, "w+b") as f:
|
||||
f.write(requests.get(mnist_url, stream=True).content)
|
||||
res = np.load(mnist_path)
|
||||
return res["x_train"], res["y_train"], res["x_test"], res["y_test"]
|
||||
|
||||
|
||||
def mnist_test(
|
||||
bottleneck: int,
|
||||
max_epoch: int,
|
||||
patience: int,
|
||||
):
|
||||
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
|
||||
x_train, _, x_test, _ = load_mnist()
|
||||
x_train = np.divide(x_train, 255)
|
||||
x_test = np.divide(x_train, 255)
|
||||
in_len = x_train[0].flatten().shape[0]
|
||||
|
||||
Reference in New Issue
Block a user