From 5ff6cfe55eaac5221b458f44c1485c2357369a69 Mon Sep 17 00:00:00 2001 From: Lenoctambule <106790775+lenoctambule@users.noreply.github.com> Date: Fri, 10 Apr 2026 20:37:03 +0200 Subject: [PATCH] feat(plotters.py): add VAEPlotter class + seperate training logic --- src/easyvae/autoencoder.py | 185 +++++++++++++++++++++++++------------ src/easyvae/plotters.py | 47 +++++++++- 2 files changed, 171 insertions(+), 61 deletions(-) diff --git a/src/easyvae/autoencoder.py b/src/easyvae/autoencoder.py index d913699..652e664 100644 --- a/src/easyvae/autoencoder.py +++ b/src/easyvae/autoencoder.py @@ -2,7 +2,7 @@ import numpy as np from tqdm import tqdm from .layers import DeepNNLayer, SampleLayer from .activations import ActivationFunc, Identity -from .plotters import Plotter, CAPlotter +from .plotters import Plotter, CAPlotter, VAEPlotter from abc import ABC, abstractmethod LOADER = ['⡿', '⣟', '⣯', '⣷', '⣾', '⣽', '⣻', '⢿'] @@ -27,6 +27,65 @@ class AAutoencoder(ABC): self.lr = lr self.losses = [0] + def save(self, path: str): + path = path.removesuffix('.npy') + np.save(path, self) + + def load(path: str) -> 'ClassicalAutoencoder': + path = path.removesuffix('.npy') + '.npy' + data = np.load(path, allow_pickle=True) + return data.item() + + @abstractmethod + def loss(self, data_set: list[np.ndarray]) -> float: + pass + + @abstractmethod + def train(self, v: np.ndarray) -> float: + pass + + @abstractmethod + def forward(self, v: np.ndarray) -> np.ndarray: + pass + + @abstractmethod + def train_dataset(self, *args, **kwargs) -> list[float]: + pass + + +class ClassicalAutoencoder(AAutoencoder): + plotter_cls = CAPlotter + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.losses = [] + + def __str__(self): + return "\n".join(( + f"Type: {__class__.__name__}", + "Encoder:", + f"{self.encoder}", + "Decoder:", + f"{self.decoder}" + ) + ) + + def loss(self, data_set: list[np.ndarray]) -> float: + loss = 0 + for x in data_set: + loss += np.sum(np.abs(x - self.forward(x)[0])) / len(x) + return loss / len(data_set) + + def train(self, v: np.ndarray): + out = self.decoder.forward( + self.encoder.forward(v) + ) + error = out - v + self.encoder.backprop( + self.decoder.backprop(error) + ) + return np.sum(np.abs(error)) / len(v) + def train_dataset(self, data_set: list[np.ndarray], max_epoch: int, @@ -63,59 +122,6 @@ class AAutoencoder(ABC): plotter.close() return self.losses - def save(self, path: str): - path = path.removesuffix('.npy') - np.save(path, self) - - def load(path: str) -> 'ClassicalAutoencoder': - path = path.removesuffix('.npy') + '.npy' - data = np.load(path, allow_pickle=True) - return data.item() - - @abstractmethod - def loss(self, data_set: list[np.ndarray]) -> float: - pass - - @abstractmethod - def train(self, v: np.ndarray) -> float: - pass - - @abstractmethod - def forward(self, v: np.ndarray) -> np.ndarray: - pass - - -class ClassicalAutoencoder(AAutoencoder): - plotter_cls = CAPlotter - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def __str__(self): - return "\n".join(( - f"Type: {__class__.__name__}", - "Encoder:", - f"{self.encoder}", - "Decoder:", - f"{self.decoder}" - )) - - def loss(self, data_set: list[np.ndarray]) -> float: - loss = 0 - for x in data_set: - loss += np.sum(np.abs(x - self.forward(x)[0])) / len(x) - return loss / len(data_set) - - def train(self, v: np.ndarray): - out = self.decoder.forward( - self.encoder.forward(v) - ) - error = out - v - self.encoder.backprop( - self.decoder.backprop(error) - ) - return np.sum(np.abs(error)) / len(v) - def encode(self, v: np.ndarray) -> np.ndarray: return self.encoder.forward(v) @@ -129,9 +135,13 @@ class ClassicalAutoencoder(AAutoencoder): class VariationalAutoencoder(AAutoencoder): + plotter_cls = VAEPlotter + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.sampler = SampleLayer(self.encoder.out_size, self.lr, Identity()) + self.KL_losses = [] + self.recon_losses = [] def __str__(self): return "\n".join(( @@ -143,15 +153,18 @@ class VariationalAutoencoder(AAutoencoder): )) def loss(self, data_set: list[np.ndarray]) -> float: - loss = 0 + kl_loss = 0 + recon_loss = 0 for x in data_set: out = self.forward(x)[0] kl = self.sampler.DKL() - loss += np.mean((out - x) ** 2) - loss += kl - return loss / len(data_set) + recon_loss += np.mean((out - x) ** 2) + kl_loss += kl + kl_loss /= len(data_set) + recon_loss /= len(data_set) + return recon_loss, kl_loss - def train(self, v: np.ndarray) -> float: + def train(self, v: np.ndarray) -> tuple[float, float]: out, _ = self.forward(v) error = out - v self.encoder.backprop( @@ -159,10 +172,62 @@ class VariationalAutoencoder(AAutoencoder): self.decoder.backprop(error) ) ) - return np.mean(error ** 2) + self.sampler.DKL() + return np.mean(error ** 2), self.sampler.DKL() + + def train_dataset(self, + data_set: list[np.ndarray], + max_epoch: int, + patience: int, + display_loss: bool = False) -> list[float]: + plotter = self.plotter_cls(self) if display_loss else Plotter(self) + recon_0, kl_0 = self.loss(data_set) + self.recon_losses = [recon_0] + self.KL_losses = [kl_0] + epoch = 0 + no_improv = 0 + prev_loss = self.recon_losses[0] + self.KL_losses[0] + with tqdm(bar_format="{desc} {elapsed} {rate_fmt}") as lbar: + while True: + lbar.set_description( + f"{LOADER[epoch % len(LOADER)]} Training ({epoch=} loss={float(prev_loss):.6f})", # noqa + ) + lbar.update() + dkl = 0 + recon = 0 + for x in tqdm(data_set, leave=False): + recon_i, dkl_i = self.train(x) + dkl += dkl_i + recon += recon_i + recon /= len(data_set) + dkl /= len(data_set) + loss = recon + dkl + dloss = prev_loss - loss + if dloss <= 0 or abs(dloss) < 1e-4: + no_improv += 1 + else: + no_improv = 0 + prev_loss = float(loss) + self.recon_losses.append(recon) + self.KL_losses.append(dkl) + if no_improv > patience: + break + if epoch > max_epoch: + break + plotter.update() + epoch += 1 + plotter.close() + return self.recon_losses def forward(self, v: np.ndarray) -> tuple[np.ndarray, np.ndarray]: code = self.encoder.forward(v) sample = self.sampler.forward(code) out = self.decoder.forward(sample) return out, code + + def encode(self, v: np.ndarray) -> np.ndarray: + return self.sampler.forward( + self.encoder.forward(v) + ) + + def decode(self, v: np.ndarray) -> np.ndarray: + return self.decoder.forward(v) diff --git a/src/easyvae/plotters.py b/src/easyvae/plotters.py index 97a36f9..e4ec493 100644 --- a/src/easyvae/plotters.py +++ b/src/easyvae/plotters.py @@ -2,7 +2,7 @@ import matplotlib.pyplot as plt from typing import TYPE_CHECKING if TYPE_CHECKING: - from .autoencoder import AAutoencoder + from .autoencoder import AAutoencoder, VariationalAutoencoder class Plotter: @@ -46,3 +46,48 @@ class CAPlotter(Plotter): def close(self): plt.ioff() plt.show() + + +class VAEPlotter(Plotter): + def __init__(self, autoencoder: 'VariationalAutoencoder'): + self.autoencoder = autoencoder + plt.ion() + self.fig, (self.ax_recon, self.ax_dkl) = plt.subplots(1, 2) + self.line, = self.ax_recon.plot( + list(range(len(self.autoencoder.recon_losses))), + self.autoencoder.recon_losses, + label="Loss" + ) + self.ax_recon.set_xlabel("Epoch") + self.ax_recon.set_ylabel("Loss") + self.ax_recon.set_title("Reconstruction MSE Loss") + self.ax_recon.legend() + + self.dkl_line, = self.ax_dkl.plot( + list(range(len(self.autoencoder.KL_losses))), + self.autoencoder.KL_losses, + label="DKL Loss", + ) + self.ax_dkl.set_xlabel("Epoch") + self.ax_dkl.set_ylabel("Loss") + self.ax_dkl.set_title("DKL Loss") + self.ax_dkl.legend() + self.update() + + def update(self): + self.line.set_xdata(range(len(self.autoencoder.recon_losses))) + self.line.set_ydata(self.autoencoder.recon_losses) + self.ax_recon.relim() + self.ax_recon.autoscale_view() + + self.dkl_line.set_xdata(range(len(self.autoencoder.KL_losses))) + self.dkl_line.set_ydata(self.autoencoder.KL_losses) + self.ax_dkl.relim() + self.ax_dkl.autoscale_view() + + plt.draw() + plt.pause(0.1) + + def close(self): + plt.ioff() + plt.show()