refactor(autoencoder.py): __init__ code de-dup
This commit is contained in:
@@ -13,6 +13,21 @@ LOADER = ['⡿', '⣟', '⣯', '⣷', '⣾', '⣽', '⣻', '⢿']
|
||||
|
||||
|
||||
class AAutoencoder(ABC):
|
||||
@abstractmethod
|
||||
def __init__(self,
|
||||
encoder_layers: list[int],
|
||||
decoder_layers: list[int],
|
||||
lr: float,
|
||||
activation_func: ActivationFunc):
|
||||
if encoder_layers[-1] != decoder_layers[0]:
|
||||
raise Exception(
|
||||
f"Encoder output and decoder input don't match {encoder_layers[-1]} != {encoder_layers[0]}" # noqa
|
||||
)
|
||||
self.encoder = DeepNNLayer(encoder_layers, lr, activation_func)
|
||||
self.decoder = DeepNNLayer(decoder_layers, lr, activation_func)
|
||||
self.space_dim = decoder_layers[0]
|
||||
self.lr = lr
|
||||
|
||||
def train_dataset(self,
|
||||
data_set: list[np.ndarray],
|
||||
max_epoch: int,
|
||||
@@ -75,17 +90,8 @@ class AAutoencoder(ABC):
|
||||
|
||||
|
||||
class ClassicalAutoencoder(AAutoencoder):
|
||||
def __init__(self,
|
||||
encoder_layers: list[int],
|
||||
decoder_layers: list[int],
|
||||
lr: float,
|
||||
activation_func: ActivationFunc):
|
||||
if encoder_layers[-1] != decoder_layers[0]:
|
||||
raise Exception(
|
||||
f"Encoder output and decoder input don't match {encoder_layers[-1]} != {encoder_layers[0]}" # noqa
|
||||
)
|
||||
self.encoder = DeepNNLayer(encoder_layers, lr, activation_func)
|
||||
self.decoder = DeepNNLayer(decoder_layers, lr, activation_func)
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def __str__(self):
|
||||
return f'Encoder:\n{self.encoder}\n\nDecoder:\n{self.decoder}'
|
||||
@@ -119,18 +125,9 @@ class ClassicalAutoencoder(AAutoencoder):
|
||||
|
||||
|
||||
class VariationalAutoencoder(AAutoencoder):
|
||||
def __init__(self,
|
||||
encoder_layers: list[int],
|
||||
decoder_layers: list[int],
|
||||
lr: float,
|
||||
activation_func: ActivationFunc):
|
||||
if encoder_layers[-1] != decoder_layers[0]:
|
||||
raise Exception(
|
||||
f"Encoder output and decoder input don't match {encoder_layers[-1]} != {encoder_layers[0]}" # noqa
|
||||
)
|
||||
self.encoder = DeepNNLayer(encoder_layers, lr, activation_func)
|
||||
self.decoder = DeepNNLayer(decoder_layers, lr, activation_func)
|
||||
self.sampler = SampleLayer(self.encoder.out_size, lr, Identity())
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.sampler = SampleLayer(self.encoder.out_size, self.lr, Identity())
|
||||
|
||||
def loss(self, data_set: list[np.ndarray]) -> float:
|
||||
loss = 0
|
||||
|
||||
Reference in New Issue
Block a user