refactor: move plot logic to plotters.py
This commit is contained in:
@@ -8,7 +8,6 @@ from easyvae.autoencoder import ( # noqa
|
||||
AAutoencoder
|
||||
)
|
||||
from easyvae.activations import LeakyReLU
|
||||
from easyvae.utils import dynamic_loss_plot_finish
|
||||
|
||||
|
||||
def load_mnist() -> list[np.ndarray]:
|
||||
@@ -33,6 +32,7 @@ def mnist_train(
|
||||
x_train.resize(x_train.shape[0], in_len)
|
||||
x_test.resize(x_test.shape[0], in_len)
|
||||
x_train = x_train / 255
|
||||
x_train = x_train[:5000]
|
||||
if os.path.exists(filename):
|
||||
autoencoder = cls.load(filename)
|
||||
else:
|
||||
@@ -46,8 +46,8 @@ def mnist_train(
|
||||
def handler(signum, frame):
|
||||
print(f"Saving {filename} before exit ...")
|
||||
autoencoder.save(filename)
|
||||
if plt.get_fignums():
|
||||
dynamic_loss_plot_finish()
|
||||
plt.close('all')
|
||||
plt.ioff()
|
||||
mnist_test(autoencoder)
|
||||
exit()
|
||||
|
||||
@@ -84,10 +84,11 @@ def plot_mnist_latent_space(autoencoder: AAutoencoder, x: np.ndarray, y,):
|
||||
plt.show()
|
||||
|
||||
|
||||
def plot_random_reconstruction(autoencoder: AAutoencoder,
|
||||
example: np.ndarray,
|
||||
img_shape,
|
||||
y):
|
||||
def plot_random_reconstruction(
|
||||
autoencoder: AAutoencoder,
|
||||
example: np.ndarray,
|
||||
img_shape,
|
||||
y):
|
||||
output, code = autoencoder.forward(example.flatten())
|
||||
plt.subplot(1, 2, 1)
|
||||
plt.matshow(
|
||||
@@ -114,6 +115,8 @@ def mnist_test(model: str | AAutoencoder):
|
||||
autoencoder: AAutoencoder = AAutoencoder.load(model)
|
||||
else:
|
||||
autoencoder = model
|
||||
print("Testing model ...\n")
|
||||
print(autoencoder)
|
||||
idx = np.random.randint(0, len(x_test))
|
||||
example: np.ndarray = x_test[idx]
|
||||
plot_random_reconstruction(autoencoder, example, img_shape, y_test[idx])
|
||||
|
||||
Reference in New Issue
Block a user