feat: plot 2d latent space + signal handling + fix SGD in Sampler
This commit is contained in:
@@ -31,5 +31,5 @@ class Identity(ActivationFunc):
|
|||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def d(x):
|
def d(self, x):
|
||||||
return 1
|
return 1
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from utils import (dynamic_loss_plot_init,
|
|||||||
dynamic_loss_plot_finish)
|
dynamic_loss_plot_finish)
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from layers import DeepNNLayer, SampleLayer
|
from layers import DeepNNLayer, SampleLayer
|
||||||
from activations import ActivationFunc
|
from activations import ActivationFunc, Identity
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
LOADER = ['⡿', '⣟', '⣯', '⣷', '⣾', '⣽', '⣻', '⢿']
|
LOADER = ['⡿', '⣟', '⣯', '⣷', '⣾', '⣽', '⣻', '⢿']
|
||||||
@@ -46,7 +46,6 @@ class AAutoencoder(ABC):
|
|||||||
if epoch > max_epoch:
|
if epoch > max_epoch:
|
||||||
break
|
break
|
||||||
epoch += 1
|
epoch += 1
|
||||||
print("Training complete !")
|
|
||||||
if display_loss is True:
|
if display_loss is True:
|
||||||
dynamic_loss_plot_finish(ax, line)
|
dynamic_loss_plot_finish(ax, line)
|
||||||
return losses
|
return losses
|
||||||
@@ -129,12 +128,15 @@ class VariationalAutoencoder(AAutoencoder):
|
|||||||
)
|
)
|
||||||
self.encoder = DeepNNLayer(encoder_layers, lr, activation_func)
|
self.encoder = DeepNNLayer(encoder_layers, lr, activation_func)
|
||||||
self.decoder = DeepNNLayer(decoder_layers, lr, activation_func)
|
self.decoder = DeepNNLayer(decoder_layers, lr, activation_func)
|
||||||
self.sampler = SampleLayer(self.encoder.out_size, lr, activation_func)
|
self.sampler = SampleLayer(self.encoder.out_size, lr, Identity())
|
||||||
|
|
||||||
def loss(self, data_set: list[np.ndarray]) -> float:
|
def loss(self, data_set: list[np.ndarray]) -> float:
|
||||||
loss = 0
|
loss = 0
|
||||||
for x in data_set:
|
for x in data_set:
|
||||||
loss += np.sum(np.abs(x - self.forward(x)[0])) / len(x)
|
out = self.forward(x)[0]
|
||||||
|
kl = self.sampler.DKL()
|
||||||
|
loss += np.mean((out - x) ** 2)
|
||||||
|
loss += kl
|
||||||
return loss / len(data_set)
|
return loss / len(data_set)
|
||||||
|
|
||||||
def train(self, v: np.ndarray) -> float:
|
def train(self, v: np.ndarray) -> float:
|
||||||
@@ -145,7 +147,7 @@ class VariationalAutoencoder(AAutoencoder):
|
|||||||
self.decoder.backprop(error)
|
self.decoder.backprop(error)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return np.sum(np.abs(error)) / len(v)
|
return np.mean(error ** 2) + self.sampler.DKL()
|
||||||
|
|
||||||
def forward(self, v: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
def forward(self, v: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
||||||
code = self.encoder.forward(v)
|
code = self.encoder.forward(v)
|
||||||
|
|||||||
27
layers.py
27
layers.py
@@ -1,6 +1,5 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from utils import normalize
|
from activations import ActivationFunc, Identity
|
||||||
from activations import ActivationFunc
|
|
||||||
|
|
||||||
|
|
||||||
class NNLayer:
|
class NNLayer:
|
||||||
@@ -9,7 +8,8 @@ class NNLayer:
|
|||||||
out_size: int,
|
out_size: int,
|
||||||
lr: float,
|
lr: float,
|
||||||
activation_func: ActivationFunc):
|
activation_func: ActivationFunc):
|
||||||
self.W = np.random.uniform(-1, 1, (in_size, out_size))
|
limit = np.sqrt(6 / (in_size + out_size))
|
||||||
|
self.W = np.random.uniform(-limit, limit, (in_size, out_size))
|
||||||
self.B = np.zeros((out_size))
|
self.B = np.zeros((out_size))
|
||||||
self.lr = lr
|
self.lr = lr
|
||||||
self.input = None
|
self.input = None
|
||||||
@@ -21,7 +21,7 @@ class NNLayer:
|
|||||||
return f'[ {self.W.shape[0]} => {self.W.shape[1]}\tlr:{self.lr}\tactivation:{self.activation_func.__class__.__name__} ]' # noqa
|
return f'[ {self.W.shape[0]} => {self.W.shape[1]}\tlr:{self.lr}\tactivation:{self.activation_func.__class__.__name__} ]' # noqa
|
||||||
|
|
||||||
def forward(self, v: np.ndarray) -> np.ndarray:
|
def forward(self, v: np.ndarray) -> np.ndarray:
|
||||||
self.input = normalize(v)
|
self.input = v
|
||||||
self.output_linear = self.input @ self.W + self.B
|
self.output_linear = self.input @ self.W + self.B
|
||||||
self.output = self.activation_func(
|
self.output = self.activation_func(
|
||||||
self.output_linear
|
self.output_linear
|
||||||
@@ -55,17 +55,23 @@ class SampleLayer:
|
|||||||
lr,
|
lr,
|
||||||
activation_func)
|
activation_func)
|
||||||
|
|
||||||
|
def DKL(self):
|
||||||
|
return -0.5 * np.mean(1 + self.logvar - self.mean ** 2 - np.exp(self.logvar)) # noqa
|
||||||
|
|
||||||
def forward(self, v: np.ndarray) -> np.ndarray:
|
def forward(self, v: np.ndarray) -> np.ndarray:
|
||||||
self.input = v
|
self.input = v
|
||||||
self.mean = self.mean_nn.forward(v)
|
self.mean = self.mean_nn.forward(v)
|
||||||
self.std = self.std_nn.forward(v)
|
self.logvar = np.clip(self.std_nn.forward(v))
|
||||||
|
self.std = np.exp(0.5 * self.logvar)
|
||||||
self.eps = np.random.normal(0, 1, self.mean.shape)
|
self.eps = np.random.normal(0, 1, self.mean.shape)
|
||||||
return self.eps * self.std + self.mean
|
return 0.5 * self.eps * self.std + self.mean
|
||||||
|
|
||||||
def backprop(self, error: np.ndarray) -> np.ndarray:
|
def backprop(self, error: np.ndarray) -> np.ndarray:
|
||||||
mu_error = self.mean_nn.backprop(error)
|
dmean = error + self.mean
|
||||||
std_error = self.std_nn.backprop(error * self.eps * self.std * 0.5)
|
dstd = error * self.eps + 0.5 * (np.exp(self.logvar) - 1)
|
||||||
return mu_error + std_error
|
mean_error = self.mean_nn.backprop(dmean)
|
||||||
|
logvar_error = self.std_nn.backprop(dstd * self.std)
|
||||||
|
return mean_error + logvar_error
|
||||||
|
|
||||||
|
|
||||||
class DeepNNLayer:
|
class DeepNNLayer:
|
||||||
@@ -80,7 +86,8 @@ class DeepNNLayer:
|
|||||||
layers[i],
|
layers[i],
|
||||||
layers[i+1],
|
layers[i+1],
|
||||||
lr,
|
lr,
|
||||||
activation_func)
|
activation_func if i != len(layers) - 2 else Identity()
|
||||||
|
)
|
||||||
)
|
)
|
||||||
self.in_size = layers[0]
|
self.in_size = layers[0]
|
||||||
self.out_size = layers[-1]
|
self.out_size = layers[-1]
|
||||||
|
|||||||
@@ -1,8 +1,11 @@
|
|||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from autoencoder import VariationalAutoencoder, AAutoencoder
|
|
||||||
from activations import LeakyReLU
|
|
||||||
import os
|
import os
|
||||||
|
import signal
|
||||||
|
from autoencoder import (VariationalAutoencoder, # noqa
|
||||||
|
ClassicalAutoencoder,
|
||||||
|
AAutoencoder)
|
||||||
|
from activations import LeakyReLU
|
||||||
|
|
||||||
|
|
||||||
def load_mnist() -> list[np.ndarray]:
|
def load_mnist() -> list[np.ndarray]:
|
||||||
@@ -21,29 +24,39 @@ def mnist_train(
|
|||||||
filename: str,
|
filename: str,
|
||||||
max_epoch: int,
|
max_epoch: int,
|
||||||
patience: int,
|
patience: int,
|
||||||
cls: type[AAutoencoder]
|
cls: type[AAutoencoder],) -> AAutoencoder:
|
||||||
) -> AAutoencoder:
|
|
||||||
x_train, _, x_test, _ = load_mnist()
|
x_train, _, x_test, _ = load_mnist()
|
||||||
in_len = x_train[0].shape[0] * x_train[0].shape[0]
|
in_len = x_train[0].shape[0] * x_train[0].shape[0]
|
||||||
x_train.resize(x_train.shape[0], in_len)
|
x_train.resize(x_train.shape[0], in_len)
|
||||||
x_test.resize(x_test.shape[0], in_len)
|
x_test.resize(x_test.shape[0], in_len)
|
||||||
x_train = x_train / 255
|
x_train = x_train / 255
|
||||||
x_test = x_test / 255
|
|
||||||
if os.path.exists(filename):
|
if os.path.exists(filename):
|
||||||
autoencoder = cls.load(filename)
|
autoencoder = cls.load(filename)
|
||||||
else:
|
else:
|
||||||
autoencoder = cls(
|
autoencoder = cls(
|
||||||
[in_len, 16],
|
[in_len, 256, 2],
|
||||||
[16, in_len],
|
[2, 256, in_len],
|
||||||
0.01,
|
0.001,
|
||||||
LeakyReLU()
|
LeakyReLU()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def handler(signum, frame):
|
||||||
|
print(f"Saving {filename} before exit ...")
|
||||||
|
autoencoder.save(filename)
|
||||||
|
plt.close()
|
||||||
|
plt.ioff()
|
||||||
|
mnist_test(autoencoder)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
signal.signal(signal.SIGINT, handler)
|
||||||
|
print("CTRL+C to exit and save model.")
|
||||||
autoencoder.train_dataset(
|
autoencoder.train_dataset(
|
||||||
x_train,
|
x_train,
|
||||||
max_epoch,
|
max_epoch,
|
||||||
patience,
|
patience,
|
||||||
display_loss=True)
|
display_loss=True)
|
||||||
autoencoder.save(filename)
|
autoencoder.save(filename)
|
||||||
|
print("Training complete !")
|
||||||
return autoencoder
|
return autoencoder
|
||||||
|
|
||||||
|
|
||||||
@@ -59,7 +72,6 @@ def mnist_test(model: str | AAutoencoder):
|
|||||||
autoencoder: AAutoencoder = AAutoencoder.load(model)
|
autoencoder: AAutoencoder = AAutoencoder.load(model)
|
||||||
else:
|
else:
|
||||||
autoencoder = model
|
autoencoder = model
|
||||||
print(autoencoder)
|
|
||||||
idx = np.random.randint(0, len(x_test))
|
idx = np.random.randint(0, len(x_test))
|
||||||
example: np.ndarray = x_test[idx]
|
example: np.ndarray = x_test[idx]
|
||||||
output, code = autoencoder.forward(example.flatten())
|
output, code = autoencoder.forward(example.flatten())
|
||||||
@@ -74,11 +86,29 @@ def mnist_test(model: str | AAutoencoder):
|
|||||||
fignum=False)
|
fignum=False)
|
||||||
plt.title(f"Output ({y_test[idx]})")
|
plt.title(f"Output ({y_test[idx]})")
|
||||||
plt.subplot(1, 3, 3)
|
plt.subplot(1, 3, 3)
|
||||||
s = int(np.ceil(np.sqrt(code.shape[0])))
|
code = np.reshape(code, (code.shape[0], 1))
|
||||||
code.resize((s, s), refcheck=False)
|
|
||||||
plt.matshow(code, fignum=False)
|
plt.matshow(code, fignum=False)
|
||||||
plt.title(f"Code ({y_test[idx]})")
|
plt.title(f"Code ({y_test[idx]})")
|
||||||
plt.show()
|
plt.show()
|
||||||
|
if code.shape[0] == 2:
|
||||||
|
codes = []
|
||||||
|
for x in x_test:
|
||||||
|
_, c = autoencoder.forward(x.flatten())
|
||||||
|
codes.append(c)
|
||||||
|
codes = np.array(codes)
|
||||||
|
if codes.shape[1] == 2:
|
||||||
|
plt.figure(figsize=(6, 6))
|
||||||
|
scatter = plt.scatter(
|
||||||
|
codes[:, 0],
|
||||||
|
codes[:, 1],
|
||||||
|
c=y_test,
|
||||||
|
cmap='tab10',
|
||||||
|
s=5,
|
||||||
|
alpha=0.7
|
||||||
|
)
|
||||||
|
plt.colorbar(scatter)
|
||||||
|
plt.grid(True)
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user