refactor: code de-dup __str__ method
This commit is contained in:
@@ -123,7 +123,7 @@ def mnist_test(model: str | AAutoencoder | LabelingVAE):
|
||||
idx = np.random.randint(0, len(x_test))
|
||||
example: np.ndarray = x_test[idx]
|
||||
labels_train = [str(int(i)) for i in y_train]
|
||||
if isinstance(model, LabelingVAE):
|
||||
if isinstance(autoencoder, LabelingVAE):
|
||||
autoencoder.learn_labels(x_train, labels_train)
|
||||
labeling_accuracy(autoencoder, x_test, y_test)
|
||||
res = autoencoder.label(example)
|
||||
|
||||
Reference in New Issue
Block a user