refactor: move plot logic to plotters.py

This commit is contained in:
Lenoctambule
2026-04-09 22:47:22 +02:00
parent 9d718a6bc8
commit ea8a4079ac
6 changed files with 81 additions and 51 deletions

1
.gitignore vendored
View File

@@ -5,3 +5,4 @@ __pycache__
.venv
dist
*.egg-info
.env

View File

@@ -6,7 +6,7 @@
src="https://raw.githubusercontent.com/lenoctambule/autoencoder/refs/heads/main/media/latent-space.png"
alt="Latent-space of the MNIST dataset"
width=70%>
<figcaption align=center>
<figcaption>
<p align="center">
Latent-space representation of the MNIST dataset using Variational Autoencoder
</p>

View File

@@ -8,7 +8,6 @@ from easyvae.autoencoder import ( # noqa
AAutoencoder
)
from easyvae.activations import LeakyReLU
from easyvae.utils import dynamic_loss_plot_finish
def load_mnist() -> list[np.ndarray]:
@@ -33,6 +32,7 @@ def mnist_train(
x_train.resize(x_train.shape[0], in_len)
x_test.resize(x_test.shape[0], in_len)
x_train = x_train / 255
x_train = x_train[:5000]
if os.path.exists(filename):
autoencoder = cls.load(filename)
else:
@@ -46,8 +46,8 @@ def mnist_train(
def handler(signum, frame):
print(f"Saving {filename} before exit ...")
autoencoder.save(filename)
if plt.get_fignums():
dynamic_loss_plot_finish()
plt.close('all')
plt.ioff()
mnist_test(autoencoder)
exit()
@@ -84,7 +84,8 @@ def plot_mnist_latent_space(autoencoder: AAutoencoder, x: np.ndarray, y,):
plt.show()
def plot_random_reconstruction(autoencoder: AAutoencoder,
def plot_random_reconstruction(
autoencoder: AAutoencoder,
example: np.ndarray,
img_shape,
y):
@@ -114,6 +115,8 @@ def mnist_test(model: str | AAutoencoder):
autoencoder: AAutoencoder = AAutoencoder.load(model)
else:
autoencoder = model
print("Testing model ...\n")
print(autoencoder)
idx = np.random.randint(0, len(x_test))
example: np.ndarray = x_test[idx]
plot_random_reconstruction(autoencoder, example, img_shape, y_test[idx])

View File

@@ -1,18 +1,16 @@
import numpy as np
from .utils import (
dynamic_loss_plot_init,
dynamic_loss_plot_update,
dynamic_loss_plot_finish
)
from tqdm import tqdm
from .layers import DeepNNLayer, SampleLayer
from .activations import ActivationFunc, Identity
from .plotters import Plotter, CAPlotter
from abc import ABC, abstractmethod
LOADER = ['', '', '', '', '', '', '', '']
class AAutoencoder(ABC):
plotter_cls = Plotter
@abstractmethod
def __init__(self,
encoder_layers: list[int],
@@ -27,18 +25,18 @@ class AAutoencoder(ABC):
self.decoder = DeepNNLayer(decoder_layers, lr, activation_func)
self.space_dim = decoder_layers[0]
self.lr = lr
self.losses = [0]
def train_dataset(self,
data_set: list[np.ndarray],
max_epoch: int,
patience: int,
display_loss: bool = False) -> list[float]:
losses = [self.loss(data_set)]
if display_loss is True:
ax, line = dynamic_loss_plot_init(losses)
plotter = self.plotter_cls(self) if display_loss else Plotter(self)
self.losses = [self.loss(data_set)]
epoch = 0
no_improv = 0
prev_error = losses[0]
prev_error = self.losses[0]
with tqdm(bar_format="{desc} {elapsed} {rate_fmt}") as lbar:
while True:
lbar.set_description(
@@ -55,17 +53,15 @@ class AAutoencoder(ABC):
else:
no_improv = 0
prev_error = float(error)
losses.append(error)
if display_loss is True:
dynamic_loss_plot_update(ax, line, losses)
self.losses.append(error)
if no_improv > patience:
break
if epoch > max_epoch:
break
plotter.update()
epoch += 1
if display_loss is True:
dynamic_loss_plot_finish()
return losses
plotter.close()
return self.losses
def save(self, path: str):
path = path.removesuffix('.npy')
@@ -90,11 +86,19 @@ class AAutoencoder(ABC):
class ClassicalAutoencoder(AAutoencoder):
plotter_cls = CAPlotter
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __str__(self):
return f'Encoder:\n{self.encoder}\n\nDecoder:\n{self.decoder}'
return "\n".join((
f"Type: {__class__.__name__}",
"Encoder:",
f"{self.encoder}",
"Decoder:"
f"{self.decoder}"
))
def loss(self, data_set: list[np.ndarray]) -> float:
loss = 0

48
src/easyvae/plotters.py Normal file
View File

@@ -0,0 +1,48 @@
import matplotlib.pyplot as plt
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .autoencoder import AAutoencoder
class Plotter:
def __init__(self, autoencoder: 'AAutoencoder'):
pass
def update(self):
pass
def close(self):
pass
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
class CAPlotter(Plotter):
def __init__(self, autoencoder: 'AAutoencoder'):
self.autoencoder = autoencoder
plt.ion()
self.fig, self.ax = plt.subplots()
self.line, = self.ax.plot(
list(range(len(autoencoder.losses))),
autoencoder.losses,
label="Loss"
)
self.ax.set_xlabel("Epoch")
self.ax.set_ylabel("Loss")
self.ax.set_title("Training MSE Loss")
self.ax.legend()
self.update()
def update(self):
self.line.set_xdata(range(len(self.autoencoder.losses)))
self.line.set_ydata(self.autoencoder.losses)
self.ax.relim()
self.ax.autoscale_view()
plt.draw()
plt.pause(0.1)
def close(self):
plt.ioff()
plt.show()

View File

@@ -1,6 +1,5 @@
import numpy as np
import matplotlib.pyplot as plt
def softmax(v: np.ndarray) -> np.ndarray:
@@ -19,28 +18,3 @@ def regularize(v: np.ndarray) -> np.ndarray:
if v_min - v_max == 0:
return v
return (v - v_min) / (v_max - v_min)
def dynamic_loss_plot_init(losses: list):
plt.ion()
fig, ax = plt.subplots()
line, = ax.plot([0], losses, label="Loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Training Loss")
ax.legend()
return ax, line
def dynamic_loss_plot_update(ax, line, loss):
line.set_xdata(range(len(loss)))
line.set_ydata(loss)
ax.relim()
ax.autoscale_view()
plt.draw()
plt.pause(0.1)
def dynamic_loss_plot_finish():
plt.ioff()
plt.show()