feat: save and load methods for Autoencoder

This commit is contained in:
Lenoctambule
2026-03-28 17:50:27 +01:00
parent 6155649655
commit efd328e530
3 changed files with 33 additions and 11 deletions

3
.gitignore vendored
View File

@@ -1,4 +1,5 @@
__pycache__ __pycache__
*.pyc *.pyc
*.npz
*.npy
.venv .venv
mnist.npz

View File

@@ -53,7 +53,8 @@ class Autoencoder:
for x in data_set: for x in data_set:
error += self.train(x) error += self.train(x)
error /= len(data_set) error /= len(data_set)
if prev_error - error <= 1e-8: derror = prev_error - error
if derror <= 0 or abs(derror) < 1e-8:
no_improv += 1 no_improv += 1
else: else:
no_improv = 0 no_improv = 0
@@ -81,3 +82,10 @@ class Autoencoder:
code = self.encode(v) code = self.encode(v)
out = self.decode(code) out = self.decode(code)
return out, code return out, code
def save(self, path: str):
np.save(path, self)
def load(path: str) -> 'Autoencoder':
data = np.load(path, allow_pickle=True)
return data.item()

View File

@@ -17,19 +17,17 @@ def load_mnist() -> list[np.ndarray]:
return res["x_train"], res["y_train"], res["x_test"], res["y_test"] return res["x_train"], res["y_train"], res["x_test"], res["y_test"]
def mnist_test( def mnist_train(
bottleneck: int, bottleneck: int,
max_epoch: int, max_epoch: int,
patience: int, patience: int,
): ):
x_train, _, x_test, _ = load_mnist() x_train, _, x_test, _ = load_mnist()
in_len = x_train[0].shape[0] * x_train[0].shape[0] in_len = x_train[0].shape[0] * x_train[0].shape[0]
img_shape = x_train[0].shape
x_train.resize(x_train.shape[0], in_len) x_train.resize(x_train.shape[0], in_len)
x_test.resize(x_test.shape[0], in_len) x_test.resize(x_test.shape[0], in_len)
x_train = np.divide(x_train, 255) x_train = x_train / 255
x_test = np.divide(x_train, 255) x_test = x_test / 255
x_train = x_train[:1000]
autoencoder = Autoencoder( autoencoder = Autoencoder(
[in_len, bottleneck], [in_len, bottleneck],
[bottleneck, in_len], [bottleneck, in_len],
@@ -41,9 +39,20 @@ def mnist_test(
max_epoch, max_epoch,
patience, patience,
display_loss=True) display_loss=True)
autoencoder.save("autoencoder_mnist")
def mnist_test():
x_train, _, x_test, _ = load_mnist()
in_len = x_train[0].shape[0] * x_train[0].shape[0]
img_shape = x_train[0].shape
x_train.resize(x_train.shape[0], in_len)
x_test.resize(x_test.shape[0], in_len)
x_train = x_train / 255
x_test = x_test / 255
autoencoder = Autoencoder.load('autoencoder_mnist.npy')
example: np.ndarray = x_test[np.random.randint(0, len(x_test))] example: np.ndarray = x_test[np.random.randint(0, len(x_test))]
code = autoencoder.encode(example.flatten()) output, _ = autoencoder.forward(example.flatten())
output = autoencoder.decode(code)
plt.subplot(1, 2, 1) plt.subplot(1, 2, 1)
plt.matshow(example.reshape(img_shape), fignum=False) plt.matshow(example.reshape(img_shape), fignum=False)
plt.subplot(1, 2, 2) plt.subplot(1, 2, 2)
@@ -55,10 +64,14 @@ if __name__ == "__main__":
import argparse import argparse
import sys import sys
options = "b:e:p:"
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-b', type=int, nargs='?', default=50) parser.add_argument('-b', type=int, nargs='?', default=50)
parser.add_argument('-e', type=int, nargs='?', default=1000) parser.add_argument('-e', type=int, nargs='?', default=1000)
parser.add_argument('-p', type=int, nargs='?', default=5) parser.add_argument('-p', type=int, nargs='?', default=5)
parser.add_argument('-r', action='store_true')
args = parser.parse_args(sys.argv[1:]) args = parser.parse_args(sys.argv[1:])
mnist_test(args.b, args.e, args.p) if args.r:
mnist_test()
else:
mnist_train(args.b, args.e, args.p)
mnist_test()