feat: test label accuracy in mnist example
This commit is contained in:
@@ -8,6 +8,7 @@ from easyvae.autoencoder import ( # noqa
|
|||||||
AAutoencoder
|
AAutoencoder
|
||||||
)
|
)
|
||||||
from easyvae.activations import LeakyReLU
|
from easyvae.activations import LeakyReLU
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
def load_mnist() -> list[np.ndarray]:
|
def load_mnist() -> list[np.ndarray]:
|
||||||
@@ -90,6 +91,21 @@ def plot_random_reconstruction(
|
|||||||
print(f'{code.tolist()}')
|
print(f'{code.tolist()}')
|
||||||
|
|
||||||
|
|
||||||
|
def labeling_accuracy(autoencoder: LabelingVAE, x_test, y_test):
|
||||||
|
accuracy = 0
|
||||||
|
for x, y in tqdm(
|
||||||
|
zip(x_test, y_test),
|
||||||
|
desc="Testing labeling",
|
||||||
|
total=len(x_test)
|
||||||
|
):
|
||||||
|
res = autoencoder.label(x)
|
||||||
|
res = list(res.items())[0][0]
|
||||||
|
if res == str(int(y)):
|
||||||
|
accuracy += 1
|
||||||
|
accuracy /= len(y_test)
|
||||||
|
print(f"Accuracy : {accuracy * 100:.2f}%")
|
||||||
|
|
||||||
|
|
||||||
def mnist_test(model: str | AAutoencoder | LabelingVAE):
|
def mnist_test(model: str | AAutoencoder | LabelingVAE):
|
||||||
x_train, y_train, x_test, y_test = load_mnist()
|
x_train, y_train, x_test, y_test = load_mnist()
|
||||||
in_len = x_train[0].shape[0] * x_train[0].shape[0]
|
in_len = x_train[0].shape[0] * x_train[0].shape[0]
|
||||||
@@ -107,7 +123,9 @@ def mnist_test(model: str | AAutoencoder | LabelingVAE):
|
|||||||
idx = np.random.randint(0, len(x_test))
|
idx = np.random.randint(0, len(x_test))
|
||||||
example: np.ndarray = x_test[idx]
|
example: np.ndarray = x_test[idx]
|
||||||
labels_train = [str(int(i)) for i in y_train]
|
labels_train = [str(int(i)) for i in y_train]
|
||||||
|
if isinstance(model, LabelingVAE):
|
||||||
autoencoder.learn_labels(x_train, labels_train)
|
autoencoder.learn_labels(x_train, labels_train)
|
||||||
|
labeling_accuracy(autoencoder, x_test, y_test)
|
||||||
res = autoencoder.label(example)
|
res = autoencoder.label(example)
|
||||||
for k, v in res.items():
|
for k, v in res.items():
|
||||||
print(f"{k} => {v}")
|
print(f"{k} => {v}")
|
||||||
|
|||||||
Reference in New Issue
Block a user