refactor: code de-dup
This commit is contained in:
@@ -9,7 +9,7 @@ def load_mnist():
|
||||
import requests
|
||||
|
||||
mnist_path = "./mnist.npz"
|
||||
mnist_url = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz"
|
||||
mnist_url = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz" # noqa
|
||||
if not os.path.exists(mnist_path):
|
||||
with open(mnist_path, "w+b") as f:
|
||||
f.write(requests.get(mnist_url, stream=True).content)
|
||||
@@ -25,7 +25,7 @@ def mnist_test(
|
||||
x_train, _, x_test, _ = load_mnist()
|
||||
x_train = np.divide(x_train, 255)
|
||||
x_test = np.divide(x_train, 255)
|
||||
in_len = x_train[0].flatten().shape[0]
|
||||
in_len = x_train[0].shape[0] * x_train[0].shape[0]
|
||||
autoencoder = Autoencoder(in_len, bottleneck, 0.001, relu)
|
||||
autoencoder.train_dataset(x_train, max_epoch, patience, display_loss=True)
|
||||
example: np.ndarray = x_test[np.random.randint(0, len(x_test))]
|
||||
|
||||
Reference in New Issue
Block a user