feat: leaky relu func
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from autoencoder import Autoencoder
|
from autoencoder import Autoencoder
|
||||||
from utils import relu, regularize
|
from utils import leaky_relu
|
||||||
|
|
||||||
|
|
||||||
def load_mnist() -> list[np.ndarray]:
|
def load_mnist() -> list[np.ndarray]:
|
||||||
@@ -32,7 +32,7 @@ def mnist_train(
|
|||||||
[in_len, 64, 16],
|
[in_len, 64, 16],
|
||||||
[16, 64, in_len],
|
[16, 64, in_len],
|
||||||
0.01,
|
0.01,
|
||||||
relu
|
leaky_relu
|
||||||
)
|
)
|
||||||
autoencoder.train_dataset(
|
autoencoder.train_dataset(
|
||||||
x_train,
|
x_train,
|
||||||
@@ -56,7 +56,6 @@ def mnist_test(filename: str):
|
|||||||
idx = np.random.randint(0, len(x_test))
|
idx = np.random.randint(0, len(x_test))
|
||||||
example: np.ndarray = x_test[idx]
|
example: np.ndarray = x_test[idx]
|
||||||
output, code = autoencoder.forward(example.flatten())
|
output, code = autoencoder.forward(example.flatten())
|
||||||
output = regularize(output)
|
|
||||||
plt.subplot(1, 3, 1)
|
plt.subplot(1, 3, 1)
|
||||||
plt.matshow(
|
plt.matshow(
|
||||||
example.reshape(img_shape),
|
example.reshape(img_shape),
|
||||||
|
|||||||
6
utils.py
6
utils.py
@@ -15,6 +15,12 @@ def relu(x: np.ndarray, derivative=False) -> np.ndarray:
|
|||||||
return x * (x > 0)
|
return x * (x > 0)
|
||||||
|
|
||||||
|
|
||||||
|
def leaky_relu(x: np.ndarray, derivative=False, k=0.01) -> np.ndarray:
|
||||||
|
if derivative:
|
||||||
|
return 1 * (x > 0) + k * (x <= 0)
|
||||||
|
return x * (x > 0) + x * 0.01 * (x <= 0)
|
||||||
|
|
||||||
|
|
||||||
def normalize(v: np.ndarray) -> np.ndarray:
|
def normalize(v: np.ndarray) -> np.ndarray:
|
||||||
return v / (np.linalg.norm(v) + 1e-8)
|
return v / (np.linalg.norm(v) + 1e-8)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user