fix: add history mem cap and mid point pruning
This commit is contained in:
@@ -245,20 +245,28 @@ class VariationalAutoencoder(AAutoencoder):
|
|||||||
|
|
||||||
|
|
||||||
class Label:
|
class Label:
|
||||||
def __init__(self, name: str, embedding_size: int):
|
def __init__(self,
|
||||||
|
name: str,
|
||||||
|
embedding_size: int,
|
||||||
|
N=100):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.embedding_size = embedding_size
|
self.embedding_size = embedding_size
|
||||||
self.history = []
|
self.N = N
|
||||||
|
self.idx = 0
|
||||||
|
self.history = np.zeros((self.N, embedding_size))
|
||||||
|
|
||||||
def observe(self, code: np.ndarray):
|
def observe(self, code: np.ndarray):
|
||||||
self.history.append(code)
|
if self.idx < self.N:
|
||||||
|
self.history[self.idx] = code
|
||||||
def cache(self):
|
self.idx += 1
|
||||||
self.history_np = np.array(self.history)
|
else:
|
||||||
|
diffs = np.linalg.norm(self.history - code, axis=0)
|
||||||
|
idx = np.argmin(diffs)
|
||||||
|
self.history[idx] = (self.history[idx] + code) / 2
|
||||||
|
|
||||||
def p(self, x: np.ndarray):
|
def p(self, x: np.ndarray):
|
||||||
return np.mean(
|
return np.mean(
|
||||||
np.exp(-np.abs(self.history_np - x))
|
np.exp(-np.abs(self.history - x))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -282,8 +290,6 @@ class LabelingVAE(VariationalAutoencoder):
|
|||||||
else:
|
else:
|
||||||
label = self.labels[idx]
|
label = self.labels[idx]
|
||||||
label.observe(y_i)
|
label.observe(y_i)
|
||||||
for label in self.labels:
|
|
||||||
label.cache()
|
|
||||||
|
|
||||||
def label(self, x: np.ndarray):
|
def label(self, x: np.ndarray):
|
||||||
probs = {}
|
probs = {}
|
||||||
|
|||||||
Reference in New Issue
Block a user