refactor: move kb interrupt handling to autoencoder classes
This commit is contained in:
@@ -3,6 +3,7 @@ from tqdm import tqdm
|
||||
from .layers import DeepNNLayer, SampleLayer
|
||||
from .activations import ActivationFunc, Identity
|
||||
from .plotters import Plotter, CAPlotter, VAEPlotter
|
||||
from .utils import interruptable
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
LOADER = ['⡿', '⣟', '⣯', '⣷', '⣾', '⣽', '⣻', '⢿']
|
||||
@@ -86,6 +87,7 @@ class ClassicalAutoencoder(AAutoencoder):
|
||||
)
|
||||
return np.sum(np.abs(error)) / len(v)
|
||||
|
||||
@interruptable
|
||||
def train_dataset(self,
|
||||
data_set: list[np.ndarray],
|
||||
max_epoch: int,
|
||||
@@ -119,8 +121,6 @@ class ClassicalAutoencoder(AAutoencoder):
|
||||
break
|
||||
plotter.update()
|
||||
epoch += 1
|
||||
plotter.close()
|
||||
return self.losses
|
||||
|
||||
def encode(self, v: np.ndarray) -> np.ndarray:
|
||||
return self.encoder.forward(v)
|
||||
@@ -174,6 +174,7 @@ class VariationalAutoencoder(AAutoencoder):
|
||||
)
|
||||
return np.mean(error ** 2), self.sampler.DKL()
|
||||
|
||||
@interruptable
|
||||
def train_dataset(self,
|
||||
data_set: list[np.ndarray],
|
||||
max_epoch: int,
|
||||
@@ -215,8 +216,6 @@ class VariationalAutoencoder(AAutoencoder):
|
||||
break
|
||||
plotter.update()
|
||||
epoch += 1
|
||||
plotter.close()
|
||||
return self.recon_losses
|
||||
|
||||
def forward(self, v: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
||||
code = self.encoder.forward(v)
|
||||
|
||||
Reference in New Issue
Block a user