feat: post training and online labeling VAE class
This commit is contained in:
@@ -4,6 +4,7 @@ import os
|
|||||||
from easyvae.autoencoder import ( # noqa
|
from easyvae.autoencoder import ( # noqa
|
||||||
VariationalAutoencoder,
|
VariationalAutoencoder,
|
||||||
ClassicalAutoencoder,
|
ClassicalAutoencoder,
|
||||||
|
LabelingVAE,
|
||||||
AAutoencoder
|
AAutoencoder
|
||||||
)
|
)
|
||||||
from easyvae.activations import LeakyReLU
|
from easyvae.activations import LeakyReLU
|
||||||
@@ -26,10 +27,9 @@ def mnist_train(
|
|||||||
max_epoch: int,
|
max_epoch: int,
|
||||||
patience: int,
|
patience: int,
|
||||||
cls: type[AAutoencoder],) -> AAutoencoder:
|
cls: type[AAutoencoder],) -> AAutoencoder:
|
||||||
x_train, _, x_test, _ = load_mnist()
|
x_train, _, _, _ = load_mnist()
|
||||||
in_len = x_train[0].shape[0] * x_train[0].shape[0]
|
in_len = x_train[0].shape[0] * x_train[0].shape[0]
|
||||||
x_train.resize(x_train.shape[0], in_len)
|
x_train.resize(x_train.shape[0], in_len)
|
||||||
x_test.resize(x_test.shape[0], in_len)
|
|
||||||
x_train = x_train / 255
|
x_train = x_train / 255
|
||||||
if os.path.exists(filename):
|
if os.path.exists(filename):
|
||||||
autoencoder = cls.load(filename)
|
autoencoder = cls.load(filename)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from .utils import interruptable
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
LOADER = ['⡿', '⣟', '⣯', '⣷', '⣾', '⣽', '⣻', '⢿']
|
LOADER = ['⡿', '⣟', '⣯', '⣷', '⣾', '⣽', '⣻', '⢿']
|
||||||
|
SQRT_2PI = np.sqrt(2 * np.pi)
|
||||||
|
|
||||||
|
|
||||||
class AAutoencoder(ABC):
|
class AAutoencoder(ABC):
|
||||||
@@ -241,3 +242,62 @@ class VariationalAutoencoder(AAutoencoder):
|
|||||||
|
|
||||||
def decode(self, v: np.ndarray) -> np.ndarray:
|
def decode(self, v: np.ndarray) -> np.ndarray:
|
||||||
return self.decoder.forward(v)
|
return self.decoder.forward(v)
|
||||||
|
|
||||||
|
|
||||||
|
class Label:
|
||||||
|
def __init__(self, name: str, embedding_size: int):
|
||||||
|
self.name = name
|
||||||
|
self.embedding_size = embedding_size
|
||||||
|
self.history = []
|
||||||
|
self.mean = np.zeros(embedding_size)
|
||||||
|
self.std = np.zeros(embedding_size)
|
||||||
|
|
||||||
|
def observe(self, code: np.ndarray):
|
||||||
|
self.history.append(code)
|
||||||
|
|
||||||
|
def cache(self):
|
||||||
|
history = np.array(self.history)
|
||||||
|
self.mean = np.mean(history, axis=0)
|
||||||
|
self.std = np.std(history, axis=0, mean=self.mean)
|
||||||
|
|
||||||
|
def p(self, x: np.ndarray):
|
||||||
|
return np.mean(
|
||||||
|
np.exp(-(x - self.mean) ** 2 / (2 * self.std)) / (self.std * SQRT_2PI) # noqa
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LabelingVAE(VariationalAutoencoder):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.labels: list[Label] = []
|
||||||
|
self.labels_idxs: dict[str, int] = {}
|
||||||
|
|
||||||
|
def learn_labels(self, data: np.ndarray, labels: list[list[str]], epoch=5):
|
||||||
|
self.labels.clear()
|
||||||
|
self.labels_idxs.clear()
|
||||||
|
for _ in range(epoch):
|
||||||
|
for x_i, labels_i in zip(data, labels):
|
||||||
|
y_i = self.encode(x_i)
|
||||||
|
for c in labels_i:
|
||||||
|
idx = self.labels_idxs.get(c, None)
|
||||||
|
if idx is None:
|
||||||
|
label = Label(c, self.encoder.out_size)
|
||||||
|
self.labels.append(label)
|
||||||
|
self.labels_idxs[c] = len(self.labels) - 1
|
||||||
|
else:
|
||||||
|
label = self.labels[idx]
|
||||||
|
label.observe(y_i)
|
||||||
|
for label in self.labels:
|
||||||
|
label.cache()
|
||||||
|
|
||||||
|
def label(self, x: np.ndarray):
|
||||||
|
y = self.encode(x)
|
||||||
|
probs = {}
|
||||||
|
total = 0
|
||||||
|
for label in self.labels:
|
||||||
|
p = label.p(y)
|
||||||
|
probs[label.name] = p
|
||||||
|
total += p
|
||||||
|
for k in probs:
|
||||||
|
probs[k] = float(probs[k] / total)
|
||||||
|
return dict(sorted(probs.items()))
|
||||||
|
|||||||
Reference in New Issue
Block a user