diff --git a/mnist_test.py b/mnist_test.py index 0bd23a9..03a486b 100644 --- a/mnist_test.py +++ b/mnist_test.py @@ -14,9 +14,8 @@ def mnist_test( x_train = np.divide(x_train, 255) x_test = np.divide(x_train, 255) in_len = x_train[0].flatten().shape[0] - autoencoder = Autoencoder(in_len, bottleneck, 0.0001, relu) - x_train = x_train[:] - autoencoder.train_dataset(x_train, max_epoch, patience) + autoencoder = Autoencoder(in_len, bottleneck, 0.001, relu) + autoencoder.train_dataset(x_train, max_epoch, patience, display_loss=True) example: np.ndarray = x_test[np.random.randint(0, len(x_test))] code = autoencoder.encode(example.flatten()) output = autoencoder.decode(code)