diff --git a/src/easyvae/autoencoder.py b/src/easyvae/autoencoder.py index 4209976..0d98896 100644 --- a/src/easyvae/autoencoder.py +++ b/src/easyvae/autoencoder.py @@ -103,7 +103,7 @@ class ClassicalAutoencoder(AAutoencoder): self.losses = [self.loss(data_set)] epoch = 0 no_improv = 0 - prev_error = self.losses[0] + prev_error = self.losses[-1] with tqdm(bar_format="{desc} {elapsed} {rate_fmt}") as lbar: while True: lbar.set_description( @@ -198,7 +198,7 @@ class VariationalAutoencoder(AAutoencoder): self.KL_losses = [kl_0] epoch = 0 no_improv = 0 - prev_loss = self.recon_losses[0] + self.KL_losses[0] + prev_loss = self.recon_losses[-1] + self.KL_losses[-1] with tqdm(bar_format="{desc} {elapsed} {rate_fmt}") as lbar: while True: lbar.set_description( @@ -265,9 +265,7 @@ class Label: self.history[idx] = (self.history[idx] + code) / 2 def p(self, x: np.ndarray): - return np.mean( - np.exp(-np.abs(self.history - x)) - ) + return 1 / (1e-4 + np.mean(np.abs(self.history - x))) class LabelingVAE(VariationalAutoencoder):