feat: use simple distances instead of std+mean for labeling

This commit is contained in:
Lenoctambule
2026-04-14 20:53:13 +02:00
parent b635bf0467
commit 4cc349c22c
2 changed files with 23 additions and 31 deletions

View File

@@ -69,7 +69,6 @@ def plot_mnist_latent_space(autoencoder: AAutoencoder, x: np.ndarray, y,):
) )
plt.colorbar(scatter) plt.colorbar(scatter)
plt.grid(True) plt.grid(True)
plt.show()
def plot_random_reconstruction( def plot_random_reconstruction(
@@ -107,14 +106,15 @@ def mnist_test(model: str | AAutoencoder | LabelingVAE):
print(autoencoder) print(autoencoder)
idx = np.random.randint(0, len(x_test)) idx = np.random.randint(0, len(x_test))
example: np.ndarray = x_test[idx] example: np.ndarray = x_test[idx]
y_train = [str(int(i)) for i in y_train] labels_train = [str(int(i)) for i in y_train]
autoencoder.learn_labels(x_train, y_train, 5) autoencoder.learn_labels(x_train, labels_train)
res = autoencoder.label(x_train[idx]) res = autoencoder.label(example)
for k, v in res.items(): for k, v in res.items():
print(f"{k} => {v}") print(f"{k} => {v}")
plot_random_reconstruction(autoencoder, example, img_shape, y_test[idx]) plot_random_reconstruction(autoencoder, example, img_shape, y_test[idx])
if autoencoder.space_dim == 2: if autoencoder.space_dim == 2:
plot_mnist_latent_space(autoencoder, x_test, y_test) plot_mnist_latent_space(autoencoder, x_test, y_test)
plt.show()
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -249,20 +249,16 @@ class Label:
self.name = name self.name = name
self.embedding_size = embedding_size self.embedding_size = embedding_size
self.history = [] self.history = []
self.mean = np.zeros(embedding_size)
self.std = np.zeros(embedding_size)
def observe(self, code: np.ndarray): def observe(self, code: np.ndarray):
self.history.append(code) self.history.append(code)
def cache(self): def cache(self):
history = np.array(self.history) self.history_np = np.array(self.history)
self.mean = np.mean(history, axis=0)
self.std = np.std(history, axis=0, mean=self.mean)
def p(self, x: np.ndarray): def p(self, x: np.ndarray):
return np.mean( return np.mean(
np.exp(-(x - self.mean) ** 2 / (2 * self.std)) / (self.std * SQRT_2PI) # noqa np.exp(-np.abs(self.history_np - x))
) )
@@ -272,10 +268,9 @@ class LabelingVAE(VariationalAutoencoder):
self.labels: list[Label] = [] self.labels: list[Label] = []
self.labels_idxs: dict[str, int] = {} self.labels_idxs: dict[str, int] = {}
def learn_labels(self, data: np.ndarray, labels: list[list[str]], epoch=5): def learn_labels(self, data: np.ndarray, labels: list[list[str]]):
self.labels.clear() self.labels.clear()
self.labels_idxs.clear() self.labels_idxs.clear()
for _ in range(epoch):
for x_i, labels_i in zip(data, labels): for x_i, labels_i in zip(data, labels):
y_i = self.encode(x_i) y_i = self.encode(x_i)
for c in labels_i: for c in labels_i:
@@ -290,15 +285,12 @@ class LabelingVAE(VariationalAutoencoder):
for label in self.labels: for label in self.labels:
label.cache() label.cache()
def label(self, x: np.ndarray, samples=10): def label(self, x: np.ndarray):
y = np.zeros((samples, self.encoder.out_size))
for i in range(samples):
y[i] = self.encode(x)
y = np.mean(y, axis=0)
probs = {} probs = {}
total = 0 total = 0
code = self.encode(x)
for label in self.labels: for label in self.labels:
p = label.p(y) p = label.p(code)
probs[label.name] = p probs[label.name] = p
total += p total += p
for k in probs: for k in probs: