refactor: move plot logic to plotters.py
This commit is contained in:
@@ -1,18 +1,16 @@
|
||||
import numpy as np
|
||||
from .utils import (
|
||||
dynamic_loss_plot_init,
|
||||
dynamic_loss_plot_update,
|
||||
dynamic_loss_plot_finish
|
||||
)
|
||||
from tqdm import tqdm
|
||||
from .layers import DeepNNLayer, SampleLayer
|
||||
from .activations import ActivationFunc, Identity
|
||||
from .plotters import Plotter, CAPlotter
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
LOADER = ['⡿', '⣟', '⣯', '⣷', '⣾', '⣽', '⣻', '⢿']
|
||||
|
||||
|
||||
class AAutoencoder(ABC):
|
||||
plotter_cls = Plotter
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self,
|
||||
encoder_layers: list[int],
|
||||
@@ -27,18 +25,18 @@ class AAutoencoder(ABC):
|
||||
self.decoder = DeepNNLayer(decoder_layers, lr, activation_func)
|
||||
self.space_dim = decoder_layers[0]
|
||||
self.lr = lr
|
||||
self.losses = [0]
|
||||
|
||||
def train_dataset(self,
|
||||
data_set: list[np.ndarray],
|
||||
max_epoch: int,
|
||||
patience: int,
|
||||
display_loss: bool = False) -> list[float]:
|
||||
losses = [self.loss(data_set)]
|
||||
if display_loss is True:
|
||||
ax, line = dynamic_loss_plot_init(losses)
|
||||
plotter = self.plotter_cls(self) if display_loss else Plotter(self)
|
||||
self.losses = [self.loss(data_set)]
|
||||
epoch = 0
|
||||
no_improv = 0
|
||||
prev_error = losses[0]
|
||||
prev_error = self.losses[0]
|
||||
with tqdm(bar_format="{desc} {elapsed} {rate_fmt}") as lbar:
|
||||
while True:
|
||||
lbar.set_description(
|
||||
@@ -55,17 +53,15 @@ class AAutoencoder(ABC):
|
||||
else:
|
||||
no_improv = 0
|
||||
prev_error = float(error)
|
||||
losses.append(error)
|
||||
if display_loss is True:
|
||||
dynamic_loss_plot_update(ax, line, losses)
|
||||
self.losses.append(error)
|
||||
if no_improv > patience:
|
||||
break
|
||||
if epoch > max_epoch:
|
||||
break
|
||||
plotter.update()
|
||||
epoch += 1
|
||||
if display_loss is True:
|
||||
dynamic_loss_plot_finish()
|
||||
return losses
|
||||
plotter.close()
|
||||
return self.losses
|
||||
|
||||
def save(self, path: str):
|
||||
path = path.removesuffix('.npy')
|
||||
@@ -90,11 +86,19 @@ class AAutoencoder(ABC):
|
||||
|
||||
|
||||
class ClassicalAutoencoder(AAutoencoder):
|
||||
plotter_cls = CAPlotter
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def __str__(self):
|
||||
return f'Encoder:\n{self.encoder}\n\nDecoder:\n{self.decoder}'
|
||||
return "\n".join((
|
||||
f"Type: {__class__.__name__}",
|
||||
"Encoder:",
|
||||
f"{self.encoder}",
|
||||
"Decoder:"
|
||||
f"{self.decoder}"
|
||||
))
|
||||
|
||||
def loss(self, data_set: list[np.ndarray]) -> float:
|
||||
loss = 0
|
||||
|
||||
Reference in New Issue
Block a user