From 44a55c1871710afaad8b7bfd8492af746465b9c3 Mon Sep 17 00:00:00 2001 From: Lenoctambule <106790775+lenoctambule@users.noreply.github.com> Date: Fri, 27 Mar 2026 05:23:09 +0100 Subject: [PATCH] fix: missing patience reset --- mnist_test.py | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/mnist_test.py b/mnist_test.py index e10a4c1..daae9e5 100644 --- a/mnist_test.py +++ b/mnist_test.py @@ -1,5 +1,6 @@ import matplotlib.pyplot as plt import numpy as np +import keras from autoencoder import Autoencoder from utils import (relu, dynamic_loss_plot_init, @@ -7,14 +8,16 @@ from utils import (relu, dynamic_loss_plot_finish) -def mnist_embed(): - import keras +def mnist_embed( + bottleneck: int, + max_epoch: int, + patience: int, + ): (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data() - IN_LEN = x_train[0].flatten().shape[0] - DIM = 50 - autoencoder = Autoencoder(IN_LEN, DIM, 0.001, relu) + in_len = x_train[0].flatten().shape[0] + autoencoder = Autoencoder(in_len, bottleneck, 0.001, relu) ax, line = dynamic_loss_plot_init() - NO_IMPROV = 0 + no_improv = 0 prev_error = float('inf') losses = [] epoch = 0 @@ -25,17 +28,19 @@ def mnist_embed(): input = x.flatten() / 255 error += autoencoder.train(input) error /= len(x_train) - if error >= prev_error: - NO_IMPROV += 1 + if error - prev_error <= 1e-8: + no_improv += 1 + else: + no_improv = 0 prev_error = error losses.append(error) dynamic_loss_plot_update(ax, line, losses) - if NO_IMPROV > 5: - print('Done !') + if no_improv > patience: break - if epoch > 500: + if epoch > max_epoch: break epoch += 1 + print("Done!") dynamic_loss_plot_finish(ax, line) example: np.ndarray = x_test[np.random.randint(0, len(x_test))] code = autoencoder.encode(example.flatten() / 255) @@ -48,4 +53,4 @@ def mnist_embed(): if __name__ == "__main__": - mnist_embed() + mnist_embed(10, 1000, 5)