From ea8a4079ac4d94fcd11d2c4b9f8166dda1f09158 Mon Sep 17 00:00:00 2001 From: Lenoctambule <106790775+lenoctambule@users.noreply.github.com> Date: Thu, 9 Apr 2026 22:47:22 +0200 Subject: [PATCH 1/4] refactor: move plot logic to plotters.py --- .gitignore | 3 ++- README.md | 2 +- examples/mnist_test.py | 17 ++++++++------ src/easyvae/autoencoder.py | 36 +++++++++++++++------------- src/easyvae/plotters.py | 48 ++++++++++++++++++++++++++++++++++++++ src/easyvae/utils.py | 26 --------------------- 6 files changed, 81 insertions(+), 51 deletions(-) create mode 100644 src/easyvae/plotters.py diff --git a/.gitignore b/.gitignore index 50eb41c..d199881 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,5 @@ __pycache__ *.npy .venv dist -*.egg-info \ No newline at end of file +*.egg-info +.env \ No newline at end of file diff --git a/README.md b/README.md index 66d2ae7..49f39a3 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ src="https://raw.githubusercontent.com/lenoctambule/autoencoder/refs/heads/main/media/latent-space.png" alt="Latent-space of the MNIST dataset" width=70%> -
+

Latent-space representation of the MNIST dataset using Variational Autoencoder

diff --git a/examples/mnist_test.py b/examples/mnist_test.py index 5e68c9a..e0e37f5 100644 --- a/examples/mnist_test.py +++ b/examples/mnist_test.py @@ -8,7 +8,6 @@ from easyvae.autoencoder import ( # noqa AAutoencoder ) from easyvae.activations import LeakyReLU -from easyvae.utils import dynamic_loss_plot_finish def load_mnist() -> list[np.ndarray]: @@ -33,6 +32,7 @@ def mnist_train( x_train.resize(x_train.shape[0], in_len) x_test.resize(x_test.shape[0], in_len) x_train = x_train / 255 + x_train = x_train[:5000] if os.path.exists(filename): autoencoder = cls.load(filename) else: @@ -46,8 +46,8 @@ def mnist_train( def handler(signum, frame): print(f"Saving {filename} before exit ...") autoencoder.save(filename) - if plt.get_fignums(): - dynamic_loss_plot_finish() + plt.close('all') + plt.ioff() mnist_test(autoencoder) exit() @@ -84,10 +84,11 @@ def plot_mnist_latent_space(autoencoder: AAutoencoder, x: np.ndarray, y,): plt.show() -def plot_random_reconstruction(autoencoder: AAutoencoder, - example: np.ndarray, - img_shape, - y): +def plot_random_reconstruction( + autoencoder: AAutoencoder, + example: np.ndarray, + img_shape, + y): output, code = autoencoder.forward(example.flatten()) plt.subplot(1, 2, 1) plt.matshow( @@ -114,6 +115,8 @@ def mnist_test(model: str | AAutoencoder): autoencoder: AAutoencoder = AAutoencoder.load(model) else: autoencoder = model + print("Testing model ...\n") + print(autoencoder) idx = np.random.randint(0, len(x_test)) example: np.ndarray = x_test[idx] plot_random_reconstruction(autoencoder, example, img_shape, y_test[idx]) diff --git a/src/easyvae/autoencoder.py b/src/easyvae/autoencoder.py index 5f436c9..f6f9566 100644 --- a/src/easyvae/autoencoder.py +++ b/src/easyvae/autoencoder.py @@ -1,18 +1,16 @@ import numpy as np -from .utils import ( - dynamic_loss_plot_init, - dynamic_loss_plot_update, - dynamic_loss_plot_finish - ) from tqdm import tqdm from .layers import DeepNNLayer, SampleLayer from .activations import ActivationFunc, 
Identity +from .plotters import Plotter, CAPlotter from abc import ABC, abstractmethod LOADER = ['⡿', '⣟', '⣯', '⣷', '⣾', '⣽', '⣻', '⢿'] class AAutoencoder(ABC): + plotter_cls = Plotter + @abstractmethod def __init__(self, encoder_layers: list[int], @@ -27,18 +25,18 @@ class AAutoencoder(ABC): self.decoder = DeepNNLayer(decoder_layers, lr, activation_func) self.space_dim = decoder_layers[0] self.lr = lr + self.losses = [0] def train_dataset(self, data_set: list[np.ndarray], max_epoch: int, patience: int, display_loss: bool = False) -> list[float]: - losses = [self.loss(data_set)] - if display_loss is True: - ax, line = dynamic_loss_plot_init(losses) + plotter = self.plotter_cls(self) if display_loss else Plotter(self) + self.losses = [self.loss(data_set)] epoch = 0 no_improv = 0 - prev_error = losses[0] + prev_error = self.losses[0] with tqdm(bar_format="{desc} {elapsed} {rate_fmt}") as lbar: while True: lbar.set_description( @@ -55,17 +53,15 @@ class AAutoencoder(ABC): else: no_improv = 0 prev_error = float(error) - losses.append(error) - if display_loss is True: - dynamic_loss_plot_update(ax, line, losses) + self.losses.append(error) if no_improv > patience: break if epoch > max_epoch: break + plotter.update() epoch += 1 - if display_loss is True: - dynamic_loss_plot_finish() - return losses + plotter.close() + return self.losses def save(self, path: str): path = path.removesuffix('.npy') @@ -90,11 +86,19 @@ class AAutoencoder(ABC): class ClassicalAutoencoder(AAutoencoder): + plotter_cls = CAPlotter + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def __str__(self): - return f'Encoder:\n{self.encoder}\n\nDecoder:\n{self.decoder}' + return "\n".join(( + f"Type: {__class__.__name__}", + "Encoder:", + f"{self.encoder}", + "Decoder:" + f"{self.decoder}" + )) def loss(self, data_set: list[np.ndarray]) -> float: loss = 0 diff --git a/src/easyvae/plotters.py b/src/easyvae/plotters.py new file mode 100644 index 0000000..97a36f9 --- /dev/null +++ 
b/src/easyvae/plotters.py @@ -0,0 +1,48 @@ +import matplotlib.pyplot as plt +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .autoencoder import AAutoencoder + + +class Plotter: + def __init__(self, autoencoder: 'AAutoencoder'): + pass + + def update(self): + pass + + def close(self): + pass + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + +class CAPlotter(Plotter): + def __init__(self, autoencoder: 'AAutoencoder'): + self.autoencoder = autoencoder + plt.ion() + self.fig, self.ax = plt.subplots() + self.line, = self.ax.plot( + list(range(len(autoencoder.losses))), + autoencoder.losses, + label="Loss" + ) + self.ax.set_xlabel("Epoch") + self.ax.set_ylabel("Loss") + self.ax.set_title("Training MSE Loss") + self.ax.legend() + self.update() + + def update(self): + self.line.set_xdata(range(len(self.autoencoder.losses))) + self.line.set_ydata(self.autoencoder.losses) + self.ax.relim() + self.ax.autoscale_view() + plt.draw() + plt.pause(0.1) + + def close(self): + plt.ioff() + plt.show() diff --git a/src/easyvae/utils.py b/src/easyvae/utils.py index 7e61971..2d10966 100644 --- a/src/easyvae/utils.py +++ b/src/easyvae/utils.py @@ -1,6 +1,5 @@ import numpy as np -import matplotlib.pyplot as plt def softmax(v: np.ndarray) -> np.ndarray: @@ -19,28 +18,3 @@ def regularize(v: np.ndarray) -> np.ndarray: if v_min - v_max == 0: return v return (v - v_min) / (v_max - v_min) - - -def dynamic_loss_plot_init(losses: list): - plt.ion() - fig, ax = plt.subplots() - line, = ax.plot([0], losses, label="Loss") - ax.set_xlabel("Epoch") - ax.set_ylabel("Loss") - ax.set_title("Training Loss") - ax.legend() - return ax, line - - -def dynamic_loss_plot_update(ax, line, loss): - line.set_xdata(range(len(loss))) - line.set_ydata(loss) - ax.relim() - ax.autoscale_view() - plt.draw() - plt.pause(0.1) - - -def dynamic_loss_plot_finish(): - plt.ioff() - plt.show() From 849d988de5833712f5caeba48881cc6001a97e69 Mon Sep 17 00:00:00 2001 From: Lenoctambule 
<106790775+lenoctambule@users.noreply.github.com> Date: Fri, 10 Apr 2026 15:00:04 +0200 Subject: [PATCH 2/4] feat(autoencoder.py): __str__ method for VariationalAutoencoder class --- src/easyvae/autoencoder.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/easyvae/autoencoder.py b/src/easyvae/autoencoder.py index f6f9566..d913699 100644 --- a/src/easyvae/autoencoder.py +++ b/src/easyvae/autoencoder.py @@ -96,7 +96,7 @@ class ClassicalAutoencoder(AAutoencoder): f"Type: {__class__.__name__}", "Encoder:", f"{self.encoder}", - "Decoder:" + "Decoder:", f"{self.decoder}" )) @@ -133,6 +133,15 @@ class VariationalAutoencoder(AAutoencoder): super().__init__(*args, **kwargs) self.sampler = SampleLayer(self.encoder.out_size, self.lr, Identity()) + def __str__(self): + return "\n".join(( + f"Type: {__class__.__name__}", + "Encoder:", + f"{self.encoder}", + "Decoder:", + f"{self.decoder}" + )) + def loss(self, data_set: list[np.ndarray]) -> float: loss = 0 for x in data_set: From 5ff6cfe55eaac5221b458f44c1485c2357369a69 Mon Sep 17 00:00:00 2001 From: Lenoctambule <106790775+lenoctambule@users.noreply.github.com> Date: Fri, 10 Apr 2026 20:37:03 +0200 Subject: [PATCH 3/4] feat(plotters.py): add VAEPlotter class + separate training logic --- src/easyvae/autoencoder.py | 185 +++++++++++++++++++++++++------------ src/easyvae/plotters.py | 47 +++++++++- 2 files changed, 171 insertions(+), 61 deletions(-) diff --git a/src/easyvae/autoencoder.py b/src/easyvae/autoencoder.py index d913699..652e664 100644 --- a/src/easyvae/autoencoder.py +++ b/src/easyvae/autoencoder.py @@ -2,7 +2,7 @@ import numpy as np from tqdm import tqdm from .layers import DeepNNLayer, SampleLayer from .activations import ActivationFunc, Identity -from .plotters import Plotter, CAPlotter +from .plotters import Plotter, CAPlotter, VAEPlotter from abc import ABC, abstractmethod LOADER = ['⡿', '⣟', '⣯', '⣷', '⣾', '⣽', '⣻', '⢿'] @@ -27,6 +27,65 @@ class AAutoencoder(ABC): self.lr = lr 
self.losses = [0] + def save(self, path: str): + path = path.removesuffix('.npy') + np.save(path, self) + + def load(path: str) -> 'ClassicalAutoencoder': + path = path.removesuffix('.npy') + '.npy' + data = np.load(path, allow_pickle=True) + return data.item() + + @abstractmethod + def loss(self, data_set: list[np.ndarray]) -> float: + pass + + @abstractmethod + def train(self, v: np.ndarray) -> float: + pass + + @abstractmethod + def forward(self, v: np.ndarray) -> np.ndarray: + pass + + @abstractmethod + def train_dataset(self, *args, **kwargs) -> list[float]: + pass + + +class ClassicalAutoencoder(AAutoencoder): + plotter_cls = CAPlotter + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.losses = [] + + def __str__(self): + return "\n".join(( + f"Type: {__class__.__name__}", + "Encoder:", + f"{self.encoder}", + "Decoder:", + f"{self.decoder}" + ) + ) + + def loss(self, data_set: list[np.ndarray]) -> float: + loss = 0 + for x in data_set: + loss += np.sum(np.abs(x - self.forward(x)[0])) / len(x) + return loss / len(data_set) + + def train(self, v: np.ndarray): + out = self.decoder.forward( + self.encoder.forward(v) + ) + error = out - v + self.encoder.backprop( + self.decoder.backprop(error) + ) + return np.sum(np.abs(error)) / len(v) + def train_dataset(self, data_set: list[np.ndarray], max_epoch: int, @@ -63,59 +122,6 @@ class AAutoencoder(ABC): plotter.close() return self.losses - def save(self, path: str): - path = path.removesuffix('.npy') - np.save(path, self) - - def load(path: str) -> 'ClassicalAutoencoder': - path = path.removesuffix('.npy') + '.npy' - data = np.load(path, allow_pickle=True) - return data.item() - - @abstractmethod - def loss(self, data_set: list[np.ndarray]) -> float: - pass - - @abstractmethod - def train(self, v: np.ndarray) -> float: - pass - - @abstractmethod - def forward(self, v: np.ndarray) -> np.ndarray: - pass - - -class ClassicalAutoencoder(AAutoencoder): - plotter_cls = CAPlotter - - def 
__init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def __str__(self): - return "\n".join(( - f"Type: {__class__.__name__}", - "Encoder:", - f"{self.encoder}", - "Decoder:", - f"{self.decoder}" - )) - - def loss(self, data_set: list[np.ndarray]) -> float: - loss = 0 - for x in data_set: - loss += np.sum(np.abs(x - self.forward(x)[0])) / len(x) - return loss / len(data_set) - - def train(self, v: np.ndarray): - out = self.decoder.forward( - self.encoder.forward(v) - ) - error = out - v - self.encoder.backprop( - self.decoder.backprop(error) - ) - return np.sum(np.abs(error)) / len(v) - def encode(self, v: np.ndarray) -> np.ndarray: return self.encoder.forward(v) @@ -129,9 +135,13 @@ class ClassicalAutoencoder(AAutoencoder): class VariationalAutoencoder(AAutoencoder): + plotter_cls = VAEPlotter + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.sampler = SampleLayer(self.encoder.out_size, self.lr, Identity()) + self.KL_losses = [] + self.recon_losses = [] def __str__(self): return "\n".join(( @@ -143,15 +153,18 @@ class VariationalAutoencoder(AAutoencoder): )) def loss(self, data_set: list[np.ndarray]) -> float: - loss = 0 + kl_loss = 0 + recon_loss = 0 for x in data_set: out = self.forward(x)[0] kl = self.sampler.DKL() - loss += np.mean((out - x) ** 2) - loss += kl - return loss / len(data_set) + recon_loss += np.mean((out - x) ** 2) + kl_loss += kl + kl_loss /= len(data_set) + recon_loss /= len(data_set) + return recon_loss, kl_loss - def train(self, v: np.ndarray) -> float: + def train(self, v: np.ndarray) -> tuple[float, float]: out, _ = self.forward(v) error = out - v self.encoder.backprop( @@ -159,10 +172,62 @@ class VariationalAutoencoder(AAutoencoder): self.decoder.backprop(error) ) ) - return np.mean(error ** 2) + self.sampler.DKL() + return np.mean(error ** 2), self.sampler.DKL() + + def train_dataset(self, + data_set: list[np.ndarray], + max_epoch: int, + patience: int, + display_loss: bool = False) -> 
list[float]: + plotter = self.plotter_cls(self) if display_loss else Plotter(self) + recon_0, kl_0 = self.loss(data_set) + self.recon_losses = [recon_0] + self.KL_losses = [kl_0] + epoch = 0 + no_improv = 0 + prev_loss = self.recon_losses[0] + self.KL_losses[0] + with tqdm(bar_format="{desc} {elapsed} {rate_fmt}") as lbar: + while True: + lbar.set_description( + f"{LOADER[epoch % len(LOADER)]} Training ({epoch=} loss={float(prev_loss):.6f})", # noqa + ) + lbar.update() + dkl = 0 + recon = 0 + for x in tqdm(data_set, leave=False): + recon_i, dkl_i = self.train(x) + dkl += dkl_i + recon += recon_i + recon /= len(data_set) + dkl /= len(data_set) + loss = recon + dkl + dloss = prev_loss - loss + if dloss <= 0 or abs(dloss) < 1e-4: + no_improv += 1 + else: + no_improv = 0 + prev_loss = float(loss) + self.recon_losses.append(recon) + self.KL_losses.append(dkl) + if no_improv > patience: + break + if epoch > max_epoch: + break + plotter.update() + epoch += 1 + plotter.close() + return self.recon_losses def forward(self, v: np.ndarray) -> tuple[np.ndarray, np.ndarray]: code = self.encoder.forward(v) sample = self.sampler.forward(code) out = self.decoder.forward(sample) return out, code + + def encode(self, v: np.ndarray) -> np.ndarray: + return self.sampler.forward( + self.encoder.forward(v) + ) + + def decode(self, v: np.ndarray) -> np.ndarray: + return self.decoder.forward(v) diff --git a/src/easyvae/plotters.py b/src/easyvae/plotters.py index 97a36f9..e4ec493 100644 --- a/src/easyvae/plotters.py +++ b/src/easyvae/plotters.py @@ -2,7 +2,7 @@ import matplotlib.pyplot as plt from typing import TYPE_CHECKING if TYPE_CHECKING: - from .autoencoder import AAutoencoder + from .autoencoder import AAutoencoder, VariationalAutoencoder class Plotter: @@ -46,3 +46,48 @@ class CAPlotter(Plotter): def close(self): plt.ioff() plt.show() + + +class VAEPlotter(Plotter): + def __init__(self, autoencoder: 'VariationalAutoencoder'): + self.autoencoder = autoencoder + plt.ion() + self.fig, 
(self.ax_recon, self.ax_dkl) = plt.subplots(1, 2) + self.line, = self.ax_recon.plot( + list(range(len(self.autoencoder.recon_losses))), + self.autoencoder.recon_losses, + label="Loss" + ) + self.ax_recon.set_xlabel("Epoch") + self.ax_recon.set_ylabel("Loss") + self.ax_recon.set_title("Reconstruction MSE Loss") + self.ax_recon.legend() + + self.dkl_line, = self.ax_dkl.plot( + list(range(len(self.autoencoder.KL_losses))), + self.autoencoder.KL_losses, + label="DKL Loss", + ) + self.ax_dkl.set_xlabel("Epoch") + self.ax_dkl.set_ylabel("Loss") + self.ax_dkl.set_title("DKL Loss") + self.ax_dkl.legend() + self.update() + + def update(self): + self.line.set_xdata(range(len(self.autoencoder.recon_losses))) + self.line.set_ydata(self.autoencoder.recon_losses) + self.ax_recon.relim() + self.ax_recon.autoscale_view() + + self.dkl_line.set_xdata(range(len(self.autoencoder.KL_losses))) + self.dkl_line.set_ydata(self.autoencoder.KL_losses) + self.ax_dkl.relim() + self.ax_dkl.autoscale_view() + + plt.draw() + plt.pause(0.1) + + def close(self): + plt.ioff() + plt.show() From 7a822782a56a0338548e9f586d6a91783a182325 Mon Sep 17 00:00:00 2001 From: Lenoctambule <106790775+lenoctambule@users.noreply.github.com> Date: Fri, 10 Apr 2026 22:20:35 +0200 Subject: [PATCH 4/4] refactor: move kb interrupt handling to autoencoder classes --- examples/mnist_test.py | 16 ++-------------- src/easyvae/autoencoder.py | 7 +++---- src/easyvae/plotters.py | 6 +++--- src/easyvae/utils.py | 9 +++++++++ 4 files changed, 17 insertions(+), 21 deletions(-) diff --git a/examples/mnist_test.py b/examples/mnist_test.py index e0e37f5..160cb16 100644 --- a/examples/mnist_test.py +++ b/examples/mnist_test.py @@ -1,7 +1,6 @@ import matplotlib.pyplot as plt import numpy as np import os -import signal from easyvae.autoencoder import ( # noqa VariationalAutoencoder, ClassicalAutoencoder, @@ -32,7 +31,6 @@ def mnist_train( x_train.resize(x_train.shape[0], in_len) x_test.resize(x_test.shape[0], in_len) x_train = x_train 
/ 255 - x_train = x_train[:5000] if os.path.exists(filename): autoencoder = cls.load(filename) else: @@ -42,17 +40,7 @@ def mnist_train( 0.0001, LeakyReLU() ) - - def handler(signum, frame): - print(f"Saving {filename} before exit ...") - autoencoder.save(filename) - plt.close('all') - plt.ioff() - mnist_test(autoencoder) - exit() - - signal.signal(signal.SIGINT, handler) - print("CTRL+C to exit and save model.") + print("CTRL+C to interrupt training.") autoencoder.train_dataset( x_train, max_epoch, @@ -100,7 +88,7 @@ def plot_random_reconstruction( output.reshape(img_shape), fignum=False) plt.title(f"Output ({y})") - print(f'{code=}') + print(f'{code.tolist()}') def mnist_test(model: str | AAutoencoder): diff --git a/src/easyvae/autoencoder.py b/src/easyvae/autoencoder.py index 652e664..3926481 100644 --- a/src/easyvae/autoencoder.py +++ b/src/easyvae/autoencoder.py @@ -3,6 +3,7 @@ from tqdm import tqdm from .layers import DeepNNLayer, SampleLayer from .activations import ActivationFunc, Identity from .plotters import Plotter, CAPlotter, VAEPlotter +from .utils import interruptable from abc import ABC, abstractmethod LOADER = ['⡿', '⣟', '⣯', '⣷', '⣾', '⣽', '⣻', '⢿'] @@ -86,6 +87,7 @@ class ClassicalAutoencoder(AAutoencoder): ) return np.sum(np.abs(error)) / len(v) + @interruptable def train_dataset(self, data_set: list[np.ndarray], max_epoch: int, @@ -119,8 +121,6 @@ class ClassicalAutoencoder(AAutoencoder): break plotter.update() epoch += 1 - plotter.close() - return self.losses def encode(self, v: np.ndarray) -> np.ndarray: return self.encoder.forward(v) @@ -174,6 +174,7 @@ class VariationalAutoencoder(AAutoencoder): ) return np.mean(error ** 2), self.sampler.DKL() + @interruptable def train_dataset(self, data_set: list[np.ndarray], max_epoch: int, @@ -215,8 +216,6 @@ class VariationalAutoencoder(AAutoencoder): break plotter.update() epoch += 1 - plotter.close() - return self.recon_losses def forward(self, v: np.ndarray) -> tuple[np.ndarray, np.ndarray]: code = 
self.encoder.forward(v) diff --git a/src/easyvae/plotters.py b/src/easyvae/plotters.py index e4ec493..6750b43 100644 --- a/src/easyvae/plotters.py +++ b/src/easyvae/plotters.py @@ -15,7 +15,7 @@ class Plotter: def close(self): pass - def __exit__(self, exc_type, exc_val, exc_tb): + def __del__(self): self.close() @@ -45,7 +45,7 @@ class CAPlotter(Plotter): def close(self): plt.ioff() - plt.show() + plt.close(self.fig) class VAEPlotter(Plotter): @@ -90,4 +90,4 @@ class VAEPlotter(Plotter): def close(self): plt.ioff() - plt.show() + plt.close(self.fig) diff --git a/src/easyvae/utils.py b/src/easyvae/utils.py index 2d10966..414db1a 100644 --- a/src/easyvae/utils.py +++ b/src/easyvae/utils.py @@ -18,3 +18,12 @@ def regularize(v: np.ndarray) -> np.ndarray: if v_min - v_max == 0: return v return (v - v_min) / (v_max - v_min) + + +def interruptable(func): + def inner(*args, **kwargs): + try: + return func(*args, **kwargs) + except KeyboardInterrupt: + pass + return inner