feat: plot 2d latent space + signal handling + fix SGD in Sampler

This commit is contained in:
Lenoctambule
2026-04-07 22:25:39 +02:00
parent 3440de851a
commit 510ad8720c
4 changed files with 66 additions and 27 deletions

View File

@@ -1,8 +1,11 @@
import matplotlib.pyplot as plt
import numpy as np
from autoencoder import VariationalAutoencoder, AAutoencoder
from activations import LeakyReLU
import os
import signal
from autoencoder import (VariationalAutoencoder, # noqa
ClassicalAutoencoder,
AAutoencoder)
from activations import LeakyReLU
def load_mnist() -> list[np.ndarray]:
@@ -21,29 +24,39 @@ def mnist_train(
filename: str,
max_epoch: int,
patience: int,
cls: type[AAutoencoder]
) -> AAutoencoder:
cls: type[AAutoencoder],) -> AAutoencoder:
x_train, _, x_test, _ = load_mnist()
in_len = x_train[0].shape[0] * x_train[0].shape[0]
x_train.resize(x_train.shape[0], in_len)
x_test.resize(x_test.shape[0], in_len)
x_train = x_train / 255
x_test = x_test / 255
if os.path.exists(filename):
autoencoder = cls.load(filename)
else:
autoencoder = cls(
[in_len, 16],
[16, in_len],
0.01,
[in_len, 256, 2],
[2, 256, in_len],
0.001,
LeakyReLU()
)
def handler(signum, frame):
print(f"Saving {filename} before exit ...")
autoencoder.save(filename)
plt.close()
plt.ioff()
mnist_test(autoencoder)
exit()
signal.signal(signal.SIGINT, handler)
print("CTRL+C to exit and save model.")
autoencoder.train_dataset(
x_train,
max_epoch,
patience,
display_loss=True)
autoencoder.save(filename)
print("Training complete !")
return autoencoder
@@ -59,7 +72,6 @@ def mnist_test(model: str | AAutoencoder):
autoencoder: AAutoencoder = AAutoencoder.load(model)
else:
autoencoder = model
print(autoencoder)
idx = np.random.randint(0, len(x_test))
example: np.ndarray = x_test[idx]
output, code = autoencoder.forward(example.flatten())
@@ -74,11 +86,29 @@ def mnist_test(model: str | AAutoencoder):
fignum=False)
plt.title(f"Output ({y_test[idx]})")
plt.subplot(1, 3, 3)
s = int(np.ceil(np.sqrt(code.shape[0])))
code.resize((s, s), refcheck=False)
code = np.reshape(code, (code.shape[0], 1))
plt.matshow(code, fignum=False)
plt.title(f"Code ({y_test[idx]})")
plt.show()
if code.shape[0] == 2:
codes = []
for x in x_test:
_, c = autoencoder.forward(x.flatten())
codes.append(c)
codes = np.array(codes)
if codes.shape[1] == 2:
plt.figure(figsize=(6, 6))
scatter = plt.scatter(
codes[:, 0],
codes[:, 1],
c=y_test,
cmap='tab10',
s=5,
alpha=0.7
)
plt.colorbar(scatter)
plt.grid(True)
plt.show()
if __name__ == "__main__":