From b1dc34e699a1d4fa296422e21ec40ae9efb993d9 Mon Sep 17 00:00:00 2001 From: Lenoctambule <106790775+lenoctambule@users.noreply.github.com> Date: Tue, 14 Apr 2026 14:12:22 +0200 Subject: [PATCH] feat: post training and online labeling VAE class --- examples/mnist_test.py | 4 +-- src/easyvae/autoencoder.py | 64 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 64 insertions(+), 4 deletions(-) diff --git a/examples/mnist_test.py b/examples/mnist_test.py index 160cb16..7082e1b 100644 --- a/examples/mnist_test.py +++ b/examples/mnist_test.py @@ -4,6 +4,7 @@ import os from easyvae.autoencoder import ( # noqa VariationalAutoencoder, ClassicalAutoencoder, + LabelingVAE, AAutoencoder ) from easyvae.activations import LeakyReLU @@ -26,10 +27,9 @@ def mnist_train( max_epoch: int, patience: int, cls: type[AAutoencoder],) -> AAutoencoder: - x_train, _, x_test, _ = load_mnist() + x_train, _, _, _ = load_mnist() in_len = x_train[0].shape[0] * x_train[0].shape[0] x_train.resize(x_train.shape[0], in_len) - x_test.resize(x_test.shape[0], in_len) x_train = x_train / 255 if os.path.exists(filename): autoencoder = cls.load(filename) diff --git a/src/easyvae/autoencoder.py b/src/easyvae/autoencoder.py index 1880781..490b6c5 100644 --- a/src/easyvae/autoencoder.py +++ b/src/easyvae/autoencoder.py @@ -7,6 +7,7 @@ from .utils import interruptable from abc import ABC, abstractmethod LOADER = ['⡿', '⣟', '⣯', '⣷', '⣾', '⣽', '⣻', '⢿'] +SQRT_2PI = np.sqrt(2 * np.pi) class AAutoencoder(ABC): @@ -236,8 +237,67 @@ class VariationalAutoencoder(AAutoencoder): def encode(self, v: np.ndarray) -> np.ndarray: return self.sampler.forward( - self.encoder.forward(v) - ) + self.encoder.forward(v) + ) def decode(self, v: np.ndarray) -> np.ndarray: return self.decoder.forward(v) + + +class Label: + def __init__(self, name: str, embedding_size: int): + self.name = name + self.embedding_size = embedding_size + self.history = [] + self.mean = np.zeros(embedding_size) + self.std = np.zeros(embedding_size) + + def observe(self, code: np.ndarray): + self.history.append(code) + + def cache(self): + history = np.array(self.history) + self.mean = np.mean(history, axis=0) + self.std = np.std(history, axis=0, mean=self.mean) + + def p(self, x: np.ndarray): + return np.mean( + np.exp(-(x - self.mean) ** 2 / (2 * self.std)) / (self.std * SQRT_2PI) # noqa + ) + + +class LabelingVAE(VariationalAutoencoder): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.labels: list[Label] = [] + self.labels_idxs: dict[str, int] = {} + + def learn_labels(self, data: np.ndarray, labels: list[list[str]], epoch=5): + self.labels.clear() + self.labels_idxs.clear() + for _ in range(epoch): + for x_i, labels_i in zip(data, labels): + y_i = self.encode(x_i) + for c in labels_i: + idx = self.labels_idxs.get(c, None) + if idx is None: + label = Label(c, self.encoder.out_size) + self.labels.append(label) + self.labels_idxs[c] = len(self.labels) - 1 + else: + label = self.labels[idx] + label.observe(y_i) + for label in self.labels: + label.cache() + + def label(self, x: np.ndarray): + y = self.encode(x) + probs = {} + total = 0 + for label in self.labels: + p = label.p(y) + probs[label.name] = p + total += p + for k in probs: + probs[k] = float(probs[k] / total) + return dict(sorted(probs.items()))