feat: working implementation of VAE

This commit is contained in:
Lenoctambule
2026-04-05 01:17:51 +02:00
parent 577e679425
commit 5a8fb2c48b
3 changed files with 42 additions and 27 deletions

View File

@@ -1,6 +1,6 @@
import matplotlib.pyplot as plt
import numpy as np
from autoencoder import ClassicalAutoencoder
from autoencoder import VariationalAutoencoder, AAutoencoder
from activations import LeakyReLU
import os
@@ -21,19 +21,21 @@ def mnist_train(
filename: str,
max_epoch: int,
patience: int,
) -> ClassicalAutoencoder:
cls: type[AAutoencoder]
) -> AAutoencoder:
x_train, _, x_test, _ = load_mnist()
in_len = x_train[0].shape[0] * x_train[0].shape[0]
x_train.resize(x_train.shape[0], in_len)
x_test.resize(x_test.shape[0], in_len)
x_train = x_train / 255
x_test = x_test / 255
x_train = x_train[:5000]
if os.path.exists(filename):
autoencoder = ClassicalAutoencoder.load(filename)
autoencoder = cls.load(filename)
else:
autoencoder = ClassicalAutoencoder(
[in_len, 64, 16],
[16, 64, in_len],
autoencoder = cls(
[in_len, 16],
[16, in_len],
0.01,
LeakyReLU()
)
@@ -46,7 +48,7 @@ def mnist_train(
return autoencoder
def mnist_test(model: str | ClassicalAutoencoder):
def mnist_test(model: str | AAutoencoder):
x_train, _, x_test, y_test = load_mnist()
in_len = x_train[0].shape[0] * x_train[0].shape[0]
img_shape = x_train[0].shape
@@ -55,7 +57,7 @@ def mnist_test(model: str | ClassicalAutoencoder):
x_train = x_train / 255
x_test = x_test / 255
if isinstance(model, str):
autoencoder: ClassicalAutoencoder = ClassicalAutoencoder.load(model)
autoencoder: AAutoencoder = AAutoencoder.load(model)
else:
autoencoder = model
print(autoencoder)
@@ -114,5 +116,10 @@ if __name__ == "__main__":
if args.r:
mnist_test(args.m)
else:
autoencoder = mnist_train(args.m, args.e, args.p)
autoencoder = mnist_train(
args.m,
args.e,
args.p,
VariationalAutoencoder
)
mnist_test(autoencoder)