diff --git a/autoencoder.py b/autoencoder.py
index 7de9ca0..8d05e0d 100644
--- a/autoencoder.py
+++ b/autoencoder.py
@@ -1,7 +1,12 @@
 import numpy as np
-from utils import regularize
+from utils import (regularize,
+                   dynamic_loss_plot_init,
+                   dynamic_loss_plot_update,
+                   dynamic_loss_plot_finish)
 import types
 
+LOADER = ['⡿', '⣟', '⣯', '⣷', '⣾', '⣽', '⣻', '⢿']
+
 class Encoder:
 
     def __init__(self,
@@ -73,6 +78,56 @@ class Autoencoder:
         error = v - reconstructed
         return np.sum(np.abs(error))
 
+    def train_dataset(self,
+                      data_set: list[np.ndarray],
+                      max_epoch: int,
+                      patience: int,
+                      display_loss: bool = False) -> list[float]:
+        """Run training epochs over data_set until the loss plateaus.
+
+        Stops after `patience` consecutive epochs in which the mean error
+        fails to drop by more than 1e-8, or once `max_epoch` is exceeded.
+
+        :param data_set: training samples; each is flattened before use
+        :param max_epoch: hard cap on the number of epochs
+        :param patience: non-improving epochs tolerated before stopping
+        :param display_loss: when True, update a live loss plot per epoch
+        :return: mean absolute reconstruction error of every epoch
+        """
+        if display_loss:
+            ax, line = dynamic_loss_plot_init()
+        losses = []
+        epoch = 0
+        no_improv = 0
+        prev_error = float('inf')
+        while True:
+            print(
+                f"{LOADER[epoch % len(LOADER)]} Training \t({epoch=} error={prev_error:.2f})",  # noqa
+                end="\r"
+            )
+            error = 0
+            for x in data_set:
+                flat = x.flatten()  # 'flat', not 'input': avoid shadowing the builtin
+                error += self.train(flat)
+            error /= len(data_set)
+            # no improvement = error failed to DROP by more than the tolerance
+            if error - prev_error >= -1e-8:
+                no_improv += 1
+            else:
+                no_improv = 0
+            prev_error = float(error)
+            losses.append(error)
+            if display_loss:
+                dynamic_loss_plot_update(ax, line, losses)
+            if no_improv > patience:
+                break
+            if epoch > max_epoch:
+                break
+            epoch += 1
+        if display_loss:
+            dynamic_loss_plot_finish(ax, line)
+        return losses
+
     def encode(self, v: np.ndarray) -> np.ndarray:
         return self.encoder.forward(v)
 
diff --git a/mnist_test.py b/mnist_test.py
index 7b62a39..0bd23a9 100644
--- a/mnist_test.py
+++ b/mnist_test.py
@@ -2,48 +2,23 @@
 import matplotlib.pyplot as plt
 import numpy as np
 import keras
 from autoencoder import Autoencoder
-from utils import (relu,
-                   dynamic_loss_plot_init,
-                   dynamic_loss_plot_update,
-                   dynamic_loss_plot_finish)
+from utils import relu
 
-def mnist_embed(
+def mnist_test(
         bottleneck: int,
         max_epoch: int,
         patience: int,
         ):
-    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
+    (x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
+    x_train = np.divide(x_train, 255)
+    x_test = np.divide(x_test, 255)
     in_len = x_train[0].flatten().shape[0]
-    autoencoder = Autoencoder(in_len, bottleneck, 0.001, relu)
-    ax, line = dynamic_loss_plot_init()
-    no_improv = 0
-    prev_error = float('inf')
-    losses = []
-    epoch = 0
+    autoencoder = Autoencoder(in_len, bottleneck, 0.0001, relu)
     x_train = x_train[:]
-    while True:
-        error = 0
-        for x in x_train:
-            input = x.flatten() / 255
-            error += autoencoder.train(input)
-        error /= len(x_train)
-        if error - prev_error <= 1e-8:
-            no_improv += 1
-        else:
-            no_improv = 0
-        prev_error = error
-        losses.append(error)
-        dynamic_loss_plot_update(ax, line, losses)
-        if no_improv > patience:
-            break
-        if epoch > max_epoch:
-            break
-        epoch += 1
-    print("Done!")
-    dynamic_loss_plot_finish(ax, line)
+    autoencoder.train_dataset(x_train, max_epoch, patience)
     example: np.ndarray = x_test[np.random.randint(0, len(x_test))]
-    code = autoencoder.encode(example.flatten() / 255)
+    code = autoencoder.encode(example.flatten())
     output = autoencoder.decode(code)
     plt.subplot(1, 2, 1)
     plt.matshow(example, fignum=False)
@@ -58,9 +33,8 @@ if __name__ == "__main__":
     options = "b:e:p:"
 
     parser = argparse.ArgumentParser()
-    parser.add_argument('-b', type=int, nargs='+', default=50)
-    parser.add_argument('-e', type=int, nargs='+', default=1000)
-    parser.add_argument('-p', type=int, nargs='+', default=5)
+    parser.add_argument('-b', type=int, nargs='?', default=50)
+    parser.add_argument('-e', type=int, nargs='?', default=1000)
+    parser.add_argument('-p', type=int, nargs='?', default=5)
     args = parser.parse_args(sys.argv[1:])
-
-    mnist_embed(args.b, args.e, args.p)
+    mnist_test(args.b, args.e, args.p)