refactor: move kb interrupt handling to autoencoder classes

This commit is contained in:
Lenoctambule
2026-04-10 22:20:35 +02:00
parent 5ff6cfe55e
commit 7a822782a5
4 changed files with 17 additions and 21 deletions

View File

@@ -3,6 +3,7 @@ from tqdm import tqdm
from .layers import DeepNNLayer, SampleLayer
from .activations import ActivationFunc, Identity
from .plotters import Plotter, CAPlotter, VAEPlotter
from .utils import interruptable
from abc import ABC, abstractmethod
LOADER = ['', '', '', '', '', '', '', '']
@@ -86,6 +87,7 @@ class ClassicalAutoencoder(AAutoencoder):
)
return np.sum(np.abs(error)) / len(v)
@interruptable
def train_dataset(self,
data_set: list[np.ndarray],
max_epoch: int,
@@ -119,8 +121,6 @@ class ClassicalAutoencoder(AAutoencoder):
break
plotter.update()
epoch += 1
plotter.close()
return self.losses
def encode(self, v: np.ndarray) -> np.ndarray:
return self.encoder.forward(v)
@@ -174,6 +174,7 @@ class VariationalAutoencoder(AAutoencoder):
)
return np.mean(error ** 2), self.sampler.DKL()
@interruptable
def train_dataset(self,
data_set: list[np.ndarray],
max_epoch: int,
@@ -215,8 +216,6 @@ class VariationalAutoencoder(AAutoencoder):
break
plotter.update()
epoch += 1
plotter.close()
return self.recon_losses
def forward(self, v: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
code = self.encoder.forward(v)