refactor: code de-dup __str__ method

This commit is contained in:
Lenoctambule
2026-04-17 19:53:58 +02:00
parent 6eaaa43285
commit 583fc796f6
2 changed files with 12 additions and 21 deletions

View File

@@ -123,7 +123,7 @@ def mnist_test(model: str | AAutoencoder | LabelingVAE):
idx = np.random.randint(0, len(x_test))
example: np.ndarray = x_test[idx]
labels_train = [str(int(i)) for i in y_train]
if isinstance(model, LabelingVAE):
if isinstance(autoencoder, LabelingVAE):
autoencoder.learn_labels(x_train, labels_train)
labeling_accuracy(autoencoder, x_test, y_test)
res = autoencoder.label(example)

View File

@@ -35,7 +35,7 @@ class AAutoencoder(ABC):
path = path.removesuffix('.npy')
np.save(path, self)
def load(path: str) -> 'ClassicalAutoencoder':
def load(path: str) -> 'AAutoencoder':
path = path.removesuffix('.npy') + '.npy'
data = np.load(path, allow_pickle=True)
return data.item()
@@ -56,6 +56,16 @@ class AAutoencoder(ABC):
def train_dataset(self, *args, **kwargs) -> list[float]:
pass
def __str__(self):
return "\n".join((
f"Type: {self.__class__.__name__}",
"Encoder:",
f"{self.encoder}",
"Decoder:",
f"{self.decoder}"
)
)
class ClassicalAutoencoder(AAutoencoder):
plotter_cls = CAPlotter
@@ -64,16 +74,6 @@ class ClassicalAutoencoder(AAutoencoder):
super().__init__(*args, **kwargs)
self.losses = []
def __str__(self):
return "\n".join((
f"Type: {__class__.__name__}",
"Encoder:",
f"{self.encoder}",
"Decoder:",
f"{self.decoder}"
)
)
def loss(self, data_set: list[np.ndarray]) -> float:
loss = 0
for x in data_set:
@@ -149,15 +149,6 @@ class VariationalAutoencoder(AAutoencoder):
self.KL_losses = []
self.recon_losses = []
def __str__(self):
return "\n".join((
f"Type: {__class__.__name__}",
"Encoder:",
f"{self.encoder}",
"Decoder:",
f"{self.decoder}"
))
def loss(self, data_set: list[np.ndarray]) -> float:
kl_loss = 0
recon_loss = 0