feat: move train over dataset logic to Autoencoder class

This commit is contained in:
Lenoctambule
2026-03-27 07:07:41 +01:00
parent af9a0c70b2
commit 9859863ec9
2 changed files with 56 additions and 39 deletions

View File

@@ -1,7 +1,12 @@
import numpy as np import numpy as np
from utils import regularize from utils import (regularize,
dynamic_loss_plot_init,
dynamic_loss_plot_update,
dynamic_loss_plot_finish)
import types import types
LOADER = ['', '', '', '', '', '', '', '']
class Encoder: class Encoder:
def __init__(self, def __init__(self,
@@ -73,6 +78,44 @@ class Autoencoder:
error = v - reconstructed error = v - reconstructed
return np.sum(np.abs(error)) return np.sum(np.abs(error))
def train_dataset(self,
                  data_set: list[np.ndarray],
                  max_epoch: int,
                  patience: int,
                  display_loss: bool = False) -> list[float]:
    """Train the autoencoder over a whole dataset with early stopping.

    Each sample is flattened and passed to ``self.train``; the mean
    per-sample error of the epoch is recorded. Training stops after
    ``max_epoch`` epochs, or earlier once the mean error has failed to
    improve by more than 1e-8 for more than ``patience`` consecutive
    epochs.

    Parameters
    ----------
    data_set : list[np.ndarray]
        Training samples; each array is flattened before training.
        Must be non-empty (the per-epoch mean divides by its length).
    max_epoch : int
        Upper bound on the number of training epochs.
    patience : int
        Number of consecutive non-improving epochs tolerated before
        stopping early.
    display_loss : bool, optional
        When True, draw/update a live loss plot each epoch via the
        ``dynamic_loss_plot_*`` helpers.

    Returns
    -------
    list[float]
        The mean error recorded for every completed epoch.
    """
    if display_loss:
        ax, line = dynamic_loss_plot_init()
    losses: list[float] = []
    no_improv = 0
    prev_error = float('inf')
    # Off-by-one fixed: the original post-increment `epoch > max_epoch`
    # check ran max_epoch + 2 epochs; this runs at most max_epoch.
    for epoch in range(max_epoch):
        print(
            f"{LOADER[epoch % len(LOADER)]} Training \t({epoch=} error={prev_error:.2f})",  # noqa
            end="\r"
        )
        error = 0.0
        for sample in data_set:
            # (`input` shadowed the builtin in the original; renamed.)
            error += self.train(sample.flatten())
        error /= len(data_set)
        # Bug fix: improvement is prev_error - error. The original tested
        # `error - prev_error <= 1e-8`, which counted every *improving*
        # epoch (and the first epoch, since prev_error is inf) as "no
        # improvement", so training always stopped after `patience`
        # epochs regardless of progress.
        if prev_error - error <= 1e-8:
            no_improv += 1
        else:
            no_improv = 0
        prev_error = float(error)
        losses.append(error)
        if display_loss:
            dynamic_loss_plot_update(ax, line, losses)
        if no_improv > patience:
            break
    if display_loss:
        dynamic_loss_plot_finish(ax, line)
    return losses
def encode(self, v: np.ndarray) -> np.ndarray: def encode(self, v: np.ndarray) -> np.ndarray:
return self.encoder.forward(v) return self.encoder.forward(v)

View File

@@ -2,48 +2,23 @@ import matplotlib.pyplot as plt
import numpy as np import numpy as np
import keras import keras
from autoencoder import Autoencoder from autoencoder import Autoencoder
from utils import (relu, from utils import relu
dynamic_loss_plot_init,
dynamic_loss_plot_update,
dynamic_loss_plot_finish)
def mnist_embed( def mnist_test(
bottleneck: int, bottleneck: int,
max_epoch: int, max_epoch: int,
patience: int, patience: int,
): ):
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data() (x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
x_train = np.divide(x_train, 255)
x_test = np.divide(x_train, 255)
in_len = x_train[0].flatten().shape[0] in_len = x_train[0].flatten().shape[0]
autoencoder = Autoencoder(in_len, bottleneck, 0.001, relu) autoencoder = Autoencoder(in_len, bottleneck, 0.0001, relu)
ax, line = dynamic_loss_plot_init()
no_improv = 0
prev_error = float('inf')
losses = []
epoch = 0
x_train = x_train[:] x_train = x_train[:]
while True: autoencoder.train_dataset(x_train, max_epoch, patience)
error = 0
for x in x_train:
input = x.flatten() / 255
error += autoencoder.train(input)
error /= len(x_train)
if error - prev_error <= 1e-8:
no_improv += 1
else:
no_improv = 0
prev_error = error
losses.append(error)
dynamic_loss_plot_update(ax, line, losses)
if no_improv > patience:
break
if epoch > max_epoch:
break
epoch += 1
print("Done!")
dynamic_loss_plot_finish(ax, line)
example: np.ndarray = x_test[np.random.randint(0, len(x_test))] example: np.ndarray = x_test[np.random.randint(0, len(x_test))]
code = autoencoder.encode(example.flatten() / 255) code = autoencoder.encode(example.flatten())
output = autoencoder.decode(code) output = autoencoder.decode(code)
plt.subplot(1, 2, 1) plt.subplot(1, 2, 1)
plt.matshow(example, fignum=False) plt.matshow(example, fignum=False)
@@ -58,9 +33,8 @@ if __name__ == "__main__":
options = "b:e:p:" options = "b:e:p:"
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-b', type=int, nargs='+', default=50) parser.add_argument('-b', type=int, nargs='?', default=50)
parser.add_argument('-e', type=int, nargs='+', default=1000) parser.add_argument('-e', type=int, nargs='?', default=1000)
parser.add_argument('-p', type=int, nargs='+', default=5) parser.add_argument('-p', type=int, nargs='?', default=5)
args = parser.parse_args(sys.argv[1:]) args = parser.parse_args(sys.argv[1:])
mnist_test(args.b, args.e, args.p)
mnist_embed(args.b, args.e, args.p)