refactor(autoencoder.py): __init__ code de-dup
This commit is contained in:
@@ -8,6 +8,7 @@ from easyvae.autoencoder import ( # noqa
|
||||
AAutoencoder
|
||||
)
|
||||
from easyvae.activations import LeakyReLU
|
||||
from easyvae.utils import dynamic_loss_plot_finish
|
||||
|
||||
|
||||
def load_mnist() -> list[np.ndarray]:
|
||||
@@ -38,15 +39,15 @@ def mnist_train(
|
||||
autoencoder = cls(
|
||||
[in_len, 256, 2],
|
||||
[2, 256, in_len],
|
||||
0.001,
|
||||
0.0001,
|
||||
LeakyReLU()
|
||||
)
|
||||
|
||||
def handler(signum, frame):
|
||||
print(f"Saving {filename} before exit ...")
|
||||
autoencoder.save(filename)
|
||||
plt.close()
|
||||
plt.ioff()
|
||||
if plt.get_fignums():
|
||||
dynamic_loss_plot_finish()
|
||||
mnist_test(autoencoder)
|
||||
exit()
|
||||
|
||||
@@ -62,6 +63,45 @@ def mnist_train(
|
||||
return autoencoder
|
||||
|
||||
|
||||
def plot_mnist_latent_space(autoencoder: AAutoencoder, x: np.ndarray, y,):
|
||||
codes = []
|
||||
for x in x:
|
||||
_, c = autoencoder.forward(x.flatten())
|
||||
codes.append(c)
|
||||
codes = np.array(codes)
|
||||
if codes.shape[1] == 2:
|
||||
plt.figure(figsize=(6, 6))
|
||||
scatter = plt.scatter(
|
||||
codes[:, 0],
|
||||
codes[:, 1],
|
||||
c=y,
|
||||
cmap='tab10',
|
||||
s=5,
|
||||
alpha=0.7
|
||||
)
|
||||
plt.colorbar(scatter)
|
||||
plt.grid(True)
|
||||
plt.show()
|
||||
|
||||
|
||||
def plot_random_reconstruction(autoencoder: AAutoencoder,
|
||||
example: np.ndarray,
|
||||
img_shape,
|
||||
y):
|
||||
output, code = autoencoder.forward(example.flatten())
|
||||
plt.subplot(1, 3, 1)
|
||||
plt.matshow(
|
||||
example.reshape(img_shape),
|
||||
fignum=False)
|
||||
plt.title(f"Input ({y})")
|
||||
plt.subplot(1, 3, 2)
|
||||
plt.matshow(
|
||||
output.reshape(img_shape),
|
||||
fignum=False)
|
||||
plt.title(f"Output ({y})")
|
||||
print(f'{code=}')
|
||||
|
||||
|
||||
def mnist_test(model: str | AAutoencoder):
|
||||
x_train, _, x_test, y_test = load_mnist()
|
||||
in_len = x_train[0].shape[0] * x_train[0].shape[0]
|
||||
@@ -76,41 +116,9 @@ def mnist_test(model: str | AAutoencoder):
|
||||
autoencoder = model
|
||||
idx = np.random.randint(0, len(x_test))
|
||||
example: np.ndarray = x_test[idx]
|
||||
output, code = autoencoder.forward(example.flatten())
|
||||
plt.subplot(1, 3, 1)
|
||||
plt.matshow(
|
||||
example.reshape(img_shape),
|
||||
fignum=False)
|
||||
plt.title(f"Input ({y_test[idx]})")
|
||||
plt.subplot(1, 3, 2)
|
||||
plt.matshow(
|
||||
output.reshape(img_shape),
|
||||
fignum=False)
|
||||
plt.title(f"Output ({y_test[idx]})")
|
||||
plt.subplot(1, 3, 3)
|
||||
code = np.reshape(code, (code.shape[0], 1))
|
||||
plt.matshow(code, fignum=False)
|
||||
plt.title(f"Code ({y_test[idx]})")
|
||||
plt.show()
|
||||
if code.shape[0] == 2:
|
||||
codes = []
|
||||
for x in x_test:
|
||||
_, c = autoencoder.forward(x.flatten())
|
||||
codes.append(c)
|
||||
codes = np.array(codes)
|
||||
if codes.shape[1] == 2:
|
||||
plt.figure(figsize=(6, 6))
|
||||
scatter = plt.scatter(
|
||||
codes[:, 0],
|
||||
codes[:, 1],
|
||||
c=y_test,
|
||||
cmap='tab10',
|
||||
s=5,
|
||||
alpha=0.7
|
||||
)
|
||||
plt.colorbar(scatter)
|
||||
plt.grid(True)
|
||||
plt.show()
|
||||
plot_random_reconstruction(autoencoder, example, img_shape, y_test[idx])
|
||||
if autoencoder.space_dim == 2:
|
||||
plot_mnist_latent_space(autoencoder, x_test, y_test)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -122,14 +130,14 @@ if __name__ == "__main__":
|
||||
'-e',
|
||||
type=int,
|
||||
nargs='?',
|
||||
default=1000,
|
||||
default=30,
|
||||
help='Max epochs'
|
||||
)
|
||||
parser.add_argument(
|
||||
'-p',
|
||||
type=int,
|
||||
nargs='?',
|
||||
default=5,
|
||||
default=30,
|
||||
help='Patience'
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -141,7 +149,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
'-r',
|
||||
action='store_true',
|
||||
help='Run mode'
|
||||
help='Run the model'
|
||||
)
|
||||
args = parser.parse_args(sys.argv[1:])
|
||||
if args.r:
|
||||
|
||||
@@ -13,6 +13,21 @@ LOADER = ['⡿', '⣟', '⣯', '⣷', '⣾', '⣽', '⣻', '⢿']
|
||||
|
||||
|
||||
class AAutoencoder(ABC):
|
||||
@abstractmethod
|
||||
def __init__(self,
|
||||
encoder_layers: list[int],
|
||||
decoder_layers: list[int],
|
||||
lr: float,
|
||||
activation_func: ActivationFunc):
|
||||
if encoder_layers[-1] != decoder_layers[0]:
|
||||
raise Exception(
|
||||
f"Encoder output and decoder input don't match {encoder_layers[-1]} != {encoder_layers[0]}" # noqa
|
||||
)
|
||||
self.encoder = DeepNNLayer(encoder_layers, lr, activation_func)
|
||||
self.decoder = DeepNNLayer(decoder_layers, lr, activation_func)
|
||||
self.space_dim = decoder_layers[0]
|
||||
self.lr = lr
|
||||
|
||||
def train_dataset(self,
|
||||
data_set: list[np.ndarray],
|
||||
max_epoch: int,
|
||||
@@ -75,17 +90,8 @@ class AAutoencoder(ABC):
|
||||
|
||||
|
||||
class ClassicalAutoencoder(AAutoencoder):
|
||||
def __init__(self,
|
||||
encoder_layers: list[int],
|
||||
decoder_layers: list[int],
|
||||
lr: float,
|
||||
activation_func: ActivationFunc):
|
||||
if encoder_layers[-1] != decoder_layers[0]:
|
||||
raise Exception(
|
||||
f"Encoder output and decoder input don't match {encoder_layers[-1]} != {encoder_layers[0]}" # noqa
|
||||
)
|
||||
self.encoder = DeepNNLayer(encoder_layers, lr, activation_func)
|
||||
self.decoder = DeepNNLayer(decoder_layers, lr, activation_func)
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def __str__(self):
|
||||
return f'Encoder:\n{self.encoder}\n\nDecoder:\n{self.decoder}'
|
||||
@@ -119,18 +125,9 @@ class ClassicalAutoencoder(AAutoencoder):
|
||||
|
||||
|
||||
class VariationalAutoencoder(AAutoencoder):
|
||||
def __init__(self,
|
||||
encoder_layers: list[int],
|
||||
decoder_layers: list[int],
|
||||
lr: float,
|
||||
activation_func: ActivationFunc):
|
||||
if encoder_layers[-1] != decoder_layers[0]:
|
||||
raise Exception(
|
||||
f"Encoder output and decoder input don't match {encoder_layers[-1]} != {encoder_layers[0]}" # noqa
|
||||
)
|
||||
self.encoder = DeepNNLayer(encoder_layers, lr, activation_func)
|
||||
self.decoder = DeepNNLayer(decoder_layers, lr, activation_func)
|
||||
self.sampler = SampleLayer(self.encoder.out_size, lr, Identity())
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.sampler = SampleLayer(self.encoder.out_size, self.lr, Identity())
|
||||
|
||||
def loss(self, data_set: list[np.ndarray]) -> float:
|
||||
loss = 0
|
||||
|
||||
Reference in New Issue
Block a user