feat: loss method + mv data reshaping out of Autoencoder class
This commit is contained in:
@@ -17,6 +17,12 @@ class Autoencoder:
|
|||||||
self.encoder = DeepNNLayer(encoder_layers, lr, activation_func)
|
self.encoder = DeepNNLayer(encoder_layers, lr, activation_func)
|
||||||
self.decoder = DeepNNLayer(decoder_layers, lr, activation_func)
|
self.decoder = DeepNNLayer(decoder_layers, lr, activation_func)
|
||||||
|
|
||||||
|
def loss(self, data_set: list[np.ndarray]) -> float:
|
||||||
|
loss = 0
|
||||||
|
for x in data_set:
|
||||||
|
loss += np.sum(np.abs(x - self.forward(x)[0])) / len(x)
|
||||||
|
return loss / len(data_set)
|
||||||
|
|
||||||
def train(self, v: np.ndarray):
|
def train(self, v: np.ndarray):
|
||||||
out = self.decoder.forward(
|
out = self.decoder.forward(
|
||||||
self.encoder.forward(v)
|
self.encoder.forward(v)
|
||||||
@@ -31,12 +37,12 @@ class Autoencoder:
|
|||||||
max_epoch: int,
|
max_epoch: int,
|
||||||
patience: int,
|
patience: int,
|
||||||
display_loss: bool = False) -> list[float]:
|
display_loss: bool = False) -> list[float]:
|
||||||
|
losses = [self.loss(data_set)]
|
||||||
if display_loss is True:
|
if display_loss is True:
|
||||||
ax, line = dynamic_loss_plot_init()
|
ax, line = dynamic_loss_plot_init(losses)
|
||||||
losses = []
|
|
||||||
epoch = 0
|
epoch = 0
|
||||||
no_improv = 0
|
no_improv = 0
|
||||||
prev_error = float('inf')
|
prev_error = losses[0]
|
||||||
with tqdm(bar_format="{desc} {elapsed} {rate_fmt}") as lbar:
|
with tqdm(bar_format="{desc} {elapsed} {rate_fmt}") as lbar:
|
||||||
while True:
|
while True:
|
||||||
lbar.set_description(
|
lbar.set_description(
|
||||||
@@ -45,8 +51,7 @@ class Autoencoder:
|
|||||||
lbar.update()
|
lbar.update()
|
||||||
error = 0
|
error = 0
|
||||||
for x in data_set:
|
for x in data_set:
|
||||||
input = x.flatten()
|
error += self.train(x)
|
||||||
error += self.train(input)
|
|
||||||
error /= len(data_set)
|
error /= len(data_set)
|
||||||
if prev_error - error <= 1e-8:
|
if prev_error - error <= 1e-8:
|
||||||
no_improv += 1
|
no_improv += 1
|
||||||
@@ -71,3 +76,8 @@ class Autoencoder:
|
|||||||
|
|
||||||
def decode(self, v: np.ndarray) -> np.ndarray:
|
def decode(self, v: np.ndarray) -> np.ndarray:
|
||||||
return self.decoder.forward(v)
|
return self.decoder.forward(v)
|
||||||
|
|
||||||
|
def forward(self, v: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
||||||
|
code = self.encode(v)
|
||||||
|
out = self.decode(code)
|
||||||
|
return out, code
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from autoencoder import Autoencoder
|
|||||||
from utils import relu
|
from utils import relu
|
||||||
|
|
||||||
|
|
||||||
def load_mnist():
|
def load_mnist() -> list[np.ndarray]:
|
||||||
import os
|
import os
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
@@ -23,9 +23,13 @@ def mnist_test(
|
|||||||
patience: int,
|
patience: int,
|
||||||
):
|
):
|
||||||
x_train, _, x_test, _ = load_mnist()
|
x_train, _, x_test, _ = load_mnist()
|
||||||
|
in_len = x_train[0].shape[0] * x_train[0].shape[0]
|
||||||
|
img_shape = x_train[0].shape
|
||||||
|
x_train.resize(x_train.shape[0], in_len)
|
||||||
|
x_test.resize(x_test.shape[0], in_len)
|
||||||
x_train = np.divide(x_train, 255)
|
x_train = np.divide(x_train, 255)
|
||||||
x_test = np.divide(x_train, 255)
|
x_test = np.divide(x_train, 255)
|
||||||
in_len = x_train[0].shape[0] * x_train[0].shape[0]
|
x_train = x_train[:1000]
|
||||||
autoencoder = Autoencoder(
|
autoencoder = Autoencoder(
|
||||||
[in_len, bottleneck],
|
[in_len, bottleneck],
|
||||||
[bottleneck, in_len],
|
[bottleneck, in_len],
|
||||||
@@ -41,9 +45,9 @@ def mnist_test(
|
|||||||
code = autoencoder.encode(example.flatten())
|
code = autoencoder.encode(example.flatten())
|
||||||
output = autoencoder.decode(code)
|
output = autoencoder.decode(code)
|
||||||
plt.subplot(1, 2, 1)
|
plt.subplot(1, 2, 1)
|
||||||
plt.matshow(example, fignum=False)
|
plt.matshow(example.reshape(img_shape), fignum=False)
|
||||||
plt.subplot(1, 2, 2)
|
plt.subplot(1, 2, 2)
|
||||||
plt.matshow(output.reshape(example.shape), fignum=False)
|
plt.matshow(output.reshape(img_shape), fignum=False)
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
4
utils.py
4
utils.py
@@ -25,10 +25,10 @@ def regularize(v: np.ndarray) -> np.ndarray:
|
|||||||
return (v - v_min) / (v_max - v_min)
|
return (v - v_min) / (v_max - v_min)
|
||||||
|
|
||||||
|
|
||||||
def dynamic_loss_plot_init():
|
def dynamic_loss_plot_init(losses: list):
|
||||||
plt.ion()
|
plt.ion()
|
||||||
fig, ax = plt.subplots()
|
fig, ax = plt.subplots()
|
||||||
line, = ax.plot([], [], label="Loss")
|
line, = ax.plot([0], losses, label="Loss")
|
||||||
ax.set_xlabel("Epoch")
|
ax.set_xlabel("Epoch")
|
||||||
ax.set_ylabel("Loss")
|
ax.set_ylabel("Loss")
|
||||||
ax.set_title("Training Loss")
|
ax.set_title("Training Loss")
|
||||||
|
|||||||
Reference in New Issue
Block a user