feat(plotters.py): add VAEPlotter class + seperate training logic

This commit is contained in:
Lenoctambule
2026-04-10 20:37:03 +02:00
parent 849d988de5
commit 5ff6cfe55e
2 changed files with 171 additions and 61 deletions

View File

@@ -2,7 +2,7 @@ import numpy as np
from tqdm import tqdm from tqdm import tqdm
from .layers import DeepNNLayer, SampleLayer from .layers import DeepNNLayer, SampleLayer
from .activations import ActivationFunc, Identity from .activations import ActivationFunc, Identity
from .plotters import Plotter, CAPlotter from .plotters import Plotter, CAPlotter, VAEPlotter
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
LOADER = ['', '', '', '', '', '', '', ''] LOADER = ['', '', '', '', '', '', '', '']
@@ -27,6 +27,65 @@ class AAutoencoder(ABC):
self.lr = lr self.lr = lr
self.losses = [0] self.losses = [0]
def save(self, path: str):
path = path.removesuffix('.npy')
np.save(path, self)
def load(path: str) -> 'ClassicalAutoencoder':
path = path.removesuffix('.npy') + '.npy'
data = np.load(path, allow_pickle=True)
return data.item()
@abstractmethod
def loss(self, data_set: list[np.ndarray]) -> float:
pass
@abstractmethod
def train(self, v: np.ndarray) -> float:
pass
@abstractmethod
def forward(self, v: np.ndarray) -> np.ndarray:
pass
@abstractmethod
def train_dataset(self, *args, **kwargs) -> list[float]:
pass
class ClassicalAutoencoder(AAutoencoder):
plotter_cls = CAPlotter
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.losses = []
def __str__(self):
return "\n".join((
f"Type: {__class__.__name__}",
"Encoder:",
f"{self.encoder}",
"Decoder:",
f"{self.decoder}"
)
)
def loss(self, data_set: list[np.ndarray]) -> float:
loss = 0
for x in data_set:
loss += np.sum(np.abs(x - self.forward(x)[0])) / len(x)
return loss / len(data_set)
def train(self, v: np.ndarray):
out = self.decoder.forward(
self.encoder.forward(v)
)
error = out - v
self.encoder.backprop(
self.decoder.backprop(error)
)
return np.sum(np.abs(error)) / len(v)
def train_dataset(self, def train_dataset(self,
data_set: list[np.ndarray], data_set: list[np.ndarray],
max_epoch: int, max_epoch: int,
@@ -63,59 +122,6 @@ class AAutoencoder(ABC):
plotter.close() plotter.close()
return self.losses return self.losses
def save(self, path: str):
path = path.removesuffix('.npy')
np.save(path, self)
def load(path: str) -> 'ClassicalAutoencoder':
path = path.removesuffix('.npy') + '.npy'
data = np.load(path, allow_pickle=True)
return data.item()
@abstractmethod
def loss(self, data_set: list[np.ndarray]) -> float:
pass
@abstractmethod
def train(self, v: np.ndarray) -> float:
pass
@abstractmethod
def forward(self, v: np.ndarray) -> np.ndarray:
pass
class ClassicalAutoencoder(AAutoencoder):
plotter_cls = CAPlotter
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __str__(self):
return "\n".join((
f"Type: {__class__.__name__}",
"Encoder:",
f"{self.encoder}",
"Decoder:",
f"{self.decoder}"
))
def loss(self, data_set: list[np.ndarray]) -> float:
loss = 0
for x in data_set:
loss += np.sum(np.abs(x - self.forward(x)[0])) / len(x)
return loss / len(data_set)
def train(self, v: np.ndarray):
out = self.decoder.forward(
self.encoder.forward(v)
)
error = out - v
self.encoder.backprop(
self.decoder.backprop(error)
)
return np.sum(np.abs(error)) / len(v)
def encode(self, v: np.ndarray) -> np.ndarray: def encode(self, v: np.ndarray) -> np.ndarray:
return self.encoder.forward(v) return self.encoder.forward(v)
@@ -129,9 +135,13 @@ class ClassicalAutoencoder(AAutoencoder):
class VariationalAutoencoder(AAutoencoder): class VariationalAutoencoder(AAutoencoder):
plotter_cls = VAEPlotter
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.sampler = SampleLayer(self.encoder.out_size, self.lr, Identity()) self.sampler = SampleLayer(self.encoder.out_size, self.lr, Identity())
self.KL_losses = []
self.recon_losses = []
def __str__(self): def __str__(self):
return "\n".join(( return "\n".join((
@@ -143,15 +153,18 @@ class VariationalAutoencoder(AAutoencoder):
)) ))
def loss(self, data_set: list[np.ndarray]) -> float: def loss(self, data_set: list[np.ndarray]) -> float:
loss = 0 kl_loss = 0
recon_loss = 0
for x in data_set: for x in data_set:
out = self.forward(x)[0] out = self.forward(x)[0]
kl = self.sampler.DKL() kl = self.sampler.DKL()
loss += np.mean((out - x) ** 2) recon_loss += np.mean((out - x) ** 2)
loss += kl kl_loss += kl
return loss / len(data_set) kl_loss /= len(data_set)
recon_loss /= len(data_set)
return recon_loss, kl_loss
def train(self, v: np.ndarray) -> float: def train(self, v: np.ndarray) -> tuple[float, float]:
out, _ = self.forward(v) out, _ = self.forward(v)
error = out - v error = out - v
self.encoder.backprop( self.encoder.backprop(
@@ -159,10 +172,62 @@ class VariationalAutoencoder(AAutoencoder):
self.decoder.backprop(error) self.decoder.backprop(error)
) )
) )
return np.mean(error ** 2) + self.sampler.DKL() return np.mean(error ** 2), self.sampler.DKL()
def train_dataset(self,
data_set: list[np.ndarray],
max_epoch: int,
patience: int,
display_loss: bool = False) -> list[float]:
plotter = self.plotter_cls(self) if display_loss else Plotter(self)
recon_0, kl_0 = self.loss(data_set)
self.recon_losses = [recon_0]
self.KL_losses = [kl_0]
epoch = 0
no_improv = 0
prev_loss = self.recon_losses[0] + self.KL_losses[0]
with tqdm(bar_format="{desc} {elapsed} {rate_fmt}") as lbar:
while True:
lbar.set_description(
f"{LOADER[epoch % len(LOADER)]} Training ({epoch=} loss={float(prev_loss):.6f})", # noqa
)
lbar.update()
dkl = 0
recon = 0
for x in tqdm(data_set, leave=False):
recon_i, dkl_i = self.train(x)
dkl += dkl_i
recon += recon_i
recon /= len(data_set)
dkl /= len(data_set)
loss = recon + dkl
dloss = prev_loss - loss
if dloss <= 0 or abs(dloss) < 1e-4:
no_improv += 1
else:
no_improv = 0
prev_loss = float(loss)
self.recon_losses.append(recon)
self.KL_losses.append(dkl)
if no_improv > patience:
break
if epoch > max_epoch:
break
plotter.update()
epoch += 1
plotter.close()
return self.recon_losses
def forward(self, v: np.ndarray) -> tuple[np.ndarray, np.ndarray]: def forward(self, v: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
code = self.encoder.forward(v) code = self.encoder.forward(v)
sample = self.sampler.forward(code) sample = self.sampler.forward(code)
out = self.decoder.forward(sample) out = self.decoder.forward(sample)
return out, code return out, code
def encode(self, v: np.ndarray) -> np.ndarray:
return self.sampler.forward(
self.encoder.forward(v)
)
def decode(self, v: np.ndarray) -> np.ndarray:
return self.decoder.forward(v)

View File

@@ -2,7 +2,7 @@ import matplotlib.pyplot as plt
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from .autoencoder import AAutoencoder from .autoencoder import AAutoencoder, VariationalAutoencoder
class Plotter: class Plotter:
@@ -46,3 +46,48 @@ class CAPlotter(Plotter):
def close(self): def close(self):
plt.ioff() plt.ioff()
plt.show() plt.show()
class VAEPlotter(Plotter):
def __init__(self, autoencoder: 'VariationalAutoencoder'):
self.autoencoder = autoencoder
plt.ion()
self.fig, (self.ax_recon, self.ax_dkl) = plt.subplots(1, 2)
self.line, = self.ax_recon.plot(
list(range(len(self.autoencoder.recon_losses))),
self.autoencoder.recon_losses,
label="Loss"
)
self.ax_recon.set_xlabel("Epoch")
self.ax_recon.set_ylabel("Loss")
self.ax_recon.set_title("Reconstruction MSE Loss")
self.ax_recon.legend()
self.dkl_line, = self.ax_dkl.plot(
list(range(len(self.autoencoder.KL_losses))),
self.autoencoder.KL_losses,
label="DKL Loss",
)
self.ax_dkl.set_xlabel("Epoch")
self.ax_dkl.set_ylabel("Loss")
self.ax_dkl.set_title("DKL Loss")
self.ax_dkl.legend()
self.update()
def update(self):
self.line.set_xdata(range(len(self.autoencoder.recon_losses)))
self.line.set_ydata(self.autoencoder.recon_losses)
self.ax_recon.relim()
self.ax_recon.autoscale_view()
self.dkl_line.set_xdata(range(len(self.autoencoder.KL_losses)))
self.dkl_line.set_ydata(self.autoencoder.KL_losses)
self.ax_dkl.relim()
self.ax_dkl.autoscale_view()
plt.draw()
plt.pause(0.1)
def close(self):
plt.ioff()
plt.show()