diff --git a/examples/mnist_test.py b/examples/mnist_test.py index 3d89b1b..9c4fdc4 100644 --- a/examples/mnist_test.py +++ b/examples/mnist_test.py @@ -123,7 +123,7 @@ def mnist_test(model: str | AAutoencoder | LabelingVAE): idx = np.random.randint(0, len(x_test)) example: np.ndarray = x_test[idx] labels_train = [str(int(i)) for i in y_train] - if isinstance(model, LabelingVAE): + if isinstance(autoencoder, LabelingVAE): autoencoder.learn_labels(x_train, labels_train) labeling_accuracy(autoencoder, x_test, y_test) res = autoencoder.label(example) diff --git a/src/easyvae/autoencoder.py b/src/easyvae/autoencoder.py index 0d98896..1d8b9c1 100644 --- a/src/easyvae/autoencoder.py +++ b/src/easyvae/autoencoder.py @@ -35,7 +35,7 @@ class AAutoencoder(ABC): path = path.removesuffix('.npy') np.save(path, self) - def load(path: str) -> 'ClassicalAutoencoder': + def load(path: str) -> 'AAutoencoder': path = path.removesuffix('.npy') + '.npy' data = np.load(path, allow_pickle=True) return data.item() @@ -56,6 +56,16 @@ class AAutoencoder(ABC): def train_dataset(self, *args, **kwargs) -> list[float]: pass + def __str__(self): + return "\n".join(( + f"Type: {self.__class__.__name__}", + "Encoder:", + f"{self.encoder}", + "Decoder:", + f"{self.decoder}" + ) + ) + class ClassicalAutoencoder(AAutoencoder): plotter_cls = CAPlotter @@ -64,16 +74,6 @@ class ClassicalAutoencoder(AAutoencoder): super().__init__(*args, **kwargs) self.losses = [] - def __str__(self): - return "\n".join(( - f"Type: {__class__.__name__}", - "Encoder:", - f"{self.encoder}", - "Decoder:", - f"{self.decoder}" - ) - ) - def loss(self, data_set: list[np.ndarray]) -> float: loss = 0 for x in data_set: @@ -149,15 +149,6 @@ class VariationalAutoencoder(AAutoencoder): self.KL_losses = [] self.recon_losses = [] - def __str__(self): - return "\n".join(( - f"Type: {__class__.__name__}", - "Encoder:", - f"{self.encoder}", - "Decoder:", - f"{self.decoder}" - )) - def loss(self, data_set: list[np.ndarray]) -> float: kl_loss = 0 recon_loss = 0