feat: leaky relu func

This commit is contained in:
Lenoctambule
2026-03-29 09:19:15 +02:00
parent 09835e9afa
commit 7aabc5db48
2 changed files with 8 additions and 3 deletions

View File

@@ -1,7 +1,7 @@
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
from autoencoder import Autoencoder from autoencoder import Autoencoder
from utils import relu, regularize from utils import leaky_relu
def load_mnist() -> list[np.ndarray]: def load_mnist() -> list[np.ndarray]:
@@ -32,7 +32,7 @@ def mnist_train(
[in_len, 64, 16], [in_len, 64, 16],
[16, 64, in_len], [16, 64, in_len],
0.01, 0.01,
relu leaky_relu
) )
autoencoder.train_dataset( autoencoder.train_dataset(
x_train, x_train,
@@ -56,7 +56,6 @@ def mnist_test(filename: str):
idx = np.random.randint(0, len(x_test)) idx = np.random.randint(0, len(x_test))
example: np.ndarray = x_test[idx] example: np.ndarray = x_test[idx]
output, code = autoencoder.forward(example.flatten()) output, code = autoencoder.forward(example.flatten())
output = regularize(output)
plt.subplot(1, 3, 1) plt.subplot(1, 3, 1)
plt.matshow( plt.matshow(
example.reshape(img_shape), example.reshape(img_shape),

View File

@@ -15,6 +15,12 @@ def relu(x: np.ndarray, derivative=False) -> np.ndarray:
return x * (x > 0) return x * (x > 0)
def leaky_relu(x: np.ndarray, derivative=False, k=0.01) -> np.ndarray:
if derivative:
return 1 * (x > 0) + k * (x <= 0)
return x * (x > 0) + x * 0.01 * (x <= 0)
def normalize(v: np.ndarray) -> np.ndarray: def normalize(v: np.ndarray) -> np.ndarray:
return v / (np.linalg.norm(v) + 1e-8) return v / (np.linalg.norm(v) + 1e-8)