diff --git a/README.md b/README.md
index 9e75be6..8c04145 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,15 @@
# Python AutoEncoder from scratch using Numpy
+
+
+
+
+ Latent-space representation of the MNIST dataset using Variational Autoencoder
+
+
+
## Usage
diff --git a/examples/mnist_test.py b/examples/mnist_test.py
index 1058590..5e68c9a 100644
--- a/examples/mnist_test.py
+++ b/examples/mnist_test.py
@@ -8,6 +8,7 @@ from easyvae.autoencoder import ( # noqa
AAutoencoder
)
from easyvae.activations import LeakyReLU
+from easyvae.utils import dynamic_loss_plot_finish
def load_mnist() -> list[np.ndarray]:
@@ -38,15 +39,15 @@ def mnist_train(
autoencoder = cls(
[in_len, 256, 2],
[2, 256, in_len],
- 0.001,
+ 0.0001,
LeakyReLU()
)
def handler(signum, frame):
print(f"Saving {filename} before exit ...")
autoencoder.save(filename)
- plt.close()
- plt.ioff()
+ if plt.get_fignums():
+ dynamic_loss_plot_finish()
mnist_test(autoencoder)
exit()
@@ -62,6 +63,45 @@ def mnist_train(
return autoencoder
+def plot_mnist_latent_space(autoencoder: AAutoencoder, x: np.ndarray, y,):
+ codes = []
+ for x in x:
+ _, c = autoencoder.forward(x.flatten())
+ codes.append(c)
+ codes = np.array(codes)
+ if codes.shape[1] == 2:
+ plt.figure(figsize=(6, 6))
+ scatter = plt.scatter(
+ codes[:, 0],
+ codes[:, 1],
+ c=y,
+ cmap='tab10',
+ s=5,
+ alpha=0.7
+ )
+ plt.colorbar(scatter)
+ plt.grid(True)
+ plt.show()
+
+
+def plot_random_reconstruction(autoencoder: AAutoencoder,
+ example: np.ndarray,
+ img_shape,
+ y):
+ output, code = autoencoder.forward(example.flatten())
+ plt.subplot(1, 2, 1)
+ plt.matshow(
+ example.reshape(img_shape),
+ fignum=False)
+ plt.title(f"Input ({y})")
+ plt.subplot(1, 2, 2)
+ plt.matshow(
+ output.reshape(img_shape),
+ fignum=False)
+ plt.title(f"Output ({y})")
+ print(f'{code=}')
+
+
def mnist_test(model: str | AAutoencoder):
x_train, _, x_test, y_test = load_mnist()
in_len = x_train[0].shape[0] * x_train[0].shape[0]
@@ -76,41 +116,9 @@ def mnist_test(model: str | AAutoencoder):
autoencoder = model
idx = np.random.randint(0, len(x_test))
example: np.ndarray = x_test[idx]
- output, code = autoencoder.forward(example.flatten())
- plt.subplot(1, 3, 1)
- plt.matshow(
- example.reshape(img_shape),
- fignum=False)
- plt.title(f"Input ({y_test[idx]})")
- plt.subplot(1, 3, 2)
- plt.matshow(
- output.reshape(img_shape),
- fignum=False)
- plt.title(f"Output ({y_test[idx]})")
- plt.subplot(1, 3, 3)
- code = np.reshape(code, (code.shape[0], 1))
- plt.matshow(code, fignum=False)
- plt.title(f"Code ({y_test[idx]})")
- plt.show()
- if code.shape[0] == 2:
- codes = []
- for x in x_test:
- _, c = autoencoder.forward(x.flatten())
- codes.append(c)
- codes = np.array(codes)
- if codes.shape[1] == 2:
- plt.figure(figsize=(6, 6))
- scatter = plt.scatter(
- codes[:, 0],
- codes[:, 1],
- c=y_test,
- cmap='tab10',
- s=5,
- alpha=0.7
- )
- plt.colorbar(scatter)
- plt.grid(True)
- plt.show()
+ plot_random_reconstruction(autoencoder, example, img_shape, y_test[idx])
+ if autoencoder.space_dim == 2:
+ plot_mnist_latent_space(autoencoder, x_test, y_test)
if __name__ == "__main__":
@@ -122,14 +130,14 @@ if __name__ == "__main__":
'-e',
type=int,
nargs='?',
- default=1000,
+ default=30,
help='Max epochs'
)
parser.add_argument(
'-p',
type=int,
nargs='?',
- default=5,
+ default=30,
help='Patience'
)
parser.add_argument(
@@ -141,7 +149,7 @@ if __name__ == "__main__":
parser.add_argument(
'-r',
action='store_true',
- help='Run mode'
+ help='Run the model'
)
args = parser.parse_args(sys.argv[1:])
if args.r:
diff --git a/media/latent-space.png b/media/latent-space.png
new file mode 100644
index 0000000..1ef57b3
Binary files /dev/null and b/media/latent-space.png differ
diff --git a/pyproject.toml b/pyproject.toml
index ec0803a..26b15c0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "easyvae"
-version = "1.0"
+version = "1.1"
authors = [
{ name="Ravaka RALAMBOARIVONY", email="ravaka.rlb.pro@gmail.com" },
]
diff --git a/src/easyvae/autoencoder.py b/src/easyvae/autoencoder.py
index 410699c..5f436c9 100644
--- a/src/easyvae/autoencoder.py
+++ b/src/easyvae/autoencoder.py
@@ -13,6 +13,21 @@ LOADER = ['⡿', '⣟', '⣯', '⣷', '⣾', '⣽', '⣻', '⢿']
class AAutoencoder(ABC):
+ @abstractmethod
+ def __init__(self,
+ encoder_layers: list[int],
+ decoder_layers: list[int],
+ lr: float,
+ activation_func: ActivationFunc):
+ if encoder_layers[-1] != decoder_layers[0]:
+ raise Exception(
+ f"Encoder output and decoder input don't match {encoder_layers[-1]} != {encoder_layers[0]}" # noqa
+ )
+ self.encoder = DeepNNLayer(encoder_layers, lr, activation_func)
+ self.decoder = DeepNNLayer(decoder_layers, lr, activation_func)
+ self.space_dim = decoder_layers[0]
+ self.lr = lr
+
def train_dataset(self,
data_set: list[np.ndarray],
max_epoch: int,
@@ -49,7 +64,7 @@ class AAutoencoder(ABC):
break
epoch += 1
if display_loss is True:
- dynamic_loss_plot_finish(ax, line)
+ dynamic_loss_plot_finish()
return losses
def save(self, path: str):
@@ -75,17 +90,8 @@ class AAutoencoder(ABC):
class ClassicalAutoencoder(AAutoencoder):
- def __init__(self,
- encoder_layers: list[int],
- decoder_layers: list[int],
- lr: float,
- activation_func: ActivationFunc):
- if encoder_layers[-1] != decoder_layers[0]:
- raise Exception(
- f"Encoder output and decoder input don't match {encoder_layers[-1]} != {encoder_layers[0]}" # noqa
- )
- self.encoder = DeepNNLayer(encoder_layers, lr, activation_func)
- self.decoder = DeepNNLayer(decoder_layers, lr, activation_func)
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
def __str__(self):
return f'Encoder:\n{self.encoder}\n\nDecoder:\n{self.decoder}'
@@ -119,18 +125,9 @@ class ClassicalAutoencoder(AAutoencoder):
class VariationalAutoencoder(AAutoencoder):
- def __init__(self,
- encoder_layers: list[int],
- decoder_layers: list[int],
- lr: float,
- activation_func: ActivationFunc):
- if encoder_layers[-1] != decoder_layers[0]:
- raise Exception(
- f"Encoder output and decoder input don't match {encoder_layers[-1]} != {encoder_layers[0]}" # noqa
- )
- self.encoder = DeepNNLayer(encoder_layers, lr, activation_func)
- self.decoder = DeepNNLayer(decoder_layers, lr, activation_func)
- self.sampler = SampleLayer(self.encoder.out_size, lr, Identity())
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.sampler = SampleLayer(self.encoder.out_size, self.lr, Identity())
def loss(self, data_set: list[np.ndarray]) -> float:
loss = 0
diff --git a/src/easyvae/utils.py b/src/easyvae/utils.py
index e407bff..7e61971 100644
--- a/src/easyvae/utils.py
+++ b/src/easyvae/utils.py
@@ -41,6 +41,6 @@ def dynamic_loss_plot_update(ax, line, loss):
plt.pause(0.1)
-def dynamic_loss_plot_finish(ax, line):
+def dynamic_loss_plot_finish():
plt.ioff()
plt.show()