Merge pull request #2 from lenoctambule/dev

Refactor of plotting and kb interrupt logic
This commit is contained in:
Lenoctambule
2026-04-10 22:24:10 +02:00
committed by GitHub
6 changed files with 240 additions and 95 deletions

1
.gitignore vendored
View File

@@ -5,3 +5,4 @@ __pycache__
.venv
dist
*.egg-info
.env

View File

@@ -6,7 +6,7 @@
src="https://raw.githubusercontent.com/lenoctambule/autoencoder/refs/heads/main/media/latent-space.png"
alt="Latent-space of the MNIST dataset"
width=70%>
<figcaption align=center>
<figcaption>
<p align="center">
Latent-space representation of the MNIST dataset using Variational Autoencoder
</p>

View File

@@ -1,14 +1,12 @@
import matplotlib.pyplot as plt
import numpy as np
import os
import signal
from easyvae.autoencoder import ( # noqa
VariationalAutoencoder,
ClassicalAutoencoder,
AAutoencoder
)
from easyvae.activations import LeakyReLU
from easyvae.utils import dynamic_loss_plot_finish
def load_mnist() -> list[np.ndarray]:
@@ -42,17 +40,7 @@ def mnist_train(
0.0001,
LeakyReLU()
)
def handler(signum, frame):
print(f"Saving {filename} before exit ...")
autoencoder.save(filename)
if plt.get_fignums():
dynamic_loss_plot_finish()
mnist_test(autoencoder)
exit()
signal.signal(signal.SIGINT, handler)
print("CTRL+C to exit and save model.")
print("CTRL+C to interrupt training.")
autoencoder.train_dataset(
x_train,
max_epoch,
@@ -84,7 +72,8 @@ def plot_mnist_latent_space(autoencoder: AAutoencoder, x: np.ndarray, y,):
plt.show()
def plot_random_reconstruction(autoencoder: AAutoencoder,
def plot_random_reconstruction(
autoencoder: AAutoencoder,
example: np.ndarray,
img_shape,
y):
@@ -99,7 +88,7 @@ def plot_random_reconstruction(autoencoder: AAutoencoder,
output.reshape(img_shape),
fignum=False)
plt.title(f"Output ({y})")
print(f'{code=}')
print(f'{code.tolist()}')
def mnist_test(model: str | AAutoencoder):
@@ -114,6 +103,8 @@ def mnist_test(model: str | AAutoencoder):
autoencoder: AAutoencoder = AAutoencoder.load(model)
else:
autoencoder = model
print("Testing model ...\n")
print(autoencoder)
idx = np.random.randint(0, len(x_test))
example: np.ndarray = x_test[idx]
plot_random_reconstruction(autoencoder, example, img_shape, y_test[idx])

View File

@@ -1,18 +1,17 @@
import numpy as np
from .utils import (
dynamic_loss_plot_init,
dynamic_loss_plot_update,
dynamic_loss_plot_finish
)
from tqdm import tqdm
from .layers import DeepNNLayer, SampleLayer
from .activations import ActivationFunc, Identity
from .plotters import Plotter, CAPlotter, VAEPlotter
from .utils import interruptable
from abc import ABC, abstractmethod
LOADER = ['', '', '', '', '', '', '', '']
class AAutoencoder(ABC):
plotter_cls = Plotter
@abstractmethod
def __init__(self,
encoder_layers: list[int],
@@ -27,45 +26,7 @@ class AAutoencoder(ABC):
self.decoder = DeepNNLayer(decoder_layers, lr, activation_func)
self.space_dim = decoder_layers[0]
self.lr = lr
def train_dataset(self,
data_set: list[np.ndarray],
max_epoch: int,
patience: int,
display_loss: bool = False) -> list[float]:
losses = [self.loss(data_set)]
if display_loss is True:
ax, line = dynamic_loss_plot_init(losses)
epoch = 0
no_improv = 0
prev_error = losses[0]
with tqdm(bar_format="{desc} {elapsed} {rate_fmt}") as lbar:
while True:
lbar.set_description(
f"{LOADER[epoch % len(LOADER)]} Training ({epoch=} error={float(prev_error):.6f})", # noqa
)
lbar.update()
error = 0
for x in tqdm(data_set, leave=False):
error += self.train(x)
error /= len(data_set)
derror = prev_error - error
if derror <= 0 or abs(derror) < 1e-4:
no_improv += 1
else:
no_improv = 0
prev_error = float(error)
losses.append(error)
if display_loss is True:
dynamic_loss_plot_update(ax, line, losses)
if no_improv > patience:
break
if epoch > max_epoch:
break
epoch += 1
if display_loss is True:
dynamic_loss_plot_finish()
return losses
self.losses = [0]
def save(self, path: str):
path = path.removesuffix('.npy')
@@ -88,13 +49,27 @@ class AAutoencoder(ABC):
def forward(self, v: np.ndarray) -> np.ndarray:
pass
@abstractmethod
def train_dataset(self, *args, **kwargs) -> list[float]:
pass
class ClassicalAutoencoder(AAutoencoder):
plotter_cls = CAPlotter
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.losses = []
def __str__(self):
return f'Encoder:\n{self.encoder}\n\nDecoder:\n{self.decoder}'
return "\n".join((
f"Type: {__class__.__name__}",
"Encoder:",
f"{self.encoder}",
"Decoder:",
f"{self.decoder}"
)
)
def loss(self, data_set: list[np.ndarray]) -> float:
loss = 0
@@ -112,6 +87,41 @@ class ClassicalAutoencoder(AAutoencoder):
)
return np.sum(np.abs(error)) / len(v)
@interruptable
def train_dataset(self,
data_set: list[np.ndarray],
max_epoch: int,
patience: int,
display_loss: bool = False) -> list[float]:
plotter = self.plotter_cls(self) if display_loss else Plotter(self)
self.losses = [self.loss(data_set)]
epoch = 0
no_improv = 0
prev_error = self.losses[0]
with tqdm(bar_format="{desc} {elapsed} {rate_fmt}") as lbar:
while True:
lbar.set_description(
f"{LOADER[epoch % len(LOADER)]} Training ({epoch=} error={float(prev_error):.6f})", # noqa
)
lbar.update()
error = 0
for x in tqdm(data_set, leave=False):
error += self.train(x)
error /= len(data_set)
derror = prev_error - error
if derror <= 0 or abs(derror) < 1e-4:
no_improv += 1
else:
no_improv = 0
prev_error = float(error)
self.losses.append(error)
if no_improv > patience:
break
if epoch > max_epoch:
break
plotter.update()
epoch += 1
def encode(self, v: np.ndarray) -> np.ndarray:
return self.encoder.forward(v)
@@ -125,20 +135,36 @@ class ClassicalAutoencoder(AAutoencoder):
class VariationalAutoencoder(AAutoencoder):
plotter_cls = VAEPlotter
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.sampler = SampleLayer(self.encoder.out_size, self.lr, Identity())
self.KL_losses = []
self.recon_losses = []
def __str__(self):
return "\n".join((
f"Type: {__class__.__name__}",
"Encoder:",
f"{self.encoder}",
"Decoder:",
f"{self.decoder}"
))
def loss(self, data_set: list[np.ndarray]) -> float:
loss = 0
kl_loss = 0
recon_loss = 0
for x in data_set:
out = self.forward(x)[0]
kl = self.sampler.DKL()
loss += np.mean((out - x) ** 2)
loss += kl
return loss / len(data_set)
recon_loss += np.mean((out - x) ** 2)
kl_loss += kl
kl_loss /= len(data_set)
recon_loss /= len(data_set)
return recon_loss, kl_loss
def train(self, v: np.ndarray) -> float:
def train(self, v: np.ndarray) -> tuple[float, float]:
out, _ = self.forward(v)
error = out - v
self.encoder.backprop(
@@ -146,10 +172,61 @@ class VariationalAutoencoder(AAutoencoder):
self.decoder.backprop(error)
)
)
return np.mean(error ** 2) + self.sampler.DKL()
return np.mean(error ** 2), self.sampler.DKL()
@interruptable
def train_dataset(self,
data_set: list[np.ndarray],
max_epoch: int,
patience: int,
display_loss: bool = False) -> list[float]:
plotter = self.plotter_cls(self) if display_loss else Plotter(self)
recon_0, kl_0 = self.loss(data_set)
self.recon_losses = [recon_0]
self.KL_losses = [kl_0]
epoch = 0
no_improv = 0
prev_loss = self.recon_losses[0] + self.KL_losses[0]
with tqdm(bar_format="{desc} {elapsed} {rate_fmt}") as lbar:
while True:
lbar.set_description(
f"{LOADER[epoch % len(LOADER)]} Training ({epoch=} loss={float(prev_loss):.6f})", # noqa
)
lbar.update()
dkl = 0
recon = 0
for x in tqdm(data_set, leave=False):
recon_i, dkl_i = self.train(x)
dkl += dkl_i
recon += recon_i
recon /= len(data_set)
dkl /= len(data_set)
loss = recon + dkl
dloss = prev_loss - loss
if dloss <= 0 or abs(dloss) < 1e-4:
no_improv += 1
else:
no_improv = 0
prev_loss = float(loss)
self.recon_losses.append(recon)
self.KL_losses.append(dkl)
if no_improv > patience:
break
if epoch > max_epoch:
break
plotter.update()
epoch += 1
def forward(self, v: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
code = self.encoder.forward(v)
sample = self.sampler.forward(code)
out = self.decoder.forward(sample)
return out, code
def encode(self, v: np.ndarray) -> np.ndarray:
return self.sampler.forward(
self.encoder.forward(v)
)
def decode(self, v: np.ndarray) -> np.ndarray:
return self.decoder.forward(v)

93
src/easyvae/plotters.py Normal file
View File

@@ -0,0 +1,93 @@
import matplotlib.pyplot as plt
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .autoencoder import AAutoencoder, VariationalAutoencoder
class Plotter:
def __init__(self, autoencoder: 'AAutoencoder'):
pass
def update(self):
pass
def close(self):
pass
def __del__(self):
self.close()
class CAPlotter(Plotter):
def __init__(self, autoencoder: 'AAutoencoder'):
self.autoencoder = autoencoder
plt.ion()
self.fig, self.ax = plt.subplots()
self.line, = self.ax.plot(
list(range(len(autoencoder.losses))),
autoencoder.losses,
label="Loss"
)
self.ax.set_xlabel("Epoch")
self.ax.set_ylabel("Loss")
self.ax.set_title("Training MSE Loss")
self.ax.legend()
self.update()
def update(self):
self.line.set_xdata(range(len(self.autoencoder.losses)))
self.line.set_ydata(self.autoencoder.losses)
self.ax.relim()
self.ax.autoscale_view()
plt.draw()
plt.pause(0.1)
def close(self):
plt.ioff()
plt.close(self.fig)
class VAEPlotter(Plotter):
def __init__(self, autoencoder: 'VariationalAutoencoder'):
self.autoencoder = autoencoder
plt.ion()
self.fig, (self.ax_recon, self.ax_dkl) = plt.subplots(1, 2)
self.line, = self.ax_recon.plot(
list(range(len(self.autoencoder.recon_losses))),
self.autoencoder.recon_losses,
label="Loss"
)
self.ax_recon.set_xlabel("Epoch")
self.ax_recon.set_ylabel("Loss")
self.ax_recon.set_title("Reconstruction MSE Loss")
self.ax_recon.legend()
self.dkl_line, = self.ax_dkl.plot(
list(range(len(self.autoencoder.KL_losses))),
self.autoencoder.KL_losses,
label="DKL Loss",
)
self.ax_dkl.set_xlabel("Epoch")
self.ax_dkl.set_ylabel("Loss")
self.ax_dkl.set_title("DKL Loss")
self.ax_dkl.legend()
self.update()
def update(self):
self.line.set_xdata(range(len(self.autoencoder.recon_losses)))
self.line.set_ydata(self.autoencoder.recon_losses)
self.ax_recon.relim()
self.ax_recon.autoscale_view()
self.dkl_line.set_xdata(range(len(self.autoencoder.KL_losses)))
self.dkl_line.set_ydata(self.autoencoder.KL_losses)
self.ax_dkl.relim()
self.ax_dkl.autoscale_view()
plt.draw()
plt.pause(0.1)
def close(self):
plt.ioff()
plt.close(self.fig)

View File

@@ -1,6 +1,5 @@
import numpy as np
import matplotlib.pyplot as plt
def softmax(v: np.ndarray) -> np.ndarray:
@@ -21,26 +20,10 @@ def regularize(v: np.ndarray) -> np.ndarray:
return (v - v_min) / (v_max - v_min)
def dynamic_loss_plot_init(losses: list):
plt.ion()
fig, ax = plt.subplots()
line, = ax.plot([0], losses, label="Loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Training Loss")
ax.legend()
return ax, line
def dynamic_loss_plot_update(ax, line, loss):
line.set_xdata(range(len(loss)))
line.set_ydata(loss)
ax.relim()
ax.autoscale_view()
plt.draw()
plt.pause(0.1)
def dynamic_loss_plot_finish():
plt.ioff()
plt.show()
def interruptable(func):
def inner(*args, **kwargs):
try:
return func(*args, **kwargs)
except KeyboardInterrupt:
pass
return inner