feat: VariationalAutoencoder class + sampling nn layer

This commit is contained in:
Lenoctambule
2026-04-01 22:32:35 +02:00
parent cc74b62afd
commit 577e679425
4 changed files with 130 additions and 48 deletions

View File

@@ -1,6 +1,6 @@
import matplotlib.pyplot as plt
import numpy as np
from autoencoder import Autoencoder
from autoencoder import ClassicalAutoencoder
from activations import LeakyReLU
import os
@@ -21,7 +21,7 @@ def mnist_train(
filename: str,
max_epoch: int,
patience: int,
) -> Autoencoder:
) -> ClassicalAutoencoder:
x_train, _, x_test, _ = load_mnist()
in_len = x_train[0].shape[0] * x_train[0].shape[0]
x_train.resize(x_train.shape[0], in_len)
@@ -29,9 +29,9 @@ def mnist_train(
x_train = x_train / 255
x_test = x_test / 255
if os.path.exists(filename):
autoencoder = Autoencoder.load(filename)
autoencoder = ClassicalAutoencoder.load(filename)
else:
autoencoder = Autoencoder(
autoencoder = ClassicalAutoencoder(
[in_len, 64, 16],
[16, 64, in_len],
0.01,
@@ -46,7 +46,7 @@ def mnist_train(
return autoencoder
def mnist_test(model: str | Autoencoder):
def mnist_test(model: str | ClassicalAutoencoder):
x_train, _, x_test, y_test = load_mnist()
in_len = x_train[0].shape[0] * x_train[0].shape[0]
img_shape = x_train[0].shape
@@ -55,7 +55,7 @@ def mnist_test(model: str | Autoencoder):
x_train = x_train / 255
x_test = x_test / 255
if isinstance(model, str):
autoencoder: Autoencoder = Autoencoder.load(model)
autoencoder: ClassicalAutoencoder = ClassicalAutoencoder.load(model)
else:
autoencoder = model
print(autoencoder)