refactor: move plot logic to plotters.py
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -4,4 +4,5 @@ __pycache__
|
||||
*.npy
|
||||
.venv
|
||||
dist
|
||||
*.egg-info
|
||||
*.egg-info
|
||||
.env
|
||||
@@ -6,7 +6,7 @@
|
||||
src="https://raw.githubusercontent.com/lenoctambule/autoencoder/refs/heads/main/media/latent-space.png"
|
||||
alt="Latent-space of the MNIST dataset"
|
||||
width=70%>
|
||||
<figcaption align=center>
|
||||
<figcaption>
|
||||
<p align="center">
|
||||
Latent-space representation of the MNIST dataset using Variational Autoencoder
|
||||
</p>
|
||||
|
||||
@@ -8,7 +8,6 @@ from easyvae.autoencoder import ( # noqa
|
||||
AAutoencoder
|
||||
)
|
||||
from easyvae.activations import LeakyReLU
|
||||
from easyvae.utils import dynamic_loss_plot_finish
|
||||
|
||||
|
||||
def load_mnist() -> list[np.ndarray]:
|
||||
@@ -33,6 +32,7 @@ def mnist_train(
|
||||
x_train.resize(x_train.shape[0], in_len)
|
||||
x_test.resize(x_test.shape[0], in_len)
|
||||
x_train = x_train / 255
|
||||
x_train = x_train[:5000]
|
||||
if os.path.exists(filename):
|
||||
autoencoder = cls.load(filename)
|
||||
else:
|
||||
@@ -46,8 +46,8 @@ def mnist_train(
|
||||
def handler(signum, frame):
|
||||
print(f"Saving {filename} before exit ...")
|
||||
autoencoder.save(filename)
|
||||
if plt.get_fignums():
|
||||
dynamic_loss_plot_finish()
|
||||
plt.close('all')
|
||||
plt.ioff()
|
||||
mnist_test(autoencoder)
|
||||
exit()
|
||||
|
||||
@@ -84,10 +84,11 @@ def plot_mnist_latent_space(autoencoder: AAutoencoder, x: np.ndarray, y,):
|
||||
plt.show()
|
||||
|
||||
|
||||
def plot_random_reconstruction(autoencoder: AAutoencoder,
|
||||
example: np.ndarray,
|
||||
img_shape,
|
||||
y):
|
||||
def plot_random_reconstruction(
|
||||
autoencoder: AAutoencoder,
|
||||
example: np.ndarray,
|
||||
img_shape,
|
||||
y):
|
||||
output, code = autoencoder.forward(example.flatten())
|
||||
plt.subplot(1, 2, 1)
|
||||
plt.matshow(
|
||||
@@ -114,6 +115,8 @@ def mnist_test(model: str | AAutoencoder):
|
||||
autoencoder: AAutoencoder = AAutoencoder.load(model)
|
||||
else:
|
||||
autoencoder = model
|
||||
print("Testing model ...\n")
|
||||
print(autoencoder)
|
||||
idx = np.random.randint(0, len(x_test))
|
||||
example: np.ndarray = x_test[idx]
|
||||
plot_random_reconstruction(autoencoder, example, img_shape, y_test[idx])
|
||||
|
||||
@@ -1,18 +1,16 @@
|
||||
import numpy as np
|
||||
from .utils import (
|
||||
dynamic_loss_plot_init,
|
||||
dynamic_loss_plot_update,
|
||||
dynamic_loss_plot_finish
|
||||
)
|
||||
from tqdm import tqdm
|
||||
from .layers import DeepNNLayer, SampleLayer
|
||||
from .activations import ActivationFunc, Identity
|
||||
from .plotters import Plotter, CAPlotter
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
LOADER = ['⡿', '⣟', '⣯', '⣷', '⣾', '⣽', '⣻', '⢿']
|
||||
|
||||
|
||||
class AAutoencoder(ABC):
|
||||
plotter_cls = Plotter
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self,
|
||||
encoder_layers: list[int],
|
||||
@@ -27,18 +25,18 @@ class AAutoencoder(ABC):
|
||||
self.decoder = DeepNNLayer(decoder_layers, lr, activation_func)
|
||||
self.space_dim = decoder_layers[0]
|
||||
self.lr = lr
|
||||
self.losses = [0]
|
||||
|
||||
def train_dataset(self,
|
||||
data_set: list[np.ndarray],
|
||||
max_epoch: int,
|
||||
patience: int,
|
||||
display_loss: bool = False) -> list[float]:
|
||||
losses = [self.loss(data_set)]
|
||||
if display_loss is True:
|
||||
ax, line = dynamic_loss_plot_init(losses)
|
||||
plotter = self.plotter_cls(self) if display_loss else Plotter(self)
|
||||
self.losses = [self.loss(data_set)]
|
||||
epoch = 0
|
||||
no_improv = 0
|
||||
prev_error = losses[0]
|
||||
prev_error = self.losses[0]
|
||||
with tqdm(bar_format="{desc} {elapsed} {rate_fmt}") as lbar:
|
||||
while True:
|
||||
lbar.set_description(
|
||||
@@ -55,17 +53,15 @@ class AAutoencoder(ABC):
|
||||
else:
|
||||
no_improv = 0
|
||||
prev_error = float(error)
|
||||
losses.append(error)
|
||||
if display_loss is True:
|
||||
dynamic_loss_plot_update(ax, line, losses)
|
||||
self.losses.append(error)
|
||||
if no_improv > patience:
|
||||
break
|
||||
if epoch > max_epoch:
|
||||
break
|
||||
plotter.update()
|
||||
epoch += 1
|
||||
if display_loss is True:
|
||||
dynamic_loss_plot_finish()
|
||||
return losses
|
||||
plotter.close()
|
||||
return self.losses
|
||||
|
||||
def save(self, path: str):
|
||||
path = path.removesuffix('.npy')
|
||||
@@ -90,11 +86,19 @@ class AAutoencoder(ABC):
|
||||
|
||||
|
||||
class ClassicalAutoencoder(AAutoencoder):
|
||||
plotter_cls = CAPlotter
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def __str__(self):
|
||||
return f'Encoder:\n{self.encoder}\n\nDecoder:\n{self.decoder}'
|
||||
return "\n".join((
|
||||
f"Type: {__class__.__name__}",
|
||||
"Encoder:",
|
||||
f"{self.encoder}",
|
||||
"Decoder:"
|
||||
f"{self.decoder}"
|
||||
))
|
||||
|
||||
def loss(self, data_set: list[np.ndarray]) -> float:
|
||||
loss = 0
|
||||
|
||||
48
src/easyvae/plotters.py
Normal file
48
src/easyvae/plotters.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import matplotlib.pyplot as plt
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .autoencoder import AAutoencoder
|
||||
|
||||
|
||||
class Plotter:
|
||||
def __init__(self, autoencoder: 'AAutoencoder'):
|
||||
pass
|
||||
|
||||
def update(self):
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
|
||||
|
||||
class CAPlotter(Plotter):
|
||||
def __init__(self, autoencoder: 'AAutoencoder'):
|
||||
self.autoencoder = autoencoder
|
||||
plt.ion()
|
||||
self.fig, self.ax = plt.subplots()
|
||||
self.line, = self.ax.plot(
|
||||
list(range(len(autoencoder.losses))),
|
||||
autoencoder.losses,
|
||||
label="Loss"
|
||||
)
|
||||
self.ax.set_xlabel("Epoch")
|
||||
self.ax.set_ylabel("Loss")
|
||||
self.ax.set_title("Training MSE Loss")
|
||||
self.ax.legend()
|
||||
self.update()
|
||||
|
||||
def update(self):
|
||||
self.line.set_xdata(range(len(self.autoencoder.losses)))
|
||||
self.line.set_ydata(self.autoencoder.losses)
|
||||
self.ax.relim()
|
||||
self.ax.autoscale_view()
|
||||
plt.draw()
|
||||
plt.pause(0.1)
|
||||
|
||||
def close(self):
|
||||
plt.ioff()
|
||||
plt.show()
|
||||
@@ -1,6 +1,5 @@
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
def softmax(v: np.ndarray) -> np.ndarray:
|
||||
@@ -19,28 +18,3 @@ def regularize(v: np.ndarray) -> np.ndarray:
|
||||
if v_min - v_max == 0:
|
||||
return v
|
||||
return (v - v_min) / (v_max - v_min)
|
||||
|
||||
|
||||
def dynamic_loss_plot_init(losses: list):
|
||||
plt.ion()
|
||||
fig, ax = plt.subplots()
|
||||
line, = ax.plot([0], losses, label="Loss")
|
||||
ax.set_xlabel("Epoch")
|
||||
ax.set_ylabel("Loss")
|
||||
ax.set_title("Training Loss")
|
||||
ax.legend()
|
||||
return ax, line
|
||||
|
||||
|
||||
def dynamic_loss_plot_update(ax, line, loss):
|
||||
line.set_xdata(range(len(loss)))
|
||||
line.set_ydata(loss)
|
||||
ax.relim()
|
||||
ax.autoscale_view()
|
||||
plt.draw()
|
||||
plt.pause(0.1)
|
||||
|
||||
|
||||
def dynamic_loss_plot_finish():
|
||||
plt.ioff()
|
||||
plt.show()
|
||||
|
||||
Reference in New Issue
Block a user