diff --git a/examples/mnist_test.py b/examples/mnist_test.py index 7aff0f9..5e68c9a 100644 --- a/examples/mnist_test.py +++ b/examples/mnist_test.py @@ -89,12 +89,12 @@ def plot_random_reconstruction(autoencoder: AAutoencoder, img_shape, y): output, code = autoencoder.forward(example.flatten()) - plt.subplot(1, 3, 1) + plt.subplot(1, 2, 1) plt.matshow( example.reshape(img_shape), fignum=False) plt.title(f"Input ({y})") - plt.subplot(1, 3, 2) + plt.subplot(1, 2, 2) plt.matshow( output.reshape(img_shape), fignum=False)