feat: test label accuracy in mnist example

This commit is contained in:
Lenoctambule
2026-04-15 18:13:24 +02:00
parent 0f1c9f920b
commit 251d66a625

View File

@@ -8,6 +8,7 @@ from easyvae.autoencoder import ( # noqa
AAutoencoder AAutoencoder
) )
from easyvae.activations import LeakyReLU from easyvae.activations import LeakyReLU
from tqdm import tqdm
def load_mnist() -> list[np.ndarray]: def load_mnist() -> list[np.ndarray]:
@@ -90,6 +91,21 @@ def plot_random_reconstruction(
print(f'{code.tolist()}') print(f'{code.tolist()}')
def labeling_accuracy(autoencoder: LabelingVAE, x_test, y_test):
accuracy = 0
for x, y in tqdm(
zip(x_test, y_test),
desc="Testing labeling",
total=len(x_test)
):
res = autoencoder.label(x)
res = list(res.items())[0][0]
if res == str(int(y)):
accuracy += 1
accuracy /= len(y_test)
print(f"Accuracy : {accuracy * 100:.2f}%")
def mnist_test(model: str | AAutoencoder | LabelingVAE): def mnist_test(model: str | AAutoencoder | LabelingVAE):
x_train, y_train, x_test, y_test = load_mnist() x_train, y_train, x_test, y_test = load_mnist()
in_len = x_train[0].shape[0] * x_train[0].shape[0] in_len = x_train[0].shape[0] * x_train[0].shape[0]
@@ -107,7 +123,9 @@ def mnist_test(model: str | AAutoencoder | LabelingVAE):
idx = np.random.randint(0, len(x_test)) idx = np.random.randint(0, len(x_test))
example: np.ndarray = x_test[idx] example: np.ndarray = x_test[idx]
labels_train = [str(int(i)) for i in y_train] labels_train = [str(int(i)) for i in y_train]
if isinstance(model, LabelingVAE):
autoencoder.learn_labels(x_train, labels_train) autoencoder.learn_labels(x_train, labels_train)
labeling_accuracy(autoencoder, x_test, y_test)
res = autoencoder.label(example) res = autoencoder.label(example)
for k, v in res.items(): for k, v in res.items():
print(f"{k} => {v}") print(f"{k} => {v}")