diff --git a/.gitignore b/.gitignore index 50eb41c..d199881 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,5 @@ __pycache__ *.npy .venv dist -*.egg-info \ No newline at end of file +*.egg-info +.env \ No newline at end of file diff --git a/README.md b/README.md index 068406a..d038f33 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ src="https://raw.githubusercontent.com/lenoctambule/autoencoder/refs/heads/main/media/latent-space.png" alt="Latent-space of the MNIST dataset" width=70%> -
+

Latent-space representation of the MNIST dataset using Variational Autoencoder

diff --git a/examples/mnist_test.py b/examples/mnist_test.py index 5e68c9a..160cb16 100644 --- a/examples/mnist_test.py +++ b/examples/mnist_test.py @@ -1,14 +1,12 @@ import matplotlib.pyplot as plt import numpy as np import os -import signal from easyvae.autoencoder import ( # noqa VariationalAutoencoder, ClassicalAutoencoder, AAutoencoder ) from easyvae.activations import LeakyReLU -from easyvae.utils import dynamic_loss_plot_finish def load_mnist() -> list[np.ndarray]: @@ -42,17 +40,7 @@ def mnist_train( 0.0001, LeakyReLU() ) - - def handler(signum, frame): - print(f"Saving {filename} before exit ...") - autoencoder.save(filename) - if plt.get_fignums(): - dynamic_loss_plot_finish() - mnist_test(autoencoder) - exit() - - signal.signal(signal.SIGINT, handler) - print("CTRL+C to exit and save model.") + print("CTRL+C to interrupt training.") autoencoder.train_dataset( x_train, max_epoch, @@ -84,10 +72,11 @@ def plot_mnist_latent_space(autoencoder: AAutoencoder, x: np.ndarray, y,): plt.show() -def plot_random_reconstruction(autoencoder: AAutoencoder, - example: np.ndarray, - img_shape, - y): +def plot_random_reconstruction( + autoencoder: AAutoencoder, + example: np.ndarray, + img_shape, + y): output, code = autoencoder.forward(example.flatten()) plt.subplot(1, 2, 1) plt.matshow( @@ -99,7 +88,7 @@ def plot_random_reconstruction(autoencoder: AAutoencoder, output.reshape(img_shape), fignum=False) plt.title(f"Output ({y})") - print(f'{code=}') + print(f'{code.tolist()}') def mnist_test(model: str | AAutoencoder): @@ -114,6 +103,8 @@ def mnist_test(model: str | AAutoencoder): autoencoder: AAutoencoder = AAutoencoder.load(model) else: autoencoder = model + print("Testing model ...\n") + print(autoencoder) idx = np.random.randint(0, len(x_test)) example: np.ndarray = x_test[idx] plot_random_reconstruction(autoencoder, example, img_shape, y_test[idx]) diff --git a/src/easyvae/autoencoder.py b/src/easyvae/autoencoder.py index 5f436c9..3926481 100644 --- a/src/easyvae/autoencoder.py +++ b/src/easyvae/autoencoder.py @@ -1,18 +1,17 @@ import numpy as np -from .utils import ( - dynamic_loss_plot_init, - dynamic_loss_plot_update, - dynamic_loss_plot_finish - ) from tqdm import tqdm from .layers import DeepNNLayer, SampleLayer from .activations import ActivationFunc, Identity +from .plotters import Plotter, CAPlotter, VAEPlotter +from .utils import interruptable from abc import ABC, abstractmethod LOADER = ['⡿', '⣟', '⣯', '⣷', '⣾', '⣽', '⣻', '⢿'] class AAutoencoder(ABC): + plotter_cls = Plotter + @abstractmethod def __init__(self, encoder_layers: list[int], @@ -27,45 +26,7 @@ class AAutoencoder(ABC): self.decoder = DeepNNLayer(decoder_layers, lr, activation_func) self.space_dim = decoder_layers[0] self.lr = lr - - def train_dataset(self, - data_set: list[np.ndarray], - max_epoch: int, - patience: int, - display_loss: bool = False) -> list[float]: - losses = [self.loss(data_set)] - if display_loss is True: - ax, line = dynamic_loss_plot_init(losses) - epoch = 0 - no_improv = 0 - prev_error = losses[0] - with tqdm(bar_format="{desc} {elapsed} {rate_fmt}") as lbar: - while True: - lbar.set_description( - f"{LOADER[epoch % len(LOADER)]} Training ({epoch=} error={float(prev_error):.6f})", # noqa - ) - lbar.update() - error = 0 - for x in tqdm(data_set, leave=False): - error += self.train(x) - error /= len(data_set) - derror = prev_error - error - if derror <= 0 or abs(derror) < 1e-4: - no_improv += 1 - else: - no_improv = 0 - prev_error = float(error) - losses.append(error) - if display_loss is True: - dynamic_loss_plot_update(ax, line, losses) - if no_improv > patience: - break - if epoch > max_epoch: - break - epoch += 1 - if display_loss is True: - dynamic_loss_plot_finish() - return losses + self.losses = [0] def save(self, path: str): path = path.removesuffix('.npy') @@ -88,13 +49,27 @@ class AAutoencoder(ABC): def forward(self, v: np.ndarray) -> np.ndarray: pass + @abstractmethod + def train_dataset(self, *args, **kwargs) -> list[float]: + pass + class ClassicalAutoencoder(AAutoencoder): + plotter_cls = CAPlotter + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self.losses = [] def __str__(self): - return f'Encoder:\n{self.encoder}\n\nDecoder:\n{self.decoder}' + return "\n".join(( + f"Type: {__class__.__name__}", + "Encoder:", + f"{self.encoder}", + "Decoder:", + f"{self.decoder}" + ) + ) def loss(self, data_set: list[np.ndarray]) -> float: loss = 0 @@ -112,6 +87,41 @@ class ClassicalAutoencoder(AAutoencoder): ) return np.sum(np.abs(error)) / len(v) + @interruptable + def train_dataset(self, + data_set: list[np.ndarray], + max_epoch: int, + patience: int, + display_loss: bool = False) -> list[float]: + plotter = self.plotter_cls(self) if display_loss else Plotter(self) + self.losses = [self.loss(data_set)] + epoch = 0 + no_improv = 0 + prev_error = self.losses[0] + with tqdm(bar_format="{desc} {elapsed} {rate_fmt}") as lbar: + while True: + lbar.set_description( + f"{LOADER[epoch % len(LOADER)]} Training ({epoch=} error={float(prev_error):.6f})", # noqa + ) + lbar.update() + error = 0 + for x in tqdm(data_set, leave=False): + error += self.train(x) + error /= len(data_set) + derror = prev_error - error + if derror <= 0 or abs(derror) < 1e-4: + no_improv += 1 + else: + no_improv = 0 + prev_error = float(error) + self.losses.append(error) + if no_improv > patience: + break + if epoch > max_epoch: + break + plotter.update() + epoch += 1 + def encode(self, v: np.ndarray) -> np.ndarray: return self.encoder.forward(v) @@ -125,20 +135,36 @@ class ClassicalAutoencoder(AAutoencoder): class VariationalAutoencoder(AAutoencoder): + plotter_cls = VAEPlotter + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.sampler = SampleLayer(self.encoder.out_size, self.lr, Identity()) + self.KL_losses = [] + self.recon_losses = [] + + def __str__(self): + return "\n".join(( + f"Type: {__class__.__name__}", + "Encoder:", + f"{self.encoder}", + "Decoder:", + f"{self.decoder}" + )) def loss(self, data_set: list[np.ndarray]) -> float: - loss = 0 + kl_loss = 0 + recon_loss = 0 for x in data_set: out = self.forward(x)[0] kl = self.sampler.DKL() - loss += np.mean((out - x) ** 2) - loss += kl - return loss / len(data_set) + recon_loss += np.mean((out - x) ** 2) + kl_loss += kl + kl_loss /= len(data_set) + recon_loss /= len(data_set) + return recon_loss, kl_loss - def train(self, v: np.ndarray) -> float: + def train(self, v: np.ndarray) -> tuple[float, float]: out, _ = self.forward(v) error = out - v self.encoder.backprop( @@ -146,10 +172,61 @@ class VariationalAutoencoder(AAutoencoder): self.decoder.backprop(error) ) ) - return np.mean(error ** 2) + self.sampler.DKL() + return np.mean(error ** 2), self.sampler.DKL() + + @interruptable + def train_dataset(self, + data_set: list[np.ndarray], + max_epoch: int, + patience: int, + display_loss: bool = False) -> list[float]: + plotter = self.plotter_cls(self) if display_loss else Plotter(self) + recon_0, kl_0 = self.loss(data_set) + self.recon_losses = [recon_0] + self.KL_losses = [kl_0] + epoch = 0 + no_improv = 0 + prev_loss = self.recon_losses[0] + self.KL_losses[0] + with tqdm(bar_format="{desc} {elapsed} {rate_fmt}") as lbar: + while True: + lbar.set_description( + f"{LOADER[epoch % len(LOADER)]} Training ({epoch=} loss={float(prev_loss):.6f})", # noqa + ) + lbar.update() + dkl = 0 + recon = 0 + for x in tqdm(data_set, leave=False): + recon_i, dkl_i = self.train(x) + dkl += dkl_i + recon += recon_i + recon /= len(data_set) + dkl /= len(data_set) + loss = recon + dkl + dloss = prev_loss - loss + if dloss <= 0 or abs(dloss) < 1e-4: + no_improv += 1 + else: + no_improv = 0 + prev_loss = float(loss) + self.recon_losses.append(recon) + self.KL_losses.append(dkl) + if no_improv > patience: + break + if epoch > max_epoch: + break + plotter.update() + epoch += 1 def forward(self, v: np.ndarray) -> tuple[np.ndarray, np.ndarray]: code = self.encoder.forward(v) sample = self.sampler.forward(code) out = self.decoder.forward(sample) return out, code + + def encode(self, v: np.ndarray) -> np.ndarray: + return self.sampler.forward( + self.encoder.forward(v) + ) + + def decode(self, v: np.ndarray) -> np.ndarray: + return self.decoder.forward(v) diff --git a/src/easyvae/plotters.py b/src/easyvae/plotters.py new file mode 100644 index 0000000..6750b43 --- /dev/null +++ b/src/easyvae/plotters.py @@ -0,0 +1,93 @@ +import matplotlib.pyplot as plt +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .autoencoder import AAutoencoder, VariationalAutoencoder + + +class Plotter: + def __init__(self, autoencoder: 'AAutoencoder'): + pass + + def update(self): + pass + + def close(self): + pass + + def __del__(self): + self.close() + + +class CAPlotter(Plotter): + def __init__(self, autoencoder: 'AAutoencoder'): + self.autoencoder = autoencoder + plt.ion() + self.fig, self.ax = plt.subplots() + self.line, = self.ax.plot( + list(range(len(autoencoder.losses))), + autoencoder.losses, + label="Loss" + ) + self.ax.set_xlabel("Epoch") + self.ax.set_ylabel("Loss") + self.ax.set_title("Training MSE Loss") + self.ax.legend() + self.update() + + def update(self): + self.line.set_xdata(range(len(self.autoencoder.losses))) + self.line.set_ydata(self.autoencoder.losses) + self.ax.relim() + self.ax.autoscale_view() + plt.draw() + plt.pause(0.1) + + def close(self): + plt.ioff() + plt.close(self.fig) + + +class VAEPlotter(Plotter): + def __init__(self, autoencoder: 'VariationalAutoencoder'): + self.autoencoder = autoencoder + plt.ion() + self.fig, (self.ax_recon, self.ax_dkl) = plt.subplots(1, 2) + self.line, = self.ax_recon.plot( + list(range(len(self.autoencoder.recon_losses))), + self.autoencoder.recon_losses, + label="Loss" + ) + self.ax_recon.set_xlabel("Epoch") + self.ax_recon.set_ylabel("Loss") + self.ax_recon.set_title("Reconstruction MSE Loss") + self.ax_recon.legend() + + self.dkl_line, = self.ax_dkl.plot( + list(range(len(self.autoencoder.KL_losses))), + self.autoencoder.KL_losses, + label="DKL Loss", + ) + self.ax_dkl.set_xlabel("Epoch") + self.ax_dkl.set_ylabel("Loss") + self.ax_dkl.set_title("DKL Loss") + self.ax_dkl.legend() + self.update() + + def update(self): + self.line.set_xdata(range(len(self.autoencoder.recon_losses))) + self.line.set_ydata(self.autoencoder.recon_losses) + self.ax_recon.relim() + self.ax_recon.autoscale_view() + + self.dkl_line.set_xdata(range(len(self.autoencoder.KL_losses))) + self.dkl_line.set_ydata(self.autoencoder.KL_losses) + self.ax_dkl.relim() + self.ax_dkl.autoscale_view() + + plt.draw() + plt.pause(0.1) + + def close(self): + plt.ioff() + plt.close(self.fig) diff --git a/src/easyvae/utils.py b/src/easyvae/utils.py index 7e61971..414db1a 100644 --- a/src/easyvae/utils.py +++ b/src/easyvae/utils.py @@ -1,6 +1,5 @@ import numpy as np -import matplotlib.pyplot as plt def softmax(v: np.ndarray) -> np.ndarray: @@ -21,26 +20,10 @@ def regularize(v: np.ndarray) -> np.ndarray: return (v - v_min) / (v_max - v_min) -def dynamic_loss_plot_init(losses: list): - plt.ion() - fig, ax = plt.subplots() - line, = ax.plot([0], losses, label="Loss") - ax.set_xlabel("Epoch") - ax.set_ylabel("Loss") - ax.set_title("Training Loss") - ax.legend() - return ax, line - - -def dynamic_loss_plot_update(ax, line, loss): - line.set_xdata(range(len(loss))) - line.set_ydata(loss) - ax.relim() - ax.autoscale_view() - plt.draw() - plt.pause(0.1) - - -def dynamic_loss_plot_finish(): - plt.ioff() - plt.show() +def interruptable(func): + def inner(*args, **kwargs): + try: + return func(*args, **kwargs) + except KeyboardInterrupt: + pass + return inner