feat(plotters.py): add VAEPlotter class + seperate training logic
This commit is contained in:
@@ -2,7 +2,7 @@ import numpy as np
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from .layers import DeepNNLayer, SampleLayer
|
from .layers import DeepNNLayer, SampleLayer
|
||||||
from .activations import ActivationFunc, Identity
|
from .activations import ActivationFunc, Identity
|
||||||
from .plotters import Plotter, CAPlotter
|
from .plotters import Plotter, CAPlotter, VAEPlotter
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
LOADER = ['⡿', '⣟', '⣯', '⣷', '⣾', '⣽', '⣻', '⢿']
|
LOADER = ['⡿', '⣟', '⣯', '⣷', '⣾', '⣽', '⣻', '⢿']
|
||||||
@@ -27,6 +27,65 @@ class AAutoencoder(ABC):
|
|||||||
self.lr = lr
|
self.lr = lr
|
||||||
self.losses = [0]
|
self.losses = [0]
|
||||||
|
|
||||||
|
def save(self, path: str):
|
||||||
|
path = path.removesuffix('.npy')
|
||||||
|
np.save(path, self)
|
||||||
|
|
||||||
|
def load(path: str) -> 'ClassicalAutoencoder':
|
||||||
|
path = path.removesuffix('.npy') + '.npy'
|
||||||
|
data = np.load(path, allow_pickle=True)
|
||||||
|
return data.item()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def loss(self, data_set: list[np.ndarray]) -> float:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def train(self, v: np.ndarray) -> float:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def forward(self, v: np.ndarray) -> np.ndarray:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def train_dataset(self, *args, **kwargs) -> list[float]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ClassicalAutoencoder(AAutoencoder):
|
||||||
|
plotter_cls = CAPlotter
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.losses = []
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return "\n".join((
|
||||||
|
f"Type: {__class__.__name__}",
|
||||||
|
"Encoder:",
|
||||||
|
f"{self.encoder}",
|
||||||
|
"Decoder:",
|
||||||
|
f"{self.decoder}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def loss(self, data_set: list[np.ndarray]) -> float:
|
||||||
|
loss = 0
|
||||||
|
for x in data_set:
|
||||||
|
loss += np.sum(np.abs(x - self.forward(x)[0])) / len(x)
|
||||||
|
return loss / len(data_set)
|
||||||
|
|
||||||
|
def train(self, v: np.ndarray):
|
||||||
|
out = self.decoder.forward(
|
||||||
|
self.encoder.forward(v)
|
||||||
|
)
|
||||||
|
error = out - v
|
||||||
|
self.encoder.backprop(
|
||||||
|
self.decoder.backprop(error)
|
||||||
|
)
|
||||||
|
return np.sum(np.abs(error)) / len(v)
|
||||||
|
|
||||||
def train_dataset(self,
|
def train_dataset(self,
|
||||||
data_set: list[np.ndarray],
|
data_set: list[np.ndarray],
|
||||||
max_epoch: int,
|
max_epoch: int,
|
||||||
@@ -63,59 +122,6 @@ class AAutoencoder(ABC):
|
|||||||
plotter.close()
|
plotter.close()
|
||||||
return self.losses
|
return self.losses
|
||||||
|
|
||||||
def save(self, path: str):
|
|
||||||
path = path.removesuffix('.npy')
|
|
||||||
np.save(path, self)
|
|
||||||
|
|
||||||
def load(path: str) -> 'ClassicalAutoencoder':
|
|
||||||
path = path.removesuffix('.npy') + '.npy'
|
|
||||||
data = np.load(path, allow_pickle=True)
|
|
||||||
return data.item()
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def loss(self, data_set: list[np.ndarray]) -> float:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def train(self, v: np.ndarray) -> float:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def forward(self, v: np.ndarray) -> np.ndarray:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ClassicalAutoencoder(AAutoencoder):
|
|
||||||
plotter_cls = CAPlotter
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return "\n".join((
|
|
||||||
f"Type: {__class__.__name__}",
|
|
||||||
"Encoder:",
|
|
||||||
f"{self.encoder}",
|
|
||||||
"Decoder:",
|
|
||||||
f"{self.decoder}"
|
|
||||||
))
|
|
||||||
|
|
||||||
def loss(self, data_set: list[np.ndarray]) -> float:
|
|
||||||
loss = 0
|
|
||||||
for x in data_set:
|
|
||||||
loss += np.sum(np.abs(x - self.forward(x)[0])) / len(x)
|
|
||||||
return loss / len(data_set)
|
|
||||||
|
|
||||||
def train(self, v: np.ndarray):
|
|
||||||
out = self.decoder.forward(
|
|
||||||
self.encoder.forward(v)
|
|
||||||
)
|
|
||||||
error = out - v
|
|
||||||
self.encoder.backprop(
|
|
||||||
self.decoder.backprop(error)
|
|
||||||
)
|
|
||||||
return np.sum(np.abs(error)) / len(v)
|
|
||||||
|
|
||||||
def encode(self, v: np.ndarray) -> np.ndarray:
|
def encode(self, v: np.ndarray) -> np.ndarray:
|
||||||
return self.encoder.forward(v)
|
return self.encoder.forward(v)
|
||||||
|
|
||||||
@@ -129,9 +135,13 @@ class ClassicalAutoencoder(AAutoencoder):
|
|||||||
|
|
||||||
|
|
||||||
class VariationalAutoencoder(AAutoencoder):
|
class VariationalAutoencoder(AAutoencoder):
|
||||||
|
plotter_cls = VAEPlotter
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.sampler = SampleLayer(self.encoder.out_size, self.lr, Identity())
|
self.sampler = SampleLayer(self.encoder.out_size, self.lr, Identity())
|
||||||
|
self.KL_losses = []
|
||||||
|
self.recon_losses = []
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return "\n".join((
|
return "\n".join((
|
||||||
@@ -143,15 +153,18 @@ class VariationalAutoencoder(AAutoencoder):
|
|||||||
))
|
))
|
||||||
|
|
||||||
def loss(self, data_set: list[np.ndarray]) -> float:
|
def loss(self, data_set: list[np.ndarray]) -> float:
|
||||||
loss = 0
|
kl_loss = 0
|
||||||
|
recon_loss = 0
|
||||||
for x in data_set:
|
for x in data_set:
|
||||||
out = self.forward(x)[0]
|
out = self.forward(x)[0]
|
||||||
kl = self.sampler.DKL()
|
kl = self.sampler.DKL()
|
||||||
loss += np.mean((out - x) ** 2)
|
recon_loss += np.mean((out - x) ** 2)
|
||||||
loss += kl
|
kl_loss += kl
|
||||||
return loss / len(data_set)
|
kl_loss /= len(data_set)
|
||||||
|
recon_loss /= len(data_set)
|
||||||
|
return recon_loss, kl_loss
|
||||||
|
|
||||||
def train(self, v: np.ndarray) -> float:
|
def train(self, v: np.ndarray) -> tuple[float, float]:
|
||||||
out, _ = self.forward(v)
|
out, _ = self.forward(v)
|
||||||
error = out - v
|
error = out - v
|
||||||
self.encoder.backprop(
|
self.encoder.backprop(
|
||||||
@@ -159,10 +172,62 @@ class VariationalAutoencoder(AAutoencoder):
|
|||||||
self.decoder.backprop(error)
|
self.decoder.backprop(error)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return np.mean(error ** 2) + self.sampler.DKL()
|
return np.mean(error ** 2), self.sampler.DKL()
|
||||||
|
|
||||||
|
def train_dataset(self,
|
||||||
|
data_set: list[np.ndarray],
|
||||||
|
max_epoch: int,
|
||||||
|
patience: int,
|
||||||
|
display_loss: bool = False) -> list[float]:
|
||||||
|
plotter = self.plotter_cls(self) if display_loss else Plotter(self)
|
||||||
|
recon_0, kl_0 = self.loss(data_set)
|
||||||
|
self.recon_losses = [recon_0]
|
||||||
|
self.KL_losses = [kl_0]
|
||||||
|
epoch = 0
|
||||||
|
no_improv = 0
|
||||||
|
prev_loss = self.recon_losses[0] + self.KL_losses[0]
|
||||||
|
with tqdm(bar_format="{desc} {elapsed} {rate_fmt}") as lbar:
|
||||||
|
while True:
|
||||||
|
lbar.set_description(
|
||||||
|
f"{LOADER[epoch % len(LOADER)]} Training ({epoch=} loss={float(prev_loss):.6f})", # noqa
|
||||||
|
)
|
||||||
|
lbar.update()
|
||||||
|
dkl = 0
|
||||||
|
recon = 0
|
||||||
|
for x in tqdm(data_set, leave=False):
|
||||||
|
recon_i, dkl_i = self.train(x)
|
||||||
|
dkl += dkl_i
|
||||||
|
recon += recon_i
|
||||||
|
recon /= len(data_set)
|
||||||
|
dkl /= len(data_set)
|
||||||
|
loss = recon + dkl
|
||||||
|
dloss = prev_loss - loss
|
||||||
|
if dloss <= 0 or abs(dloss) < 1e-4:
|
||||||
|
no_improv += 1
|
||||||
|
else:
|
||||||
|
no_improv = 0
|
||||||
|
prev_loss = float(loss)
|
||||||
|
self.recon_losses.append(recon)
|
||||||
|
self.KL_losses.append(dkl)
|
||||||
|
if no_improv > patience:
|
||||||
|
break
|
||||||
|
if epoch > max_epoch:
|
||||||
|
break
|
||||||
|
plotter.update()
|
||||||
|
epoch += 1
|
||||||
|
plotter.close()
|
||||||
|
return self.recon_losses
|
||||||
|
|
||||||
def forward(self, v: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
def forward(self, v: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
||||||
code = self.encoder.forward(v)
|
code = self.encoder.forward(v)
|
||||||
sample = self.sampler.forward(code)
|
sample = self.sampler.forward(code)
|
||||||
out = self.decoder.forward(sample)
|
out = self.decoder.forward(sample)
|
||||||
return out, code
|
return out, code
|
||||||
|
|
||||||
|
def encode(self, v: np.ndarray) -> np.ndarray:
|
||||||
|
return self.sampler.forward(
|
||||||
|
self.encoder.forward(v)
|
||||||
|
)
|
||||||
|
|
||||||
|
def decode(self, v: np.ndarray) -> np.ndarray:
|
||||||
|
return self.decoder.forward(v)
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import matplotlib.pyplot as plt
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .autoencoder import AAutoencoder
|
from .autoencoder import AAutoencoder, VariationalAutoencoder
|
||||||
|
|
||||||
|
|
||||||
class Plotter:
|
class Plotter:
|
||||||
@@ -46,3 +46,48 @@ class CAPlotter(Plotter):
|
|||||||
def close(self):
|
def close(self):
|
||||||
plt.ioff()
|
plt.ioff()
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
class VAEPlotter(Plotter):
|
||||||
|
def __init__(self, autoencoder: 'VariationalAutoencoder'):
|
||||||
|
self.autoencoder = autoencoder
|
||||||
|
plt.ion()
|
||||||
|
self.fig, (self.ax_recon, self.ax_dkl) = plt.subplots(1, 2)
|
||||||
|
self.line, = self.ax_recon.plot(
|
||||||
|
list(range(len(self.autoencoder.recon_losses))),
|
||||||
|
self.autoencoder.recon_losses,
|
||||||
|
label="Loss"
|
||||||
|
)
|
||||||
|
self.ax_recon.set_xlabel("Epoch")
|
||||||
|
self.ax_recon.set_ylabel("Loss")
|
||||||
|
self.ax_recon.set_title("Reconstruction MSE Loss")
|
||||||
|
self.ax_recon.legend()
|
||||||
|
|
||||||
|
self.dkl_line, = self.ax_dkl.plot(
|
||||||
|
list(range(len(self.autoencoder.KL_losses))),
|
||||||
|
self.autoencoder.KL_losses,
|
||||||
|
label="DKL Loss",
|
||||||
|
)
|
||||||
|
self.ax_dkl.set_xlabel("Epoch")
|
||||||
|
self.ax_dkl.set_ylabel("Loss")
|
||||||
|
self.ax_dkl.set_title("DKL Loss")
|
||||||
|
self.ax_dkl.legend()
|
||||||
|
self.update()
|
||||||
|
|
||||||
|
def update(self):
|
||||||
|
self.line.set_xdata(range(len(self.autoencoder.recon_losses)))
|
||||||
|
self.line.set_ydata(self.autoencoder.recon_losses)
|
||||||
|
self.ax_recon.relim()
|
||||||
|
self.ax_recon.autoscale_view()
|
||||||
|
|
||||||
|
self.dkl_line.set_xdata(range(len(self.autoencoder.KL_losses)))
|
||||||
|
self.dkl_line.set_ydata(self.autoencoder.KL_losses)
|
||||||
|
self.ax_dkl.relim()
|
||||||
|
self.ax_dkl.autoscale_view()
|
||||||
|
|
||||||
|
plt.draw()
|
||||||
|
plt.pause(0.1)
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
plt.ioff()
|
||||||
|
plt.show()
|
||||||
|
|||||||
Reference in New Issue
Block a user