feat: VariationalAutoencoder class + sampling nn layer
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from autoencoder import Autoencoder
|
||||
from autoencoder import ClassicalAutoencoder
|
||||
from activations import LeakyReLU
|
||||
import os
|
||||
|
||||
@@ -21,7 +21,7 @@ def mnist_train(
|
||||
filename: str,
|
||||
max_epoch: int,
|
||||
patience: int,
|
||||
) -> Autoencoder:
|
||||
) -> ClassicalAutoencoder:
|
||||
x_train, _, x_test, _ = load_mnist()
|
||||
in_len = x_train[0].shape[0] * x_train[0].shape[0]
|
||||
x_train.resize(x_train.shape[0], in_len)
|
||||
@@ -29,9 +29,9 @@ def mnist_train(
|
||||
x_train = x_train / 255
|
||||
x_test = x_test / 255
|
||||
if os.path.exists(filename):
|
||||
autoencoder = Autoencoder.load(filename)
|
||||
autoencoder = ClassicalAutoencoder.load(filename)
|
||||
else:
|
||||
autoencoder = Autoencoder(
|
||||
autoencoder = ClassicalAutoencoder(
|
||||
[in_len, 64, 16],
|
||||
[16, 64, in_len],
|
||||
0.01,
|
||||
@@ -46,7 +46,7 @@ def mnist_train(
|
||||
return autoencoder
|
||||
|
||||
|
||||
def mnist_test(model: str | Autoencoder):
|
||||
def mnist_test(model: str | ClassicalAutoencoder):
|
||||
x_train, _, x_test, y_test = load_mnist()
|
||||
in_len = x_train[0].shape[0] * x_train[0].shape[0]
|
||||
img_shape = x_train[0].shape
|
||||
@@ -55,7 +55,7 @@ def mnist_test(model: str | Autoencoder):
|
||||
x_train = x_train / 255
|
||||
x_test = x_test / 255
|
||||
if isinstance(model, str):
|
||||
autoencoder: Autoencoder = Autoencoder.load(model)
|
||||
autoencoder: ClassicalAutoencoder = ClassicalAutoencoder.load(model)
|
||||
else:
|
||||
autoencoder = model
|
||||
print(autoencoder)
|
||||
|
||||
Reference in New Issue
Block a user