feat: working implementation of VAE
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from autoencoder import ClassicalAutoencoder
|
||||
from autoencoder import VariationalAutoencoder, AAutoencoder
|
||||
from activations import LeakyReLU
|
||||
import os
|
||||
|
||||
@@ -21,19 +21,21 @@ def mnist_train(
|
||||
filename: str,
|
||||
max_epoch: int,
|
||||
patience: int,
|
||||
) -> ClassicalAutoencoder:
|
||||
cls: type[AAutoencoder]
|
||||
) -> AAutoencoder:
|
||||
x_train, _, x_test, _ = load_mnist()
|
||||
in_len = x_train[0].shape[0] * x_train[0].shape[0]
|
||||
x_train.resize(x_train.shape[0], in_len)
|
||||
x_test.resize(x_test.shape[0], in_len)
|
||||
x_train = x_train / 255
|
||||
x_test = x_test / 255
|
||||
x_train = x_train[:5000]
|
||||
if os.path.exists(filename):
|
||||
autoencoder = ClassicalAutoencoder.load(filename)
|
||||
autoencoder = cls.load(filename)
|
||||
else:
|
||||
autoencoder = ClassicalAutoencoder(
|
||||
[in_len, 64, 16],
|
||||
[16, 64, in_len],
|
||||
autoencoder = cls(
|
||||
[in_len, 16],
|
||||
[16, in_len],
|
||||
0.01,
|
||||
LeakyReLU()
|
||||
)
|
||||
@@ -46,7 +48,7 @@ def mnist_train(
|
||||
return autoencoder
|
||||
|
||||
|
||||
def mnist_test(model: str | ClassicalAutoencoder):
|
||||
def mnist_test(model: str | AAutoencoder):
|
||||
x_train, _, x_test, y_test = load_mnist()
|
||||
in_len = x_train[0].shape[0] * x_train[0].shape[0]
|
||||
img_shape = x_train[0].shape
|
||||
@@ -55,7 +57,7 @@ def mnist_test(model: str | ClassicalAutoencoder):
|
||||
x_train = x_train / 255
|
||||
x_test = x_test / 255
|
||||
if isinstance(model, str):
|
||||
autoencoder: ClassicalAutoencoder = ClassicalAutoencoder.load(model)
|
||||
autoencoder: AAutoencoder = AAutoencoder.load(model)
|
||||
else:
|
||||
autoencoder = model
|
||||
print(autoencoder)
|
||||
@@ -114,5 +116,10 @@ if __name__ == "__main__":
|
||||
if args.r:
|
||||
mnist_test(args.m)
|
||||
else:
|
||||
autoencoder = mnist_train(args.m, args.e, args.p)
|
||||
autoencoder = mnist_train(
|
||||
args.m,
|
||||
args.e,
|
||||
args.p,
|
||||
VariationalAutoencoder
|
||||
)
|
||||
mnist_test(autoencoder)
|
||||
|
||||
Reference in New Issue
Block a user