From 4cc349c22c8012c9109e14238bbf42f6ab00c617 Mon Sep 17 00:00:00 2001 From: Lenoctambule <106790775+lenoctambule@users.noreply.github.com> Date: Tue, 14 Apr 2026 20:53:13 +0200 Subject: [PATCH] feat: simple distances instead of std+mean for labeling --- examples/mnist_test.py | 8 +++---- src/easyvae/autoencoder.py | 46 ++++++++++++++++---------------------- 2 files changed, 23 insertions(+), 31 deletions(-) diff --git a/examples/mnist_test.py b/examples/mnist_test.py index 236f507..883128a 100644 --- a/examples/mnist_test.py +++ b/examples/mnist_test.py @@ -69,7 +69,6 @@ def plot_mnist_latent_space(autoencoder: AAutoencoder, x: np.ndarray, y,): ) plt.colorbar(scatter) plt.grid(True) - plt.show() def plot_random_reconstruction( @@ -107,14 +106,15 @@ def mnist_test(model: str | AAutoencoder | LabelingVAE): print(autoencoder) idx = np.random.randint(0, len(x_test)) example: np.ndarray = x_test[idx] - y_train = [str(int(i)) for i in y_train] - autoencoder.learn_labels(x_train, y_train, 5) - res = autoencoder.label(x_train[idx]) + labels_train = [str(int(i)) for i in y_train] + autoencoder.learn_labels(x_train, labels_train) + res = autoencoder.label(example) for k, v in res.items(): print(f"{k} => {v}") plot_random_reconstruction(autoencoder, example, img_shape, y_test[idx]) if autoencoder.space_dim == 2: plot_mnist_latent_space(autoencoder, x_test, y_test) + plt.show() if __name__ == "__main__": diff --git a/src/easyvae/autoencoder.py b/src/easyvae/autoencoder.py index 3c5dd30..e6c035e 100644 --- a/src/easyvae/autoencoder.py +++ b/src/easyvae/autoencoder.py @@ -249,20 +249,16 @@ class Label: self.name = name self.embedding_size = embedding_size self.history = [] - self.mean = np.zeros(embedding_size) - self.std = np.zeros(embedding_size) def observe(self, code: np.ndarray): self.history.append(code) def cache(self): - history = np.array(self.history) - self.mean = np.mean(history, axis=0) - self.std = np.std(history, axis=0, mean=self.mean) + self.history_np = np.array(self.history) def p(self, x: np.ndarray): return np.mean( - np.exp(-(x - self.mean) ** 2 / (2 * self.std)) / (self.std * SQRT_2PI) # noqa + np.exp(-np.abs(self.history_np - x)) ) @@ -272,33 +268,29 @@ class LabelingVAE(VariationalAutoencoder): self.labels: list[Label] = [] self.labels_idxs: dict[str, int] = {} - def learn_labels(self, data: np.ndarray, labels: list[list[str]], epoch=5): + def learn_labels(self, data: np.ndarray, labels: list[list[str]]): self.labels.clear() self.labels_idxs.clear() - for _ in range(epoch): - for x_i, labels_i in zip(data, labels): - y_i = self.encode(x_i) - for c in labels_i: - idx = self.labels_idxs.get(c, None) - if idx is None: - label = Label(c, self.encoder.out_size) - self.labels.append(label) - self.labels_idxs[c] = len(self.labels) - 1 - else: - label = self.labels[idx] - label.observe(y_i) - for label in self.labels: - label.cache() + for x_i, labels_i in zip(data, labels): + y_i = self.encode(x_i) + for c in labels_i: + idx = self.labels_idxs.get(c, None) + if idx is None: + label = Label(c, self.encoder.out_size) + self.labels.append(label) + self.labels_idxs[c] = len(self.labels) - 1 + else: + label = self.labels[idx] + label.observe(y_i) + for label in self.labels: + label.cache() - def label(self, x: np.ndarray, samples=10): - y = np.zeros((samples, self.encoder.out_size)) - for i in range(samples): - y[i] = self.encode(x) - y = np.mean(y, axis=0) + def label(self, x: np.ndarray): probs = {} total = 0 + code = self.encode(x) for label in self.labels: - p = label.p(y) + p = label.p(code) probs[label.name] = p total += p for k in probs: