feat: add monte-carlo method and MSE to labeling method
This commit is contained in:
@@ -290,8 +290,11 @@ class LabelingVAE(VariationalAutoencoder):
|
||||
for label in self.labels:
|
||||
label.cache()
|
||||
|
||||
def label(self, x: np.ndarray):
|
||||
y = self.encode(x)
|
||||
def label(self, x: np.ndarray, samples=10):
|
||||
y = np.zeros((samples, self.encoder.out_size))
|
||||
for i in range(samples):
|
||||
y[i] = self.encode(x)
|
||||
y = np.mean(y, axis=0)
|
||||
probs = {}
|
||||
total = 0
|
||||
for label in self.labels:
|
||||
@@ -300,4 +303,10 @@ class LabelingVAE(VariationalAutoencoder):
|
||||
total += p
|
||||
for k in probs:
|
||||
probs[k] = float(probs[k] / total)
|
||||
return dict(sorted(probs.items()))
|
||||
return dict(
|
||||
sorted(
|
||||
probs.items(),
|
||||
key=lambda item: item[1],
|
||||
reverse=True
|
||||
)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user