feat: plot 2d latent space + signal handling + fix SGD in Sampler
This commit is contained in:
@@ -4,7 +4,7 @@ from utils import (dynamic_loss_plot_init,
|
||||
dynamic_loss_plot_finish)
|
||||
from tqdm import tqdm
|
||||
from layers import DeepNNLayer, SampleLayer
|
||||
from activations import ActivationFunc
|
||||
from activations import ActivationFunc, Identity
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
LOADER = ['⡿', '⣟', '⣯', '⣷', '⣾', '⣽', '⣻', '⢿']
|
||||
@@ -46,7 +46,6 @@ class AAutoencoder(ABC):
|
||||
if epoch > max_epoch:
|
||||
break
|
||||
epoch += 1
|
||||
print("Training complete !")
|
||||
if display_loss is True:
|
||||
dynamic_loss_plot_finish(ax, line)
|
||||
return losses
|
||||
@@ -129,12 +128,15 @@ class VariationalAutoencoder(AAutoencoder):
|
||||
)
|
||||
self.encoder = DeepNNLayer(encoder_layers, lr, activation_func)
|
||||
self.decoder = DeepNNLayer(decoder_layers, lr, activation_func)
|
||||
self.sampler = SampleLayer(self.encoder.out_size, lr, activation_func)
|
||||
self.sampler = SampleLayer(self.encoder.out_size, lr, Identity())
|
||||
|
||||
def loss(self, data_set: list[np.ndarray]) -> float:
|
||||
loss = 0
|
||||
for x in data_set:
|
||||
loss += np.sum(np.abs(x - self.forward(x)[0])) / len(x)
|
||||
out = self.forward(x)[0]
|
||||
kl = self.sampler.DKL()
|
||||
loss += np.mean((out - x) ** 2)
|
||||
loss += kl
|
||||
return loss / len(data_set)
|
||||
|
||||
def train(self, v: np.ndarray) -> float:
|
||||
@@ -145,7 +147,7 @@ class VariationalAutoencoder(AAutoencoder):
|
||||
self.decoder.backprop(error)
|
||||
)
|
||||
)
|
||||
return np.sum(np.abs(error)) / len(v)
|
||||
return np.mean(error ** 2) + self.sampler.DKL()
|
||||
|
||||
def forward(self, v: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
||||
code = self.encoder.forward(v)
|
||||
|
||||
Reference in New Issue
Block a user