fix: add history mem cap and mid point pruning

This commit is contained in:
Lenoctambule
2026-04-14 21:03:04 +02:00
parent 4cc349c22c
commit 32a4a39ab9

View File

@@ -245,20 +245,28 @@ class VariationalAutoencoder(AAutoencoder):
class Label:
def __init__(self, name: str, embedding_size: int):
def __init__(self,
name: str,
embedding_size: int,
N=100):
self.name = name
self.embedding_size = embedding_size
self.history = []
self.N = N
self.idx = 0
self.history = np.zeros((self.N, embedding_size))
def observe(self, code: np.ndarray):
self.history.append(code)
def cache(self):
self.history_np = np.array(self.history)
if self.idx < self.N:
self.history[self.idx] = code
self.idx += 1
else:
diffs = np.linalg.norm(self.history - code, axis=0)
idx = np.argmin(diffs)
self.history[idx] = (self.history[idx] + code) / 2
def p(self, x: np.ndarray):
return np.mean(
np.exp(-np.abs(self.history_np - x))
np.exp(-np.abs(self.history - x))
)
@@ -282,8 +290,6 @@ class LabelingVAE(VariationalAutoencoder):
else:
label = self.labels[idx]
label.observe(y_i)
for label in self.labels:
label.cache()
def label(self, x: np.ndarray):
probs = {}