refactor: move plot logic to plotters.py

This commit is contained in:
Lenoctambule
2026-04-09 22:47:22 +02:00
parent 9d718a6bc8
commit ea8a4079ac
6 changed files with 81 additions and 51 deletions

View File

@@ -1,18 +1,16 @@
import numpy as np
from .utils import (
dynamic_loss_plot_init,
dynamic_loss_plot_update,
dynamic_loss_plot_finish
)
from tqdm import tqdm
from .layers import DeepNNLayer, SampleLayer
from .activations import ActivationFunc, Identity
from .plotters import Plotter, CAPlotter
from abc import ABC, abstractmethod
LOADER = ['', '', '', '', '', '', '', '']
class AAutoencoder(ABC):
plotter_cls = Plotter
@abstractmethod
def __init__(self,
encoder_layers: list[int],
@@ -27,18 +25,18 @@ class AAutoencoder(ABC):
self.decoder = DeepNNLayer(decoder_layers, lr, activation_func)
self.space_dim = decoder_layers[0]
self.lr = lr
self.losses = [0]
def train_dataset(self,
data_set: list[np.ndarray],
max_epoch: int,
patience: int,
display_loss: bool = False) -> list[float]:
losses = [self.loss(data_set)]
if display_loss is True:
ax, line = dynamic_loss_plot_init(losses)
plotter = self.plotter_cls(self) if display_loss else Plotter(self)
self.losses = [self.loss(data_set)]
epoch = 0
no_improv = 0
prev_error = losses[0]
prev_error = self.losses[0]
with tqdm(bar_format="{desc} {elapsed} {rate_fmt}") as lbar:
while True:
lbar.set_description(
@@ -55,17 +53,15 @@ class AAutoencoder(ABC):
else:
no_improv = 0
prev_error = float(error)
losses.append(error)
if display_loss is True:
dynamic_loss_plot_update(ax, line, losses)
self.losses.append(error)
if no_improv > patience:
break
if epoch > max_epoch:
break
plotter.update()
epoch += 1
if display_loss is True:
dynamic_loss_plot_finish()
return losses
plotter.close()
return self.losses
def save(self, path: str):
path = path.removesuffix('.npy')
@@ -90,11 +86,19 @@ class AAutoencoder(ABC):
class ClassicalAutoencoder(AAutoencoder):
plotter_cls = CAPlotter
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __str__(self):
return f'Encoder:\n{self.encoder}\n\nDecoder:\n{self.decoder}'
return "\n".join((
f"Type: {__class__.__name__}",
"Encoder:",
f"{self.encoder}",
"Decoder:"
f"{self.decoder}"
))
def loss(self, data_set: list[np.ndarray]) -> float:
loss = 0