refactor: code de-dup __str__ method

This commit is contained in:
Lenoctambule
2026-04-17 19:53:58 +02:00
parent 6eaaa43285
commit 583fc796f6
2 changed files with 12 additions and 21 deletions

View File

@@ -123,7 +123,7 @@ def mnist_test(model: str | AAutoencoder | LabelingVAE):
idx = np.random.randint(0, len(x_test))
example: np.ndarray = x_test[idx]
labels_train = [str(int(i)) for i in y_train]
if isinstance(model, LabelingVAE):
if isinstance(autoencoder, LabelingVAE):
autoencoder.learn_labels(x_train, labels_train)
labeling_accuracy(autoencoder, x_test, y_test)
res = autoencoder.label(example)