From 251d66a62521b07a6a5a1a46112fd2bf64148d35 Mon Sep 17 00:00:00 2001 From: Lenoctambule <106790775+lenoctambule@users.noreply.github.com> Date: Wed, 15 Apr 2026 18:13:24 +0200 Subject: [PATCH 1/4] feat: test label accuracy in mnist example --- examples/mnist_test.py | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/examples/mnist_test.py b/examples/mnist_test.py index 883128a..3d89b1b 100644 --- a/examples/mnist_test.py +++ b/examples/mnist_test.py @@ -8,6 +8,7 @@ from easyvae.autoencoder import ( # noqa AAutoencoder ) from easyvae.activations import LeakyReLU +from tqdm import tqdm def load_mnist() -> list[np.ndarray]: @@ -90,6 +91,21 @@ def plot_random_reconstruction( print(f'{code.tolist()}') +def labeling_accuracy(autoencoder: LabelingVAE, x_test, y_test): + accuracy = 0 + for x, y in tqdm( + zip(x_test, y_test), + desc="Testing labeling", + total=len(x_test) + ): + res = autoencoder.label(x) + res = list(res.items())[0][0] + if res == str(int(y)): + accuracy += 1 + accuracy /= len(y_test) + print(f"Accuracy : {accuracy * 100:.2f}%") + + def mnist_test(model: str | AAutoencoder | LabelingVAE): x_train, y_train, x_test, y_test = load_mnist() in_len = x_train[0].shape[0] * x_train[0].shape[0] @@ -107,10 +123,12 @@ def mnist_test(model: str | AAutoencoder | LabelingVAE): idx = np.random.randint(0, len(x_test)) example: np.ndarray = x_test[idx] labels_train = [str(int(i)) for i in y_train] - autoencoder.learn_labels(x_train, labels_train) - res = autoencoder.label(example) - for k, v in res.items(): - print(f"{k} => {v}") + if isinstance(model, LabelingVAE): + autoencoder.learn_labels(x_train, labels_train) + labeling_accuracy(autoencoder, x_test, y_test) + res = autoencoder.label(example) + for k, v in res.items(): + print(f"{k} => {v}") plot_random_reconstruction(autoencoder, example, img_shape, y_test[idx]) if autoencoder.space_dim == 2: plot_mnist_latent_space(autoencoder, x_test, y_test) From 65c6d3bbee294e6f75aa7e528bfb71f7d3e7bf4a Mon Sep 17 00:00:00 2001 From: Lenoctambule <106790775+lenoctambule@users.noreply.github.com> Date: Fri, 17 Apr 2026 02:31:33 +0200 Subject: [PATCH 2/4] fix: wrong axis typo in Label's observe method --- src/easyvae/autoencoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/easyvae/autoencoder.py b/src/easyvae/autoencoder.py index 1c0c978..4209976 100644 --- a/src/easyvae/autoencoder.py +++ b/src/easyvae/autoencoder.py @@ -260,7 +260,7 @@ class Label: self.history[self.idx] = code self.idx += 1 else: - diffs = np.linalg.norm(self.history - code, axis=0) + diffs = np.linalg.norm(self.history - code, axis=1) idx = np.argmin(diffs) self.history[idx] = (self.history[idx] + code) / 2 From 6eaaa4328593cc51105220f456610a29f3ffc522 Mon Sep 17 00:00:00 2001 From: Lenoctambule <106790775+lenoctambule@users.noreply.github.com> Date: Fri, 17 Apr 2026 04:52:23 +0200 Subject: [PATCH 3/4] fix: use inverse func in p method in Label class --- src/easyvae/autoencoder.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/easyvae/autoencoder.py b/src/easyvae/autoencoder.py index 4209976..0d98896 100644 --- a/src/easyvae/autoencoder.py +++ b/src/easyvae/autoencoder.py @@ -103,7 +103,7 @@ class ClassicalAutoencoder(AAutoencoder): self.losses = [self.loss(data_set)] epoch = 0 no_improv = 0 - prev_error = self.losses[0] + prev_error = self.losses[-1] with tqdm(bar_format="{desc} {elapsed} {rate_fmt}") as lbar: while True: lbar.set_description( @@ -198,7 +198,7 @@ class VariationalAutoencoder(AAutoencoder): self.KL_losses = [kl_0] epoch = 0 no_improv = 0 - prev_loss = self.recon_losses[0] + self.KL_losses[0] + prev_loss = self.recon_losses[-1] + self.KL_losses[-1] with tqdm(bar_format="{desc} {elapsed} {rate_fmt}") as lbar: while True: lbar.set_description( @@ -265,9 +265,7 @@ class Label: self.history[idx] = (self.history[idx] + code) / 2 def p(self, x: np.ndarray): - return np.mean( - np.exp(-np.abs(self.history - x)) - ) + return 1 / (1e-4 + np.mean(np.abs(self.history - x))) class LabelingVAE(VariationalAutoencoder): From 583fc796f69c74d2e974635f721c061a6eaceb12 Mon Sep 17 00:00:00 2001 From: Lenoctambule <106790775+lenoctambule@users.noreply.github.com> Date: Fri, 17 Apr 2026 19:53:58 +0200 Subject: [PATCH 4/4] refactor: code de-dup __str__ method --- examples/mnist_test.py | 2 +- src/easyvae/autoencoder.py | 31 +++++++++++-------------------- 2 files changed, 12 insertions(+), 21 deletions(-) diff --git a/examples/mnist_test.py b/examples/mnist_test.py index 3d89b1b..9c4fdc4 100644 --- a/examples/mnist_test.py +++ b/examples/mnist_test.py @@ -123,7 +123,7 @@ def mnist_test(model: str | AAutoencoder | LabelingVAE): idx = np.random.randint(0, len(x_test)) example: np.ndarray = x_test[idx] labels_train = [str(int(i)) for i in y_train] - if isinstance(model, LabelingVAE): + if isinstance(autoencoder, LabelingVAE): autoencoder.learn_labels(x_train, labels_train) labeling_accuracy(autoencoder, x_test, y_test) res = autoencoder.label(example) diff --git a/src/easyvae/autoencoder.py b/src/easyvae/autoencoder.py index 0d98896..1d8b9c1 100644 --- a/src/easyvae/autoencoder.py +++ b/src/easyvae/autoencoder.py @@ -35,7 +35,7 @@ class AAutoencoder(ABC): path = path.removesuffix('.npy') np.save(path, self) - def load(path: str) -> 'ClassicalAutoencoder': + def load(path: str) -> 'AAutoencoder': path = path.removesuffix('.npy') + '.npy' data = np.load(path, allow_pickle=True) return data.item() @@ -56,6 +56,16 @@ class AAutoencoder(ABC): def train_dataset(self, *args, **kwargs) -> list[float]: pass + def __str__(self): + return "\n".join(( + f"Type: {self.__class__.__name__}", + "Encoder:", + f"{self.encoder}", + "Decoder:", + f"{self.decoder}" + ) + ) + class ClassicalAutoencoder(AAutoencoder): plotter_cls = CAPlotter @@ -64,16 +74,6 @@ class ClassicalAutoencoder(AAutoencoder): super().__init__(*args, **kwargs) self.losses = [] - def __str__(self): - return "\n".join(( - f"Type: {__class__.__name__}", - "Encoder:", - f"{self.encoder}", - "Decoder:", - f"{self.decoder}" - ) - ) - def loss(self, data_set: list[np.ndarray]) -> float: loss = 0 for x in data_set: @@ -149,15 +149,6 @@ class VariationalAutoencoder(AAutoencoder): self.KL_losses = [] self.recon_losses = [] - def __str__(self): - return "\n".join(( - f"Type: {__class__.__name__}", - "Encoder:", - f"{self.encoder}", - "Decoder:", - f"{self.decoder}" - )) - def loss(self, data_set: list[np.ndarray]) -> float: kl_loss = 0 recon_loss = 0