fix: missing patience reset
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import keras
|
||||||
from autoencoder import Autoencoder
|
from autoencoder import Autoencoder
|
||||||
from utils import (relu,
|
from utils import (relu,
|
||||||
dynamic_loss_plot_init,
|
dynamic_loss_plot_init,
|
||||||
@@ -7,14 +8,16 @@ from utils import (relu,
|
|||||||
dynamic_loss_plot_finish)
|
dynamic_loss_plot_finish)
|
||||||
|
|
||||||
|
|
||||||
def mnist_embed():
|
def mnist_embed(
|
||||||
import keras
|
bottleneck: int,
|
||||||
|
max_epoch: int,
|
||||||
|
patience: int,
|
||||||
|
):
|
||||||
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
|
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
|
||||||
IN_LEN = x_train[0].flatten().shape[0]
|
in_len = x_train[0].flatten().shape[0]
|
||||||
DIM = 50
|
autoencoder = Autoencoder(in_len, bottleneck, 0.001, relu)
|
||||||
autoencoder = Autoencoder(IN_LEN, DIM, 0.001, relu)
|
|
||||||
ax, line = dynamic_loss_plot_init()
|
ax, line = dynamic_loss_plot_init()
|
||||||
NO_IMPROV = 0
|
no_improv = 0
|
||||||
prev_error = float('inf')
|
prev_error = float('inf')
|
||||||
losses = []
|
losses = []
|
||||||
epoch = 0
|
epoch = 0
|
||||||
@@ -25,17 +28,19 @@ def mnist_embed():
|
|||||||
input = x.flatten() / 255
|
input = x.flatten() / 255
|
||||||
error += autoencoder.train(input)
|
error += autoencoder.train(input)
|
||||||
error /= len(x_train)
|
error /= len(x_train)
|
||||||
if error >= prev_error:
|
if prev_error - error <= 1e-8:
|
||||||
NO_IMPROV += 1
|
no_improv += 1
|
||||||
|
else:
|
||||||
|
no_improv = 0
|
||||||
prev_error = error
|
prev_error = error
|
||||||
losses.append(error)
|
losses.append(error)
|
||||||
dynamic_loss_plot_update(ax, line, losses)
|
dynamic_loss_plot_update(ax, line, losses)
|
||||||
if NO_IMPROV > 5:
|
if no_improv > patience:
|
||||||
print('Done !')
|
|
||||||
break
|
break
|
||||||
if epoch > 500:
|
if epoch > max_epoch:
|
||||||
break
|
break
|
||||||
epoch += 1
|
epoch += 1
|
||||||
|
print("Done!")
|
||||||
dynamic_loss_plot_finish(ax, line)
|
dynamic_loss_plot_finish(ax, line)
|
||||||
example: np.ndarray = x_test[np.random.randint(0, len(x_test))]
|
example: np.ndarray = x_test[np.random.randint(0, len(x_test))]
|
||||||
code = autoencoder.encode(example.flatten() / 255)
|
code = autoencoder.encode(example.flatten() / 255)
|
||||||
@@ -48,4 +53,4 @@ def mnist_embed():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
mnist_embed()
|
mnist_embed(10, 1000, 5)
|
||||||
|
|||||||
Reference in New Issue
Block a user