feat: post training and online labeling VAE class
This commit is contained in:
@@ -4,6 +4,7 @@ import os
|
||||
from easyvae.autoencoder import ( # noqa
|
||||
VariationalAutoencoder,
|
||||
ClassicalAutoencoder,
|
||||
LabelingVAE,
|
||||
AAutoencoder
|
||||
)
|
||||
from easyvae.activations import LeakyReLU
|
||||
@@ -26,10 +27,9 @@ def mnist_train(
|
||||
max_epoch: int,
|
||||
patience: int,
|
||||
cls: type[AAutoencoder],) -> AAutoencoder:
|
||||
x_train, _, x_test, _ = load_mnist()
|
||||
x_train, _, _, _ = load_mnist()
|
||||
in_len = x_train[0].shape[0] * x_train[0].shape[0]
|
||||
x_train.resize(x_train.shape[0], in_len)
|
||||
x_test.resize(x_test.shape[0], in_len)
|
||||
x_train = x_train / 255
|
||||
if os.path.exists(filename):
|
||||
autoencoder = cls.load(filename)
|
||||
|
||||
@@ -7,6 +7,7 @@ from .utils import interruptable
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
LOADER = ['⡿', '⣟', '⣯', '⣷', '⣾', '⣽', '⣻', '⢿']
|
||||
SQRT_2PI = np.sqrt(2 * np.pi)
|
||||
|
||||
|
||||
class AAutoencoder(ABC):
|
||||
@@ -236,8 +237,67 @@ class VariationalAutoencoder(AAutoencoder):
|
||||
|
||||
def encode(self, v: np.ndarray) -> np.ndarray:
|
||||
return self.sampler.forward(
|
||||
self.encoder.forward(v)
|
||||
)
|
||||
self.encoder.forward(v)
|
||||
)
|
||||
|
||||
def decode(self, v: np.ndarray) -> np.ndarray:
|
||||
return self.decoder.forward(v)
|
||||
|
||||
|
||||
class Label:
|
||||
def __init__(self, name: str, embedding_size: int):
|
||||
self.name = name
|
||||
self.embedding_size = embedding_size
|
||||
self.history = []
|
||||
self.mean = np.zeros(embedding_size)
|
||||
self.std = np.zeros(embedding_size)
|
||||
|
||||
def observe(self, code: np.ndarray):
|
||||
self.history.append(code)
|
||||
|
||||
def cache(self):
|
||||
history = np.array(self.history)
|
||||
self.mean = np.mean(history, axis=0)
|
||||
self.std = np.std(history, axis=0, mean=self.mean)
|
||||
|
||||
def p(self, x: np.ndarray):
|
||||
return np.mean(
|
||||
np.exp(-(x - self.mean) ** 2 / (2 * self.std)) / (self.std * SQRT_2PI) # noqa
|
||||
)
|
||||
|
||||
|
||||
class LabelingVAE(VariationalAutoencoder):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.labels: list[Label] = []
|
||||
self.labels_idxs: dict[str, int] = {}
|
||||
|
||||
def learn_labels(self, data: np.ndarray, labels: list[list[str]], epoch=5):
|
||||
self.labels.clear()
|
||||
self.labels_idxs.clear()
|
||||
for _ in range(epoch):
|
||||
for x_i, labels_i in zip(data, labels):
|
||||
y_i = self.encode(x_i)
|
||||
for c in labels_i:
|
||||
idx = self.labels_idxs.get(c, None)
|
||||
if idx is None:
|
||||
label = Label(c, self.encoder.out_size)
|
||||
self.labels.append(label)
|
||||
self.labels_idxs[c] = len(self.labels) - 1
|
||||
else:
|
||||
label = self.labels[idx]
|
||||
label.observe(y_i)
|
||||
for label in self.labels:
|
||||
label.cache()
|
||||
|
||||
def label(self, x: np.ndarray):
|
||||
y = self.encode(x)
|
||||
probs = {}
|
||||
total = 0
|
||||
for label in self.labels:
|
||||
p = label.p(y)
|
||||
probs[label.name] = p
|
||||
total += p
|
||||
for k in probs:
|
||||
probs[k] = float(probs[k] / total)
|
||||
return dict(sorted(probs.items()))
|
||||
|
||||
Reference in New Issue
Block a user