From 82d61dd10f261e2b642c3d6c622eccaa511e0b5d Mon Sep 17 00:00:00 2001 From: Lenoctambule <106790775+lenoctambule@users.noreply.github.com> Date: Sun, 5 Apr 2026 01:44:10 +0200 Subject: [PATCH] fix: bad type hint and typo for forward method return value --- autoencoder.py | 4 ++-- mnist_test.py | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/autoencoder.py b/autoencoder.py index 687ae64..910efde 100644 --- a/autoencoder.py +++ b/autoencoder.py @@ -143,7 +143,7 @@ class VariationalAutoencoder(AAutoencoder): return data.item() def train(self, v: np.ndarray) -> float: - out = self.forward(v) + out, _ = self.forward(v) error = out - v self.encoder.backprop( self.sampler.backprop( @@ -152,7 +152,7 @@ class VariationalAutoencoder(AAutoencoder): ) return np.sum(np.abs(error)) / len(v) - def forward(self, v: np.ndarray) -> np.ndarray: + def forward(self, v: np.ndarray) -> tuple[np.ndarray, np.ndarray]: code = self.encoder.forward(v) sample = self.sampler.forward(code) out = self.decoder.forward(sample) diff --git a/mnist_test.py b/mnist_test.py index e7fc0aa..ba28bba 100644 --- a/mnist_test.py +++ b/mnist_test.py @@ -29,7 +29,6 @@ def mnist_train( x_test.resize(x_test.shape[0], in_len) x_train = x_train / 255 x_test = x_test / 255 - x_train = x_train[:5000] if os.path.exists(filename): autoencoder = cls.load(filename) else: