fix: missing activation func derivative + send error before update
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from autoencoder import Autoencoder
|
||||
from utils import relu
|
||||
from utils import relu, regularize
|
||||
|
||||
|
||||
def load_mnist() -> list[np.ndarray]:
|
||||
@@ -18,7 +18,7 @@ def load_mnist() -> list[np.ndarray]:
|
||||
|
||||
|
||||
def mnist_train(
|
||||
bottleneck: int,
|
||||
filename: str,
|
||||
max_epoch: int,
|
||||
patience: int,
|
||||
):
|
||||
@@ -29,9 +29,9 @@ def mnist_train(
|
||||
x_train = x_train / 255
|
||||
x_test = x_test / 255
|
||||
autoencoder = Autoencoder(
|
||||
[in_len, bottleneck],
|
||||
[bottleneck, in_len],
|
||||
0.1,
|
||||
[in_len, 64, 16],
|
||||
[16, 64, in_len],
|
||||
0.01,
|
||||
relu
|
||||
)
|
||||
autoencoder.train_dataset(
|
||||
@@ -39,24 +39,39 @@ def mnist_train(
|
||||
max_epoch,
|
||||
patience,
|
||||
display_loss=True)
|
||||
autoencoder.save("autoencoder_mnist")
|
||||
autoencoder.save(filename)
|
||||
|
||||
|
||||
def mnist_test():
|
||||
x_train, _, x_test, _ = load_mnist()
|
||||
def mnist_test(filename: str):
|
||||
x_train, _, x_test, y_test = load_mnist()
|
||||
in_len = x_train[0].shape[0] * x_train[0].shape[0]
|
||||
img_shape = x_train[0].shape
|
||||
x_train.resize(x_train.shape[0], in_len)
|
||||
x_test.resize(x_test.shape[0], in_len)
|
||||
x_train = x_train / 255
|
||||
x_test = x_test / 255
|
||||
autoencoder = Autoencoder.load('autoencoder_mnist.npy')
|
||||
example: np.ndarray = x_test[np.random.randint(0, len(x_test))]
|
||||
output, _ = autoencoder.forward(example.flatten())
|
||||
plt.subplot(1, 2, 1)
|
||||
plt.matshow(example.reshape(img_shape), fignum=False)
|
||||
plt.subplot(1, 2, 2)
|
||||
plt.matshow(output.reshape(img_shape), fignum=False)
|
||||
autoencoder: Autoencoder = Autoencoder.load(filename)
|
||||
for i in autoencoder.encoder.layers:
|
||||
print(len(i.input), len(i.output))
|
||||
idx = np.random.randint(0, len(x_test))
|
||||
example: np.ndarray = x_test[idx]
|
||||
output, code = autoencoder.forward(example.flatten())
|
||||
output = regularize(output)
|
||||
plt.subplot(1, 3, 1)
|
||||
plt.matshow(
|
||||
example.reshape(img_shape),
|
||||
fignum=False)
|
||||
plt.title(f"Input ({y_test[idx]})")
|
||||
plt.subplot(1, 3, 2)
|
||||
plt.matshow(
|
||||
output.reshape(img_shape),
|
||||
fignum=False)
|
||||
plt.title(f"Output ({y_test[idx]})")
|
||||
plt.subplot(1, 3, 3)
|
||||
s = int(np.ceil(np.sqrt(code.shape[0])))
|
||||
code.resize((s, s), refcheck=False)
|
||||
plt.matshow(code, fignum=False)
|
||||
plt.title(f"Code ({y_test[idx]})")
|
||||
plt.show()
|
||||
|
||||
|
||||
@@ -65,13 +80,34 @@ if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-b', type=int, nargs='?', default=50)
|
||||
parser.add_argument('-e', type=int, nargs='?', default=1000)
|
||||
parser.add_argument('-p', type=int, nargs='?', default=5)
|
||||
parser.add_argument('-r', action='store_true')
|
||||
parser.add_argument(
|
||||
'-e',
|
||||
type=int,
|
||||
nargs='?',
|
||||
default=1000,
|
||||
help='Max epochs'
|
||||
)
|
||||
parser.add_argument(
|
||||
'-p',
|
||||
type=int,
|
||||
nargs='?',
|
||||
default=5,
|
||||
help='Patience'
|
||||
)
|
||||
parser.add_argument(
|
||||
'-m',
|
||||
type=str, nargs='?',
|
||||
default='autoencoder_mnist.npy',
|
||||
help='Model filename to save in run mode or load in training mode'
|
||||
)
|
||||
parser.add_argument(
|
||||
'-r',
|
||||
action='store_true',
|
||||
help='Run mode'
|
||||
)
|
||||
args = parser.parse_args(sys.argv[1:])
|
||||
if args.r:
|
||||
mnist_test()
|
||||
mnist_test(args.m)
|
||||
else:
|
||||
mnist_train(args.b, args.e, args.p)
|
||||
mnist_test()
|
||||
mnist_train(args.m, args.e, args.p)
|
||||
mnist_test(args.m)
|
||||
|
||||
Reference in New Issue
Block a user