refactor(autoencoder.py): __init__ code de-dup

This commit is contained in:
Lenoctambule
2026-04-09 18:31:32 +02:00
parent 058b7a0f2a
commit 15865812d8
2 changed files with 69 additions and 64 deletions

View File

@@ -13,6 +13,21 @@ LOADER = ['⡿', '⣟', '⣯', '⣷', '⣾', '⣽', '⣻', '⢿']
class AAutoencoder(ABC):
@abstractmethod
def __init__(self,
encoder_layers: list[int],
decoder_layers: list[int],
lr: float,
activation_func: ActivationFunc):
if encoder_layers[-1] != decoder_layers[0]:
raise Exception(
f"Encoder output and decoder input don't match {encoder_layers[-1]} != {encoder_layers[0]}" # noqa
)
self.encoder = DeepNNLayer(encoder_layers, lr, activation_func)
self.decoder = DeepNNLayer(decoder_layers, lr, activation_func)
self.space_dim = decoder_layers[0]
self.lr = lr
def train_dataset(self,
data_set: list[np.ndarray],
max_epoch: int,
@@ -75,17 +90,8 @@ class AAutoencoder(ABC):
class ClassicalAutoencoder(AAutoencoder):
def __init__(self,
encoder_layers: list[int],
decoder_layers: list[int],
lr: float,
activation_func: ActivationFunc):
if encoder_layers[-1] != decoder_layers[0]:
raise Exception(
f"Encoder output and decoder input don't match {encoder_layers[-1]} != {encoder_layers[0]}" # noqa
)
self.encoder = DeepNNLayer(encoder_layers, lr, activation_func)
self.decoder = DeepNNLayer(decoder_layers, lr, activation_func)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __str__(self):
return f'Encoder:\n{self.encoder}\n\nDecoder:\n{self.decoder}'
@@ -119,18 +125,9 @@ class ClassicalAutoencoder(AAutoencoder):
class VariationalAutoencoder(AAutoencoder):
def __init__(self,
encoder_layers: list[int],
decoder_layers: list[int],
lr: float,
activation_func: ActivationFunc):
if encoder_layers[-1] != decoder_layers[0]:
raise Exception(
f"Encoder output and decoder input don't match {encoder_layers[-1]} != {encoder_layers[0]}" # noqa
)
self.encoder = DeepNNLayer(encoder_layers, lr, activation_func)
self.decoder = DeepNNLayer(decoder_layers, lr, activation_func)
self.sampler = SampleLayer(self.encoder.out_size, lr, Identity())
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.sampler = SampleLayer(self.encoder.out_size, self.lr, Identity())
def loss(self, data_set: list[np.ndarray]) -> float:
loss = 0