refactor: move kb interrupt handling to autoencoder classes

This commit is contained in:
Lenoctambule
2026-04-10 22:20:35 +02:00
parent 5ff6cfe55e
commit 7a822782a5
4 changed files with 17 additions and 21 deletions

View File

@@ -1,7 +1,6 @@
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import os import os
import signal
from easyvae.autoencoder import ( # noqa from easyvae.autoencoder import ( # noqa
VariationalAutoencoder, VariationalAutoencoder,
ClassicalAutoencoder, ClassicalAutoencoder,
@@ -32,7 +31,6 @@ def mnist_train(
x_train.resize(x_train.shape[0], in_len) x_train.resize(x_train.shape[0], in_len)
x_test.resize(x_test.shape[0], in_len) x_test.resize(x_test.shape[0], in_len)
x_train = x_train / 255 x_train = x_train / 255
x_train = x_train[:5000]
if os.path.exists(filename): if os.path.exists(filename):
autoencoder = cls.load(filename) autoencoder = cls.load(filename)
else: else:
@@ -42,17 +40,7 @@ def mnist_train(
0.0001, 0.0001,
LeakyReLU() LeakyReLU()
) )
print("CTRL+C to interrupt training.")
def handler(signum, frame):
print(f"Saving {filename} before exit ...")
autoencoder.save(filename)
plt.close('all')
plt.ioff()
mnist_test(autoencoder)
exit()
signal.signal(signal.SIGINT, handler)
print("CTRL+C to exit and save model.")
autoencoder.train_dataset( autoencoder.train_dataset(
x_train, x_train,
max_epoch, max_epoch,
@@ -100,7 +88,7 @@ def plot_random_reconstruction(
output.reshape(img_shape), output.reshape(img_shape),
fignum=False) fignum=False)
plt.title(f"Output ({y})") plt.title(f"Output ({y})")
print(f'{code=}') print(f'{code.tolist()}')
def mnist_test(model: str | AAutoencoder): def mnist_test(model: str | AAutoencoder):

View File

@@ -3,6 +3,7 @@ from tqdm import tqdm
from .layers import DeepNNLayer, SampleLayer from .layers import DeepNNLayer, SampleLayer
from .activations import ActivationFunc, Identity from .activations import ActivationFunc, Identity
from .plotters import Plotter, CAPlotter, VAEPlotter from .plotters import Plotter, CAPlotter, VAEPlotter
from .utils import interruptable
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
LOADER = ['', '', '', '', '', '', '', ''] LOADER = ['', '', '', '', '', '', '', '']
@@ -86,6 +87,7 @@ class ClassicalAutoencoder(AAutoencoder):
) )
return np.sum(np.abs(error)) / len(v) return np.sum(np.abs(error)) / len(v)
@interruptable
def train_dataset(self, def train_dataset(self,
data_set: list[np.ndarray], data_set: list[np.ndarray],
max_epoch: int, max_epoch: int,
@@ -119,8 +121,6 @@ class ClassicalAutoencoder(AAutoencoder):
break break
plotter.update() plotter.update()
epoch += 1 epoch += 1
plotter.close()
return self.losses
def encode(self, v: np.ndarray) -> np.ndarray: def encode(self, v: np.ndarray) -> np.ndarray:
return self.encoder.forward(v) return self.encoder.forward(v)
@@ -174,6 +174,7 @@ class VariationalAutoencoder(AAutoencoder):
) )
return np.mean(error ** 2), self.sampler.DKL() return np.mean(error ** 2), self.sampler.DKL()
@interruptable
def train_dataset(self, def train_dataset(self,
data_set: list[np.ndarray], data_set: list[np.ndarray],
max_epoch: int, max_epoch: int,
@@ -215,8 +216,6 @@ class VariationalAutoencoder(AAutoencoder):
break break
plotter.update() plotter.update()
epoch += 1 epoch += 1
plotter.close()
return self.recon_losses
def forward(self, v: np.ndarray) -> tuple[np.ndarray, np.ndarray]: def forward(self, v: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
code = self.encoder.forward(v) code = self.encoder.forward(v)

View File

@@ -15,7 +15,7 @@ class Plotter:
def close(self): def close(self):
pass pass
def __exit__(self, exc_type, exc_val, exc_tb): def __del__(self):
self.close() self.close()
@@ -45,7 +45,7 @@ class CAPlotter(Plotter):
def close(self): def close(self):
plt.ioff() plt.ioff()
plt.show() plt.close(self.fig)
class VAEPlotter(Plotter): class VAEPlotter(Plotter):
@@ -90,4 +90,4 @@ class VAEPlotter(Plotter):
def close(self): def close(self):
plt.ioff() plt.ioff()
plt.show() plt.close(self.fig)

View File

@@ -18,3 +18,12 @@ def regularize(v: np.ndarray) -> np.ndarray:
if v_min - v_max == 0: if v_min - v_max == 0:
return v return v
return (v - v_min) / (v_max - v_min) return (v - v_min) / (v_max - v_min)
def interruptable(func):
def inner(*args, **kwargs):
try:
return func(*args, **kwargs)
except KeyboardInterrupt:
pass
return inner