feat: save and load methods for Autoencoder
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -1,4 +1,5 @@
|
|||||||
__pycache__
|
__pycache__
|
||||||
*.pyc
|
*.pyc
|
||||||
|
*.npz
|
||||||
|
*.npy
|
||||||
.venv
|
.venv
|
||||||
mnist.npz
|
|
||||||
@@ -53,7 +53,8 @@ class Autoencoder:
|
|||||||
for x in data_set:
|
for x in data_set:
|
||||||
error += self.train(x)
|
error += self.train(x)
|
||||||
error /= len(data_set)
|
error /= len(data_set)
|
||||||
if prev_error - error <= 1e-8:
|
derror = prev_error - error
|
||||||
|
if derror <= 0 or abs(derror) < 1e-8:
|
||||||
no_improv += 1
|
no_improv += 1
|
||||||
else:
|
else:
|
||||||
no_improv = 0
|
no_improv = 0
|
||||||
@@ -81,3 +82,10 @@ class Autoencoder:
|
|||||||
code = self.encode(v)
|
code = self.encode(v)
|
||||||
out = self.decode(code)
|
out = self.decode(code)
|
||||||
return out, code
|
return out, code
|
||||||
|
|
||||||
|
def save(self, path: str):
|
||||||
|
np.save(path, self)
|
||||||
|
|
||||||
|
def load(path: str) -> 'Autoencoder':
|
||||||
|
data = np.load(path, allow_pickle=True)
|
||||||
|
return data.item()
|
||||||
|
|||||||
@@ -17,19 +17,17 @@ def load_mnist() -> list[np.ndarray]:
|
|||||||
return res["x_train"], res["y_train"], res["x_test"], res["y_test"]
|
return res["x_train"], res["y_train"], res["x_test"], res["y_test"]
|
||||||
|
|
||||||
|
|
||||||
def mnist_test(
|
def mnist_train(
|
||||||
bottleneck: int,
|
bottleneck: int,
|
||||||
max_epoch: int,
|
max_epoch: int,
|
||||||
patience: int,
|
patience: int,
|
||||||
):
|
):
|
||||||
x_train, _, x_test, _ = load_mnist()
|
x_train, _, x_test, _ = load_mnist()
|
||||||
in_len = x_train[0].shape[0] * x_train[0].shape[0]
|
in_len = x_train[0].shape[0] * x_train[0].shape[0]
|
||||||
img_shape = x_train[0].shape
|
|
||||||
x_train.resize(x_train.shape[0], in_len)
|
x_train.resize(x_train.shape[0], in_len)
|
||||||
x_test.resize(x_test.shape[0], in_len)
|
x_test.resize(x_test.shape[0], in_len)
|
||||||
x_train = np.divide(x_train, 255)
|
x_train = x_train / 255
|
||||||
x_test = np.divide(x_train, 255)
|
x_test = x_test / 255
|
||||||
x_train = x_train[:1000]
|
|
||||||
autoencoder = Autoencoder(
|
autoencoder = Autoencoder(
|
||||||
[in_len, bottleneck],
|
[in_len, bottleneck],
|
||||||
[bottleneck, in_len],
|
[bottleneck, in_len],
|
||||||
@@ -41,9 +39,20 @@ def mnist_test(
|
|||||||
max_epoch,
|
max_epoch,
|
||||||
patience,
|
patience,
|
||||||
display_loss=True)
|
display_loss=True)
|
||||||
|
autoencoder.save("autoencoder_mnist")
|
||||||
|
|
||||||
|
|
||||||
|
def mnist_test():
|
||||||
|
x_train, _, x_test, _ = load_mnist()
|
||||||
|
in_len = x_train[0].shape[0] * x_train[0].shape[0]
|
||||||
|
img_shape = x_train[0].shape
|
||||||
|
x_train.resize(x_train.shape[0], in_len)
|
||||||
|
x_test.resize(x_test.shape[0], in_len)
|
||||||
|
x_train = x_train / 255
|
||||||
|
x_test = x_test / 255
|
||||||
|
autoencoder = Autoencoder.load('autoencoder_mnist.npy')
|
||||||
example: np.ndarray = x_test[np.random.randint(0, len(x_test))]
|
example: np.ndarray = x_test[np.random.randint(0, len(x_test))]
|
||||||
code = autoencoder.encode(example.flatten())
|
output, _ = autoencoder.forward(example.flatten())
|
||||||
output = autoencoder.decode(code)
|
|
||||||
plt.subplot(1, 2, 1)
|
plt.subplot(1, 2, 1)
|
||||||
plt.matshow(example.reshape(img_shape), fignum=False)
|
plt.matshow(example.reshape(img_shape), fignum=False)
|
||||||
plt.subplot(1, 2, 2)
|
plt.subplot(1, 2, 2)
|
||||||
@@ -55,10 +64,14 @@ if __name__ == "__main__":
|
|||||||
import argparse
|
import argparse
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
options = "b:e:p:"
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-b', type=int, nargs='?', default=50)
|
parser.add_argument('-b', type=int, nargs='?', default=50)
|
||||||
parser.add_argument('-e', type=int, nargs='?', default=1000)
|
parser.add_argument('-e', type=int, nargs='?', default=1000)
|
||||||
parser.add_argument('-p', type=int, nargs='?', default=5)
|
parser.add_argument('-p', type=int, nargs='?', default=5)
|
||||||
|
parser.add_argument('-r', action='store_true')
|
||||||
args = parser.parse_args(sys.argv[1:])
|
args = parser.parse_args(sys.argv[1:])
|
||||||
mnist_test(args.b, args.e, args.p)
|
if args.r:
|
||||||
|
mnist_test()
|
||||||
|
else:
|
||||||
|
mnist_train(args.b, args.e, args.p)
|
||||||
|
mnist_test()
|
||||||
|
|||||||
Reference in New Issue
Block a user