diff --git a/examples/mnist_test.py b/examples/mnist_test.py index e0e37f5..160cb16 100644 --- a/examples/mnist_test.py +++ b/examples/mnist_test.py @@ -1,7 +1,6 @@ import matplotlib.pyplot as plt import numpy as np import os -import signal from easyvae.autoencoder import ( # noqa VariationalAutoencoder, ClassicalAutoencoder, @@ -32,7 +31,6 @@ def mnist_train( x_train.resize(x_train.shape[0], in_len) x_test.resize(x_test.shape[0], in_len) x_train = x_train / 255 - x_train = x_train[:5000] if os.path.exists(filename): autoencoder = cls.load(filename) else: @@ -42,17 +40,7 @@ def mnist_train( 0.0001, LeakyReLU() ) - - def handler(signum, frame): - print(f"Saving {filename} before exit ...") - autoencoder.save(filename) - plt.close('all') - plt.ioff() - mnist_test(autoencoder) - exit() - - signal.signal(signal.SIGINT, handler) - print("CTRL+C to exit and save model.") + print("CTRL+C to interrupt training.") autoencoder.train_dataset( x_train, max_epoch, @@ -100,7 +88,7 @@ def plot_random_reconstruction( output.reshape(img_shape), fignum=False) plt.title(f"Output ({y})") - print(f'{code=}') + print(f'{code.tolist()}') def mnist_test(model: str | AAutoencoder): diff --git a/src/easyvae/autoencoder.py b/src/easyvae/autoencoder.py index 652e664..3926481 100644 --- a/src/easyvae/autoencoder.py +++ b/src/easyvae/autoencoder.py @@ -3,6 +3,7 @@ from tqdm import tqdm from .layers import DeepNNLayer, SampleLayer from .activations import ActivationFunc, Identity from .plotters import Plotter, CAPlotter, VAEPlotter +from .utils import interruptable from abc import ABC, abstractmethod LOADER = ['⡿', '⣟', '⣯', '⣷', '⣾', '⣽', '⣻', '⢿'] @@ -86,6 +87,7 @@ class ClassicalAutoencoder(AAutoencoder): ) return np.sum(np.abs(error)) / len(v) + @interruptable def train_dataset(self, data_set: list[np.ndarray], max_epoch: int, @@ -119,8 +121,6 @@ class ClassicalAutoencoder(AAutoencoder): break plotter.update() epoch += 1 - plotter.close() - return self.losses def encode(self, v: np.ndarray) -> np.ndarray: return self.encoder.forward(v) @@ -174,6 +174,7 @@ class VariationalAutoencoder(AAutoencoder): ) return np.mean(error ** 2), self.sampler.DKL() + @interruptable def train_dataset(self, data_set: list[np.ndarray], max_epoch: int, @@ -215,8 +216,6 @@ class VariationalAutoencoder(AAutoencoder): break plotter.update() epoch += 1 - plotter.close() - return self.recon_losses def forward(self, v: np.ndarray) -> tuple[np.ndarray, np.ndarray]: code = self.encoder.forward(v) diff --git a/src/easyvae/plotters.py b/src/easyvae/plotters.py index e4ec493..6750b43 100644 --- a/src/easyvae/plotters.py +++ b/src/easyvae/plotters.py @@ -15,7 +15,7 @@ class Plotter: def close(self): pass - def __exit__(self, exc_type, exc_val, exc_tb): + def __del__(self): self.close() @@ -45,7 +45,7 @@ class CAPlotter(Plotter): def close(self): plt.ioff() - plt.show() + plt.close(self.fig) class VAEPlotter(Plotter): @@ -90,4 +90,4 @@ class VAEPlotter(Plotter): def close(self): plt.ioff() - plt.show() + plt.close(self.fig) diff --git a/src/easyvae/utils.py b/src/easyvae/utils.py index 2d10966..414db1a 100644 --- a/src/easyvae/utils.py +++ b/src/easyvae/utils.py @@ -18,3 +18,12 @@ def regularize(v: np.ndarray) -> np.ndarray: if v_min - v_max == 0: return v return (v - v_min) / (v_max - v_min) + + +def interruptable(func): + def inner(*args, **kwargs): + try: + return func(*args, **kwargs) + except KeyboardInterrupt: + pass + return inner