refactor: move plot logic to plotters.py

This commit is contained in:
Lenoctambule
2026-04-09 22:47:22 +02:00
parent 9d718a6bc8
commit ea8a4079ac
6 changed files with 81 additions and 51 deletions

3
.gitignore vendored
View File

@@ -4,4 +4,5 @@ __pycache__
*.npy *.npy
.venv .venv
dist dist
*.egg-info *.egg-info
.env

View File

@@ -6,7 +6,7 @@
src="https://raw.githubusercontent.com/lenoctambule/autoencoder/refs/heads/main/media/latent-space.png" src="https://raw.githubusercontent.com/lenoctambule/autoencoder/refs/heads/main/media/latent-space.png"
alt="Latent-space of the MNIST dataset" alt="Latent-space of the MNIST dataset"
width=70%> width=70%>
<figcaption align=center> <figcaption>
<p align="center"> <p align="center">
Latent-space representation of the MNIST dataset using Variational Autoencoder Latent-space representation of the MNIST dataset using Variational Autoencoder
</p> </p>

View File

@@ -8,7 +8,6 @@ from easyvae.autoencoder import ( # noqa
AAutoencoder AAutoencoder
) )
from easyvae.activations import LeakyReLU from easyvae.activations import LeakyReLU
from easyvae.utils import dynamic_loss_plot_finish
def load_mnist() -> list[np.ndarray]: def load_mnist() -> list[np.ndarray]:
@@ -33,6 +32,7 @@ def mnist_train(
x_train.resize(x_train.shape[0], in_len) x_train.resize(x_train.shape[0], in_len)
x_test.resize(x_test.shape[0], in_len) x_test.resize(x_test.shape[0], in_len)
x_train = x_train / 255 x_train = x_train / 255
x_train = x_train[:5000]
if os.path.exists(filename): if os.path.exists(filename):
autoencoder = cls.load(filename) autoencoder = cls.load(filename)
else: else:
@@ -46,8 +46,8 @@ def mnist_train(
def handler(signum, frame): def handler(signum, frame):
print(f"Saving {filename} before exit ...") print(f"Saving {filename} before exit ...")
autoencoder.save(filename) autoencoder.save(filename)
if plt.get_fignums(): plt.close('all')
dynamic_loss_plot_finish() plt.ioff()
mnist_test(autoencoder) mnist_test(autoencoder)
exit() exit()
@@ -84,10 +84,11 @@ def plot_mnist_latent_space(autoencoder: AAutoencoder, x: np.ndarray, y,):
plt.show() plt.show()
def plot_random_reconstruction(autoencoder: AAutoencoder, def plot_random_reconstruction(
example: np.ndarray, autoencoder: AAutoencoder,
img_shape, example: np.ndarray,
y): img_shape,
y):
output, code = autoencoder.forward(example.flatten()) output, code = autoencoder.forward(example.flatten())
plt.subplot(1, 2, 1) plt.subplot(1, 2, 1)
plt.matshow( plt.matshow(
@@ -114,6 +115,8 @@ def mnist_test(model: str | AAutoencoder):
autoencoder: AAutoencoder = AAutoencoder.load(model) autoencoder: AAutoencoder = AAutoencoder.load(model)
else: else:
autoencoder = model autoencoder = model
print("Testing model ...\n")
print(autoencoder)
idx = np.random.randint(0, len(x_test)) idx = np.random.randint(0, len(x_test))
example: np.ndarray = x_test[idx] example: np.ndarray = x_test[idx]
plot_random_reconstruction(autoencoder, example, img_shape, y_test[idx]) plot_random_reconstruction(autoencoder, example, img_shape, y_test[idx])

View File

@@ -1,18 +1,16 @@
import numpy as np import numpy as np
from .utils import (
dynamic_loss_plot_init,
dynamic_loss_plot_update,
dynamic_loss_plot_finish
)
from tqdm import tqdm from tqdm import tqdm
from .layers import DeepNNLayer, SampleLayer from .layers import DeepNNLayer, SampleLayer
from .activations import ActivationFunc, Identity from .activations import ActivationFunc, Identity
from .plotters import Plotter, CAPlotter
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
LOADER = ['', '', '', '', '', '', '', ''] LOADER = ['', '', '', '', '', '', '', '']
class AAutoencoder(ABC): class AAutoencoder(ABC):
plotter_cls = Plotter
@abstractmethod @abstractmethod
def __init__(self, def __init__(self,
encoder_layers: list[int], encoder_layers: list[int],
@@ -27,18 +25,18 @@ class AAutoencoder(ABC):
self.decoder = DeepNNLayer(decoder_layers, lr, activation_func) self.decoder = DeepNNLayer(decoder_layers, lr, activation_func)
self.space_dim = decoder_layers[0] self.space_dim = decoder_layers[0]
self.lr = lr self.lr = lr
self.losses = [0]
def train_dataset(self, def train_dataset(self,
data_set: list[np.ndarray], data_set: list[np.ndarray],
max_epoch: int, max_epoch: int,
patience: int, patience: int,
display_loss: bool = False) -> list[float]: display_loss: bool = False) -> list[float]:
losses = [self.loss(data_set)] plotter = self.plotter_cls(self) if display_loss else Plotter(self)
if display_loss is True: self.losses = [self.loss(data_set)]
ax, line = dynamic_loss_plot_init(losses)
epoch = 0 epoch = 0
no_improv = 0 no_improv = 0
prev_error = losses[0] prev_error = self.losses[0]
with tqdm(bar_format="{desc} {elapsed} {rate_fmt}") as lbar: with tqdm(bar_format="{desc} {elapsed} {rate_fmt}") as lbar:
while True: while True:
lbar.set_description( lbar.set_description(
@@ -55,17 +53,15 @@ class AAutoencoder(ABC):
else: else:
no_improv = 0 no_improv = 0
prev_error = float(error) prev_error = float(error)
losses.append(error) self.losses.append(error)
if display_loss is True:
dynamic_loss_plot_update(ax, line, losses)
if no_improv > patience: if no_improv > patience:
break break
if epoch > max_epoch: if epoch > max_epoch:
break break
plotter.update()
epoch += 1 epoch += 1
if display_loss is True: plotter.close()
dynamic_loss_plot_finish() return self.losses
return losses
def save(self, path: str): def save(self, path: str):
path = path.removesuffix('.npy') path = path.removesuffix('.npy')
@@ -90,11 +86,19 @@ class AAutoencoder(ABC):
class ClassicalAutoencoder(AAutoencoder): class ClassicalAutoencoder(AAutoencoder):
plotter_cls = CAPlotter
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def __str__(self): def __str__(self):
return f'Encoder:\n{self.encoder}\n\nDecoder:\n{self.decoder}' return "\n".join((
f"Type: {__class__.__name__}",
"Encoder:",
f"{self.encoder}",
"Decoder:"
f"{self.decoder}"
))
def loss(self, data_set: list[np.ndarray]) -> float: def loss(self, data_set: list[np.ndarray]) -> float:
loss = 0 loss = 0

48
src/easyvae/plotters.py Normal file
View File

@@ -0,0 +1,48 @@
import matplotlib.pyplot as plt
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .autoencoder import AAutoencoder
class Plotter:
def __init__(self, autoencoder: 'AAutoencoder'):
pass
def update(self):
pass
def close(self):
pass
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
class CAPlotter(Plotter):
def __init__(self, autoencoder: 'AAutoencoder'):
self.autoencoder = autoencoder
plt.ion()
self.fig, self.ax = plt.subplots()
self.line, = self.ax.plot(
list(range(len(autoencoder.losses))),
autoencoder.losses,
label="Loss"
)
self.ax.set_xlabel("Epoch")
self.ax.set_ylabel("Loss")
self.ax.set_title("Training MSE Loss")
self.ax.legend()
self.update()
def update(self):
self.line.set_xdata(range(len(self.autoencoder.losses)))
self.line.set_ydata(self.autoencoder.losses)
self.ax.relim()
self.ax.autoscale_view()
plt.draw()
plt.pause(0.1)
def close(self):
plt.ioff()
plt.show()

View File

@@ -1,6 +1,5 @@
import numpy as np import numpy as np
import matplotlib.pyplot as plt
def softmax(v: np.ndarray) -> np.ndarray: def softmax(v: np.ndarray) -> np.ndarray:
@@ -19,28 +18,3 @@ def regularize(v: np.ndarray) -> np.ndarray:
if v_min - v_max == 0: if v_min - v_max == 0:
return v return v
return (v - v_min) / (v_max - v_min) return (v - v_min) / (v_max - v_min)
def dynamic_loss_plot_init(losses: list):
plt.ion()
fig, ax = plt.subplots()
line, = ax.plot([0], losses, label="Loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Training Loss")
ax.legend()
return ax, line
def dynamic_loss_plot_update(ax, line, loss):
line.set_xdata(range(len(loss)))
line.set_ydata(loss)
ax.relim()
ax.autoscale_view()
plt.draw()
plt.pause(0.1)
def dynamic_loss_plot_finish():
plt.ioff()
plt.show()