feat: simple distances instead of std+mean for labeling
This commit is contained in:
@@ -249,20 +249,16 @@ class Label:
|
||||
self.name = name
|
||||
self.embedding_size = embedding_size
|
||||
self.history = []
|
||||
self.mean = np.zeros(embedding_size)
|
||||
self.std = np.zeros(embedding_size)
|
||||
|
||||
def observe(self, code: np.ndarray):
|
||||
self.history.append(code)
|
||||
|
||||
def cache(self):
|
||||
history = np.array(self.history)
|
||||
self.mean = np.mean(history, axis=0)
|
||||
self.std = np.std(history, axis=0, mean=self.mean)
|
||||
self.history_np = np.array(self.history)
|
||||
|
||||
def p(self, x: np.ndarray):
|
||||
return np.mean(
|
||||
np.exp(-(x - self.mean) ** 2 / (2 * self.std)) / (self.std * SQRT_2PI) # noqa
|
||||
np.exp(-np.abs(self.history_np - x))
|
||||
)
|
||||
|
||||
|
||||
@@ -272,33 +268,29 @@ class LabelingVAE(VariationalAutoencoder):
|
||||
self.labels: list[Label] = []
|
||||
self.labels_idxs: dict[str, int] = {}
|
||||
|
||||
def learn_labels(self, data: np.ndarray, labels: list[list[str]], epoch=5):
|
||||
def learn_labels(self, data: np.ndarray, labels: list[list[str]]):
|
||||
self.labels.clear()
|
||||
self.labels_idxs.clear()
|
||||
for _ in range(epoch):
|
||||
for x_i, labels_i in zip(data, labels):
|
||||
y_i = self.encode(x_i)
|
||||
for c in labels_i:
|
||||
idx = self.labels_idxs.get(c, None)
|
||||
if idx is None:
|
||||
label = Label(c, self.encoder.out_size)
|
||||
self.labels.append(label)
|
||||
self.labels_idxs[c] = len(self.labels) - 1
|
||||
else:
|
||||
label = self.labels[idx]
|
||||
label.observe(y_i)
|
||||
for label in self.labels:
|
||||
label.cache()
|
||||
for x_i, labels_i in zip(data, labels):
|
||||
y_i = self.encode(x_i)
|
||||
for c in labels_i:
|
||||
idx = self.labels_idxs.get(c, None)
|
||||
if idx is None:
|
||||
label = Label(c, self.encoder.out_size)
|
||||
self.labels.append(label)
|
||||
self.labels_idxs[c] = len(self.labels) - 1
|
||||
else:
|
||||
label = self.labels[idx]
|
||||
label.observe(y_i)
|
||||
for label in self.labels:
|
||||
label.cache()
|
||||
|
||||
def label(self, x: np.ndarray, samples=10):
|
||||
y = np.zeros((samples, self.encoder.out_size))
|
||||
for i in range(samples):
|
||||
y[i] = self.encode(x)
|
||||
y = np.mean(y, axis=0)
|
||||
def label(self, x: np.ndarray):
|
||||
probs = {}
|
||||
total = 0
|
||||
code = self.encode(x)
|
||||
for label in self.labels:
|
||||
p = label.p(y)
|
||||
p = label.p(code)
|
||||
probs[label.name] = p
|
||||
total += p
|
||||
for k in probs:
|
||||
|
||||
Reference in New Issue
Block a user