Merge pull request #1 from lenoctambule/dev
Bug fixes and README improvements
This commit is contained in:
11
README.md
11
README.md
@@ -1,4 +1,15 @@
|
|||||||
# Python AutoEncoder from scratch using Numpy
|
# Python AutoEncoder from scratch using Numpy
|
||||||
|
<center>
|
||||||
|
<figure>
|
||||||
|
<img
|
||||||
|
src="./media/latent-space.png"
|
||||||
|
alt="Latent-space of the MNIST dataset"
|
||||||
|
width=70%>
|
||||||
|
<figcaption>
|
||||||
|
Latent-space representation of the MNIST dataset using a Variational Autoencoder
|
||||||
|
</figcaption>
|
||||||
|
</figure>
|
||||||
|
</center>
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from easyvae.autoencoder import ( # noqa
|
|||||||
AAutoencoder
|
AAutoencoder
|
||||||
)
|
)
|
||||||
from easyvae.activations import LeakyReLU
|
from easyvae.activations import LeakyReLU
|
||||||
|
from easyvae.utils import dynamic_loss_plot_finish
|
||||||
|
|
||||||
|
|
||||||
def load_mnist() -> list[np.ndarray]:
|
def load_mnist() -> list[np.ndarray]:
|
||||||
@@ -38,15 +39,15 @@ def mnist_train(
|
|||||||
autoencoder = cls(
|
autoencoder = cls(
|
||||||
[in_len, 256, 2],
|
[in_len, 256, 2],
|
||||||
[2, 256, in_len],
|
[2, 256, in_len],
|
||||||
0.001,
|
0.0001,
|
||||||
LeakyReLU()
|
LeakyReLU()
|
||||||
)
|
)
|
||||||
|
|
||||||
def handler(signum, frame):
|
def handler(signum, frame):
|
||||||
print(f"Saving {filename} before exit ...")
|
print(f"Saving {filename} before exit ...")
|
||||||
autoencoder.save(filename)
|
autoencoder.save(filename)
|
||||||
plt.close()
|
if plt.get_fignums():
|
||||||
plt.ioff()
|
dynamic_loss_plot_finish()
|
||||||
mnist_test(autoencoder)
|
mnist_test(autoencoder)
|
||||||
exit()
|
exit()
|
||||||
|
|
||||||
@@ -62,6 +63,45 @@ def mnist_train(
|
|||||||
return autoencoder
|
return autoencoder
|
||||||
|
|
||||||
|
|
||||||
|
def plot_mnist_latent_space(autoencoder: AAutoencoder, x: np.ndarray, y):
    """Scatter-plot the latent codes of ``x`` when the codes have two columns.

    Every sample is flattened and passed through ``autoencoder.forward``; the
    second element of the returned pair is collected as that sample's code.
    Nothing is drawn when ``x`` is empty or when the stacked codes do not have
    exactly two columns.

    Args:
        autoencoder: Model whose ``forward`` returns an ``(output, code)``
            pair.
        x: Iterable of samples; each one is flattened before the forward pass.
        y: Per-sample values handed to the scatter plot as colours (``c=y``).
    """
    # A dedicated loop name keeps the ``x`` parameter from being rebound.
    codes = np.array(
        [autoencoder.forward(sample.flatten())[1] for sample in x]
    )
    # An empty ``x`` yields a 1-D array, so check the rank before reading
    # ``shape[1]`` (the unguarded access raised IndexError).
    if codes.ndim < 2 or codes.shape[1] != 2:
        return
    plt.figure(figsize=(6, 6))
    scatter = plt.scatter(
        codes[:, 0],
        codes[:, 1],
        c=y,
        cmap='tab10',
        s=5,
        alpha=0.7,
    )
    plt.colorbar(scatter)
    plt.grid(True)
    plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
def plot_random_reconstruction(autoencoder: AAutoencoder,
                               example: np.ndarray,
                               img_shape,
                               y):
    """Draw ``example`` beside its reconstruction on a 1x2 subplot grid.

    The flattened example is passed through ``autoencoder.forward``, which
    returns an ``(output, code)`` pair. The input and the output are each
    reshaped to ``img_shape`` and drawn with ``matshow`` into the current
    figure; the code is printed to stdout. This function does not call
    ``plt.show()`` itself.

    Args:
        autoencoder: Model whose ``forward`` returns ``(output, code)``.
        example: Sample to reconstruct; flattened before the forward pass.
        img_shape: Shape both images are reshaped to for display.
        y: Value interpolated into both panel titles.
    """
    output, code = autoencoder.forward(example.flatten())
    panels = (
        (example, f"Input ({y})"),
        (output, f"Output ({y})"),
    )
    for position, (image, caption) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        # fignum=False draws into the current axes instead of a new figure.
        plt.matshow(image.reshape(img_shape), fignum=False)
        plt.title(caption)
    print(f'{code=}')
|
||||||
|
|
||||||
|
|
||||||
def mnist_test(model: str | AAutoencoder):
|
def mnist_test(model: str | AAutoencoder):
|
||||||
x_train, _, x_test, y_test = load_mnist()
|
x_train, _, x_test, y_test = load_mnist()
|
||||||
in_len = x_train[0].shape[0] * x_train[0].shape[0]
|
in_len = x_train[0].shape[0] * x_train[0].shape[0]
|
||||||
@@ -76,41 +116,9 @@ def mnist_test(model: str | AAutoencoder):
|
|||||||
autoencoder = model
|
autoencoder = model
|
||||||
idx = np.random.randint(0, len(x_test))
|
idx = np.random.randint(0, len(x_test))
|
||||||
example: np.ndarray = x_test[idx]
|
example: np.ndarray = x_test[idx]
|
||||||
output, code = autoencoder.forward(example.flatten())
|
plot_random_reconstruction(autoencoder, example, img_shape, y_test[idx])
|
||||||
plt.subplot(1, 3, 1)
|
if autoencoder.space_dim == 2:
|
||||||
plt.matshow(
|
plot_mnist_latent_space(autoencoder, x_test, y_test)
|
||||||
example.reshape(img_shape),
|
|
||||||
fignum=False)
|
|
||||||
plt.title(f"Input ({y_test[idx]})")
|
|
||||||
plt.subplot(1, 3, 2)
|
|
||||||
plt.matshow(
|
|
||||||
output.reshape(img_shape),
|
|
||||||
fignum=False)
|
|
||||||
plt.title(f"Output ({y_test[idx]})")
|
|
||||||
plt.subplot(1, 3, 3)
|
|
||||||
code = np.reshape(code, (code.shape[0], 1))
|
|
||||||
plt.matshow(code, fignum=False)
|
|
||||||
plt.title(f"Code ({y_test[idx]})")
|
|
||||||
plt.show()
|
|
||||||
if code.shape[0] == 2:
|
|
||||||
codes = []
|
|
||||||
for x in x_test:
|
|
||||||
_, c = autoencoder.forward(x.flatten())
|
|
||||||
codes.append(c)
|
|
||||||
codes = np.array(codes)
|
|
||||||
if codes.shape[1] == 2:
|
|
||||||
plt.figure(figsize=(6, 6))
|
|
||||||
scatter = plt.scatter(
|
|
||||||
codes[:, 0],
|
|
||||||
codes[:, 1],
|
|
||||||
c=y_test,
|
|
||||||
cmap='tab10',
|
|
||||||
s=5,
|
|
||||||
alpha=0.7
|
|
||||||
)
|
|
||||||
plt.colorbar(scatter)
|
|
||||||
plt.grid(True)
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@@ -122,14 +130,14 @@ if __name__ == "__main__":
|
|||||||
'-e',
|
'-e',
|
||||||
type=int,
|
type=int,
|
||||||
nargs='?',
|
nargs='?',
|
||||||
default=1000,
|
default=30,
|
||||||
help='Max epochs'
|
help='Max epochs'
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'-p',
|
'-p',
|
||||||
type=int,
|
type=int,
|
||||||
nargs='?',
|
nargs='?',
|
||||||
default=5,
|
default=30,
|
||||||
help='Patience'
|
help='Patience'
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -141,7 +149,7 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'-r',
|
'-r',
|
||||||
action='store_true',
|
action='store_true',
|
||||||
help='Run mode'
|
help='Run the model'
|
||||||
)
|
)
|
||||||
args = parser.parse_args(sys.argv[1:])
|
args = parser.parse_args(sys.argv[1:])
|
||||||
if args.r:
|
if args.r:
|
||||||
|
|||||||
BIN
media/latent-space.png
Normal file
BIN
media/latent-space.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 181 KiB |
@@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "easyvae"
|
name = "easyvae"
|
||||||
version = "1.0"
|
version = "1.1"
|
||||||
authors = [
|
authors = [
|
||||||
{ name="Ravaka RALAMBOARIVONY", email="ravaka.rlb.pro@gmail.com" },
|
{ name="Ravaka RALAMBOARIVONY", email="ravaka.rlb.pro@gmail.com" },
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -13,6 +13,21 @@ LOADER = ['⡿', '⣟', '⣯', '⣷', '⣾', '⣽', '⣻', '⢿']
|
|||||||
|
|
||||||
|
|
||||||
class AAutoencoder(ABC):
|
class AAutoencoder(ABC):
|
||||||
|
@abstractmethod
def __init__(self,
             encoder_layers: list[int],
             decoder_layers: list[int],
             lr: float,
             activation_func: ActivationFunc):
    """Build the encoder/decoder pair and record the shared settings.

    Args:
        encoder_layers: Passed to the encoder ``DeepNNLayer``; its last
            entry must equal the first entry of ``decoder_layers``.
        decoder_layers: Passed to the decoder ``DeepNNLayer``; its first
            entry is stored as ``space_dim``.
        lr: Passed to both ``DeepNNLayer`` instances and stored as ``lr``.
        activation_func: Passed to both ``DeepNNLayer`` instances.

    Raises:
        Exception: If ``encoder_layers[-1] != decoder_layers[0]``.
    """
    encoder_out, decoder_in = encoder_layers[-1], decoder_layers[0]
    if encoder_out != decoder_in:
        # Report the two values actually compared; the message used to
        # print encoder_layers[0] in place of the decoder input size.
        raise Exception(
            f"Encoder output and decoder input don't match {encoder_out} != {decoder_in}"  # noqa
        )
    self.encoder = DeepNNLayer(encoder_layers, lr, activation_func)
    self.decoder = DeepNNLayer(decoder_layers, lr, activation_func)
    self.space_dim = decoder_in
    self.lr = lr
|
||||||
|
|
||||||
def train_dataset(self,
|
def train_dataset(self,
|
||||||
data_set: list[np.ndarray],
|
data_set: list[np.ndarray],
|
||||||
max_epoch: int,
|
max_epoch: int,
|
||||||
@@ -49,7 +64,7 @@ class AAutoencoder(ABC):
|
|||||||
break
|
break
|
||||||
epoch += 1
|
epoch += 1
|
||||||
if display_loss is True:
|
if display_loss is True:
|
||||||
dynamic_loss_plot_finish(ax, line)
|
dynamic_loss_plot_finish()
|
||||||
return losses
|
return losses
|
||||||
|
|
||||||
def save(self, path: str):
|
def save(self, path: str):
|
||||||
@@ -75,17 +90,8 @@ class AAutoencoder(ABC):
|
|||||||
|
|
||||||
|
|
||||||
class ClassicalAutoencoder(AAutoencoder):
|
class ClassicalAutoencoder(AAutoencoder):
|
||||||
def __init__(self,
|
def __init__(self, *args, **kwargs):
|
||||||
encoder_layers: list[int],
|
super().__init__(*args, **kwargs)
|
||||||
decoder_layers: list[int],
|
|
||||||
lr: float,
|
|
||||||
activation_func: ActivationFunc):
|
|
||||||
if encoder_layers[-1] != decoder_layers[0]:
|
|
||||||
raise Exception(
|
|
||||||
f"Encoder output and decoder input don't match {encoder_layers[-1]} != {encoder_layers[0]}" # noqa
|
|
||||||
)
|
|
||||||
self.encoder = DeepNNLayer(encoder_layers, lr, activation_func)
|
|
||||||
self.decoder = DeepNNLayer(decoder_layers, lr, activation_func)
|
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return f'Encoder:\n{self.encoder}\n\nDecoder:\n{self.decoder}'
|
return f'Encoder:\n{self.encoder}\n\nDecoder:\n{self.decoder}'
|
||||||
@@ -119,18 +125,9 @@ class ClassicalAutoencoder(AAutoencoder):
|
|||||||
|
|
||||||
|
|
||||||
class VariationalAutoencoder(AAutoencoder):
|
class VariationalAutoencoder(AAutoencoder):
|
||||||
def __init__(self,
|
def __init__(self, *args, **kwargs):
|
||||||
encoder_layers: list[int],
|
super().__init__(*args, **kwargs)
|
||||||
decoder_layers: list[int],
|
self.sampler = SampleLayer(self.encoder.out_size, self.lr, Identity())
|
||||||
lr: float,
|
|
||||||
activation_func: ActivationFunc):
|
|
||||||
if encoder_layers[-1] != decoder_layers[0]:
|
|
||||||
raise Exception(
|
|
||||||
f"Encoder output and decoder input don't match {encoder_layers[-1]} != {encoder_layers[0]}" # noqa
|
|
||||||
)
|
|
||||||
self.encoder = DeepNNLayer(encoder_layers, lr, activation_func)
|
|
||||||
self.decoder = DeepNNLayer(decoder_layers, lr, activation_func)
|
|
||||||
self.sampler = SampleLayer(self.encoder.out_size, lr, Identity())
|
|
||||||
|
|
||||||
def loss(self, data_set: list[np.ndarray]) -> float:
|
def loss(self, data_set: list[np.ndarray]) -> float:
|
||||||
loss = 0
|
loss = 0
|
||||||
|
|||||||
@@ -41,6 +41,6 @@ def dynamic_loss_plot_update(ax, line, loss):
|
|||||||
plt.pause(0.1)
|
plt.pause(0.1)
|
||||||
|
|
||||||
|
|
||||||
def dynamic_loss_plot_finish(ax, line):
|
def dynamic_loss_plot_finish():
|
||||||
plt.ioff()
|
plt.ioff()
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|||||||
Reference in New Issue
Block a user