From ea8a4079ac4d94fcd11d2c4b9f8166dda1f09158 Mon Sep 17 00:00:00 2001 From: Lenoctambule <106790775+lenoctambule@users.noreply.github.com> Date: Thu, 9 Apr 2026 22:47:22 +0200 Subject: [PATCH] refactor: move plot logic to plotters.py --- .gitignore | 3 ++- README.md | 2 +- examples/mnist_test.py | 17 ++++++++------ src/easyvae/autoencoder.py | 36 +++++++++++++++------------- src/easyvae/plotters.py | 48 ++++++++++++++++++++++++++++++++++++++ src/easyvae/utils.py | 26 --------------------- 6 files changed, 81 insertions(+), 51 deletions(-) create mode 100644 src/easyvae/plotters.py diff --git a/.gitignore b/.gitignore index 50eb41c..d199881 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,5 @@ __pycache__ *.npy .venv dist -*.egg-info \ No newline at end of file +*.egg-info +.env \ No newline at end of file diff --git a/README.md b/README.md index 66d2ae7..49f39a3 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ src="https://raw.githubusercontent.com/lenoctambule/autoencoder/refs/heads/main/media/latent-space.png" alt="Latent-space of the MNIST dataset" width=70%> -
+

Latent-space representation of the MNIST dataset using Variational Autoencoder

diff --git a/examples/mnist_test.py b/examples/mnist_test.py index 5e68c9a..e0e37f5 100644 --- a/examples/mnist_test.py +++ b/examples/mnist_test.py @@ -8,7 +8,6 @@ from easyvae.autoencoder import ( # noqa AAutoencoder ) from easyvae.activations import LeakyReLU -from easyvae.utils import dynamic_loss_plot_finish def load_mnist() -> list[np.ndarray]: @@ -33,6 +32,7 @@ def mnist_train( x_train.resize(x_train.shape[0], in_len) x_test.resize(x_test.shape[0], in_len) x_train = x_train / 255 + x_train = x_train[:5000] if os.path.exists(filename): autoencoder = cls.load(filename) else: @@ -46,8 +46,8 @@ def mnist_train( def handler(signum, frame): print(f"Saving {filename} before exit ...") autoencoder.save(filename) - if plt.get_fignums(): - dynamic_loss_plot_finish() + plt.close('all') + plt.ioff() mnist_test(autoencoder) exit() @@ -84,10 +84,11 @@ def plot_mnist_latent_space(autoencoder: AAutoencoder, x: np.ndarray, y,): plt.show() -def plot_random_reconstruction(autoencoder: AAutoencoder, - example: np.ndarray, - img_shape, - y): +def plot_random_reconstruction( + autoencoder: AAutoencoder, + example: np.ndarray, + img_shape, + y): output, code = autoencoder.forward(example.flatten()) plt.subplot(1, 2, 1) plt.matshow( @@ -114,6 +115,8 @@ def mnist_test(model: str | AAutoencoder): autoencoder: AAutoencoder = AAutoencoder.load(model) else: autoencoder = model + print("Testing model ...\n") + print(autoencoder) idx = np.random.randint(0, len(x_test)) example: np.ndarray = x_test[idx] plot_random_reconstruction(autoencoder, example, img_shape, y_test[idx]) diff --git a/src/easyvae/autoencoder.py b/src/easyvae/autoencoder.py index 5f436c9..f6f9566 100644 --- a/src/easyvae/autoencoder.py +++ b/src/easyvae/autoencoder.py @@ -1,18 +1,16 @@ import numpy as np -from .utils import ( - dynamic_loss_plot_init, - dynamic_loss_plot_update, - dynamic_loss_plot_finish - ) from tqdm import tqdm from .layers import DeepNNLayer, SampleLayer from .activations import ActivationFunc, Identity +from .plotters import Plotter, CAPlotter from abc import ABC, abstractmethod LOADER = ['⡿', '⣟', '⣯', '⣷', '⣾', '⣽', '⣻', '⢿'] class AAutoencoder(ABC): + plotter_cls = Plotter + @abstractmethod def __init__(self, encoder_layers: list[int], @@ -27,18 +25,18 @@ class AAutoencoder(ABC): self.decoder = DeepNNLayer(decoder_layers, lr, activation_func) self.space_dim = decoder_layers[0] self.lr = lr + self.losses = [0] def train_dataset(self, data_set: list[np.ndarray], max_epoch: int, patience: int, display_loss: bool = False) -> list[float]: - losses = [self.loss(data_set)] - if display_loss is True: - ax, line = dynamic_loss_plot_init(losses) + plotter = self.plotter_cls(self) if display_loss else Plotter(self) + self.losses = [self.loss(data_set)] epoch = 0 no_improv = 0 - prev_error = losses[0] + prev_error = self.losses[0] with tqdm(bar_format="{desc} {elapsed} {rate_fmt}") as lbar: while True: lbar.set_description( @@ -55,17 +53,15 @@ class AAutoencoder(ABC): else: no_improv = 0 prev_error = float(error) - losses.append(error) - if display_loss is True: - dynamic_loss_plot_update(ax, line, losses) + self.losses.append(error) if no_improv > patience: break if epoch > max_epoch: break + plotter.update() epoch += 1 - if display_loss is True: - dynamic_loss_plot_finish() - return losses + plotter.close() + return self.losses def save(self, path: str): path = path.removesuffix('.npy') @@ -90,11 +86,19 @@ class AAutoencoder(ABC): class ClassicalAutoencoder(AAutoencoder): + plotter_cls = CAPlotter + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def __str__(self): - return f'Encoder:\n{self.encoder}\n\nDecoder:\n{self.decoder}' + return "\n".join(( + f"Type: {__class__.__name__}", + "Encoder:", + f"{self.encoder}", + "Decoder:" + f"{self.decoder}" + )) def loss(self, data_set: list[np.ndarray]) -> float: loss = 0 diff --git a/src/easyvae/plotters.py b/src/easyvae/plotters.py new file mode 100644 index 0000000..97a36f9 --- /dev/null +++ b/src/easyvae/plotters.py @@ -0,0 +1,48 @@ +import matplotlib.pyplot as plt +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .autoencoder import AAutoencoder + + +class Plotter: + def __init__(self, autoencoder: 'AAutoencoder'): + pass + + def update(self): + pass + + def close(self): + pass + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + +class CAPlotter(Plotter): + def __init__(self, autoencoder: 'AAutoencoder'): + self.autoencoder = autoencoder + plt.ion() + self.fig, self.ax = plt.subplots() + self.line, = self.ax.plot( + list(range(len(autoencoder.losses))), + autoencoder.losses, + label="Loss" + ) + self.ax.set_xlabel("Epoch") + self.ax.set_ylabel("Loss") + self.ax.set_title("Training MSE Loss") + self.ax.legend() + self.update() + + def update(self): + self.line.set_xdata(range(len(self.autoencoder.losses))) + self.line.set_ydata(self.autoencoder.losses) + self.ax.relim() + self.ax.autoscale_view() + plt.draw() + plt.pause(0.1) + + def close(self): + plt.ioff() + plt.show() diff --git a/src/easyvae/utils.py b/src/easyvae/utils.py index 7e61971..2d10966 100644 --- a/src/easyvae/utils.py +++ b/src/easyvae/utils.py @@ -1,6 +1,5 @@ import numpy as np -import matplotlib.pyplot as plt def softmax(v: np.ndarray) -> np.ndarray: @@ -19,28 +18,3 @@ def regularize(v: np.ndarray) -> np.ndarray: if v_min - v_max == 0: return v return (v - v_min) / (v_max - v_min) - - -def dynamic_loss_plot_init(losses: list): - plt.ion() - fig, ax = plt.subplots() - line, = ax.plot([0], losses, label="Loss") - ax.set_xlabel("Epoch") - ax.set_ylabel("Loss") - ax.set_title("Training Loss") - ax.legend() - return ax, line - - -def dynamic_loss_plot_update(ax, line, loss): - line.set_xdata(range(len(loss))) - line.set_ydata(loss) - ax.relim() - ax.autoscale_view() - plt.draw() - plt.pause(0.1) - - -def dynamic_loss_plot_finish(): - plt.ioff() - plt.show()