feat: post training and online labeling VAE class
This commit is contained in:
@@ -4,6 +4,7 @@ import os
|
||||
from easyvae.autoencoder import ( # noqa
|
||||
VariationalAutoencoder,
|
||||
ClassicalAutoencoder,
|
||||
LabelingVAE,
|
||||
AAutoencoder
|
||||
)
|
||||
from easyvae.activations import LeakyReLU
|
||||
@@ -26,10 +27,9 @@ def mnist_train(
|
||||
max_epoch: int,
|
||||
patience: int,
|
||||
cls: type[AAutoencoder],) -> AAutoencoder:
|
||||
x_train, _, x_test, _ = load_mnist()
|
||||
x_train, _, _, _ = load_mnist()
|
||||
in_len = x_train[0].shape[0] * x_train[0].shape[0]
|
||||
x_train.resize(x_train.shape[0], in_len)
|
||||
x_test.resize(x_test.shape[0], in_len)
|
||||
x_train = x_train / 255
|
||||
if os.path.exists(filename):
|
||||
autoencoder = cls.load(filename)
|
||||
|
||||
Reference in New Issue
Block a user