diff --git a/src/easyvae/autoencoder.py b/src/easyvae/autoencoder.py index e6c035e..1c0c978 100644 --- a/src/easyvae/autoencoder.py +++ b/src/easyvae/autoencoder.py @@ -245,20 +245,28 @@ class VariationalAutoencoder(AAutoencoder): class Label: - def __init__(self, name: str, embedding_size: int): + def __init__(self, + name: str, + embedding_size: int, + N=100): self.name = name self.embedding_size = embedding_size - self.history = [] + self.N = N + self.idx = 0 + self.history = np.zeros((self.N, embedding_size)) def observe(self, code: np.ndarray): - self.history.append(code) - - def cache(self): - self.history_np = np.array(self.history) + if self.idx < self.N: + self.history[self.idx] = code + self.idx += 1 + else: + diffs = np.linalg.norm(self.history - code, axis=0) + idx = np.argmin(diffs) + self.history[idx] = (self.history[idx] + code) / 2 def p(self, x: np.ndarray): return np.mean( - np.exp(-np.abs(self.history_np - x)) + np.exp(-np.abs(self.history - x)) ) @@ -282,8 +290,6 @@ class LabelingVAE(VariationalAutoencoder): else: label = self.labels[idx] label.observe(y_i) - for label in self.labels: - label.cache() def label(self, x: np.ndarray): probs = {}