From a4334568ecf58cd3e17c9ea00f0f175f4c5371a9 Mon Sep 17 00:00:00 2001 From: Lenoctambule <106790775+lenoctambule@users.noreply.github.com> Date: Sun, 12 Apr 2026 19:40:04 +0200 Subject: [PATCH] refactor: separate gradient back and weight updates --- src/easyvae/autoencoder.py | 17 +++++++++++------ src/easyvae/layers.py | 33 ++++++++++++++++++++++----------- 2 files changed, 33 insertions(+), 17 deletions(-) diff --git a/src/easyvae/autoencoder.py b/src/easyvae/autoencoder.py index 3926481..e5f0fed 100644 --- a/src/easyvae/autoencoder.py +++ b/src/easyvae/autoencoder.py @@ -82,9 +82,11 @@ class ClassicalAutoencoder(AAutoencoder): self.encoder.forward(v) ) error = out - v - self.encoder.backprop( - self.decoder.backprop(error) + self.encoder.back( + self.decoder.back(error) ) + self.encoder.backprop() + self.decoder.backprop() return np.sum(np.abs(error)) / len(v) @interruptable @@ -109,7 +111,7 @@ class ClassicalAutoencoder(AAutoencoder): error += self.train(x) error /= len(data_set) derror = prev_error - error - if derror <= 0 or abs(derror) < 1e-4: + if abs(derror) < 1e-4: no_improv += 1 else: no_improv = 0 @@ -167,11 +169,14 @@ class VariationalAutoencoder(AAutoencoder): def train(self, v: np.ndarray) -> tuple[float, float]: out, _ = self.forward(v) error = out - v - self.encoder.backprop( - self.sampler.backprop( - self.decoder.backprop(error) + self.encoder.back( + self.sampler.back( + self.decoder.back(error) ) ) + self.encoder.backprop() + self.sampler.backprop() + self.decoder.backprop() return np.mean(error ** 2), self.sampler.DKL() @interruptable diff --git a/src/easyvae/layers.py b/src/easyvae/layers.py index a79c6cd..327b0f3 100644 --- a/src/easyvae/layers.py +++ b/src/easyvae/layers.py @@ -15,6 +15,7 @@ class NNLayer: self.input = None self.output = None self.output_linear = None + self.error = None self.activation_func = activation_func def __str__(self): @@ -27,15 +28,16 @@ class NNLayer: self.output_linear ) return self.output + + def back(self, error: np.ndarray) -> np.ndarray: + self.error = error * self.activation_func.d(self.output_linear) + return self.W @ self.error - def backprop(self, error: np.ndarray) -> np.ndarray: - error *= self.activation_func.d(self.output_linear) - ret = self.W @ error - dW = np.outer(self.input, error) * self.lr - dB = error * self.lr + def backprop(self) -> np.ndarray: + dW = np.outer(self.input, self.error) * self.lr + dB = self.error * self.lr self.W -= dW self.B -= dB - return ret class SampleLayer: @@ -66,13 +68,17 @@ class SampleLayer: self.eps = np.random.normal(0, 1, self.mean.shape) return 0.5 * self.eps * self.std + self.mean - def backprop(self, error: np.ndarray) -> np.ndarray: + def back(self, error: np.ndarray) -> np.ndarray: dmean = error + self.mean dstd = error * self.eps + 0.5 * (np.exp(self.logvar) - 1) - mean_error = self.mean_nn.backprop(dmean) - logvar_error = self.std_nn.backprop(dstd * self.std) + mean_error = self.mean_nn.back(dmean) + logvar_error = self.std_nn.back(dstd * self.std) return mean_error + logvar_error + def backprop(self): + self.mean_nn.backprop() + self.std_nn.backprop() + class DeepNNLayer: def __init__(self, @@ -100,7 +106,12 @@ class DeepNNLayer: v = layer.forward(v) return v - def backprop(self, error: np.ndarray) -> np.ndarray: + def back(self, error: np.ndarray): for layer in self.layers[::-1]: - error = layer.backprop(error) + error = layer.back(error) return error + + def backprop(self) -> np.ndarray: + for layer in self.layers: + layer.backprop() +