feat(plotters.py): add VAEPlotter class + seperate training logic
This commit is contained in:
@@ -2,7 +2,7 @@ import numpy as np
|
||||
from tqdm import tqdm
|
||||
from .layers import DeepNNLayer, SampleLayer
|
||||
from .activations import ActivationFunc, Identity
|
||||
from .plotters import Plotter, CAPlotter
|
||||
from .plotters import Plotter, CAPlotter, VAEPlotter
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
LOADER = ['⡿', '⣟', '⣯', '⣷', '⣾', '⣽', '⣻', '⢿']
|
||||
@@ -27,6 +27,65 @@ class AAutoencoder(ABC):
|
||||
self.lr = lr
|
||||
self.losses = [0]
|
||||
|
||||
def save(self, path: str):
|
||||
path = path.removesuffix('.npy')
|
||||
np.save(path, self)
|
||||
|
||||
def load(path: str) -> 'ClassicalAutoencoder':
|
||||
path = path.removesuffix('.npy') + '.npy'
|
||||
data = np.load(path, allow_pickle=True)
|
||||
return data.item()
|
||||
|
||||
@abstractmethod
|
||||
def loss(self, data_set: list[np.ndarray]) -> float:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def train(self, v: np.ndarray) -> float:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, v: np.ndarray) -> np.ndarray:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def train_dataset(self, *args, **kwargs) -> list[float]:
|
||||
pass
|
||||
|
||||
|
||||
class ClassicalAutoencoder(AAutoencoder):
|
||||
plotter_cls = CAPlotter
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.losses = []
|
||||
|
||||
def __str__(self):
|
||||
return "\n".join((
|
||||
f"Type: {__class__.__name__}",
|
||||
"Encoder:",
|
||||
f"{self.encoder}",
|
||||
"Decoder:",
|
||||
f"{self.decoder}"
|
||||
)
|
||||
)
|
||||
|
||||
def loss(self, data_set: list[np.ndarray]) -> float:
|
||||
loss = 0
|
||||
for x in data_set:
|
||||
loss += np.sum(np.abs(x - self.forward(x)[0])) / len(x)
|
||||
return loss / len(data_set)
|
||||
|
||||
def train(self, v: np.ndarray):
|
||||
out = self.decoder.forward(
|
||||
self.encoder.forward(v)
|
||||
)
|
||||
error = out - v
|
||||
self.encoder.backprop(
|
||||
self.decoder.backprop(error)
|
||||
)
|
||||
return np.sum(np.abs(error)) / len(v)
|
||||
|
||||
def train_dataset(self,
|
||||
data_set: list[np.ndarray],
|
||||
max_epoch: int,
|
||||
@@ -63,59 +122,6 @@ class AAutoencoder(ABC):
|
||||
plotter.close()
|
||||
return self.losses
|
||||
|
||||
def save(self, path: str):
|
||||
path = path.removesuffix('.npy')
|
||||
np.save(path, self)
|
||||
|
||||
def load(path: str) -> 'ClassicalAutoencoder':
|
||||
path = path.removesuffix('.npy') + '.npy'
|
||||
data = np.load(path, allow_pickle=True)
|
||||
return data.item()
|
||||
|
||||
@abstractmethod
|
||||
def loss(self, data_set: list[np.ndarray]) -> float:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def train(self, v: np.ndarray) -> float:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, v: np.ndarray) -> np.ndarray:
|
||||
pass
|
||||
|
||||
|
||||
class ClassicalAutoencoder(AAutoencoder):
|
||||
plotter_cls = CAPlotter
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def __str__(self):
|
||||
return "\n".join((
|
||||
f"Type: {__class__.__name__}",
|
||||
"Encoder:",
|
||||
f"{self.encoder}",
|
||||
"Decoder:",
|
||||
f"{self.decoder}"
|
||||
))
|
||||
|
||||
def loss(self, data_set: list[np.ndarray]) -> float:
|
||||
loss = 0
|
||||
for x in data_set:
|
||||
loss += np.sum(np.abs(x - self.forward(x)[0])) / len(x)
|
||||
return loss / len(data_set)
|
||||
|
||||
def train(self, v: np.ndarray):
|
||||
out = self.decoder.forward(
|
||||
self.encoder.forward(v)
|
||||
)
|
||||
error = out - v
|
||||
self.encoder.backprop(
|
||||
self.decoder.backprop(error)
|
||||
)
|
||||
return np.sum(np.abs(error)) / len(v)
|
||||
|
||||
def encode(self, v: np.ndarray) -> np.ndarray:
|
||||
return self.encoder.forward(v)
|
||||
|
||||
@@ -129,9 +135,13 @@ class ClassicalAutoencoder(AAutoencoder):
|
||||
|
||||
|
||||
class VariationalAutoencoder(AAutoencoder):
|
||||
plotter_cls = VAEPlotter
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.sampler = SampleLayer(self.encoder.out_size, self.lr, Identity())
|
||||
self.KL_losses = []
|
||||
self.recon_losses = []
|
||||
|
||||
def __str__(self):
|
||||
return "\n".join((
|
||||
@@ -143,15 +153,18 @@ class VariationalAutoencoder(AAutoencoder):
|
||||
))
|
||||
|
||||
def loss(self, data_set: list[np.ndarray]) -> float:
|
||||
loss = 0
|
||||
kl_loss = 0
|
||||
recon_loss = 0
|
||||
for x in data_set:
|
||||
out = self.forward(x)[0]
|
||||
kl = self.sampler.DKL()
|
||||
loss += np.mean((out - x) ** 2)
|
||||
loss += kl
|
||||
return loss / len(data_set)
|
||||
recon_loss += np.mean((out - x) ** 2)
|
||||
kl_loss += kl
|
||||
kl_loss /= len(data_set)
|
||||
recon_loss /= len(data_set)
|
||||
return recon_loss, kl_loss
|
||||
|
||||
def train(self, v: np.ndarray) -> float:
|
||||
def train(self, v: np.ndarray) -> tuple[float, float]:
|
||||
out, _ = self.forward(v)
|
||||
error = out - v
|
||||
self.encoder.backprop(
|
||||
@@ -159,10 +172,62 @@ class VariationalAutoencoder(AAutoencoder):
|
||||
self.decoder.backprop(error)
|
||||
)
|
||||
)
|
||||
return np.mean(error ** 2) + self.sampler.DKL()
|
||||
return np.mean(error ** 2), self.sampler.DKL()
|
||||
|
||||
def train_dataset(self,
|
||||
data_set: list[np.ndarray],
|
||||
max_epoch: int,
|
||||
patience: int,
|
||||
display_loss: bool = False) -> list[float]:
|
||||
plotter = self.plotter_cls(self) if display_loss else Plotter(self)
|
||||
recon_0, kl_0 = self.loss(data_set)
|
||||
self.recon_losses = [recon_0]
|
||||
self.KL_losses = [kl_0]
|
||||
epoch = 0
|
||||
no_improv = 0
|
||||
prev_loss = self.recon_losses[0] + self.KL_losses[0]
|
||||
with tqdm(bar_format="{desc} {elapsed} {rate_fmt}") as lbar:
|
||||
while True:
|
||||
lbar.set_description(
|
||||
f"{LOADER[epoch % len(LOADER)]} Training ({epoch=} loss={float(prev_loss):.6f})", # noqa
|
||||
)
|
||||
lbar.update()
|
||||
dkl = 0
|
||||
recon = 0
|
||||
for x in tqdm(data_set, leave=False):
|
||||
recon_i, dkl_i = self.train(x)
|
||||
dkl += dkl_i
|
||||
recon += recon_i
|
||||
recon /= len(data_set)
|
||||
dkl /= len(data_set)
|
||||
loss = recon + dkl
|
||||
dloss = prev_loss - loss
|
||||
if dloss <= 0 or abs(dloss) < 1e-4:
|
||||
no_improv += 1
|
||||
else:
|
||||
no_improv = 0
|
||||
prev_loss = float(loss)
|
||||
self.recon_losses.append(recon)
|
||||
self.KL_losses.append(dkl)
|
||||
if no_improv > patience:
|
||||
break
|
||||
if epoch > max_epoch:
|
||||
break
|
||||
plotter.update()
|
||||
epoch += 1
|
||||
plotter.close()
|
||||
return self.recon_losses
|
||||
|
||||
def forward(self, v: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
||||
code = self.encoder.forward(v)
|
||||
sample = self.sampler.forward(code)
|
||||
out = self.decoder.forward(sample)
|
||||
return out, code
|
||||
|
||||
def encode(self, v: np.ndarray) -> np.ndarray:
|
||||
return self.sampler.forward(
|
||||
self.encoder.forward(v)
|
||||
)
|
||||
|
||||
def decode(self, v: np.ndarray) -> np.ndarray:
|
||||
return self.decoder.forward(v)
|
||||
|
||||
@@ -2,7 +2,7 @@ import matplotlib.pyplot as plt
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .autoencoder import AAutoencoder
|
||||
from .autoencoder import AAutoencoder, VariationalAutoencoder
|
||||
|
||||
|
||||
class Plotter:
|
||||
@@ -46,3 +46,48 @@ class CAPlotter(Plotter):
|
||||
def close(self):
|
||||
plt.ioff()
|
||||
plt.show()
|
||||
|
||||
|
||||
class VAEPlotter(Plotter):
|
||||
def __init__(self, autoencoder: 'VariationalAutoencoder'):
|
||||
self.autoencoder = autoencoder
|
||||
plt.ion()
|
||||
self.fig, (self.ax_recon, self.ax_dkl) = plt.subplots(1, 2)
|
||||
self.line, = self.ax_recon.plot(
|
||||
list(range(len(self.autoencoder.recon_losses))),
|
||||
self.autoencoder.recon_losses,
|
||||
label="Loss"
|
||||
)
|
||||
self.ax_recon.set_xlabel("Epoch")
|
||||
self.ax_recon.set_ylabel("Loss")
|
||||
self.ax_recon.set_title("Reconstruction MSE Loss")
|
||||
self.ax_recon.legend()
|
||||
|
||||
self.dkl_line, = self.ax_dkl.plot(
|
||||
list(range(len(self.autoencoder.KL_losses))),
|
||||
self.autoencoder.KL_losses,
|
||||
label="DKL Loss",
|
||||
)
|
||||
self.ax_dkl.set_xlabel("Epoch")
|
||||
self.ax_dkl.set_ylabel("Loss")
|
||||
self.ax_dkl.set_title("DKL Loss")
|
||||
self.ax_dkl.legend()
|
||||
self.update()
|
||||
|
||||
def update(self):
|
||||
self.line.set_xdata(range(len(self.autoencoder.recon_losses)))
|
||||
self.line.set_ydata(self.autoencoder.recon_losses)
|
||||
self.ax_recon.relim()
|
||||
self.ax_recon.autoscale_view()
|
||||
|
||||
self.dkl_line.set_xdata(range(len(self.autoencoder.KL_losses)))
|
||||
self.dkl_line.set_ydata(self.autoencoder.KL_losses)
|
||||
self.ax_dkl.relim()
|
||||
self.ax_dkl.autoscale_view()
|
||||
|
||||
plt.draw()
|
||||
plt.pause(0.1)
|
||||
|
||||
def close(self):
|
||||
plt.ioff()
|
||||
plt.show()
|
||||
|
||||
Reference in New Issue
Block a user