diff --git a/layers.py b/layers.py index 0ca8a1b..51c99e7 100644 --- a/layers.py +++ b/layers.py @@ -61,7 +61,7 @@ class SampleLayer: def forward(self, v: np.ndarray) -> np.ndarray: self.input = v self.mean = self.mean_nn.forward(v) - self.logvar = np.clip(self.std_nn.forward(v)) + self.logvar = np.clip(self.std_nn.forward(v), -10, 10) self.std = np.exp(0.5 * self.logvar) self.eps = np.random.normal(0, 1, self.mean.shape) return 0.5 * self.eps * self.std + self.mean