diff --git a/autoencoder.py b/autoencoder.py index 687ae64..910efde 100644 --- a/autoencoder.py +++ b/autoencoder.py @@ -143,7 +143,7 @@ class VariationalAutoencoder(AAutoencoder): return data.item() def train(self, v: np.ndarray) -> float: - out = self.forward(v) + out, _ = self.forward(v) error = out - v self.encoder.backprop( self.sampler.backprop( @@ -152,7 +152,7 @@ class VariationalAutoencoder(AAutoencoder): ) return np.sum(np.abs(error)) / len(v) - def forward(self, v: np.ndarray) -> np.ndarray: + def forward(self, v: np.ndarray) -> tuple[np.ndarray, np.ndarray]: code = self.encoder.forward(v) sample = self.sampler.forward(code) out = self.decoder.forward(sample) diff --git a/mnist_test.py b/mnist_test.py index e7fc0aa..ba28bba 100644 --- a/mnist_test.py +++ b/mnist_test.py @@ -29,7 +29,6 @@ def mnist_train( x_test.resize(x_test.shape[0], in_len) x_train = x_train / 255 x_test = x_test / 255 - x_train = x_train[:5000] if os.path.exists(filename): autoencoder = cls.load(filename) else: