From 15865812d898db495b00bf8df72edb1b4ddc54f6 Mon Sep 17 00:00:00 2001 From: Lenoctambule <106790775+lenoctambule@users.noreply.github.com> Date: Thu, 9 Apr 2026 18:31:32 +0200 Subject: [PATCH] refactor(autoencoder.py): __init__ code de-dup --- examples/mnist_test.py | 90 +++++++++++++++++++++----------------- src/easyvae/autoencoder.py | 43 +++++++++--------- 2 files changed, 69 insertions(+), 64 deletions(-) diff --git a/examples/mnist_test.py b/examples/mnist_test.py index 1058590..7aff0f9 100644 --- a/examples/mnist_test.py +++ b/examples/mnist_test.py @@ -8,6 +8,7 @@ from easyvae.autoencoder import ( # noqa AAutoencoder ) from easyvae.activations import LeakyReLU +from easyvae.utils import dynamic_loss_plot_finish def load_mnist() -> list[np.ndarray]: @@ -38,15 +39,15 @@ def mnist_train( autoencoder = cls( [in_len, 256, 2], [2, 256, in_len], - 0.001, + 0.0001, LeakyReLU() ) def handler(signum, frame): print(f"Saving {filename} before exit ...") autoencoder.save(filename) - plt.close() - plt.ioff() + if plt.get_fignums(): + dynamic_loss_plot_finish() mnist_test(autoencoder) exit() @@ -62,6 +63,45 @@ def mnist_train( return autoencoder +def plot_mnist_latent_space(autoencoder: AAutoencoder, x: np.ndarray, y,): + codes = [] + for x in x: + _, c = autoencoder.forward(x.flatten()) + codes.append(c) + codes = np.array(codes) + if codes.shape[1] == 2: + plt.figure(figsize=(6, 6)) + scatter = plt.scatter( + codes[:, 0], + codes[:, 1], + c=y, + cmap='tab10', + s=5, + alpha=0.7 + ) + plt.colorbar(scatter) + plt.grid(True) + plt.show() + + +def plot_random_reconstruction(autoencoder: AAutoencoder, + example: np.ndarray, + img_shape, + y): + output, code = autoencoder.forward(example.flatten()) + plt.subplot(1, 3, 1) + plt.matshow( + example.reshape(img_shape), + fignum=False) + plt.title(f"Input ({y})") + plt.subplot(1, 3, 2) + plt.matshow( + output.reshape(img_shape), + fignum=False) + plt.title(f"Output ({y})") + print(f'{code=}') + + def mnist_test(model: str | AAutoencoder): x_train, _, x_test, y_test = load_mnist() in_len = x_train[0].shape[0] * x_train[0].shape[0] @@ -76,41 +116,9 @@ def mnist_test(model: str | AAutoencoder): autoencoder = model idx = np.random.randint(0, len(x_test)) example: np.ndarray = x_test[idx] - output, code = autoencoder.forward(example.flatten()) - plt.subplot(1, 3, 1) - plt.matshow( - example.reshape(img_shape), - fignum=False) - plt.title(f"Input ({y_test[idx]})") - plt.subplot(1, 3, 2) - plt.matshow( - output.reshape(img_shape), - fignum=False) - plt.title(f"Output ({y_test[idx]})") - plt.subplot(1, 3, 3) - code = np.reshape(code, (code.shape[0], 1)) - plt.matshow(code, fignum=False) - plt.title(f"Code ({y_test[idx]})") - plt.show() - if code.shape[0] == 2: - codes = [] - for x in x_test: - _, c = autoencoder.forward(x.flatten()) - codes.append(c) - codes = np.array(codes) - if codes.shape[1] == 2: - plt.figure(figsize=(6, 6)) - scatter = plt.scatter( - codes[:, 0], - codes[:, 1], - c=y_test, - cmap='tab10', - s=5, - alpha=0.7 - ) - plt.colorbar(scatter) - plt.grid(True) - plt.show() + plot_random_reconstruction(autoencoder, example, img_shape, y_test[idx]) + if autoencoder.space_dim == 2: + plot_mnist_latent_space(autoencoder, x_test, y_test) if __name__ == "__main__": @@ -122,14 +130,14 @@ if __name__ == "__main__": '-e', type=int, nargs='?', - default=1000, + default=30, help='Max epochs' ) parser.add_argument( '-p', type=int, nargs='?', - default=5, + default=30, help='Patience' ) parser.add_argument( @@ -141,7 +149,7 @@ if __name__ == "__main__": parser.add_argument( '-r', action='store_true', - help='Run mode' + help='Run the model' ) args = parser.parse_args(sys.argv[1:]) if args.r: diff --git a/src/easyvae/autoencoder.py b/src/easyvae/autoencoder.py index 410699c..e4a9be9 100644 --- a/src/easyvae/autoencoder.py +++ b/src/easyvae/autoencoder.py @@ -13,6 +13,21 @@ LOADER = ['⡿', '⣟', '⣯', '⣷', '⣾', '⣽', '⣻', '⢿'] class AAutoencoder(ABC): + @abstractmethod + def __init__(self, + encoder_layers: list[int], + decoder_layers: list[int], + lr: float, + activation_func: ActivationFunc): + if encoder_layers[-1] != decoder_layers[0]: + raise Exception( + f"Encoder output and decoder input don't match {encoder_layers[-1]} != {encoder_layers[0]}" # noqa + ) + self.encoder = DeepNNLayer(encoder_layers, lr, activation_func) + self.decoder = DeepNNLayer(decoder_layers, lr, activation_func) + self.space_dim = decoder_layers[0] + self.lr = lr + def train_dataset(self, data_set: list[np.ndarray], max_epoch: int, @@ -75,17 +90,8 @@ class AAutoencoder(ABC): class ClassicalAutoencoder(AAutoencoder): - def __init__(self, - encoder_layers: list[int], - decoder_layers: list[int], - lr: float, - activation_func: ActivationFunc): - if encoder_layers[-1] != decoder_layers[0]: - raise Exception( - f"Encoder output and decoder input don't match {encoder_layers[-1]} != {encoder_layers[0]}" # noqa - ) - self.encoder = DeepNNLayer(encoder_layers, lr, activation_func) - self.decoder = DeepNNLayer(decoder_layers, lr, activation_func) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) def __str__(self): return f'Encoder:\n{self.encoder}\n\nDecoder:\n{self.decoder}' @@ -119,18 +125,9 @@ class ClassicalAutoencoder(AAutoencoder): class VariationalAutoencoder(AAutoencoder): - def __init__(self, - encoder_layers: list[int], - decoder_layers: list[int], - lr: float, - activation_func: ActivationFunc): - if encoder_layers[-1] != decoder_layers[0]: - raise Exception( - f"Encoder output and decoder input don't match {encoder_layers[-1]} != {encoder_layers[0]}" # noqa - ) - self.encoder = DeepNNLayer(encoder_layers, lr, activation_func) - self.decoder = DeepNNLayer(decoder_layers, lr, activation_func) - self.sampler = SampleLayer(self.encoder.out_size, lr, Identity()) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.sampler = SampleLayer(self.encoder.out_size, self.lr, Identity()) def loss(self, data_set: list[np.ndarray]) -> float: loss = 0