feat: working implementation of VAE

This commit is contained in:
Lenoctambule
2026-04-05 01:17:51 +02:00
parent 577e679425
commit 5a8fb2c48b
3 changed files with 42 additions and 27 deletions

View File

@@ -121,7 +121,6 @@ class VariationalAutoencoder(AAutoencoder):
def __init__(self,
encoder_layers: list[int],
decoder_layers: list[int],
sampling_size: int,
lr: float,
activation_func: ActivationFunc):
if encoder_layers[-1] != decoder_layers[0]:
@@ -131,7 +130,12 @@ class VariationalAutoencoder(AAutoencoder):
self.encoder = DeepNNLayer(encoder_layers, lr, activation_func)
self.decoder = DeepNNLayer(decoder_layers, lr, activation_func)
self.sampler = SampleLayer(self.encoder.out_size, lr, activation_func)
self.sampling_size = sampling_size
def loss(self, data_set: list[np.ndarray]) -> float:
    """Return the average reconstruction error over ``data_set``.

    For every sample the model's reconstruction (first element of the
    value returned by ``self.forward``) is compared to the sample; the
    summed absolute difference is divided by ``len(sample)``. The
    per-sample values are then averaged over the number of samples.

    Raises:
        ZeroDivisionError: if ``data_set`` is empty.
    """
    total = sum(
        np.sum(np.abs(sample - self.forward(sample)[0])) / len(sample)
        for sample in data_set
    )
    return total / len(data_set)
def load(path: str) -> 'ClassicalAutoencoder':
path = path.removesuffix('.npy') + '.npy'
@@ -139,16 +143,17 @@ class VariationalAutoencoder(AAutoencoder):
return data.item()
def train(self, v: np.ndarray) -> float:
    """Run one training step on a single input vector.

    Performs a forward pass, computes the reconstruction error
    ``out - v`` and propagates it back through the decoder, the sampler
    and the encoder, in that order.

    Args:
        v: Input vector to reconstruct.

    Returns:
        The summed absolute reconstruction error divided by ``len(v)``,
        measured before any weights are updated.
    """
    # forward() returns (reconstruction, code); only the reconstruction
    # takes part in the error.
    out, _ = self.forward(v)
    error = out - v
    # Compute the reported loss *before* backprop: layer backprop may
    # scale the error array in place (NNLayer.backprop does), which
    # would otherwise corrupt the value returned here.
    loss = np.sum(np.abs(error)) / len(v)
    self.encoder.backprop(
        self.sampler.backprop(
            self.decoder.backprop(error)
        )
    )
    return loss
def forward(self, v: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Encode ``v``, draw a latent sample and decode it.

    Args:
        v: Input vector fed to the encoder.

    Returns:
        A ``(out, code)`` pair: ``out`` is the decoder output for the
        sampled latent vector and ``code`` is the raw encoder output
        that was handed to the sampler.
    """
    code = self.encoder.forward(v)
    sample = self.sampler.forward(code)
    out = self.decoder.forward(sample)
    return out, code

View File

@@ -29,7 +29,7 @@ class NNLayer:
return self.output
def backprop(self, error: np.ndarray) -> np.ndarray:
error *= self.activation_func.derivative(self.output_linear)
error *= self.activation_func.d(self.output_linear)
ret = self.W @ error
dW = np.outer(self.input, error) * self.lr
dB = error * self.lr
@@ -57,12 +57,15 @@ class SampleLayer:
def forward(self, v: np.ndarray) -> np.ndarray:
    """Draw a latent sample using the reparameterisation ``mean + eps * std``.

    ``mean`` and ``std`` are produced by ``self.mean_nn`` and
    ``self.std_nn`` from the same input. The input, both outputs and the
    noise are stored on the instance so that ``backprop`` can reuse them.

    Args:
        v: Input vector for the mean and std sub-layers.

    Returns:
        ``self.eps * self.std + self.mean``.
    """
    self.input = v
    self.mean = self.mean_nn.forward(v)
    self.std = self.std_nn.forward(v)
    # Draw one independent standard-normal value per latent component;
    # a single shared scalar would make the noise of every component
    # perfectly correlated.
    self.eps = np.random.normal(0, 1, size=np.shape(self.std))
    return self.eps * self.std + self.mean
def backprop(self, error: np.ndarray) -> np.ndarray:
    """Propagate ``error`` through the mean and std sub-layers.

    The mean branch receives ``error`` unchanged; the std branch
    receives ``self.eps * error``, using the noise stored by the last
    ``forward`` call.

    Args:
        error: Error signal with respect to the sampled output.

    Returns:
        The sum of the errors returned by the two sub-layers.
    """
    # Build the std-branch signal first: a layer's backprop may scale
    # its argument in place (NNLayer.backprop does), so evaluating
    # ``self.eps * error`` afterwards would use an already-modified array.
    std_signal = self.eps * error
    mu_error = self.mean_nn.backprop(error)
    std_error = self.std_nn.backprop(std_signal)
    return mu_error + std_error
class DeepNNLayer:

View File

@@ -1,6 +1,6 @@
import matplotlib.pyplot as plt
import numpy as np
from autoencoder import ClassicalAutoencoder
from autoencoder import VariationalAutoencoder, AAutoencoder
from activations import LeakyReLU
import os
@@ -21,19 +21,21 @@ def mnist_train(
filename: str,
max_epoch: int,
patience: int,
) -> ClassicalAutoencoder:
cls: type[AAutoencoder]
) -> AAutoencoder:
x_train, _, x_test, _ = load_mnist()
in_len = x_train[0].shape[0] * x_train[0].shape[0]
x_train.resize(x_train.shape[0], in_len)
x_test.resize(x_test.shape[0], in_len)
x_train = x_train / 255
x_test = x_test / 255
x_train = x_train[:5000]
if os.path.exists(filename):
autoencoder = ClassicalAutoencoder.load(filename)
autoencoder = cls.load(filename)
else:
autoencoder = ClassicalAutoencoder(
[in_len, 64, 16],
[16, 64, in_len],
autoencoder = cls(
[in_len, 16],
[16, in_len],
0.01,
LeakyReLU()
)
@@ -46,7 +48,7 @@ def mnist_train(
return autoencoder
def mnist_test(model: str | ClassicalAutoencoder):
def mnist_test(model: str | AAutoencoder):
x_train, _, x_test, y_test = load_mnist()
in_len = x_train[0].shape[0] * x_train[0].shape[0]
img_shape = x_train[0].shape
@@ -55,7 +57,7 @@ def mnist_test(model: str | ClassicalAutoencoder):
x_train = x_train / 255
x_test = x_test / 255
if isinstance(model, str):
autoencoder: ClassicalAutoencoder = ClassicalAutoencoder.load(model)
autoencoder: AAutoencoder = AAutoencoder.load(model)
else:
autoencoder = model
print(autoencoder)
@@ -114,5 +116,10 @@ if __name__ == "__main__":
if args.r:
mnist_test(args.m)
else:
autoencoder = mnist_train(args.m, args.e, args.p)
autoencoder = mnist_train(
args.m,
args.e,
args.p,
VariationalAutoencoder
)
mnist_test(autoencoder)