refactor(autoencoder.py): __init__ code de-dup

This commit is contained in:
Lenoctambule
2026-04-09 18:31:32 +02:00
parent 058b7a0f2a
commit 15865812d8
2 changed files with 69 additions and 64 deletions

View File

@@ -8,6 +8,7 @@ from easyvae.autoencoder import ( # noqa
AAutoencoder AAutoencoder
) )
from easyvae.activations import LeakyReLU from easyvae.activations import LeakyReLU
from easyvae.utils import dynamic_loss_plot_finish
def load_mnist() -> list[np.ndarray]: def load_mnist() -> list[np.ndarray]:
@@ -38,15 +39,15 @@ def mnist_train(
autoencoder = cls( autoencoder = cls(
[in_len, 256, 2], [in_len, 256, 2],
[2, 256, in_len], [2, 256, in_len],
0.001, 0.0001,
LeakyReLU() LeakyReLU()
) )
def handler(signum, frame): def handler(signum, frame):
print(f"Saving {filename} before exit ...") print(f"Saving {filename} before exit ...")
autoencoder.save(filename) autoencoder.save(filename)
plt.close() if plt.get_fignums():
plt.ioff() dynamic_loss_plot_finish()
mnist_test(autoencoder) mnist_test(autoencoder)
exit() exit()
@@ -62,6 +63,45 @@ def mnist_train(
return autoencoder return autoencoder
def plot_mnist_latent_space(autoencoder: AAutoencoder, x: np.ndarray, y,):
codes = []
for x in x:
_, c = autoencoder.forward(x.flatten())
codes.append(c)
codes = np.array(codes)
if codes.shape[1] == 2:
plt.figure(figsize=(6, 6))
scatter = plt.scatter(
codes[:, 0],
codes[:, 1],
c=y,
cmap='tab10',
s=5,
alpha=0.7
)
plt.colorbar(scatter)
plt.grid(True)
plt.show()
def plot_random_reconstruction(autoencoder: AAutoencoder,
example: np.ndarray,
img_shape,
y):
output, code = autoencoder.forward(example.flatten())
plt.subplot(1, 3, 1)
plt.matshow(
example.reshape(img_shape),
fignum=False)
plt.title(f"Input ({y})")
plt.subplot(1, 3, 2)
plt.matshow(
output.reshape(img_shape),
fignum=False)
plt.title(f"Output ({y})")
print(f'{code=}')
def mnist_test(model: str | AAutoencoder): def mnist_test(model: str | AAutoencoder):
x_train, _, x_test, y_test = load_mnist() x_train, _, x_test, y_test = load_mnist()
in_len = x_train[0].shape[0] * x_train[0].shape[0] in_len = x_train[0].shape[0] * x_train[0].shape[0]
@@ -76,41 +116,9 @@ def mnist_test(model: str | AAutoencoder):
autoencoder = model autoencoder = model
idx = np.random.randint(0, len(x_test)) idx = np.random.randint(0, len(x_test))
example: np.ndarray = x_test[idx] example: np.ndarray = x_test[idx]
output, code = autoencoder.forward(example.flatten()) plot_random_reconstruction(autoencoder, example, img_shape, y_test[idx])
plt.subplot(1, 3, 1) if autoencoder.space_dim == 2:
plt.matshow( plot_mnist_latent_space(autoencoder, x_test, y_test)
example.reshape(img_shape),
fignum=False)
plt.title(f"Input ({y_test[idx]})")
plt.subplot(1, 3, 2)
plt.matshow(
output.reshape(img_shape),
fignum=False)
plt.title(f"Output ({y_test[idx]})")
plt.subplot(1, 3, 3)
code = np.reshape(code, (code.shape[0], 1))
plt.matshow(code, fignum=False)
plt.title(f"Code ({y_test[idx]})")
plt.show()
if code.shape[0] == 2:
codes = []
for x in x_test:
_, c = autoencoder.forward(x.flatten())
codes.append(c)
codes = np.array(codes)
if codes.shape[1] == 2:
plt.figure(figsize=(6, 6))
scatter = plt.scatter(
codes[:, 0],
codes[:, 1],
c=y_test,
cmap='tab10',
s=5,
alpha=0.7
)
plt.colorbar(scatter)
plt.grid(True)
plt.show()
if __name__ == "__main__": if __name__ == "__main__":
@@ -122,14 +130,14 @@ if __name__ == "__main__":
'-e', '-e',
type=int, type=int,
nargs='?', nargs='?',
default=1000, default=30,
help='Max epochs' help='Max epochs'
) )
parser.add_argument( parser.add_argument(
'-p', '-p',
type=int, type=int,
nargs='?', nargs='?',
default=5, default=30,
help='Patience' help='Patience'
) )
parser.add_argument( parser.add_argument(
@@ -141,7 +149,7 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
'-r', '-r',
action='store_true', action='store_true',
help='Run mode' help='Run the model'
) )
args = parser.parse_args(sys.argv[1:]) args = parser.parse_args(sys.argv[1:])
if args.r: if args.r:

View File

@@ -13,6 +13,21 @@ LOADER = ['⡿', '⣟', '⣯', '⣷', '⣾', '⣽', '⣻', '⢿']
class AAutoencoder(ABC): class AAutoencoder(ABC):
@abstractmethod
def __init__(self,
encoder_layers: list[int],
decoder_layers: list[int],
lr: float,
activation_func: ActivationFunc):
if encoder_layers[-1] != decoder_layers[0]:
raise Exception(
f"Encoder output and decoder input don't match {encoder_layers[-1]} != {encoder_layers[0]}" # noqa
)
self.encoder = DeepNNLayer(encoder_layers, lr, activation_func)
self.decoder = DeepNNLayer(decoder_layers, lr, activation_func)
self.space_dim = decoder_layers[0]
self.lr = lr
def train_dataset(self, def train_dataset(self,
data_set: list[np.ndarray], data_set: list[np.ndarray],
max_epoch: int, max_epoch: int,
@@ -75,17 +90,8 @@ class AAutoencoder(ABC):
class ClassicalAutoencoder(AAutoencoder): class ClassicalAutoencoder(AAutoencoder):
def __init__(self, def __init__(self, *args, **kwargs):
encoder_layers: list[int], super().__init__(*args, **kwargs)
decoder_layers: list[int],
lr: float,
activation_func: ActivationFunc):
if encoder_layers[-1] != decoder_layers[0]:
raise Exception(
f"Encoder output and decoder input don't match {encoder_layers[-1]} != {encoder_layers[0]}" # noqa
)
self.encoder = DeepNNLayer(encoder_layers, lr, activation_func)
self.decoder = DeepNNLayer(decoder_layers, lr, activation_func)
def __str__(self): def __str__(self):
return f'Encoder:\n{self.encoder}\n\nDecoder:\n{self.decoder}' return f'Encoder:\n{self.encoder}\n\nDecoder:\n{self.decoder}'
@@ -119,18 +125,9 @@ class ClassicalAutoencoder(AAutoencoder):
class VariationalAutoencoder(AAutoencoder): class VariationalAutoencoder(AAutoencoder):
def __init__(self, def __init__(self, *args, **kwargs):
encoder_layers: list[int], super().__init__(*args, **kwargs)
decoder_layers: list[int], self.sampler = SampleLayer(self.encoder.out_size, self.lr, Identity())
lr: float,
activation_func: ActivationFunc):
if encoder_layers[-1] != decoder_layers[0]:
raise Exception(
f"Encoder output and decoder input don't match {encoder_layers[-1]} != {encoder_layers[0]}" # noqa
)
self.encoder = DeepNNLayer(encoder_layers, lr, activation_func)
self.decoder = DeepNNLayer(decoder_layers, lr, activation_func)
self.sampler = SampleLayer(self.encoder.out_size, lr, Identity())
def loss(self, data_set: list[np.ndarray]) -> float: def loss(self, data_set: list[np.ndarray]) -> float:
loss = 0 loss = 0