fix: bad type hint and typo for forward method return value

This commit is contained in:
Lenoctambule
2026-04-05 01:44:10 +02:00
parent 5a8fb2c48b
commit 82d61dd10f
2 changed files with 2 additions and 3 deletions

View File

@@ -143,7 +143,7 @@ class VariationalAutoencoder(AAutoencoder):
return data.item() return data.item()
def train(self, v: np.ndarray) -> float: def train(self, v: np.ndarray) -> float:
out = self.forward(v) out, _ = self.forward(v)
error = out - v error = out - v
self.encoder.backprop( self.encoder.backprop(
self.sampler.backprop( self.sampler.backprop(
@@ -152,7 +152,7 @@ class VariationalAutoencoder(AAutoencoder):
) )
return np.sum(np.abs(error)) / len(v) return np.sum(np.abs(error)) / len(v)
def forward(self, v: np.ndarray) -> np.ndarray: def forward(self, v: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
code = self.encoder.forward(v) code = self.encoder.forward(v)
sample = self.sampler.forward(code) sample = self.sampler.forward(code)
out = self.decoder.forward(sample) out = self.decoder.forward(sample)

View File

@@ -29,7 +29,6 @@ def mnist_train(
x_test.resize(x_test.shape[0], in_len) x_test.resize(x_test.shape[0], in_len)
x_train = x_train / 255 x_train = x_train / 255
x_test = x_test / 255 x_test = x_test / 255
x_train = x_train[:5000]
if os.path.exists(filename): if os.path.exists(filename):
autoencoder = cls.load(filename) autoencoder = cls.load(filename)
else: else: