fix: bad type hint and typo for forward method return value
This commit is contained in:
@@ -143,7 +143,7 @@ class VariationalAutoencoder(AAutoencoder):
|
|||||||
return data.item()
|
return data.item()
|
||||||
|
|
||||||
def train(self, v: np.ndarray) -> float:
|
def train(self, v: np.ndarray) -> float:
|
||||||
out = self.forward(v)
|
out, _ = self.forward(v)
|
||||||
error = out - v
|
error = out - v
|
||||||
self.encoder.backprop(
|
self.encoder.backprop(
|
||||||
self.sampler.backprop(
|
self.sampler.backprop(
|
||||||
@@ -152,7 +152,7 @@ class VariationalAutoencoder(AAutoencoder):
|
|||||||
)
|
)
|
||||||
return np.sum(np.abs(error)) / len(v)
|
return np.sum(np.abs(error)) / len(v)
|
||||||
|
|
||||||
def forward(self, v: np.ndarray) -> np.ndarray:
|
def forward(self, v: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
||||||
code = self.encoder.forward(v)
|
code = self.encoder.forward(v)
|
||||||
sample = self.sampler.forward(code)
|
sample = self.sampler.forward(code)
|
||||||
out = self.decoder.forward(sample)
|
out = self.decoder.forward(sample)
|
||||||
|
|||||||
@@ -29,7 +29,6 @@ def mnist_train(
|
|||||||
x_test.resize(x_test.shape[0], in_len)
|
x_test.resize(x_test.shape[0], in_len)
|
||||||
x_train = x_train / 255
|
x_train = x_train / 255
|
||||||
x_test = x_test / 255
|
x_test = x_test / 255
|
||||||
x_train = x_train[:5000]
|
|
||||||
if os.path.exists(filename):
|
if os.path.exists(filename):
|
||||||
autoencoder = cls.load(filename)
|
autoencoder = cls.load(filename)
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user