feat: post training and online labeling VAE class

This commit is contained in:
Lenoctambule
2026-04-14 14:12:22 +02:00
parent e6c1229a7e
commit b1dc34e699
2 changed files with 64 additions and 4 deletions

View File

@@ -4,6 +4,7 @@ import os
from easyvae.autoencoder import ( # noqa from easyvae.autoencoder import ( # noqa
VariationalAutoencoder, VariationalAutoencoder,
ClassicalAutoencoder, ClassicalAutoencoder,
LabelingVAE,
AAutoencoder AAutoencoder
) )
from easyvae.activations import LeakyReLU from easyvae.activations import LeakyReLU
@@ -26,10 +27,9 @@ def mnist_train(
max_epoch: int, max_epoch: int,
patience: int, patience: int,
cls: type[AAutoencoder],) -> AAutoencoder: cls: type[AAutoencoder],) -> AAutoencoder:
x_train, _, x_test, _ = load_mnist() x_train, _, _, _ = load_mnist()
in_len = x_train[0].shape[0] * x_train[0].shape[0] in_len = x_train[0].shape[0] * x_train[0].shape[0]
x_train.resize(x_train.shape[0], in_len) x_train.resize(x_train.shape[0], in_len)
x_test.resize(x_test.shape[0], in_len)
x_train = x_train / 255 x_train = x_train / 255
if os.path.exists(filename): if os.path.exists(filename):
autoencoder = cls.load(filename) autoencoder = cls.load(filename)

View File

@@ -7,6 +7,7 @@ from .utils import interruptable
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
LOADER = ['', '', '', '', '', '', '', ''] LOADER = ['', '', '', '', '', '', '', '']
SQRT_2PI = np.sqrt(2 * np.pi)
class AAutoencoder(ABC): class AAutoencoder(ABC):
@@ -241,3 +242,62 @@ class VariationalAutoencoder(AAutoencoder):
def decode(self, v: np.ndarray) -> np.ndarray: def decode(self, v: np.ndarray) -> np.ndarray:
return self.decoder.forward(v) return self.decoder.forward(v)
class Label:
def __init__(self, name: str, embedding_size: int):
self.name = name
self.embedding_size = embedding_size
self.history = []
self.mean = np.zeros(embedding_size)
self.std = np.zeros(embedding_size)
def observe(self, code: np.ndarray):
self.history.append(code)
def cache(self):
history = np.array(self.history)
self.mean = np.mean(history, axis=0)
self.std = np.std(history, axis=0, mean=self.mean)
def p(self, x: np.ndarray):
return np.mean(
np.exp(-(x - self.mean) ** 2 / (2 * self.std)) / (self.std * SQRT_2PI) # noqa
)
class LabelingVAE(VariationalAutoencoder):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.labels: list[Label] = []
self.labels_idxs: dict[str, int] = {}
def learn_labels(self, data: np.ndarray, labels: list[list[str]], epoch=5):
self.labels.clear()
self.labels_idxs.clear()
for _ in range(epoch):
for x_i, labels_i in zip(data, labels):
y_i = self.encode(x_i)
for c in labels_i:
idx = self.labels_idxs.get(c, None)
if idx is None:
label = Label(c, self.encoder.out_size)
self.labels.append(label)
self.labels_idxs[c] = len(self.labels) - 1
else:
label = self.labels[idx]
label.observe(y_i)
for label in self.labels:
label.cache()
def label(self, x: np.ndarray):
y = self.encode(x)
probs = {}
total = 0
for label in self.labels:
p = label.p(y)
probs[label.name] = p
total += p
for k in probs:
probs[k] = float(probs[k] / total)
return dict(sorted(probs.items()))