From 251d66a62521b07a6a5a1a46112fd2bf64148d35 Mon Sep 17 00:00:00 2001 From: Lenoctambule <106790775+lenoctambule@users.noreply.github.com> Date: Wed, 15 Apr 2026 18:13:24 +0200 Subject: [PATCH] feat: test label accuracy in mnist example --- examples/mnist_test.py | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/examples/mnist_test.py b/examples/mnist_test.py index 883128a..3d89b1b 100644 --- a/examples/mnist_test.py +++ b/examples/mnist_test.py @@ -8,6 +8,7 @@ from easyvae.autoencoder import ( # noqa AAutoencoder ) from easyvae.activations import LeakyReLU +from tqdm import tqdm def load_mnist() -> list[np.ndarray]: @@ -90,6 +91,21 @@ def plot_random_reconstruction( print(f'{code.tolist()}') +def labeling_accuracy(autoencoder: LabelingVAE, x_test, y_test): + accuracy = 0 + for x, y in tqdm( + zip(x_test, y_test), + desc="Testing labeling", + total=len(x_test) + ): + res = autoencoder.label(x) + res = list(res.items())[0][0] + if res == str(int(y)): + accuracy += 1 + accuracy /= len(y_test) + print(f"Accuracy : {accuracy * 100:.2f}%") + + def mnist_test(model: str | AAutoencoder | LabelingVAE): x_train, y_train, x_test, y_test = load_mnist() in_len = x_train[0].shape[0] * x_train[0].shape[0] @@ -107,10 +123,12 @@ def mnist_test(model: str | AAutoencoder | LabelingVAE): idx = np.random.randint(0, len(x_test)) example: np.ndarray = x_test[idx] labels_train = [str(int(i)) for i in y_train] - autoencoder.learn_labels(x_train, labels_train) - res = autoencoder.label(example) - for k, v in res.items(): - print(f"{k} => {v}") + if isinstance(model, LabelingVAE): + autoencoder.learn_labels(x_train, labels_train) + labeling_accuracy(autoencoder, x_test, y_test) + res = autoencoder.label(example) + for k, v in res.items(): + print(f"{k} => {v}") plot_random_reconstruction(autoencoder, example, img_shape, y_test[idx]) if autoencoder.space_dim == 2: plot_mnist_latent_space(autoencoder, x_test, y_test)