refactor: use tqdm instead of custom load bar
@@ -4,6 +4,7 @@ from utils import (regularize,
                    dynamic_loss_plot_update,
                    dynamic_loss_plot_finish)
 import types
+from tqdm import tqdm
 
 LOADER = ['⡿', '⣟', '⣯', '⣷', '⣾', '⣽', '⣻', '⢿']
 
@@ -88,32 +89,34 @@ class Autoencoder:
         epoch = 0
         no_improv = 0
         prev_error = float('inf')
-        while True:
-            print(
-                f"{LOADER[epoch % len(LOADER)]} Training \t({epoch=} error={prev_error:.2f})", # noqa
-            )
-            error = 0
-            for x in data_set:
-                input = x.flatten()
-                error += self.train(input)
-            error /= len(data_set)
-            if prev_error - error <= 1e-8:
-                no_improv += 1
-            else:
-                no_improv = 0
-            prev_error = float(error)
-            losses.append(error)
-            if display_loss is True:
-                dynamic_loss_plot_update(ax, line, losses)
-            if no_improv > patience:
-                break
-            if epoch > max_epoch:
-                break
-            epoch += 1
+        with tqdm(bar_format="{desc} {elapsed} {rate_fmt}") as lbar:
+            while True:
+                lbar.set_description(
+                    f"{LOADER[epoch % len(LOADER)]} Training ({epoch=} error={prev_error:.2f})",
+                )
+                lbar.update()
+                error = 0
+                for x in data_set:
+                    input = x.flatten()
+                    error += self.train(input)
+                error /= len(data_set)
+                if prev_error - error <= 1e-8:
+                    no_improv += 1
+                else:
+                    no_improv = 0
+                prev_error = float(error)
+                losses.append(error)
+                if display_loss is True:
+                    dynamic_loss_plot_update(ax, line, losses)
+                if no_improv > patience:
+                    break
+                if epoch > max_epoch:
+                    break
+                epoch += 1
         if display_loss is True:
             dynamic_loss_plot_finish(ax, line)
         print("#Training complete !")
         return losses
 
     def encode(self, v: np.ndarray) -> np.ndarray:
         return self.encoder.forward(v)
@@ -1,3 +1,4 @@
 numpy
 matplotlib
 requests
+tqdm
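For reference, the tqdm pattern the new code relies on is a bar created without an iterable or total and driven by hand: set_description() rewrites the text shown by the {desc} field, and update() advances the internal counter from which {rate_fmt} is computed, while {elapsed} tracks wall-clock time since the bar was created. A minimal standalone sketch of that pattern (the loop body and names below are illustrative, not taken from this repo):

    from tqdm import tqdm

    # bar_format fields: {desc} comes from set_description(),
    # {elapsed} is wall-clock time since the bar was created,
    # {rate_fmt} is the formatted update rate (here: epochs per second).
    with tqdm(bar_format="{desc} {elapsed} {rate_fmt}") as lbar:
        for epoch in range(5):        # stand-in for the real training loop
            lbar.set_description(f"Training (epoch={epoch})")
            lbar.update()             # one tick per epoch
            # ... run one epoch of training here ...

Because no total is given, tqdm cannot render a percentage bar, which is presumably why the commit keeps the braille LOADER glyphs in the description as a spinner.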