Merge pull request #4 from lenoctambule/dev

Labeling test in MNIST example + fixes
This commit is contained in:
Lenoctambule
2026-04-18 19:34:36 +02:00
committed by GitHub
2 changed files with 37 additions and 30 deletions

View File

@@ -8,6 +8,7 @@ from easyvae.autoencoder import ( # noqa
AAutoencoder AAutoencoder
) )
from easyvae.activations import LeakyReLU from easyvae.activations import LeakyReLU
from tqdm import tqdm
def load_mnist() -> list[np.ndarray]: def load_mnist() -> list[np.ndarray]:
@@ -90,6 +91,21 @@ def plot_random_reconstruction(
print(f'{code.tolist()}') print(f'{code.tolist()}')
def labeling_accuracy(autoencoder: LabelingVAE, x_test, y_test):
accuracy = 0
for x, y in tqdm(
zip(x_test, y_test),
desc="Testing labeling",
total=len(x_test)
):
res = autoencoder.label(x)
res = list(res.items())[0][0]
if res == str(int(y)):
accuracy += 1
accuracy /= len(y_test)
print(f"Accuracy : {accuracy * 100:.2f}%")
def mnist_test(model: str | AAutoencoder | LabelingVAE): def mnist_test(model: str | AAutoencoder | LabelingVAE):
x_train, y_train, x_test, y_test = load_mnist() x_train, y_train, x_test, y_test = load_mnist()
in_len = x_train[0].shape[0] * x_train[0].shape[0] in_len = x_train[0].shape[0] * x_train[0].shape[0]
@@ -107,7 +123,9 @@ def mnist_test(model: str | AAutoencoder | LabelingVAE):
idx = np.random.randint(0, len(x_test)) idx = np.random.randint(0, len(x_test))
example: np.ndarray = x_test[idx] example: np.ndarray = x_test[idx]
labels_train = [str(int(i)) for i in y_train] labels_train = [str(int(i)) for i in y_train]
if isinstance(autoencoder, LabelingVAE):
autoencoder.learn_labels(x_train, labels_train) autoencoder.learn_labels(x_train, labels_train)
labeling_accuracy(autoencoder, x_test, y_test)
res = autoencoder.label(example) res = autoencoder.label(example)
for k, v in res.items(): for k, v in res.items():
print(f"{k} => {v}") print(f"{k} => {v}")

View File

@@ -35,7 +35,7 @@ class AAutoencoder(ABC):
path = path.removesuffix('.npy') path = path.removesuffix('.npy')
np.save(path, self) np.save(path, self)
def load(path: str) -> 'ClassicalAutoencoder': def load(path: str) -> 'AAutoencoder':
path = path.removesuffix('.npy') + '.npy' path = path.removesuffix('.npy') + '.npy'
data = np.load(path, allow_pickle=True) data = np.load(path, allow_pickle=True)
return data.item() return data.item()
@@ -56,6 +56,16 @@ class AAutoencoder(ABC):
def train_dataset(self, *args, **kwargs) -> list[float]: def train_dataset(self, *args, **kwargs) -> list[float]:
pass pass
def __str__(self):
return "\n".join((
f"Type: {self.__class__.__name__}",
"Encoder:",
f"{self.encoder}",
"Decoder:",
f"{self.decoder}"
)
)
class ClassicalAutoencoder(AAutoencoder): class ClassicalAutoencoder(AAutoencoder):
plotter_cls = CAPlotter plotter_cls = CAPlotter
@@ -64,16 +74,6 @@ class ClassicalAutoencoder(AAutoencoder):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.losses = [] self.losses = []
def __str__(self):
return "\n".join((
f"Type: {__class__.__name__}",
"Encoder:",
f"{self.encoder}",
"Decoder:",
f"{self.decoder}"
)
)
def loss(self, data_set: list[np.ndarray]) -> float: def loss(self, data_set: list[np.ndarray]) -> float:
loss = 0 loss = 0
for x in data_set: for x in data_set:
@@ -103,7 +103,7 @@ class ClassicalAutoencoder(AAutoencoder):
self.losses = [self.loss(data_set)] self.losses = [self.loss(data_set)]
epoch = 0 epoch = 0
no_improv = 0 no_improv = 0
prev_error = self.losses[0] prev_error = self.losses[-1]
with tqdm(bar_format="{desc} {elapsed} {rate_fmt}") as lbar: with tqdm(bar_format="{desc} {elapsed} {rate_fmt}") as lbar:
while True: while True:
lbar.set_description( lbar.set_description(
@@ -149,15 +149,6 @@ class VariationalAutoencoder(AAutoencoder):
self.KL_losses = [] self.KL_losses = []
self.recon_losses = [] self.recon_losses = []
def __str__(self):
return "\n".join((
f"Type: {__class__.__name__}",
"Encoder:",
f"{self.encoder}",
"Decoder:",
f"{self.decoder}"
))
def loss(self, data_set: list[np.ndarray]) -> float: def loss(self, data_set: list[np.ndarray]) -> float:
kl_loss = 0 kl_loss = 0
recon_loss = 0 recon_loss = 0
@@ -198,7 +189,7 @@ class VariationalAutoencoder(AAutoencoder):
self.KL_losses = [kl_0] self.KL_losses = [kl_0]
epoch = 0 epoch = 0
no_improv = 0 no_improv = 0
prev_loss = self.recon_losses[0] + self.KL_losses[0] prev_loss = self.recon_losses[-1] + self.KL_losses[-1]
with tqdm(bar_format="{desc} {elapsed} {rate_fmt}") as lbar: with tqdm(bar_format="{desc} {elapsed} {rate_fmt}") as lbar:
while True: while True:
lbar.set_description( lbar.set_description(
@@ -260,14 +251,12 @@ class Label:
self.history[self.idx] = code self.history[self.idx] = code
self.idx += 1 self.idx += 1
else: else:
diffs = np.linalg.norm(self.history - code, axis=0) diffs = np.linalg.norm(self.history - code, axis=1)
idx = np.argmin(diffs) idx = np.argmin(diffs)
self.history[idx] = (self.history[idx] + code) / 2 self.history[idx] = (self.history[idx] + code) / 2
def p(self, x: np.ndarray): def p(self, x: np.ndarray):
return np.mean( return 1 / (1e-4 + np.mean(np.abs(self.history - x)))
np.exp(-np.abs(self.history - x))
)
class LabelingVAE(VariationalAutoencoder): class LabelingVAE(VariationalAutoencoder):