From efd328e530616f08fd7d8f86d7299c4e99f48a18 Mon Sep 17 00:00:00 2001 From: Lenoctambule <106790775+lenoctambule@users.noreply.github.com> Date: Sat, 28 Mar 2026 17:50:27 +0100 Subject: [PATCH] feat: save and load methods for Autoencoder --- .gitignore | 3 ++- autoencoder.py | 10 +++++++++- mnist_test.py | 31 ++++++++++++++++++++++--------- 3 files changed, 33 insertions(+), 11 deletions(-) diff --git a/.gitignore b/.gitignore index 8eefd96..d541758 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ __pycache__ *.pyc +*.npz +*.npy .venv -mnist.npz \ No newline at end of file diff --git a/autoencoder.py b/autoencoder.py index 7e2731e..5ddd745 100644 --- a/autoencoder.py +++ b/autoencoder.py @@ -53,7 +53,8 @@ class Autoencoder: for x in data_set: error += self.train(x) error /= len(data_set) - if prev_error - error <= 1e-8: + derror = prev_error - error + if derror <= 0 or abs(derror) < 1e-8: no_improv += 1 else: no_improv = 0 @@ -81,3 +82,10 @@ class Autoencoder: code = self.encode(v) out = self.decode(code) return out, code + + def save(self, path: str): + np.save(path, self) + + def load(path: str) -> 'Autoencoder': + data = np.load(path, allow_pickle=True) + return data.item() diff --git a/mnist_test.py b/mnist_test.py index 8d55d13..db1f9a2 100644 --- a/mnist_test.py +++ b/mnist_test.py @@ -17,19 +17,17 @@ def load_mnist() -> list[np.ndarray]: return res["x_train"], res["y_train"], res["x_test"], res["y_test"] -def mnist_test( +def mnist_train( bottleneck: int, max_epoch: int, patience: int, ): x_train, _, x_test, _ = load_mnist() in_len = x_train[0].shape[0] * x_train[0].shape[0] - img_shape = x_train[0].shape x_train.resize(x_train.shape[0], in_len) x_test.resize(x_test.shape[0], in_len) - x_train = np.divide(x_train, 255) - x_test = np.divide(x_train, 255) - x_train = x_train[:1000] + x_train = x_train / 255 + x_test = x_test / 255 autoencoder = Autoencoder( [in_len, bottleneck], [bottleneck, in_len], @@ -41,9 +39,20 @@ def mnist_test( max_epoch, patience, display_loss=True) + autoencoder.save("autoencoder_mnist") + + +def mnist_test(): + x_train, _, x_test, _ = load_mnist() + in_len = x_train[0].shape[0] * x_train[0].shape[0] + img_shape = x_train[0].shape + x_train.resize(x_train.shape[0], in_len) + x_test.resize(x_test.shape[0], in_len) + x_train = x_train / 255 + x_test = x_test / 255 + autoencoder = Autoencoder.load('autoencoder_mnist.npy') example: np.ndarray = x_test[np.random.randint(0, len(x_test))] - code = autoencoder.encode(example.flatten()) - output = autoencoder.decode(code) + output, _ = autoencoder.forward(example.flatten()) plt.subplot(1, 2, 1) plt.matshow(example.reshape(img_shape), fignum=False) plt.subplot(1, 2, 2) @@ -55,10 +64,14 @@ if __name__ == "__main__": import argparse import sys - options = "b:e:p:" parser = argparse.ArgumentParser() parser.add_argument('-b', type=int, nargs='?', default=50) parser.add_argument('-e', type=int, nargs='?', default=1000) parser.add_argument('-p', type=int, nargs='?', default=5) + parser.add_argument('-r', action='store_true') args = parser.parse_args(sys.argv[1:]) - mnist_test(args.b, args.e, args.p) + if args.r: + mnist_test() + else: + mnist_train(args.b, args.e, args.p) + mnist_test()