feat: simple distances instead of std+mean for labeling

This commit is contained in:
Lenoctambule
2026-04-14 20:53:13 +02:00
parent b635bf0467
commit 4cc349c22c
2 changed files with 23 additions and 31 deletions

View File

@@ -69,7 +69,6 @@ def plot_mnist_latent_space(autoencoder: AAutoencoder, x: np.ndarray, y,):
)
plt.colorbar(scatter)
plt.grid(True)
plt.show()
def plot_random_reconstruction(
@@ -107,14 +106,15 @@ def mnist_test(model: str | AAutoencoder | LabelingVAE):
print(autoencoder)
idx = np.random.randint(0, len(x_test))
example: np.ndarray = x_test[idx]
y_train = [str(int(i)) for i in y_train]
autoencoder.learn_labels(x_train, y_train, 5)
res = autoencoder.label(x_train[idx])
labels_train = [str(int(i)) for i in y_train]
autoencoder.learn_labels(x_train, labels_train)
res = autoencoder.label(example)
for k, v in res.items():
print(f"{k} => {v}")
plot_random_reconstruction(autoencoder, example, img_shape, y_test[idx])
if autoencoder.space_dim == 2:
plot_mnist_latent_space(autoencoder, x_test, y_test)
plt.show()
if __name__ == "__main__":