feat: simple distances instead of std+mean for labeling
This commit is contained in:
@@ -69,7 +69,6 @@ def plot_mnist_latent_space(autoencoder: AAutoencoder, x: np.ndarray, y,):
|
|||||||
)
|
)
|
||||||
plt.colorbar(scatter)
|
plt.colorbar(scatter)
|
||||||
plt.grid(True)
|
plt.grid(True)
|
||||||
plt.show()
|
|
||||||
|
|
||||||
|
|
||||||
def plot_random_reconstruction(
|
def plot_random_reconstruction(
|
||||||
@@ -107,14 +106,15 @@ def mnist_test(model: str | AAutoencoder | LabelingVAE):
|
|||||||
print(autoencoder)
|
print(autoencoder)
|
||||||
idx = np.random.randint(0, len(x_test))
|
idx = np.random.randint(0, len(x_test))
|
||||||
example: np.ndarray = x_test[idx]
|
example: np.ndarray = x_test[idx]
|
||||||
y_train = [str(int(i)) for i in y_train]
|
labels_train = [str(int(i)) for i in y_train]
|
||||||
autoencoder.learn_labels(x_train, y_train, 5)
|
autoencoder.learn_labels(x_train, labels_train)
|
||||||
res = autoencoder.label(x_train[idx])
|
res = autoencoder.label(example)
|
||||||
for k, v in res.items():
|
for k, v in res.items():
|
||||||
print(f"{k} => {v}")
|
print(f"{k} => {v}")
|
||||||
plot_random_reconstruction(autoencoder, example, img_shape, y_test[idx])
|
plot_random_reconstruction(autoencoder, example, img_shape, y_test[idx])
|
||||||
if autoencoder.space_dim == 2:
|
if autoencoder.space_dim == 2:
|
||||||
plot_mnist_latent_space(autoencoder, x_test, y_test)
|
plot_mnist_latent_space(autoencoder, x_test, y_test)
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -249,20 +249,16 @@ class Label:
|
|||||||
self.name = name
|
self.name = name
|
||||||
self.embedding_size = embedding_size
|
self.embedding_size = embedding_size
|
||||||
self.history = []
|
self.history = []
|
||||||
self.mean = np.zeros(embedding_size)
|
|
||||||
self.std = np.zeros(embedding_size)
|
|
||||||
|
|
||||||
def observe(self, code: np.ndarray):
|
def observe(self, code: np.ndarray):
|
||||||
self.history.append(code)
|
self.history.append(code)
|
||||||
|
|
||||||
def cache(self):
|
def cache(self):
|
||||||
history = np.array(self.history)
|
self.history_np = np.array(self.history)
|
||||||
self.mean = np.mean(history, axis=0)
|
|
||||||
self.std = np.std(history, axis=0, mean=self.mean)
|
|
||||||
|
|
||||||
def p(self, x: np.ndarray):
|
def p(self, x: np.ndarray):
|
||||||
return np.mean(
|
return np.mean(
|
||||||
np.exp(-(x - self.mean) ** 2 / (2 * self.std)) / (self.std * SQRT_2PI) # noqa
|
np.exp(-np.abs(self.history_np - x))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -272,10 +268,9 @@ class LabelingVAE(VariationalAutoencoder):
|
|||||||
self.labels: list[Label] = []
|
self.labels: list[Label] = []
|
||||||
self.labels_idxs: dict[str, int] = {}
|
self.labels_idxs: dict[str, int] = {}
|
||||||
|
|
||||||
def learn_labels(self, data: np.ndarray, labels: list[list[str]], epoch=5):
|
def learn_labels(self, data: np.ndarray, labels: list[list[str]]):
|
||||||
self.labels.clear()
|
self.labels.clear()
|
||||||
self.labels_idxs.clear()
|
self.labels_idxs.clear()
|
||||||
for _ in range(epoch):
|
|
||||||
for x_i, labels_i in zip(data, labels):
|
for x_i, labels_i in zip(data, labels):
|
||||||
y_i = self.encode(x_i)
|
y_i = self.encode(x_i)
|
||||||
for c in labels_i:
|
for c in labels_i:
|
||||||
@@ -290,15 +285,12 @@ class LabelingVAE(VariationalAutoencoder):
|
|||||||
for label in self.labels:
|
for label in self.labels:
|
||||||
label.cache()
|
label.cache()
|
||||||
|
|
||||||
def label(self, x: np.ndarray, samples=10):
|
def label(self, x: np.ndarray):
|
||||||
y = np.zeros((samples, self.encoder.out_size))
|
|
||||||
for i in range(samples):
|
|
||||||
y[i] = self.encode(x)
|
|
||||||
y = np.mean(y, axis=0)
|
|
||||||
probs = {}
|
probs = {}
|
||||||
total = 0
|
total = 0
|
||||||
|
code = self.encode(x)
|
||||||
for label in self.labels:
|
for label in self.labels:
|
||||||
p = label.p(y)
|
p = label.p(code)
|
||||||
probs[label.name] = p
|
probs[label.name] = p
|
||||||
total += p
|
total += p
|
||||||
for k in probs:
|
for k in probs:
|
||||||
|
|||||||
Reference in New Issue
Block a user