Merge pull request #3 from lenoctambule/dev
Add simple post-training labeling + Noise layer
This commit is contained in:
@@ -4,6 +4,7 @@ import os
|
|||||||
from easyvae.autoencoder import ( # noqa
|
from easyvae.autoencoder import ( # noqa
|
||||||
VariationalAutoencoder,
|
VariationalAutoencoder,
|
||||||
ClassicalAutoencoder,
|
ClassicalAutoencoder,
|
||||||
|
LabelingVAE,
|
||||||
AAutoencoder
|
AAutoencoder
|
||||||
)
|
)
|
||||||
from easyvae.activations import LeakyReLU
|
from easyvae.activations import LeakyReLU
|
||||||
@@ -26,10 +27,9 @@ def mnist_train(
|
|||||||
max_epoch: int,
|
max_epoch: int,
|
||||||
patience: int,
|
patience: int,
|
||||||
cls: type[AAutoencoder],) -> AAutoencoder:
|
cls: type[AAutoencoder],) -> AAutoencoder:
|
||||||
x_train, _, x_test, _ = load_mnist()
|
x_train, _, _, _ = load_mnist()
|
||||||
in_len = x_train[0].shape[0] * x_train[0].shape[0]
|
in_len = x_train[0].shape[0] * x_train[0].shape[0]
|
||||||
x_train.resize(x_train.shape[0], in_len)
|
x_train.resize(x_train.shape[0], in_len)
|
||||||
x_test.resize(x_test.shape[0], in_len)
|
|
||||||
x_train = x_train / 255
|
x_train = x_train / 255
|
||||||
if os.path.exists(filename):
|
if os.path.exists(filename):
|
||||||
autoencoder = cls.load(filename)
|
autoencoder = cls.load(filename)
|
||||||
@@ -37,7 +37,7 @@ def mnist_train(
|
|||||||
autoencoder = cls(
|
autoencoder = cls(
|
||||||
[in_len, 256, 2],
|
[in_len, 256, 2],
|
||||||
[2, 256, in_len],
|
[2, 256, in_len],
|
||||||
0.0001,
|
0.001,
|
||||||
LeakyReLU()
|
LeakyReLU()
|
||||||
)
|
)
|
||||||
print("CTRL+C to interrupt training.")
|
print("CTRL+C to interrupt training.")
|
||||||
@@ -69,7 +69,6 @@ def plot_mnist_latent_space(autoencoder: AAutoencoder, x: np.ndarray, y,):
|
|||||||
)
|
)
|
||||||
plt.colorbar(scatter)
|
plt.colorbar(scatter)
|
||||||
plt.grid(True)
|
plt.grid(True)
|
||||||
plt.show()
|
|
||||||
|
|
||||||
|
|
||||||
def plot_random_reconstruction(
|
def plot_random_reconstruction(
|
||||||
@@ -91,8 +90,8 @@ def plot_random_reconstruction(
|
|||||||
print(f'{code.tolist()}')
|
print(f'{code.tolist()}')
|
||||||
|
|
||||||
|
|
||||||
def mnist_test(model: str | AAutoencoder):
|
def mnist_test(model: str | AAutoencoder | LabelingVAE):
|
||||||
x_train, _, x_test, y_test = load_mnist()
|
x_train, y_train, x_test, y_test = load_mnist()
|
||||||
in_len = x_train[0].shape[0] * x_train[0].shape[0]
|
in_len = x_train[0].shape[0] * x_train[0].shape[0]
|
||||||
img_shape = x_train[0].shape
|
img_shape = x_train[0].shape
|
||||||
x_train.resize(x_train.shape[0], in_len)
|
x_train.resize(x_train.shape[0], in_len)
|
||||||
@@ -107,9 +106,15 @@ def mnist_test(model: str | AAutoencoder):
|
|||||||
print(autoencoder)
|
print(autoencoder)
|
||||||
idx = np.random.randint(0, len(x_test))
|
idx = np.random.randint(0, len(x_test))
|
||||||
example: np.ndarray = x_test[idx]
|
example: np.ndarray = x_test[idx]
|
||||||
|
labels_train = [str(int(i)) for i in y_train]
|
||||||
|
autoencoder.learn_labels(x_train, labels_train)
|
||||||
|
res = autoencoder.label(example)
|
||||||
|
for k, v in res.items():
|
||||||
|
print(f"{k} => {v}")
|
||||||
plot_random_reconstruction(autoencoder, example, img_shape, y_test[idx])
|
plot_random_reconstruction(autoencoder, example, img_shape, y_test[idx])
|
||||||
if autoencoder.space_dim == 2:
|
if autoencoder.space_dim == 2:
|
||||||
plot_mnist_latent_space(autoencoder, x_test, y_test)
|
plot_mnist_latent_space(autoencoder, x_test, y_test)
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@@ -150,6 +155,6 @@ if __name__ == "__main__":
|
|||||||
args.m,
|
args.m,
|
||||||
args.e,
|
args.e,
|
||||||
args.p,
|
args.p,
|
||||||
VariationalAutoencoder
|
LabelingVAE
|
||||||
)
|
)
|
||||||
mnist_test(autoencoder)
|
mnist_test(autoencoder)
|
||||||
|
|||||||
@@ -1,12 +1,13 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from .layers import DeepNNLayer, SampleLayer
|
from .layers import DeepNNLayer, SampleLayer, NoiseLayer
|
||||||
from .activations import ActivationFunc, Identity
|
from .activations import ActivationFunc, Identity
|
||||||
from .plotters import Plotter, CAPlotter, VAEPlotter
|
from .plotters import Plotter, CAPlotter, VAEPlotter
|
||||||
from .utils import interruptable
|
from .utils import interruptable
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
LOADER = ['⡿', '⣟', '⣯', '⣷', '⣾', '⣽', '⣻', '⢿']
|
LOADER = ['⡿', '⣟', '⣯', '⣷', '⣾', '⣽', '⣻', '⢿']
|
||||||
|
SQRT_2PI = np.sqrt(2 * np.pi)
|
||||||
|
|
||||||
|
|
||||||
class AAutoencoder(ABC):
|
class AAutoencoder(ABC):
|
||||||
@@ -17,13 +18,15 @@ class AAutoencoder(ABC):
|
|||||||
encoder_layers: list[int],
|
encoder_layers: list[int],
|
||||||
decoder_layers: list[int],
|
decoder_layers: list[int],
|
||||||
lr: float,
|
lr: float,
|
||||||
activation_func: ActivationFunc):
|
activation_func: ActivationFunc,
|
||||||
|
noise=0):
|
||||||
if encoder_layers[-1] != decoder_layers[0]:
|
if encoder_layers[-1] != decoder_layers[0]:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Encoder output and decoder input don't match {encoder_layers[-1]} != {encoder_layers[0]}" # noqa
|
f"Encoder output and decoder input don't match {encoder_layers[-1]} != {encoder_layers[0]}" # noqa
|
||||||
)
|
)
|
||||||
self.encoder = DeepNNLayer(encoder_layers, lr, activation_func)
|
self.encoder = DeepNNLayer(encoder_layers, lr, activation_func)
|
||||||
self.decoder = DeepNNLayer(decoder_layers, lr, activation_func)
|
self.decoder = DeepNNLayer(decoder_layers, lr, activation_func)
|
||||||
|
self.noise = NoiseLayer(noise)
|
||||||
self.space_dim = decoder_layers[0]
|
self.space_dim = decoder_layers[0]
|
||||||
self.lr = lr
|
self.lr = lr
|
||||||
self.losses = [0]
|
self.losses = [0]
|
||||||
@@ -78,13 +81,15 @@ class ClassicalAutoencoder(AAutoencoder):
|
|||||||
return loss / len(data_set)
|
return loss / len(data_set)
|
||||||
|
|
||||||
def train(self, v: np.ndarray):
|
def train(self, v: np.ndarray):
|
||||||
out = self.decoder.forward(
|
out, _ = self.forward(
|
||||||
self.encoder.forward(v)
|
self.noise.forward(v)
|
||||||
)
|
)
|
||||||
error = out - v
|
error = out - v
|
||||||
self.encoder.backprop(
|
self.encoder.back(
|
||||||
self.decoder.backprop(error)
|
self.decoder.back(error)
|
||||||
)
|
)
|
||||||
|
self.encoder.backprop()
|
||||||
|
self.decoder.backprop()
|
||||||
return np.sum(np.abs(error)) / len(v)
|
return np.sum(np.abs(error)) / len(v)
|
||||||
|
|
||||||
@interruptable
|
@interruptable
|
||||||
@@ -94,7 +99,8 @@ class ClassicalAutoencoder(AAutoencoder):
|
|||||||
patience: int,
|
patience: int,
|
||||||
display_loss: bool = False) -> list[float]:
|
display_loss: bool = False) -> list[float]:
|
||||||
plotter = self.plotter_cls(self) if display_loss else Plotter(self)
|
plotter = self.plotter_cls(self) if display_loss else Plotter(self)
|
||||||
self.losses = [self.loss(data_set)]
|
if len(self.losses) == 0:
|
||||||
|
self.losses = [self.loss(data_set)]
|
||||||
epoch = 0
|
epoch = 0
|
||||||
no_improv = 0
|
no_improv = 0
|
||||||
prev_error = self.losses[0]
|
prev_error = self.losses[0]
|
||||||
@@ -109,7 +115,7 @@ class ClassicalAutoencoder(AAutoencoder):
|
|||||||
error += self.train(x)
|
error += self.train(x)
|
||||||
error /= len(data_set)
|
error /= len(data_set)
|
||||||
derror = prev_error - error
|
derror = prev_error - error
|
||||||
if derror <= 0 or abs(derror) < 1e-4:
|
if abs(derror) < 1e-4:
|
||||||
no_improv += 1
|
no_improv += 1
|
||||||
else:
|
else:
|
||||||
no_improv = 0
|
no_improv = 0
|
||||||
@@ -165,13 +171,18 @@ class VariationalAutoencoder(AAutoencoder):
|
|||||||
return recon_loss, kl_loss
|
return recon_loss, kl_loss
|
||||||
|
|
||||||
def train(self, v: np.ndarray) -> tuple[float, float]:
|
def train(self, v: np.ndarray) -> tuple[float, float]:
|
||||||
out, _ = self.forward(v)
|
out, _ = self.forward(
|
||||||
|
self.noise.forward(v)
|
||||||
|
)
|
||||||
error = out - v
|
error = out - v
|
||||||
self.encoder.backprop(
|
self.encoder.back(
|
||||||
self.sampler.backprop(
|
self.sampler.back(
|
||||||
self.decoder.backprop(error)
|
self.decoder.back(error)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
self.encoder.backprop()
|
||||||
|
self.sampler.backprop()
|
||||||
|
self.decoder.backprop()
|
||||||
return np.mean(error ** 2), self.sampler.DKL()
|
return np.mean(error ** 2), self.sampler.DKL()
|
||||||
|
|
||||||
@interruptable
|
@interruptable
|
||||||
@@ -181,9 +192,10 @@ class VariationalAutoencoder(AAutoencoder):
|
|||||||
patience: int,
|
patience: int,
|
||||||
display_loss: bool = False) -> list[float]:
|
display_loss: bool = False) -> list[float]:
|
||||||
plotter = self.plotter_cls(self) if display_loss else Plotter(self)
|
plotter = self.plotter_cls(self) if display_loss else Plotter(self)
|
||||||
recon_0, kl_0 = self.loss(data_set)
|
if len(self.recon_losses) == 0:
|
||||||
self.recon_losses = [recon_0]
|
recon_0, kl_0 = self.loss(data_set)
|
||||||
self.KL_losses = [kl_0]
|
self.recon_losses = [recon_0]
|
||||||
|
self.KL_losses = [kl_0]
|
||||||
epoch = 0
|
epoch = 0
|
||||||
no_improv = 0
|
no_improv = 0
|
||||||
prev_loss = self.recon_losses[0] + self.KL_losses[0]
|
prev_loss = self.recon_losses[0] + self.KL_losses[0]
|
||||||
@@ -221,12 +233,78 @@ class VariationalAutoencoder(AAutoencoder):
|
|||||||
code = self.encoder.forward(v)
|
code = self.encoder.forward(v)
|
||||||
sample = self.sampler.forward(code)
|
sample = self.sampler.forward(code)
|
||||||
out = self.decoder.forward(sample)
|
out = self.decoder.forward(sample)
|
||||||
return out, code
|
return out, sample
|
||||||
|
|
||||||
def encode(self, v: np.ndarray) -> np.ndarray:
|
def encode(self, v: np.ndarray) -> np.ndarray:
|
||||||
return self.sampler.forward(
|
return self.sampler.forward(
|
||||||
self.encoder.forward(v)
|
self.encoder.forward(v)
|
||||||
)
|
)
|
||||||
|
|
||||||
def decode(self, v: np.ndarray) -> np.ndarray:
|
def decode(self, v: np.ndarray) -> np.ndarray:
|
||||||
return self.decoder.forward(v)
|
return self.decoder.forward(v)
|
||||||
|
|
||||||
|
|
||||||
|
class Label:
    """Fixed-size memory of the latent codes observed for one named label.

    The first ``N`` observed codes are stored verbatim. Once the memory is
    full, each new code is merged into the stored code closest to it, so the
    memory never grows beyond ``N`` rows.
    """

    def __init__(self,
                 name: str,
                 embedding_size: int,
                 N=100):
        """Create an empty memory.

        Args:
            name: Label identifier.
            embedding_size: Length of every latent code that will be stored.
            N: Maximum number of codes kept in ``history``.
        """
        self.name = name
        self.embedding_size = embedding_size
        self.N = N
        # Number of rows of ``history`` that hold real observations; rows at
        # or beyond this index are still the zero placeholders.
        self.idx = 0
        self.history = np.zeros((self.N, embedding_size))

    def observe(self, code: np.ndarray):
        """Record one latent code for this label.

        Args:
            code: Vector of length ``embedding_size``.
        """
        if self.idx < self.N:
            self.history[self.idx] = code
            self.idx += 1
        else:
            # axis=1 yields one Euclidean distance per stored code (row).
            # axis=0 would instead yield one norm per embedding dimension,
            # and argmin would then select an unrelated row.
            diffs = np.linalg.norm(self.history - code, axis=1)
            idx = np.argmin(diffs)
            # Move the nearest stored code halfway toward the new one.
            self.history[idx] = (self.history[idx] + code) / 2

    def p(self, x: np.ndarray):
        """Return the unnormalised affinity of ``x`` to this label.

        The score is the mean of ``exp(-|h - x|)`` over every component of
        every observed code ``h``. Rows that were never filled are excluded,
        so the zero placeholders do not act as phantom observations at the
        origin.

        Args:
            x: Vector of length ``embedding_size``.

        Returns:
            A score in ``[0, 1]``; ``0.0`` when nothing has been observed.
        """
        if self.idx == 0:
            return 0.0
        observed = self.history[:self.idx]
        return np.mean(
            np.exp(-np.abs(observed - x))
        )
|
||||||
|
|
||||||
|
|
||||||
|
class LabelingVAE(VariationalAutoencoder):
    """Variational autoencoder with simple post-training labeling.

    After training, ``learn_labels`` encodes labeled samples and stores their
    latent codes per label; ``label`` then scores a new sample against every
    stored label.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``VariationalAutoencoder``."""
        super().__init__(*args, **kwargs)
        self.labels: list[Label] = []
        # Maps a label name to its position in ``self.labels``.
        self.labels_idxs: dict[str, int] = {}

    def learn_labels(self,
                     data: np.ndarray,
                     labels: list[list[str]] | list[str]):
        """Rebuild the label memories from labeled samples.

        Any previously learned labels are discarded first.

        Args:
            data: Samples, iterated along the first axis.
            labels: For each sample, either an iterable of label names or a
                single label name given as a plain string.
        """
        self.labels.clear()
        self.labels_idxs.clear()
        for x_i, labels_i in zip(data, labels):
            # A bare string is one label, not a sequence of one-character
            # labels; without this, "10" would be learned as "1" and "0".
            if isinstance(labels_i, str):
                labels_i = (labels_i,)
            y_i = self.encode(x_i)
            for c in labels_i:
                idx = self.labels_idxs.get(c)
                if idx is None:
                    label = Label(c, self.encoder.out_size)
                    self.labels_idxs[c] = len(self.labels)
                    self.labels.append(label)
                else:
                    label = self.labels[idx]
                label.observe(y_i)

    def label(self, x: np.ndarray):
        """Score ``x`` against every learned label.

        Args:
            x: One input sample.

        Returns:
            A dict mapping label name to its share of the total score,
            ordered from highest to lowest. It is empty when no labels have
            been learned, and all values are ``0.0`` when every label scores
            exactly zero.
        """
        code = self.encode(x)
        probs = {label.name: label.p(code) for label in self.labels}
        total = sum(probs.values())
        for k in probs:
            # Every score can underflow to zero for a far-away code; dividing
            # would then produce NaN (or raise), so report zeros instead.
            probs[k] = 0.0 if total == 0 else float(probs[k] / total)
        return dict(
            sorted(
                probs.items(),
                key=lambda item: item[1],
                reverse=True
            )
        )
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ class NNLayer:
|
|||||||
self.input = None
|
self.input = None
|
||||||
self.output = None
|
self.output = None
|
||||||
self.output_linear = None
|
self.output_linear = None
|
||||||
|
self.error = None
|
||||||
self.activation_func = activation_func
|
self.activation_func = activation_func
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
@@ -28,14 +29,15 @@ class NNLayer:
|
|||||||
)
|
)
|
||||||
return self.output
|
return self.output
|
||||||
|
|
||||||
def backprop(self, error: np.ndarray) -> np.ndarray:
|
def back(self, error: np.ndarray) -> np.ndarray:
|
||||||
error *= self.activation_func.d(self.output_linear)
|
self.error = error * self.activation_func.d(self.output_linear)
|
||||||
ret = self.W @ error
|
return self.W @ self.error
|
||||||
dW = np.outer(self.input, error) * self.lr
|
|
||||||
dB = error * self.lr
|
def backprop(self) -> np.ndarray:
|
||||||
|
dW = np.outer(self.input, self.error) * self.lr
|
||||||
|
dB = self.error * self.lr
|
||||||
self.W -= dW
|
self.W -= dW
|
||||||
self.B -= dB
|
self.B -= dB
|
||||||
return ret
|
|
||||||
|
|
||||||
|
|
||||||
class SampleLayer:
|
class SampleLayer:
|
||||||
@@ -66,13 +68,17 @@ class SampleLayer:
|
|||||||
self.eps = np.random.normal(0, 1, self.mean.shape)
|
self.eps = np.random.normal(0, 1, self.mean.shape)
|
||||||
return 0.5 * self.eps * self.std + self.mean
|
return 0.5 * self.eps * self.std + self.mean
|
||||||
|
|
||||||
def backprop(self, error: np.ndarray) -> np.ndarray:
|
def back(self, error: np.ndarray) -> np.ndarray:
|
||||||
dmean = error + self.mean
|
dmean = error + self.mean
|
||||||
dstd = error * self.eps + 0.5 * (np.exp(self.logvar) - 1)
|
dstd = error * self.eps + 0.5 * (np.exp(self.logvar) - 1)
|
||||||
mean_error = self.mean_nn.backprop(dmean)
|
mean_error = self.mean_nn.back(dmean)
|
||||||
logvar_error = self.std_nn.backprop(dstd * self.std)
|
logvar_error = self.std_nn.back(dstd * self.std)
|
||||||
return mean_error + logvar_error
|
return mean_error + logvar_error
|
||||||
|
|
||||||
|
def backprop(self):
    """Run ``backprop()`` on the mean network, then on the std network."""
    for sub_network in (self.mean_nn, self.std_nn):
        sub_network.backprop()
|
||||||
|
|
||||||
|
|
||||||
class DeepNNLayer:
|
class DeepNNLayer:
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@@ -100,7 +106,21 @@ class DeepNNLayer:
|
|||||||
v = layer.forward(v)
|
v = layer.forward(v)
|
||||||
return v
|
return v
|
||||||
|
|
||||||
def backprop(self, error: np.ndarray) -> np.ndarray:
|
def back(self, error: np.ndarray):
|
||||||
for layer in self.layers[::-1]:
|
for layer in self.layers[::-1]:
|
||||||
error = layer.backprop(error)
|
error = layer.back(error)
|
||||||
return error
|
return error
|
||||||
|
|
||||||
|
def backprop(self) -> None:
    """Call ``backprop()`` on every layer, in forward order.

    Nothing is returned, so the return annotation is ``None`` rather than
    ``np.ndarray``.
    """
    for layer in self.layers:
        layer.backprop()
|
||||||
|
|
||||||
|
|
||||||
|
class NoiseLayer:
    """Adds zero-mean Gaussian noise to its input."""

    def __init__(self, amount=0.1):
        """Store the noise level.

        Args:
            amount: Standard deviation of the added noise; ``0`` disables
                the layer.

        Raises:
            ValueError: If ``amount`` is negative. ``np.random.normal``
                rejects a negative scale with the same exception type, but
                only once ``forward`` is called; failing here surfaces a bad
                configuration before training starts.
        """
        if amount < 0:
            raise ValueError(f"Noise amount must be >= 0, got {amount}")
        self.amount = amount

    def forward(self, v: np.ndarray):
        """Return ``v`` with Gaussian noise added.

        Args:
            v: Input array.

        Returns:
            ``v`` itself (the same object) when ``amount`` is ``0``;
            otherwise a new array of the same shape. ``v`` is not modified.
        """
        if self.amount == 0:
            return v
        return v + np.random.normal(0, self.amount, v.shape)
|
||||||
|
|||||||
Reference in New Issue
Block a user