fix: missing patience reset

This commit is contained in:
Lenoctambule
2026-03-27 05:23:09 +01:00
parent 1e8a27ddaa
commit 44a55c1871

View File

@@ -1,5 +1,6 @@
import matplotlib.pyplot as plt
import numpy as np
import keras
from autoencoder import Autoencoder
from utils import (relu,
dynamic_loss_plot_init,
@@ -7,14 +8,16 @@ from utils import (relu,
dynamic_loss_plot_finish)
def mnist_embed():
import keras
def mnist_embed(
bottleneck: int,
max_epoch: int,
patience: int,
):
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
IN_LEN = x_train[0].flatten().shape[0]
DIM = 50
autoencoder = Autoencoder(IN_LEN, DIM, 0.001, relu)
in_len = x_train[0].flatten().shape[0]
autoencoder = Autoencoder(in_len, bottleneck, 0.001, relu)
ax, line = dynamic_loss_plot_init()
NO_IMPROV = 0
no_improv = 0
prev_error = float('inf')
losses = []
epoch = 0
@@ -25,17 +28,19 @@ def mnist_embed():
input = x.flatten() / 255
error += autoencoder.train(input)
error /= len(x_train)
if error >= prev_error:
NO_IMPROV += 1
if error - prev_error <= 1e-8:
no_improv += 1
else:
no_improv = 0
prev_error = error
losses.append(error)
dynamic_loss_plot_update(ax, line, losses)
if NO_IMPROV > 5:
print('Done !')
if no_improv > patience:
break
if epoch > 500:
if epoch > max_epoch:
break
epoch += 1
print("Done!")
dynamic_loss_plot_finish(ax, line)
example: np.ndarray = x_test[np.random.randint(0, len(x_test))]
code = autoencoder.encode(example.flatten() / 255)
@@ -48,4 +53,4 @@ def mnist_embed():
if __name__ == "__main__":
mnist_embed()
mnist_embed(10, 1000, 5)