Merge pull request #2 from lenoctambule/dev

Refactor of plotting and kb interrupt logic
This commit is contained in:
Lenoctambule
2026-04-10 22:24:10 +02:00
committed by GitHub
6 changed files with 240 additions and 95 deletions

1
.gitignore vendored
View File

@@ -5,3 +5,4 @@ __pycache__
.venv .venv
dist dist
*.egg-info *.egg-info
.env

View File

@@ -6,7 +6,7 @@
src="https://raw.githubusercontent.com/lenoctambule/autoencoder/refs/heads/main/media/latent-space.png" src="https://raw.githubusercontent.com/lenoctambule/autoencoder/refs/heads/main/media/latent-space.png"
alt="Latent-space of the MNIST dataset" alt="Latent-space of the MNIST dataset"
width=70%> width=70%>
<figcaption align=center> <figcaption>
<p align="center"> <p align="center">
Latent-space representation of the MNIST dataset using Variational Autoencoder Latent-space representation of the MNIST dataset using Variational Autoencoder
</p> </p>

View File

@@ -1,14 +1,12 @@
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import os import os
import signal
from easyvae.autoencoder import ( # noqa from easyvae.autoencoder import ( # noqa
VariationalAutoencoder, VariationalAutoencoder,
ClassicalAutoencoder, ClassicalAutoencoder,
AAutoencoder AAutoencoder
) )
from easyvae.activations import LeakyReLU from easyvae.activations import LeakyReLU
from easyvae.utils import dynamic_loss_plot_finish
def load_mnist() -> list[np.ndarray]: def load_mnist() -> list[np.ndarray]:
@@ -42,17 +40,7 @@ def mnist_train(
0.0001, 0.0001,
LeakyReLU() LeakyReLU()
) )
print("CTRL+C to interrupt training.")
def handler(signum, frame):
print(f"Saving {filename} before exit ...")
autoencoder.save(filename)
if plt.get_fignums():
dynamic_loss_plot_finish()
mnist_test(autoencoder)
exit()
signal.signal(signal.SIGINT, handler)
print("CTRL+C to exit and save model.")
autoencoder.train_dataset( autoencoder.train_dataset(
x_train, x_train,
max_epoch, max_epoch,
@@ -84,7 +72,8 @@ def plot_mnist_latent_space(autoencoder: AAutoencoder, x: np.ndarray, y,):
plt.show() plt.show()
def plot_random_reconstruction(autoencoder: AAutoencoder, def plot_random_reconstruction(
autoencoder: AAutoencoder,
example: np.ndarray, example: np.ndarray,
img_shape, img_shape,
y): y):
@@ -99,7 +88,7 @@ def plot_random_reconstruction(autoencoder: AAutoencoder,
output.reshape(img_shape), output.reshape(img_shape),
fignum=False) fignum=False)
plt.title(f"Output ({y})") plt.title(f"Output ({y})")
print(f'{code=}') print(f'{code.tolist()}')
def mnist_test(model: str | AAutoencoder): def mnist_test(model: str | AAutoencoder):
@@ -114,6 +103,8 @@ def mnist_test(model: str | AAutoencoder):
autoencoder: AAutoencoder = AAutoencoder.load(model) autoencoder: AAutoencoder = AAutoencoder.load(model)
else: else:
autoencoder = model autoencoder = model
print("Testing model ...\n")
print(autoencoder)
idx = np.random.randint(0, len(x_test)) idx = np.random.randint(0, len(x_test))
example: np.ndarray = x_test[idx] example: np.ndarray = x_test[idx]
plot_random_reconstruction(autoencoder, example, img_shape, y_test[idx]) plot_random_reconstruction(autoencoder, example, img_shape, y_test[idx])

View File

@@ -1,18 +1,17 @@
import numpy as np import numpy as np
from .utils import (
dynamic_loss_plot_init,
dynamic_loss_plot_update,
dynamic_loss_plot_finish
)
from tqdm import tqdm from tqdm import tqdm
from .layers import DeepNNLayer, SampleLayer from .layers import DeepNNLayer, SampleLayer
from .activations import ActivationFunc, Identity from .activations import ActivationFunc, Identity
from .plotters import Plotter, CAPlotter, VAEPlotter
from .utils import interruptable
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
LOADER = ['', '', '', '', '', '', '', ''] LOADER = ['', '', '', '', '', '', '', '']
class AAutoencoder(ABC): class AAutoencoder(ABC):
plotter_cls = Plotter
@abstractmethod @abstractmethod
def __init__(self, def __init__(self,
encoder_layers: list[int], encoder_layers: list[int],
@@ -27,45 +26,7 @@ class AAutoencoder(ABC):
self.decoder = DeepNNLayer(decoder_layers, lr, activation_func) self.decoder = DeepNNLayer(decoder_layers, lr, activation_func)
self.space_dim = decoder_layers[0] self.space_dim = decoder_layers[0]
self.lr = lr self.lr = lr
self.losses = [0]
def train_dataset(self,
data_set: list[np.ndarray],
max_epoch: int,
patience: int,
display_loss: bool = False) -> list[float]:
losses = [self.loss(data_set)]
if display_loss is True:
ax, line = dynamic_loss_plot_init(losses)
epoch = 0
no_improv = 0
prev_error = losses[0]
with tqdm(bar_format="{desc} {elapsed} {rate_fmt}") as lbar:
while True:
lbar.set_description(
f"{LOADER[epoch % len(LOADER)]} Training ({epoch=} error={float(prev_error):.6f})", # noqa
)
lbar.update()
error = 0
for x in tqdm(data_set, leave=False):
error += self.train(x)
error /= len(data_set)
derror = prev_error - error
if derror <= 0 or abs(derror) < 1e-4:
no_improv += 1
else:
no_improv = 0
prev_error = float(error)
losses.append(error)
if display_loss is True:
dynamic_loss_plot_update(ax, line, losses)
if no_improv > patience:
break
if epoch > max_epoch:
break
epoch += 1
if display_loss is True:
dynamic_loss_plot_finish()
return losses
def save(self, path: str): def save(self, path: str):
path = path.removesuffix('.npy') path = path.removesuffix('.npy')
@@ -88,13 +49,27 @@ class AAutoencoder(ABC):
def forward(self, v: np.ndarray) -> np.ndarray: def forward(self, v: np.ndarray) -> np.ndarray:
pass pass
@abstractmethod
def train_dataset(self, *args, **kwargs) -> list[float]:
pass
class ClassicalAutoencoder(AAutoencoder): class ClassicalAutoencoder(AAutoencoder):
plotter_cls = CAPlotter
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.losses = []
def __str__(self): def __str__(self):
return f'Encoder:\n{self.encoder}\n\nDecoder:\n{self.decoder}' return "\n".join((
f"Type: {__class__.__name__}",
"Encoder:",
f"{self.encoder}",
"Decoder:",
f"{self.decoder}"
)
)
def loss(self, data_set: list[np.ndarray]) -> float: def loss(self, data_set: list[np.ndarray]) -> float:
loss = 0 loss = 0
@@ -112,6 +87,41 @@ class ClassicalAutoencoder(AAutoencoder):
) )
return np.sum(np.abs(error)) / len(v) return np.sum(np.abs(error)) / len(v)
@interruptable
def train_dataset(self,
data_set: list[np.ndarray],
max_epoch: int,
patience: int,
display_loss: bool = False) -> list[float]:
plotter = self.plotter_cls(self) if display_loss else Plotter(self)
self.losses = [self.loss(data_set)]
epoch = 0
no_improv = 0
prev_error = self.losses[0]
with tqdm(bar_format="{desc} {elapsed} {rate_fmt}") as lbar:
while True:
lbar.set_description(
f"{LOADER[epoch % len(LOADER)]} Training ({epoch=} error={float(prev_error):.6f})", # noqa
)
lbar.update()
error = 0
for x in tqdm(data_set, leave=False):
error += self.train(x)
error /= len(data_set)
derror = prev_error - error
if derror <= 0 or abs(derror) < 1e-4:
no_improv += 1
else:
no_improv = 0
prev_error = float(error)
self.losses.append(error)
if no_improv > patience:
break
if epoch > max_epoch:
break
plotter.update()
epoch += 1
def encode(self, v: np.ndarray) -> np.ndarray: def encode(self, v: np.ndarray) -> np.ndarray:
return self.encoder.forward(v) return self.encoder.forward(v)
@@ -125,20 +135,36 @@ class ClassicalAutoencoder(AAutoencoder):
class VariationalAutoencoder(AAutoencoder): class VariationalAutoencoder(AAutoencoder):
plotter_cls = VAEPlotter
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.sampler = SampleLayer(self.encoder.out_size, self.lr, Identity()) self.sampler = SampleLayer(self.encoder.out_size, self.lr, Identity())
self.KL_losses = []
self.recon_losses = []
def __str__(self):
return "\n".join((
f"Type: {__class__.__name__}",
"Encoder:",
f"{self.encoder}",
"Decoder:",
f"{self.decoder}"
))
def loss(self, data_set: list[np.ndarray]) -> float: def loss(self, data_set: list[np.ndarray]) -> float:
loss = 0 kl_loss = 0
recon_loss = 0
for x in data_set: for x in data_set:
out = self.forward(x)[0] out = self.forward(x)[0]
kl = self.sampler.DKL() kl = self.sampler.DKL()
loss += np.mean((out - x) ** 2) recon_loss += np.mean((out - x) ** 2)
loss += kl kl_loss += kl
return loss / len(data_set) kl_loss /= len(data_set)
recon_loss /= len(data_set)
return recon_loss, kl_loss
def train(self, v: np.ndarray) -> float: def train(self, v: np.ndarray) -> tuple[float, float]:
out, _ = self.forward(v) out, _ = self.forward(v)
error = out - v error = out - v
self.encoder.backprop( self.encoder.backprop(
@@ -146,10 +172,61 @@ class VariationalAutoencoder(AAutoencoder):
self.decoder.backprop(error) self.decoder.backprop(error)
) )
) )
return np.mean(error ** 2) + self.sampler.DKL() return np.mean(error ** 2), self.sampler.DKL()
@interruptable
def train_dataset(self,
data_set: list[np.ndarray],
max_epoch: int,
patience: int,
display_loss: bool = False) -> list[float]:
plotter = self.plotter_cls(self) if display_loss else Plotter(self)
recon_0, kl_0 = self.loss(data_set)
self.recon_losses = [recon_0]
self.KL_losses = [kl_0]
epoch = 0
no_improv = 0
prev_loss = self.recon_losses[0] + self.KL_losses[0]
with tqdm(bar_format="{desc} {elapsed} {rate_fmt}") as lbar:
while True:
lbar.set_description(
f"{LOADER[epoch % len(LOADER)]} Training ({epoch=} loss={float(prev_loss):.6f})", # noqa
)
lbar.update()
dkl = 0
recon = 0
for x in tqdm(data_set, leave=False):
recon_i, dkl_i = self.train(x)
dkl += dkl_i
recon += recon_i
recon /= len(data_set)
dkl /= len(data_set)
loss = recon + dkl
dloss = prev_loss - loss
if dloss <= 0 or abs(dloss) < 1e-4:
no_improv += 1
else:
no_improv = 0
prev_loss = float(loss)
self.recon_losses.append(recon)
self.KL_losses.append(dkl)
if no_improv > patience:
break
if epoch > max_epoch:
break
plotter.update()
epoch += 1
def forward(self, v: np.ndarray) -> tuple[np.ndarray, np.ndarray]: def forward(self, v: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
code = self.encoder.forward(v) code = self.encoder.forward(v)
sample = self.sampler.forward(code) sample = self.sampler.forward(code)
out = self.decoder.forward(sample) out = self.decoder.forward(sample)
return out, code return out, code
def encode(self, v: np.ndarray) -> np.ndarray:
return self.sampler.forward(
self.encoder.forward(v)
)
def decode(self, v: np.ndarray) -> np.ndarray:
return self.decoder.forward(v)

93
src/easyvae/plotters.py Normal file
View File

@@ -0,0 +1,93 @@
import matplotlib.pyplot as plt
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .autoencoder import AAutoencoder, VariationalAutoencoder
class Plotter:
def __init__(self, autoencoder: 'AAutoencoder'):
pass
def update(self):
pass
def close(self):
pass
def __del__(self):
self.close()
class CAPlotter(Plotter):
def __init__(self, autoencoder: 'AAutoencoder'):
self.autoencoder = autoencoder
plt.ion()
self.fig, self.ax = plt.subplots()
self.line, = self.ax.plot(
list(range(len(autoencoder.losses))),
autoencoder.losses,
label="Loss"
)
self.ax.set_xlabel("Epoch")
self.ax.set_ylabel("Loss")
self.ax.set_title("Training MSE Loss")
self.ax.legend()
self.update()
def update(self):
self.line.set_xdata(range(len(self.autoencoder.losses)))
self.line.set_ydata(self.autoencoder.losses)
self.ax.relim()
self.ax.autoscale_view()
plt.draw()
plt.pause(0.1)
def close(self):
plt.ioff()
plt.close(self.fig)
class VAEPlotter(Plotter):
def __init__(self, autoencoder: 'VariationalAutoencoder'):
self.autoencoder = autoencoder
plt.ion()
self.fig, (self.ax_recon, self.ax_dkl) = plt.subplots(1, 2)
self.line, = self.ax_recon.plot(
list(range(len(self.autoencoder.recon_losses))),
self.autoencoder.recon_losses,
label="Loss"
)
self.ax_recon.set_xlabel("Epoch")
self.ax_recon.set_ylabel("Loss")
self.ax_recon.set_title("Reconstruction MSE Loss")
self.ax_recon.legend()
self.dkl_line, = self.ax_dkl.plot(
list(range(len(self.autoencoder.KL_losses))),
self.autoencoder.KL_losses,
label="DKL Loss",
)
self.ax_dkl.set_xlabel("Epoch")
self.ax_dkl.set_ylabel("Loss")
self.ax_dkl.set_title("DKL Loss")
self.ax_dkl.legend()
self.update()
def update(self):
self.line.set_xdata(range(len(self.autoencoder.recon_losses)))
self.line.set_ydata(self.autoencoder.recon_losses)
self.ax_recon.relim()
self.ax_recon.autoscale_view()
self.dkl_line.set_xdata(range(len(self.autoencoder.KL_losses)))
self.dkl_line.set_ydata(self.autoencoder.KL_losses)
self.ax_dkl.relim()
self.ax_dkl.autoscale_view()
plt.draw()
plt.pause(0.1)
def close(self):
plt.ioff()
plt.close(self.fig)

View File

@@ -1,6 +1,5 @@
import numpy as np import numpy as np
import matplotlib.pyplot as plt
def softmax(v: np.ndarray) -> np.ndarray: def softmax(v: np.ndarray) -> np.ndarray:
@@ -21,26 +20,10 @@ def regularize(v: np.ndarray) -> np.ndarray:
return (v - v_min) / (v_max - v_min) return (v - v_min) / (v_max - v_min)
def dynamic_loss_plot_init(losses: list): def interruptable(func):
plt.ion() def inner(*args, **kwargs):
fig, ax = plt.subplots() try:
line, = ax.plot([0], losses, label="Loss") return func(*args, **kwargs)
ax.set_xlabel("Epoch") except KeyboardInterrupt:
ax.set_ylabel("Loss") pass
ax.set_title("Training Loss") return inner
ax.legend()
return ax, line
def dynamic_loss_plot_update(ax, line, loss):
line.set_xdata(range(len(loss)))
line.set_ydata(loss)
ax.relim()
ax.autoscale_view()
plt.draw()
plt.pause(0.1)
def dynamic_loss_plot_finish():
plt.ioff()
plt.show()