feat: add monte-carlo method and MSE to labeling method

This commit is contained in:
Lenoctambule
2026-04-14 19:24:20 +02:00
parent b1dc34e699
commit b635bf0467
2 changed files with 21 additions and 7 deletions

View File

@@ -290,8 +290,11 @@ class LabelingVAE(VariationalAutoencoder):
for label in self.labels:
label.cache()
def label(self, x: np.ndarray):
y = self.encode(x)
def label(self, x: np.ndarray, samples=10):
y = np.zeros((samples, self.encoder.out_size))
for i in range(samples):
y[i] = self.encode(x)
y = np.mean(y, axis=0)
probs = {}
total = 0
for label in self.labels:
@@ -300,4 +303,10 @@ class LabelingVAE(VariationalAutoencoder):
total += p
for k in probs:
probs[k] = float(probs[k] / total)
return dict(sorted(probs.items()))
return dict(
sorted(
probs.items(),
key=lambda item: item[1],
reverse=True
)
)