fix: add history mem cap and mid point pruning
This commit is contained in:
@@ -245,20 +245,28 @@ class VariationalAutoencoder(AAutoencoder):
|
||||
|
||||
|
||||
class Label:
|
||||
def __init__(self, name: str, embedding_size: int):
|
||||
def __init__(self,
|
||||
name: str,
|
||||
embedding_size: int,
|
||||
N=100):
|
||||
self.name = name
|
||||
self.embedding_size = embedding_size
|
||||
self.history = []
|
||||
self.N = N
|
||||
self.idx = 0
|
||||
self.history = np.zeros((self.N, embedding_size))
|
||||
|
||||
def observe(self, code: np.ndarray):
|
||||
self.history.append(code)
|
||||
|
||||
def cache(self):
|
||||
self.history_np = np.array(self.history)
|
||||
if self.idx < self.N:
|
||||
self.history[self.idx] = code
|
||||
self.idx += 1
|
||||
else:
|
||||
diffs = np.linalg.norm(self.history - code, axis=0)
|
||||
idx = np.argmin(diffs)
|
||||
self.history[idx] = (self.history[idx] + code) / 2
|
||||
|
||||
def p(self, x: np.ndarray):
|
||||
return np.mean(
|
||||
np.exp(-np.abs(self.history_np - x))
|
||||
np.exp(-np.abs(self.history - x))
|
||||
)
|
||||
|
||||
|
||||
@@ -282,8 +290,6 @@ class LabelingVAE(VariationalAutoencoder):
|
||||
else:
|
||||
label = self.labels[idx]
|
||||
label.observe(y_i)
|
||||
for label in self.labels:
|
||||
label.cache()
|
||||
|
||||
def label(self, x: np.ndarray):
|
||||
probs = {}
|
||||
|
||||
Reference in New Issue
Block a user