feat: NoiseLayer class + keep loss across trainings
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from .layers import DeepNNLayer, SampleLayer
|
||||
from .layers import DeepNNLayer, SampleLayer, NoiseLayer
|
||||
from .activations import ActivationFunc, Identity
|
||||
from .plotters import Plotter, CAPlotter, VAEPlotter
|
||||
from .utils import interruptable
|
||||
@@ -17,13 +17,15 @@ class AAutoencoder(ABC):
|
||||
encoder_layers: list[int],
|
||||
decoder_layers: list[int],
|
||||
lr: float,
|
||||
activation_func: ActivationFunc):
|
||||
activation_func: ActivationFunc,
|
||||
noise=0):
|
||||
if encoder_layers[-1] != decoder_layers[0]:
|
||||
raise Exception(
|
||||
f"Encoder output and decoder input don't match {encoder_layers[-1]} != {encoder_layers[0]}" # noqa
|
||||
)
|
||||
self.encoder = DeepNNLayer(encoder_layers, lr, activation_func)
|
||||
self.decoder = DeepNNLayer(decoder_layers, lr, activation_func)
|
||||
self.noise = NoiseLayer(noise)
|
||||
self.space_dim = decoder_layers[0]
|
||||
self.lr = lr
|
||||
self.losses = [0]
|
||||
@@ -78,8 +80,8 @@ class ClassicalAutoencoder(AAutoencoder):
|
||||
return loss / len(data_set)
|
||||
|
||||
def train(self, v: np.ndarray):
|
||||
out = self.decoder.forward(
|
||||
self.encoder.forward(v)
|
||||
out, _ = self.forward(
|
||||
self.noise.forward(v)
|
||||
)
|
||||
error = out - v
|
||||
self.encoder.back(
|
||||
@@ -96,7 +98,8 @@ class ClassicalAutoencoder(AAutoencoder):
|
||||
patience: int,
|
||||
display_loss: bool = False) -> list[float]:
|
||||
plotter = self.plotter_cls(self) if display_loss else Plotter(self)
|
||||
self.losses = [self.loss(data_set)]
|
||||
if len(self.losses) == 0:
|
||||
self.losses = [self.loss(data_set)]
|
||||
epoch = 0
|
||||
no_improv = 0
|
||||
prev_error = self.losses[0]
|
||||
@@ -167,7 +170,9 @@ class VariationalAutoencoder(AAutoencoder):
|
||||
return recon_loss, kl_loss
|
||||
|
||||
def train(self, v: np.ndarray) -> tuple[float, float]:
|
||||
out, _ = self.forward(v)
|
||||
out, _ = self.forward(
|
||||
self.noise.forward(v)
|
||||
)
|
||||
error = out - v
|
||||
self.encoder.back(
|
||||
self.sampler.back(
|
||||
@@ -186,9 +191,10 @@ class VariationalAutoencoder(AAutoencoder):
|
||||
patience: int,
|
||||
display_loss: bool = False) -> list[float]:
|
||||
plotter = self.plotter_cls(self) if display_loss else Plotter(self)
|
||||
recon_0, kl_0 = self.loss(data_set)
|
||||
self.recon_losses = [recon_0]
|
||||
self.KL_losses = [kl_0]
|
||||
if len(self.recon_losses) == 0:
|
||||
recon_0, kl_0 = self.loss(data_set)
|
||||
self.recon_losses = [recon_0]
|
||||
self.KL_losses = [kl_0]
|
||||
epoch = 0
|
||||
no_improv = 0
|
||||
prev_loss = self.recon_losses[0] + self.KL_losses[0]
|
||||
|
||||
Reference in New Issue
Block a user