diff --git a/activations.py b/activations.py index 10c5ad2..3f6a704 100644 --- a/activations.py +++ b/activations.py @@ -31,5 +31,5 @@ class Identity(ActivationFunc): def __call__(self, x): return x - def d(x): + def d(self, x): return 1 diff --git a/autoencoder.py b/autoencoder.py index dbb4f8d..c650fb0 100644 --- a/autoencoder.py +++ b/autoencoder.py @@ -4,7 +4,7 @@ from utils import (dynamic_loss_plot_init, dynamic_loss_plot_finish) from tqdm import tqdm from layers import DeepNNLayer, SampleLayer -from activations import ActivationFunc +from activations import ActivationFunc, Identity from abc import ABC, abstractmethod LOADER = ['⡿', '⣟', '⣯', '⣷', '⣾', '⣽', '⣻', '⢿'] @@ -46,7 +46,6 @@ class AAutoencoder(ABC): if epoch > max_epoch: break epoch += 1 - print("Training complete !") if display_loss is True: dynamic_loss_plot_finish(ax, line) return losses @@ -129,12 +128,15 @@ class VariationalAutoencoder(AAutoencoder): ) self.encoder = DeepNNLayer(encoder_layers, lr, activation_func) self.decoder = DeepNNLayer(decoder_layers, lr, activation_func) - self.sampler = SampleLayer(self.encoder.out_size, lr, activation_func) + self.sampler = SampleLayer(self.encoder.out_size, lr, Identity()) def loss(self, data_set: list[np.ndarray]) -> float: loss = 0 for x in data_set: - loss += np.sum(np.abs(x - self.forward(x)[0])) / len(x) + out = self.forward(x)[0] + kl = self.sampler.DKL() + loss += np.mean((out - x) ** 2) + loss += kl return loss / len(data_set) def train(self, v: np.ndarray) -> float: @@ -145,7 +147,7 @@ class VariationalAutoencoder(AAutoencoder): self.decoder.backprop(error) ) ) - return np.sum(np.abs(error)) / len(v) + return np.mean(error ** 2) + self.sampler.DKL() def forward(self, v: np.ndarray) -> tuple[np.ndarray, np.ndarray]: code = self.encoder.forward(v) diff --git a/layers.py b/layers.py index e2a3b08..0ca8a1b 100644 --- a/layers.py +++ b/layers.py @@ -1,6 +1,5 @@ import numpy as np -from utils import normalize -from activations import ActivationFunc +from activations import ActivationFunc, Identity class NNLayer: @@ -9,7 +8,8 @@ class NNLayer: out_size: int, lr: float, activation_func: ActivationFunc): - self.W = np.random.uniform(-1, 1, (in_size, out_size)) + limit = np.sqrt(6 / (in_size + out_size)) + self.W = np.random.uniform(-limit, limit, (in_size, out_size)) self.B = np.zeros((out_size)) self.lr = lr self.input = None @@ -21,7 +21,7 @@ class NNLayer: return f'[ {self.W.shape[0]} => {self.W.shape[1]}\tlr:{self.lr}\tactivation:{self.activation_func.__class__.__name__} ]' # noqa def forward(self, v: np.ndarray) -> np.ndarray: - self.input = normalize(v) + self.input = v self.output_linear = self.input @ self.W + self.B self.output = self.activation_func( self.output_linear @@ -55,17 +55,23 @@ class SampleLayer: lr, activation_func) + def DKL(self): + return -0.5 * np.mean(1 + self.logvar - self.mean ** 2 - np.exp(self.logvar)) # noqa + def forward(self, v: np.ndarray) -> np.ndarray: self.input = v self.mean = self.mean_nn.forward(v) - self.std = self.std_nn.forward(v) + self.logvar = np.clip(self.std_nn.forward(v)) + self.std = np.exp(0.5 * self.logvar) self.eps = np.random.normal(0, 1, self.mean.shape) - return self.eps * self.std + self.mean + return 0.5 * self.eps * self.std + self.mean def backprop(self, error: np.ndarray) -> np.ndarray: - mu_error = self.mean_nn.backprop(error) - std_error = self.std_nn.backprop(error * self.eps * self.std * 0.5) - return mu_error + std_error + dmean = error + self.mean + dstd = error * self.eps + 0.5 * (np.exp(self.logvar) - 1) + mean_error = self.mean_nn.backprop(dmean) + logvar_error = self.std_nn.backprop(dstd * self.std) + return mean_error + logvar_error class DeepNNLayer: @@ -80,7 +86,8 @@ class DeepNNLayer: layers[i], layers[i+1], lr, - activation_func) + activation_func if i != len(layers) - 2 else Identity() + ) ) self.in_size = layers[0] self.out_size = layers[-1] diff --git a/mnist_test.py b/mnist_test.py index ba28bba..2355006 100644 --- a/mnist_test.py +++ b/mnist_test.py @@ -1,8 +1,11 @@ import matplotlib.pyplot as plt import numpy as np -from autoencoder import VariationalAutoencoder, AAutoencoder -from activations import LeakyReLU import os +import signal +from autoencoder import (VariationalAutoencoder, # noqa + ClassicalAutoencoder, + AAutoencoder) +from activations import LeakyReLU def load_mnist() -> list[np.ndarray]: @@ -21,29 +24,39 @@ def mnist_train( filename: str, max_epoch: int, patience: int, - cls: type[AAutoencoder] - ) -> AAutoencoder: + cls: type[AAutoencoder],) -> AAutoencoder: x_train, _, x_test, _ = load_mnist() in_len = x_train[0].shape[0] * x_train[0].shape[0] x_train.resize(x_train.shape[0], in_len) x_test.resize(x_test.shape[0], in_len) x_train = x_train / 255 - x_test = x_test / 255 if os.path.exists(filename): autoencoder = cls.load(filename) else: autoencoder = cls( - [in_len, 16], - [16, in_len], - 0.01, + [in_len, 256, 2], + [2, 256, in_len], + 0.001, LeakyReLU() ) + + def handler(signum, frame): + print(f"Saving {filename} before exit ...") + autoencoder.save(filename) + plt.close() + plt.ioff() + mnist_test(autoencoder) + exit() + + signal.signal(signal.SIGINT, handler) + print("CTRL+C to exit and save model.") autoencoder.train_dataset( x_train, max_epoch, patience, display_loss=True) autoencoder.save(filename) + print("Training complete !") return autoencoder @@ -59,7 +72,6 @@ def mnist_test(model: str | AAutoencoder): autoencoder: AAutoencoder = AAutoencoder.load(model) else: autoencoder = model - print(autoencoder) idx = np.random.randint(0, len(x_test)) example: np.ndarray = x_test[idx] output, code = autoencoder.forward(example.flatten()) @@ -74,11 +86,29 @@ def mnist_test(model: str | AAutoencoder): fignum=False) plt.title(f"Output ({y_test[idx]})") plt.subplot(1, 3, 3) - s = int(np.ceil(np.sqrt(code.shape[0]))) - code.resize((s, s), refcheck=False) + code = np.reshape(code, (code.shape[0], 1)) plt.matshow(code, fignum=False) plt.title(f"Code ({y_test[idx]})") plt.show() + if code.shape[0] == 2: + codes = [] + for x in x_test: + _, c = autoencoder.forward(x.flatten()) + codes.append(c) + codes = np.array(codes) + if codes.shape[1] == 2: + plt.figure(figsize=(6, 6)) + scatter = plt.scatter( + codes[:, 0], + codes[:, 1], + c=y_test, + cmap='tab10', + s=5, + alpha=0.7 + ) + plt.colorbar(scatter) + plt.grid(True) + plt.show() if __name__ == "__main__":