refactor: use tqdm instead of custom load bar
This commit is contained in:
@@ -4,6 +4,7 @@ from utils import (regularize,
|
||||
dynamic_loss_plot_update,
|
||||
dynamic_loss_plot_finish)
|
||||
import types
|
||||
from tqdm import tqdm
|
||||
|
||||
LOADER = ['⡿', '⣟', '⣯', '⣷', '⣾', '⣽', '⣻', '⢿']
|
||||
|
||||
@@ -88,32 +89,34 @@ class Autoencoder:
|
||||
epoch = 0
|
||||
no_improv = 0
|
||||
prev_error = float('inf')
|
||||
while True:
|
||||
print(
|
||||
f"{LOADER[epoch % len(LOADER)]} Training \t({epoch=} error={prev_error:.2f})", # noqa
|
||||
)
|
||||
error = 0
|
||||
for x in data_set:
|
||||
input = x.flatten()
|
||||
error += self.train(input)
|
||||
error /= len(data_set)
|
||||
if prev_error - error <= 1e-8:
|
||||
no_improv += 1
|
||||
else:
|
||||
no_improv = 0
|
||||
prev_error = float(error)
|
||||
losses.append(error)
|
||||
with tqdm(bar_format="{desc} {elapsed} {rate_fmt}") as lbar :
|
||||
while True:
|
||||
lbar.set_description(
|
||||
f"{LOADER[epoch % len(LOADER)]} Training ({epoch=} error={prev_error:.2f})",
|
||||
)
|
||||
lbar.update()
|
||||
error = 0
|
||||
for x in data_set:
|
||||
input = x.flatten()
|
||||
error += self.train(input)
|
||||
error /= len(data_set)
|
||||
if prev_error - error <= 1e-8:
|
||||
no_improv += 1
|
||||
else:
|
||||
no_improv = 0
|
||||
prev_error = float(error)
|
||||
losses.append(error)
|
||||
if display_loss is True:
|
||||
dynamic_loss_plot_update(ax, line, losses)
|
||||
if no_improv > patience:
|
||||
break
|
||||
if epoch > max_epoch:
|
||||
break
|
||||
epoch += 1
|
||||
if display_loss is True:
|
||||
dynamic_loss_plot_update(ax, line, losses)
|
||||
if no_improv > patience:
|
||||
break
|
||||
if epoch > max_epoch:
|
||||
break
|
||||
epoch += 1
|
||||
if display_loss is True:
|
||||
dynamic_loss_plot_finish(ax, line)
|
||||
print("#Training complete !")
|
||||
return losses
|
||||
dynamic_loss_plot_finish(ax, line)
|
||||
print("#Training complete !")
|
||||
return losses
|
||||
|
||||
def encode(self, v: np.ndarray) -> np.ndarray:
|
||||
return self.encoder.forward(v)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
numpy
|
||||
matplotlib
|
||||
requests
|
||||
requests
|
||||
tqdm
|
||||
Reference in New Issue
Block a user