Merge pull request #4 from lenoctambule/dev
Labeling test in MNIST example + fixes
This commit is contained in:
@@ -8,6 +8,7 @@ from easyvae.autoencoder import ( # noqa
|
|||||||
AAutoencoder
|
AAutoencoder
|
||||||
)
|
)
|
||||||
from easyvae.activations import LeakyReLU
|
from easyvae.activations import LeakyReLU
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
def load_mnist() -> list[np.ndarray]:
|
def load_mnist() -> list[np.ndarray]:
|
||||||
@@ -90,6 +91,21 @@ def plot_random_reconstruction(
|
|||||||
print(f'{code.tolist()}')
|
print(f'{code.tolist()}')
|
||||||
|
|
||||||
|
|
||||||
|
def labeling_accuracy(autoencoder: LabelingVAE, x_test, y_test):
|
||||||
|
accuracy = 0
|
||||||
|
for x, y in tqdm(
|
||||||
|
zip(x_test, y_test),
|
||||||
|
desc="Testing labeling",
|
||||||
|
total=len(x_test)
|
||||||
|
):
|
||||||
|
res = autoencoder.label(x)
|
||||||
|
res = list(res.items())[0][0]
|
||||||
|
if res == str(int(y)):
|
||||||
|
accuracy += 1
|
||||||
|
accuracy /= len(y_test)
|
||||||
|
print(f"Accuracy : {accuracy * 100:.2f}%")
|
||||||
|
|
||||||
|
|
||||||
def mnist_test(model: str | AAutoencoder | LabelingVAE):
|
def mnist_test(model: str | AAutoencoder | LabelingVAE):
|
||||||
x_train, y_train, x_test, y_test = load_mnist()
|
x_train, y_train, x_test, y_test = load_mnist()
|
||||||
in_len = x_train[0].shape[0] * x_train[0].shape[0]
|
in_len = x_train[0].shape[0] * x_train[0].shape[0]
|
||||||
@@ -107,7 +123,9 @@ def mnist_test(model: str | AAutoencoder | LabelingVAE):
|
|||||||
idx = np.random.randint(0, len(x_test))
|
idx = np.random.randint(0, len(x_test))
|
||||||
example: np.ndarray = x_test[idx]
|
example: np.ndarray = x_test[idx]
|
||||||
labels_train = [str(int(i)) for i in y_train]
|
labels_train = [str(int(i)) for i in y_train]
|
||||||
|
if isinstance(autoencoder, LabelingVAE):
|
||||||
autoencoder.learn_labels(x_train, labels_train)
|
autoencoder.learn_labels(x_train, labels_train)
|
||||||
|
labeling_accuracy(autoencoder, x_test, y_test)
|
||||||
res = autoencoder.label(example)
|
res = autoencoder.label(example)
|
||||||
for k, v in res.items():
|
for k, v in res.items():
|
||||||
print(f"{k} => {v}")
|
print(f"{k} => {v}")
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ class AAutoencoder(ABC):
|
|||||||
path = path.removesuffix('.npy')
|
path = path.removesuffix('.npy')
|
||||||
np.save(path, self)
|
np.save(path, self)
|
||||||
|
|
||||||
def load(path: str) -> 'ClassicalAutoencoder':
|
def load(path: str) -> 'AAutoencoder':
|
||||||
path = path.removesuffix('.npy') + '.npy'
|
path = path.removesuffix('.npy') + '.npy'
|
||||||
data = np.load(path, allow_pickle=True)
|
data = np.load(path, allow_pickle=True)
|
||||||
return data.item()
|
return data.item()
|
||||||
@@ -56,6 +56,16 @@ class AAutoencoder(ABC):
|
|||||||
def train_dataset(self, *args, **kwargs) -> list[float]:
|
def train_dataset(self, *args, **kwargs) -> list[float]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return "\n".join((
|
||||||
|
f"Type: {self.__class__.__name__}",
|
||||||
|
"Encoder:",
|
||||||
|
f"{self.encoder}",
|
||||||
|
"Decoder:",
|
||||||
|
f"{self.decoder}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ClassicalAutoencoder(AAutoencoder):
|
class ClassicalAutoencoder(AAutoencoder):
|
||||||
plotter_cls = CAPlotter
|
plotter_cls = CAPlotter
|
||||||
@@ -64,16 +74,6 @@ class ClassicalAutoencoder(AAutoencoder):
|
|||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.losses = []
|
self.losses = []
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return "\n".join((
|
|
||||||
f"Type: {__class__.__name__}",
|
|
||||||
"Encoder:",
|
|
||||||
f"{self.encoder}",
|
|
||||||
"Decoder:",
|
|
||||||
f"{self.decoder}"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def loss(self, data_set: list[np.ndarray]) -> float:
|
def loss(self, data_set: list[np.ndarray]) -> float:
|
||||||
loss = 0
|
loss = 0
|
||||||
for x in data_set:
|
for x in data_set:
|
||||||
@@ -103,7 +103,7 @@ class ClassicalAutoencoder(AAutoencoder):
|
|||||||
self.losses = [self.loss(data_set)]
|
self.losses = [self.loss(data_set)]
|
||||||
epoch = 0
|
epoch = 0
|
||||||
no_improv = 0
|
no_improv = 0
|
||||||
prev_error = self.losses[0]
|
prev_error = self.losses[-1]
|
||||||
with tqdm(bar_format="{desc} {elapsed} {rate_fmt}") as lbar:
|
with tqdm(bar_format="{desc} {elapsed} {rate_fmt}") as lbar:
|
||||||
while True:
|
while True:
|
||||||
lbar.set_description(
|
lbar.set_description(
|
||||||
@@ -149,15 +149,6 @@ class VariationalAutoencoder(AAutoencoder):
|
|||||||
self.KL_losses = []
|
self.KL_losses = []
|
||||||
self.recon_losses = []
|
self.recon_losses = []
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return "\n".join((
|
|
||||||
f"Type: {__class__.__name__}",
|
|
||||||
"Encoder:",
|
|
||||||
f"{self.encoder}",
|
|
||||||
"Decoder:",
|
|
||||||
f"{self.decoder}"
|
|
||||||
))
|
|
||||||
|
|
||||||
def loss(self, data_set: list[np.ndarray]) -> float:
|
def loss(self, data_set: list[np.ndarray]) -> float:
|
||||||
kl_loss = 0
|
kl_loss = 0
|
||||||
recon_loss = 0
|
recon_loss = 0
|
||||||
@@ -198,7 +189,7 @@ class VariationalAutoencoder(AAutoencoder):
|
|||||||
self.KL_losses = [kl_0]
|
self.KL_losses = [kl_0]
|
||||||
epoch = 0
|
epoch = 0
|
||||||
no_improv = 0
|
no_improv = 0
|
||||||
prev_loss = self.recon_losses[0] + self.KL_losses[0]
|
prev_loss = self.recon_losses[-1] + self.KL_losses[-1]
|
||||||
with tqdm(bar_format="{desc} {elapsed} {rate_fmt}") as lbar:
|
with tqdm(bar_format="{desc} {elapsed} {rate_fmt}") as lbar:
|
||||||
while True:
|
while True:
|
||||||
lbar.set_description(
|
lbar.set_description(
|
||||||
@@ -260,14 +251,12 @@ class Label:
|
|||||||
self.history[self.idx] = code
|
self.history[self.idx] = code
|
||||||
self.idx += 1
|
self.idx += 1
|
||||||
else:
|
else:
|
||||||
diffs = np.linalg.norm(self.history - code, axis=0)
|
diffs = np.linalg.norm(self.history - code, axis=1)
|
||||||
idx = np.argmin(diffs)
|
idx = np.argmin(diffs)
|
||||||
self.history[idx] = (self.history[idx] + code) / 2
|
self.history[idx] = (self.history[idx] + code) / 2
|
||||||
|
|
||||||
def p(self, x: np.ndarray):
|
def p(self, x: np.ndarray):
|
||||||
return np.mean(
|
return 1 / (1e-4 + np.mean(np.abs(self.history - x)))
|
||||||
np.exp(-np.abs(self.history - x))
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class LabelingVAE(VariationalAutoencoder):
|
class LabelingVAE(VariationalAutoencoder):
|
||||||
|
|||||||
Reference in New Issue
Block a user