diff --git a/examples/mnist_test.py b/examples/mnist_test.py index 7082e1b..236f507 100644 --- a/examples/mnist_test.py +++ b/examples/mnist_test.py @@ -37,7 +37,7 @@ def mnist_train( autoencoder = cls( [in_len, 256, 2], [2, 256, in_len], - 0.0001, + 0.001, LeakyReLU() ) print("CTRL+C to interrupt training.") @@ -91,8 +91,8 @@ def plot_random_reconstruction( print(f'{code.tolist()}') -def mnist_test(model: str | AAutoencoder): - x_train, _, x_test, y_test = load_mnist() +def mnist_test(model: str | AAutoencoder | LabelingVAE): + x_train, y_train, x_test, y_test = load_mnist() in_len = x_train[0].shape[0] * x_train[0].shape[0] img_shape = x_train[0].shape x_train.resize(x_train.shape[0], in_len) @@ -107,6 +107,11 @@ def mnist_test(model: str | AAutoencoder): print(autoencoder) idx = np.random.randint(0, len(x_test)) example: np.ndarray = x_test[idx] + y_train = [str(int(i)) for i in y_train] + autoencoder.learn_labels(x_train, y_train, 5) + res = autoencoder.label(x_train[idx]) + for k, v in res.items(): + print(f"{k} => {v}") plot_random_reconstruction(autoencoder, example, img_shape, y_test[idx]) if autoencoder.space_dim == 2: plot_mnist_latent_space(autoencoder, x_test, y_test) @@ -150,6 +155,6 @@ if __name__ == "__main__": args.m, args.e, args.p, - VariationalAutoencoder + LabelingVAE ) mnist_test(autoencoder) diff --git a/src/easyvae/autoencoder.py b/src/easyvae/autoencoder.py index 490b6c5..3c5dd30 100644 --- a/src/easyvae/autoencoder.py +++ b/src/easyvae/autoencoder.py @@ -290,8 +290,11 @@ class LabelingVAE(VariationalAutoencoder): for label in self.labels: label.cache() - def label(self, x: np.ndarray): - y = self.encode(x) + def label(self, x: np.ndarray, samples=10): + y = np.zeros((samples, self.encoder.out_size)) + for i in range(samples): + y[i] = self.encode(x) + y = np.mean(y, axis=0) probs = {} total = 0 for label in self.labels: @@ -300,4 +303,10 @@ class LabelingVAE(VariationalAutoencoder): total += p for k in probs: probs[k] = float(probs[k] / total) - return dict(sorted(probs.items())) + return dict( + sorted( + probs.items(), + key=lambda item: item[1], + reverse=True + ) + )