feat: post training and online labeling VAE class

This commit is contained in:
Lenoctambule
2026-04-14 14:12:22 +02:00
parent e6c1229a7e
commit b1dc34e699
2 changed files with 64 additions and 4 deletions

View File

@@ -4,6 +4,7 @@ import os
from easyvae.autoencoder import ( # noqa
VariationalAutoencoder,
ClassicalAutoencoder,
LabelingVAE,
AAutoencoder
)
from easyvae.activations import LeakyReLU
@@ -26,10 +27,9 @@ def mnist_train(
max_epoch: int,
patience: int,
cls: type[AAutoencoder],) -> AAutoencoder:
x_train, _, x_test, _ = load_mnist()
x_train, _, _, _ = load_mnist()
in_len = x_train[0].shape[0] * x_train[0].shape[0]
x_train.resize(x_train.shape[0], in_len)
x_test.resize(x_test.shape[0], in_len)
x_train = x_train / 255
if os.path.exists(filename):
autoencoder = cls.load(filename)