feat: add monte-carlo method and MSE to labeling method
This commit is contained in:
@@ -37,7 +37,7 @@ def mnist_train(
|
|||||||
autoencoder = cls(
|
autoencoder = cls(
|
||||||
[in_len, 256, 2],
|
[in_len, 256, 2],
|
||||||
[2, 256, in_len],
|
[2, 256, in_len],
|
||||||
0.0001,
|
0.001,
|
||||||
LeakyReLU()
|
LeakyReLU()
|
||||||
)
|
)
|
||||||
print("CTRL+C to interrupt training.")
|
print("CTRL+C to interrupt training.")
|
||||||
@@ -91,8 +91,8 @@ def plot_random_reconstruction(
|
|||||||
print(f'{code.tolist()}')
|
print(f'{code.tolist()}')
|
||||||
|
|
||||||
|
|
||||||
def mnist_test(model: str | AAutoencoder):
|
def mnist_test(model: str | AAutoencoder | LabelingVAE):
|
||||||
x_train, _, x_test, y_test = load_mnist()
|
x_train, y_train, x_test, y_test = load_mnist()
|
||||||
in_len = x_train[0].shape[0] * x_train[0].shape[0]
|
in_len = x_train[0].shape[0] * x_train[0].shape[0]
|
||||||
img_shape = x_train[0].shape
|
img_shape = x_train[0].shape
|
||||||
x_train.resize(x_train.shape[0], in_len)
|
x_train.resize(x_train.shape[0], in_len)
|
||||||
@@ -107,6 +107,11 @@ def mnist_test(model: str | AAutoencoder):
|
|||||||
print(autoencoder)
|
print(autoencoder)
|
||||||
idx = np.random.randint(0, len(x_test))
|
idx = np.random.randint(0, len(x_test))
|
||||||
example: np.ndarray = x_test[idx]
|
example: np.ndarray = x_test[idx]
|
||||||
|
y_train = [str(int(i)) for i in y_train]
|
||||||
|
autoencoder.learn_labels(x_train, y_train, 5)
|
||||||
|
res = autoencoder.label(x_train[idx])
|
||||||
|
for k, v in res.items():
|
||||||
|
print(f"{k} => {v}")
|
||||||
plot_random_reconstruction(autoencoder, example, img_shape, y_test[idx])
|
plot_random_reconstruction(autoencoder, example, img_shape, y_test[idx])
|
||||||
if autoencoder.space_dim == 2:
|
if autoencoder.space_dim == 2:
|
||||||
plot_mnist_latent_space(autoencoder, x_test, y_test)
|
plot_mnist_latent_space(autoencoder, x_test, y_test)
|
||||||
@@ -150,6 +155,6 @@ if __name__ == "__main__":
|
|||||||
args.m,
|
args.m,
|
||||||
args.e,
|
args.e,
|
||||||
args.p,
|
args.p,
|
||||||
VariationalAutoencoder
|
LabelingVAE
|
||||||
)
|
)
|
||||||
mnist_test(autoencoder)
|
mnist_test(autoencoder)
|
||||||
|
|||||||
@@ -290,8 +290,11 @@ class LabelingVAE(VariationalAutoencoder):
|
|||||||
for label in self.labels:
|
for label in self.labels:
|
||||||
label.cache()
|
label.cache()
|
||||||
|
|
||||||
def label(self, x: np.ndarray):
|
def label(self, x: np.ndarray, samples=10):
|
||||||
y = self.encode(x)
|
y = np.zeros((samples, self.encoder.out_size))
|
||||||
|
for i in range(samples):
|
||||||
|
y[i] = self.encode(x)
|
||||||
|
y = np.mean(y, axis=0)
|
||||||
probs = {}
|
probs = {}
|
||||||
total = 0
|
total = 0
|
||||||
for label in self.labels:
|
for label in self.labels:
|
||||||
@@ -300,4 +303,10 @@ class LabelingVAE(VariationalAutoencoder):
|
|||||||
total += p
|
total += p
|
||||||
for k in probs:
|
for k in probs:
|
||||||
probs[k] = float(probs[k] / total)
|
probs[k] = float(probs[k] / total)
|
||||||
return dict(sorted(probs.items()))
|
return dict(
|
||||||
|
sorted(
|
||||||
|
probs.items(),
|
||||||
|
key=lambda item: item[1],
|
||||||
|
reverse=True
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user