refactor: move plot logic to plotters.py

This commit is contained in:
Lenoctambule
2026-04-09 22:47:22 +02:00
parent 9d718a6bc8
commit ea8a4079ac
6 changed files with 81 additions and 51 deletions

View File

@@ -8,7 +8,6 @@ from easyvae.autoencoder import ( # noqa
AAutoencoder
)
from easyvae.activations import LeakyReLU
from easyvae.utils import dynamic_loss_plot_finish
def load_mnist() -> list[np.ndarray]:
@@ -33,6 +32,7 @@ def mnist_train(
x_train.resize(x_train.shape[0], in_len)
x_test.resize(x_test.shape[0], in_len)
x_train = x_train / 255
x_train = x_train[:5000]
if os.path.exists(filename):
autoencoder = cls.load(filename)
else:
@@ -46,8 +46,8 @@ def mnist_train(
def handler(signum, frame):
print(f"Saving {filename} before exit ...")
autoencoder.save(filename)
if plt.get_fignums():
dynamic_loss_plot_finish()
plt.close('all')
plt.ioff()
mnist_test(autoencoder)
exit()
@@ -84,10 +84,11 @@ def plot_mnist_latent_space(autoencoder: AAutoencoder, x: np.ndarray, y,):
plt.show()
def plot_random_reconstruction(autoencoder: AAutoencoder,
example: np.ndarray,
img_shape,
y):
def plot_random_reconstruction(
autoencoder: AAutoencoder,
example: np.ndarray,
img_shape,
y):
output, code = autoencoder.forward(example.flatten())
plt.subplot(1, 2, 1)
plt.matshow(
@@ -114,6 +115,8 @@ def mnist_test(model: str | AAutoencoder):
autoencoder: AAutoencoder = AAutoencoder.load(model)
else:
autoencoder = model
print("Testing model ...\n")
print(autoencoder)
idx = np.random.randint(0, len(x_test))
example: np.ndarray = x_test[idx]
plot_random_reconstruction(autoencoder, example, img_shape, y_test[idx])