From c37d1c9c268e4ca8b0933de73502ffcefa6d42b5 Mon Sep 17 00:00:00 2001
From: Lenoctambule <106790775+lenoctambule@users.noreply.github.com>
Date: Fri, 27 Mar 2026 22:16:17 +0100
Subject: [PATCH] refactor: use tqdm instead of custom load bar

---
 autoencoder.py   | 53 +++++++++++++++++++++++++-----------------------
 requirements.txt |  3 ++-
 2 files changed, 30 insertions(+), 26 deletions(-)

diff --git a/autoencoder.py b/autoencoder.py
index 982a798..9f2e7e9 100644
--- a/autoencoder.py
+++ b/autoencoder.py
@@ -4,6 +4,7 @@ from utils import (regularize,
                    dynamic_loss_plot_update,
                    dynamic_loss_plot_finish)
 import types
+from tqdm import tqdm
 
 LOADER = ['⡿', '⣟', '⣯', '⣷', '⣾', '⣽', '⣻', '⢿']
 
@@ -88,32 +89,34 @@ class Autoencoder:
         epoch = 0
         no_improv = 0
         prev_error = float('inf')
-        while True:
-            print(
-                f"{LOADER[epoch % len(LOADER)]} Training \t({epoch=} error={prev_error:.2f})",  # noqa
-            )
-            error = 0
-            for x in data_set:
-                input = x.flatten()
-                error += self.train(input)
-            error /= len(data_set)
-            if prev_error - error <= 1e-8:
-                no_improv += 1
-            else:
-                no_improv = 0
-            prev_error = float(error)
-            losses.append(error)
+        with tqdm(bar_format="{desc} {elapsed} {rate_fmt}") as lbar :
+            while True:
+                lbar.set_description(
+                    f"{LOADER[epoch % len(LOADER)]} Training ({epoch=} error={prev_error:.2f})",
+                )
+                lbar.update()
+                error = 0
+                for x in data_set:
+                    input = x.flatten()
+                    error += self.train(input)
+                error /= len(data_set)
+                if prev_error - error <= 1e-8:
+                    no_improv += 1
+                else:
+                    no_improv = 0
+                prev_error = float(error)
+                losses.append(error)
+                if display_loss is True:
+                    dynamic_loss_plot_update(ax, line, losses)
+                if no_improv > patience:
+                    break
+                if epoch > max_epoch:
+                    break
+                epoch += 1
             if display_loss is True:
-                dynamic_loss_plot_update(ax, line, losses)
-            if no_improv > patience:
-                break
-            if epoch > max_epoch:
-                break
-            epoch += 1
-        if display_loss is True:
-            dynamic_loss_plot_finish(ax, line)
-        print("#Training complete !")
-        return losses
+                dynamic_loss_plot_finish(ax, line)
+            print("#Training complete !")
+            return losses
 
     def encode(self, v: np.ndarray) -> np.ndarray:
         return self.encoder.forward(v)
diff --git a/requirements.txt b/requirements.txt
index 9144eaa..e53d596 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
 numpy
 matplotlib
-requests
\ No newline at end of file
+requests
+tqdm
\ No newline at end of file