fix: missing args in np.clip
This commit is contained in:
@@ -61,7 +61,7 @@ class SampleLayer:
|
|||||||
def forward(self, v: np.ndarray) -> np.ndarray:
|
def forward(self, v: np.ndarray) -> np.ndarray:
|
||||||
self.input = v
|
self.input = v
|
||||||
self.mean = self.mean_nn.forward(v)
|
self.mean = self.mean_nn.forward(v)
|
||||||
self.logvar = np.clip(self.std_nn.forward(v))
|
self.logvar = np.clip(self.std_nn.forward(v), -10, 10)
|
||||||
self.std = np.exp(0.5 * self.logvar)
|
self.std = np.exp(0.5 * self.logvar)
|
||||||
self.eps = np.random.normal(0, 1, self.mean.shape)
|
self.eps = np.random.normal(0, 1, self.mean.shape)
|
||||||
return 0.5 * self.eps * self.std + self.mean
|
return 0.5 * self.eps * self.std + self.mean
|
||||||
|
|||||||
Reference in New Issue
Block a user