fix: missing patience reset

This commit is contained in:
Lenoctambule
2026-03-27 05:23:09 +01:00
parent 1e8a27ddaa
commit 44a55c1871

View File

@@ -1,5 +1,6 @@
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import keras
from autoencoder import Autoencoder from autoencoder import Autoencoder
from utils import (relu, from utils import (relu,
dynamic_loss_plot_init, dynamic_loss_plot_init,
@@ -7,14 +8,16 @@ from utils import (relu,
dynamic_loss_plot_finish) dynamic_loss_plot_finish)
def mnist_embed(): def mnist_embed(
import keras bottleneck: int,
max_epoch: int,
patience: int,
):
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data() (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
IN_LEN = x_train[0].flatten().shape[0] in_len = x_train[0].flatten().shape[0]
DIM = 50 autoencoder = Autoencoder(in_len, bottleneck, 0.001, relu)
autoencoder = Autoencoder(IN_LEN, DIM, 0.001, relu)
ax, line = dynamic_loss_plot_init() ax, line = dynamic_loss_plot_init()
NO_IMPROV = 0 no_improv = 0
prev_error = float('inf') prev_error = float('inf')
losses = [] losses = []
epoch = 0 epoch = 0
@@ -25,17 +28,19 @@ def mnist_embed():
input = x.flatten() / 255 input = x.flatten() / 255
error += autoencoder.train(input) error += autoencoder.train(input)
error /= len(x_train) error /= len(x_train)
if error >= prev_error: if error - prev_error <= 1e-8:
NO_IMPROV += 1 no_improv += 1
else:
no_improv = 0
prev_error = error prev_error = error
losses.append(error) losses.append(error)
dynamic_loss_plot_update(ax, line, losses) dynamic_loss_plot_update(ax, line, losses)
if NO_IMPROV > 5: if no_improv > patience:
print('Done !')
break break
if epoch > 500: if epoch > max_epoch:
break break
epoch += 1 epoch += 1
print("Done!")
dynamic_loss_plot_finish(ax, line) dynamic_loss_plot_finish(ax, line)
example: np.ndarray = x_test[np.random.randint(0, len(x_test))] example: np.ndarray = x_test[np.random.randint(0, len(x_test))]
code = autoencoder.encode(example.flatten() / 255) code = autoencoder.encode(example.flatten() / 255)
@@ -48,4 +53,4 @@ def mnist_embed():
if __name__ == "__main__": if __name__ == "__main__":
mnist_embed() mnist_embed(10, 1000, 5)