refactor: code de-dup __str__ method
This commit is contained in:
@@ -123,7 +123,7 @@ def mnist_test(model: str | AAutoencoder | LabelingVAE):
|
||||
idx = np.random.randint(0, len(x_test))
|
||||
example: np.ndarray = x_test[idx]
|
||||
labels_train = [str(int(i)) for i in y_train]
|
||||
if isinstance(model, LabelingVAE):
|
||||
if isinstance(autoencoder, LabelingVAE):
|
||||
autoencoder.learn_labels(x_train, labels_train)
|
||||
labeling_accuracy(autoencoder, x_test, y_test)
|
||||
res = autoencoder.label(example)
|
||||
|
||||
@@ -35,7 +35,7 @@ class AAutoencoder(ABC):
|
||||
path = path.removesuffix('.npy')
|
||||
np.save(path, self)
|
||||
|
||||
def load(path: str) -> 'ClassicalAutoencoder':
|
||||
def load(path: str) -> 'AAutoencoder':
|
||||
path = path.removesuffix('.npy') + '.npy'
|
||||
data = np.load(path, allow_pickle=True)
|
||||
return data.item()
|
||||
@@ -56,6 +56,16 @@ class AAutoencoder(ABC):
|
||||
def train_dataset(self, *args, **kwargs) -> list[float]:
|
||||
pass
|
||||
|
||||
def __str__(self):
|
||||
return "\n".join((
|
||||
f"Type: {self.__class__.__name__}",
|
||||
"Encoder:",
|
||||
f"{self.encoder}",
|
||||
"Decoder:",
|
||||
f"{self.decoder}"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class ClassicalAutoencoder(AAutoencoder):
|
||||
plotter_cls = CAPlotter
|
||||
@@ -64,16 +74,6 @@ class ClassicalAutoencoder(AAutoencoder):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.losses = []
|
||||
|
||||
def __str__(self):
|
||||
return "\n".join((
|
||||
f"Type: {__class__.__name__}",
|
||||
"Encoder:",
|
||||
f"{self.encoder}",
|
||||
"Decoder:",
|
||||
f"{self.decoder}"
|
||||
)
|
||||
)
|
||||
|
||||
def loss(self, data_set: list[np.ndarray]) -> float:
|
||||
loss = 0
|
||||
for x in data_set:
|
||||
@@ -149,15 +149,6 @@ class VariationalAutoencoder(AAutoencoder):
|
||||
self.KL_losses = []
|
||||
self.recon_losses = []
|
||||
|
||||
def __str__(self):
|
||||
return "\n".join((
|
||||
f"Type: {__class__.__name__}",
|
||||
"Encoder:",
|
||||
f"{self.encoder}",
|
||||
"Decoder:",
|
||||
f"{self.decoder}"
|
||||
))
|
||||
|
||||
def loss(self, data_set: list[np.ndarray]) -> float:
|
||||
kl_loss = 0
|
||||
recon_loss = 0
|
||||
|
||||
Reference in New Issue
Block a user