refactor: move kb interrupt handling to autoencoder classes
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import os
|
||||
import signal
|
||||
from easyvae.autoencoder import ( # noqa
|
||||
VariationalAutoencoder,
|
||||
ClassicalAutoencoder,
|
||||
@@ -32,7 +31,6 @@ def mnist_train(
|
||||
x_train.resize(x_train.shape[0], in_len)
|
||||
x_test.resize(x_test.shape[0], in_len)
|
||||
x_train = x_train / 255
|
||||
x_train = x_train[:5000]
|
||||
if os.path.exists(filename):
|
||||
autoencoder = cls.load(filename)
|
||||
else:
|
||||
@@ -42,17 +40,7 @@ def mnist_train(
|
||||
0.0001,
|
||||
LeakyReLU()
|
||||
)
|
||||
|
||||
def handler(signum, frame):
|
||||
print(f"Saving {filename} before exit ...")
|
||||
autoencoder.save(filename)
|
||||
plt.close('all')
|
||||
plt.ioff()
|
||||
mnist_test(autoencoder)
|
||||
exit()
|
||||
|
||||
signal.signal(signal.SIGINT, handler)
|
||||
print("CTRL+C to exit and save model.")
|
||||
print("CTRL+C to interrupt training.")
|
||||
autoencoder.train_dataset(
|
||||
x_train,
|
||||
max_epoch,
|
||||
@@ -100,7 +88,7 @@ def plot_random_reconstruction(
|
||||
output.reshape(img_shape),
|
||||
fignum=False)
|
||||
plt.title(f"Output ({y})")
|
||||
print(f'{code=}')
|
||||
print(f'{code.tolist()}')
|
||||
|
||||
|
||||
def mnist_test(model: str | AAutoencoder):
|
||||
|
||||
@@ -3,6 +3,7 @@ from tqdm import tqdm
|
||||
from .layers import DeepNNLayer, SampleLayer
|
||||
from .activations import ActivationFunc, Identity
|
||||
from .plotters import Plotter, CAPlotter, VAEPlotter
|
||||
from .utils import interruptable
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
LOADER = ['⡿', '⣟', '⣯', '⣷', '⣾', '⣽', '⣻', '⢿']
|
||||
@@ -86,6 +87,7 @@ class ClassicalAutoencoder(AAutoencoder):
|
||||
)
|
||||
return np.sum(np.abs(error)) / len(v)
|
||||
|
||||
@interruptable
|
||||
def train_dataset(self,
|
||||
data_set: list[np.ndarray],
|
||||
max_epoch: int,
|
||||
@@ -119,8 +121,6 @@ class ClassicalAutoencoder(AAutoencoder):
|
||||
break
|
||||
plotter.update()
|
||||
epoch += 1
|
||||
plotter.close()
|
||||
return self.losses
|
||||
|
||||
def encode(self, v: np.ndarray) -> np.ndarray:
|
||||
return self.encoder.forward(v)
|
||||
@@ -174,6 +174,7 @@ class VariationalAutoencoder(AAutoencoder):
|
||||
)
|
||||
return np.mean(error ** 2), self.sampler.DKL()
|
||||
|
||||
@interruptable
|
||||
def train_dataset(self,
|
||||
data_set: list[np.ndarray],
|
||||
max_epoch: int,
|
||||
@@ -215,8 +216,6 @@ class VariationalAutoencoder(AAutoencoder):
|
||||
break
|
||||
plotter.update()
|
||||
epoch += 1
|
||||
plotter.close()
|
||||
return self.recon_losses
|
||||
|
||||
def forward(self, v: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
||||
code = self.encoder.forward(v)
|
||||
|
||||
@@ -15,7 +15,7 @@ class Plotter:
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
def __del__(self):
|
||||
self.close()
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ class CAPlotter(Plotter):
|
||||
|
||||
def close(self):
|
||||
plt.ioff()
|
||||
plt.show()
|
||||
plt.close(self.fig)
|
||||
|
||||
|
||||
class VAEPlotter(Plotter):
|
||||
@@ -90,4 +90,4 @@ class VAEPlotter(Plotter):
|
||||
|
||||
def close(self):
|
||||
plt.ioff()
|
||||
plt.show()
|
||||
plt.close(self.fig)
|
||||
|
||||
@@ -18,3 +18,12 @@ def regularize(v: np.ndarray) -> np.ndarray:
|
||||
if v_min - v_max == 0:
|
||||
return v
|
||||
return (v - v_min) / (v_max - v_min)
|
||||
|
||||
|
||||
def interruptable(func):
|
||||
def inner(*args, **kwargs):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
return inner
|
||||
|
||||
Reference in New Issue
Block a user