fix: bad cmp for patience in train_dataset method
This commit is contained in:
@@ -53,8 +53,7 @@ class Decoder:
|
|||||||
self.last_output = regularize(self.activation_func(res))
|
self.last_output = regularize(self.activation_func(res))
|
||||||
return self.last_output
|
return self.last_output
|
||||||
|
|
||||||
def backprop(self, target: np.ndarray):
|
def backprop(self, error: np.ndarray):
|
||||||
error = self.last_output - target
|
|
||||||
dW = np.outer(self.last_input, error)
|
dW = np.outer(self.last_input, error)
|
||||||
self.W -= self.lr * dW
|
self.W -= self.lr * dW
|
||||||
self.B -= self.lr * error
|
self.B -= self.lr * error
|
||||||
@@ -73,7 +72,7 @@ class Autoencoder:
|
|||||||
def train(self, v: np.ndarray) -> float:
|
def train(self, v: np.ndarray) -> float:
|
||||||
encoded = self.encoder.forward(v)
|
encoded = self.encoder.forward(v)
|
||||||
reconstructed = self.decoder.forward(encoded)
|
reconstructed = self.decoder.forward(encoded)
|
||||||
error = self.decoder.backprop(v)
|
error = self.decoder.backprop(v - reconstructed)
|
||||||
self.encoder.backprop(error)
|
self.encoder.backprop(error)
|
||||||
error = v - reconstructed
|
error = v - reconstructed
|
||||||
return np.sum(np.abs(error))
|
return np.sum(np.abs(error))
|
||||||
@@ -99,7 +98,7 @@ class Autoencoder:
|
|||||||
input = x.flatten()
|
input = x.flatten()
|
||||||
error += self.train(input)
|
error += self.train(input)
|
||||||
error /= len(data_set)
|
error /= len(data_set)
|
||||||
if error - prev_error <= 1e-8:
|
if prev_error - error <= 1e-8:
|
||||||
no_improv += 1
|
no_improv += 1
|
||||||
else:
|
else:
|
||||||
no_improv = 0
|
no_improv = 0
|
||||||
@@ -114,6 +113,7 @@ class Autoencoder:
|
|||||||
epoch += 1
|
epoch += 1
|
||||||
if display_loss is True:
|
if display_loss is True:
|
||||||
dynamic_loss_plot_finish(ax, line)
|
dynamic_loss_plot_finish(ax, line)
|
||||||
|
print("\r#Training complete !")
|
||||||
return losses
|
return losses
|
||||||
|
|
||||||
def encode(self, v: np.ndarray) -> np.ndarray:
|
def encode(self, v: np.ndarray) -> np.ndarray:
|
||||||
|
|||||||
Reference in New Issue
Block a user