Merge pull request #2 from lenoctambule/dev
Refactor of plotting and kb interrupt logic
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -5,3 +5,4 @@ __pycache__
|
|||||||
.venv
|
.venv
|
||||||
dist
|
dist
|
||||||
*.egg-info
|
*.egg-info
|
||||||
|
.env
|
||||||
@@ -6,7 +6,7 @@
|
|||||||
src="https://raw.githubusercontent.com/lenoctambule/autoencoder/refs/heads/main/media/latent-space.png"
|
src="https://raw.githubusercontent.com/lenoctambule/autoencoder/refs/heads/main/media/latent-space.png"
|
||||||
alt="Latent-space of the MNIST dataset"
|
alt="Latent-space of the MNIST dataset"
|
||||||
width=70%>
|
width=70%>
|
||||||
<figcaption align=center>
|
<figcaption>
|
||||||
<p align="center">
|
<p align="center">
|
||||||
Latent-space representation of the MNIST dataset using Variational Autoencoder
|
Latent-space representation of the MNIST dataset using Variational Autoencoder
|
||||||
</p>
|
</p>
|
||||||
|
|||||||
@@ -1,14 +1,12 @@
|
|||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import os
|
import os
|
||||||
import signal
|
|
||||||
from easyvae.autoencoder import ( # noqa
|
from easyvae.autoencoder import ( # noqa
|
||||||
VariationalAutoencoder,
|
VariationalAutoencoder,
|
||||||
ClassicalAutoencoder,
|
ClassicalAutoencoder,
|
||||||
AAutoencoder
|
AAutoencoder
|
||||||
)
|
)
|
||||||
from easyvae.activations import LeakyReLU
|
from easyvae.activations import LeakyReLU
|
||||||
from easyvae.utils import dynamic_loss_plot_finish
|
|
||||||
|
|
||||||
|
|
||||||
def load_mnist() -> list[np.ndarray]:
|
def load_mnist() -> list[np.ndarray]:
|
||||||
@@ -42,17 +40,7 @@ def mnist_train(
|
|||||||
0.0001,
|
0.0001,
|
||||||
LeakyReLU()
|
LeakyReLU()
|
||||||
)
|
)
|
||||||
|
print("CTRL+C to interrupt training.")
|
||||||
def handler(signum, frame):
|
|
||||||
print(f"Saving {filename} before exit ...")
|
|
||||||
autoencoder.save(filename)
|
|
||||||
if plt.get_fignums():
|
|
||||||
dynamic_loss_plot_finish()
|
|
||||||
mnist_test(autoencoder)
|
|
||||||
exit()
|
|
||||||
|
|
||||||
signal.signal(signal.SIGINT, handler)
|
|
||||||
print("CTRL+C to exit and save model.")
|
|
||||||
autoencoder.train_dataset(
|
autoencoder.train_dataset(
|
||||||
x_train,
|
x_train,
|
||||||
max_epoch,
|
max_epoch,
|
||||||
@@ -84,10 +72,11 @@ def plot_mnist_latent_space(autoencoder: AAutoencoder, x: np.ndarray, y,):
|
|||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
def plot_random_reconstruction(autoencoder: AAutoencoder,
|
def plot_random_reconstruction(
|
||||||
example: np.ndarray,
|
autoencoder: AAutoencoder,
|
||||||
img_shape,
|
example: np.ndarray,
|
||||||
y):
|
img_shape,
|
||||||
|
y):
|
||||||
output, code = autoencoder.forward(example.flatten())
|
output, code = autoencoder.forward(example.flatten())
|
||||||
plt.subplot(1, 2, 1)
|
plt.subplot(1, 2, 1)
|
||||||
plt.matshow(
|
plt.matshow(
|
||||||
@@ -99,7 +88,7 @@ def plot_random_reconstruction(autoencoder: AAutoencoder,
|
|||||||
output.reshape(img_shape),
|
output.reshape(img_shape),
|
||||||
fignum=False)
|
fignum=False)
|
||||||
plt.title(f"Output ({y})")
|
plt.title(f"Output ({y})")
|
||||||
print(f'{code=}')
|
print(f'{code.tolist()}')
|
||||||
|
|
||||||
|
|
||||||
def mnist_test(model: str | AAutoencoder):
|
def mnist_test(model: str | AAutoencoder):
|
||||||
@@ -114,6 +103,8 @@ def mnist_test(model: str | AAutoencoder):
|
|||||||
autoencoder: AAutoencoder = AAutoencoder.load(model)
|
autoencoder: AAutoencoder = AAutoencoder.load(model)
|
||||||
else:
|
else:
|
||||||
autoencoder = model
|
autoencoder = model
|
||||||
|
print("Testing model ...\n")
|
||||||
|
print(autoencoder)
|
||||||
idx = np.random.randint(0, len(x_test))
|
idx = np.random.randint(0, len(x_test))
|
||||||
example: np.ndarray = x_test[idx]
|
example: np.ndarray = x_test[idx]
|
||||||
plot_random_reconstruction(autoencoder, example, img_shape, y_test[idx])
|
plot_random_reconstruction(autoencoder, example, img_shape, y_test[idx])
|
||||||
|
|||||||
@@ -1,18 +1,17 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from .utils import (
|
|
||||||
dynamic_loss_plot_init,
|
|
||||||
dynamic_loss_plot_update,
|
|
||||||
dynamic_loss_plot_finish
|
|
||||||
)
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from .layers import DeepNNLayer, SampleLayer
|
from .layers import DeepNNLayer, SampleLayer
|
||||||
from .activations import ActivationFunc, Identity
|
from .activations import ActivationFunc, Identity
|
||||||
|
from .plotters import Plotter, CAPlotter, VAEPlotter
|
||||||
|
from .utils import interruptable
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
LOADER = ['⡿', '⣟', '⣯', '⣷', '⣾', '⣽', '⣻', '⢿']
|
LOADER = ['⡿', '⣟', '⣯', '⣷', '⣾', '⣽', '⣻', '⢿']
|
||||||
|
|
||||||
|
|
||||||
class AAutoencoder(ABC):
|
class AAutoencoder(ABC):
|
||||||
|
plotter_cls = Plotter
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
encoder_layers: list[int],
|
encoder_layers: list[int],
|
||||||
@@ -27,45 +26,7 @@ class AAutoencoder(ABC):
|
|||||||
self.decoder = DeepNNLayer(decoder_layers, lr, activation_func)
|
self.decoder = DeepNNLayer(decoder_layers, lr, activation_func)
|
||||||
self.space_dim = decoder_layers[0]
|
self.space_dim = decoder_layers[0]
|
||||||
self.lr = lr
|
self.lr = lr
|
||||||
|
self.losses = [0]
|
||||||
def train_dataset(self,
|
|
||||||
data_set: list[np.ndarray],
|
|
||||||
max_epoch: int,
|
|
||||||
patience: int,
|
|
||||||
display_loss: bool = False) -> list[float]:
|
|
||||||
losses = [self.loss(data_set)]
|
|
||||||
if display_loss is True:
|
|
||||||
ax, line = dynamic_loss_plot_init(losses)
|
|
||||||
epoch = 0
|
|
||||||
no_improv = 0
|
|
||||||
prev_error = losses[0]
|
|
||||||
with tqdm(bar_format="{desc} {elapsed} {rate_fmt}") as lbar:
|
|
||||||
while True:
|
|
||||||
lbar.set_description(
|
|
||||||
f"{LOADER[epoch % len(LOADER)]} Training ({epoch=} error={float(prev_error):.6f})", # noqa
|
|
||||||
)
|
|
||||||
lbar.update()
|
|
||||||
error = 0
|
|
||||||
for x in tqdm(data_set, leave=False):
|
|
||||||
error += self.train(x)
|
|
||||||
error /= len(data_set)
|
|
||||||
derror = prev_error - error
|
|
||||||
if derror <= 0 or abs(derror) < 1e-4:
|
|
||||||
no_improv += 1
|
|
||||||
else:
|
|
||||||
no_improv = 0
|
|
||||||
prev_error = float(error)
|
|
||||||
losses.append(error)
|
|
||||||
if display_loss is True:
|
|
||||||
dynamic_loss_plot_update(ax, line, losses)
|
|
||||||
if no_improv > patience:
|
|
||||||
break
|
|
||||||
if epoch > max_epoch:
|
|
||||||
break
|
|
||||||
epoch += 1
|
|
||||||
if display_loss is True:
|
|
||||||
dynamic_loss_plot_finish()
|
|
||||||
return losses
|
|
||||||
|
|
||||||
def save(self, path: str):
|
def save(self, path: str):
|
||||||
path = path.removesuffix('.npy')
|
path = path.removesuffix('.npy')
|
||||||
@@ -88,13 +49,27 @@ class AAutoencoder(ABC):
|
|||||||
def forward(self, v: np.ndarray) -> np.ndarray:
|
def forward(self, v: np.ndarray) -> np.ndarray:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def train_dataset(self, *args, **kwargs) -> list[float]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ClassicalAutoencoder(AAutoencoder):
|
class ClassicalAutoencoder(AAutoencoder):
|
||||||
|
plotter_cls = CAPlotter
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
self.losses = []
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return f'Encoder:\n{self.encoder}\n\nDecoder:\n{self.decoder}'
|
return "\n".join((
|
||||||
|
f"Type: {__class__.__name__}",
|
||||||
|
"Encoder:",
|
||||||
|
f"{self.encoder}",
|
||||||
|
"Decoder:",
|
||||||
|
f"{self.decoder}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def loss(self, data_set: list[np.ndarray]) -> float:
|
def loss(self, data_set: list[np.ndarray]) -> float:
|
||||||
loss = 0
|
loss = 0
|
||||||
@@ -112,6 +87,41 @@ class ClassicalAutoencoder(AAutoencoder):
|
|||||||
)
|
)
|
||||||
return np.sum(np.abs(error)) / len(v)
|
return np.sum(np.abs(error)) / len(v)
|
||||||
|
|
||||||
|
@interruptable
|
||||||
|
def train_dataset(self,
|
||||||
|
data_set: list[np.ndarray],
|
||||||
|
max_epoch: int,
|
||||||
|
patience: int,
|
||||||
|
display_loss: bool = False) -> list[float]:
|
||||||
|
plotter = self.plotter_cls(self) if display_loss else Plotter(self)
|
||||||
|
self.losses = [self.loss(data_set)]
|
||||||
|
epoch = 0
|
||||||
|
no_improv = 0
|
||||||
|
prev_error = self.losses[0]
|
||||||
|
with tqdm(bar_format="{desc} {elapsed} {rate_fmt}") as lbar:
|
||||||
|
while True:
|
||||||
|
lbar.set_description(
|
||||||
|
f"{LOADER[epoch % len(LOADER)]} Training ({epoch=} error={float(prev_error):.6f})", # noqa
|
||||||
|
)
|
||||||
|
lbar.update()
|
||||||
|
error = 0
|
||||||
|
for x in tqdm(data_set, leave=False):
|
||||||
|
error += self.train(x)
|
||||||
|
error /= len(data_set)
|
||||||
|
derror = prev_error - error
|
||||||
|
if derror <= 0 or abs(derror) < 1e-4:
|
||||||
|
no_improv += 1
|
||||||
|
else:
|
||||||
|
no_improv = 0
|
||||||
|
prev_error = float(error)
|
||||||
|
self.losses.append(error)
|
||||||
|
if no_improv > patience:
|
||||||
|
break
|
||||||
|
if epoch > max_epoch:
|
||||||
|
break
|
||||||
|
plotter.update()
|
||||||
|
epoch += 1
|
||||||
|
|
||||||
def encode(self, v: np.ndarray) -> np.ndarray:
|
def encode(self, v: np.ndarray) -> np.ndarray:
|
||||||
return self.encoder.forward(v)
|
return self.encoder.forward(v)
|
||||||
|
|
||||||
@@ -125,20 +135,36 @@ class ClassicalAutoencoder(AAutoencoder):
|
|||||||
|
|
||||||
|
|
||||||
class VariationalAutoencoder(AAutoencoder):
|
class VariationalAutoencoder(AAutoencoder):
|
||||||
|
plotter_cls = VAEPlotter
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.sampler = SampleLayer(self.encoder.out_size, self.lr, Identity())
|
self.sampler = SampleLayer(self.encoder.out_size, self.lr, Identity())
|
||||||
|
self.KL_losses = []
|
||||||
|
self.recon_losses = []
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return "\n".join((
|
||||||
|
f"Type: {__class__.__name__}",
|
||||||
|
"Encoder:",
|
||||||
|
f"{self.encoder}",
|
||||||
|
"Decoder:",
|
||||||
|
f"{self.decoder}"
|
||||||
|
))
|
||||||
|
|
||||||
def loss(self, data_set: list[np.ndarray]) -> float:
|
def loss(self, data_set: list[np.ndarray]) -> float:
|
||||||
loss = 0
|
kl_loss = 0
|
||||||
|
recon_loss = 0
|
||||||
for x in data_set:
|
for x in data_set:
|
||||||
out = self.forward(x)[0]
|
out = self.forward(x)[0]
|
||||||
kl = self.sampler.DKL()
|
kl = self.sampler.DKL()
|
||||||
loss += np.mean((out - x) ** 2)
|
recon_loss += np.mean((out - x) ** 2)
|
||||||
loss += kl
|
kl_loss += kl
|
||||||
return loss / len(data_set)
|
kl_loss /= len(data_set)
|
||||||
|
recon_loss /= len(data_set)
|
||||||
|
return recon_loss, kl_loss
|
||||||
|
|
||||||
def train(self, v: np.ndarray) -> float:
|
def train(self, v: np.ndarray) -> tuple[float, float]:
|
||||||
out, _ = self.forward(v)
|
out, _ = self.forward(v)
|
||||||
error = out - v
|
error = out - v
|
||||||
self.encoder.backprop(
|
self.encoder.backprop(
|
||||||
@@ -146,10 +172,61 @@ class VariationalAutoencoder(AAutoencoder):
|
|||||||
self.decoder.backprop(error)
|
self.decoder.backprop(error)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return np.mean(error ** 2) + self.sampler.DKL()
|
return np.mean(error ** 2), self.sampler.DKL()
|
||||||
|
|
||||||
|
@interruptable
|
||||||
|
def train_dataset(self,
|
||||||
|
data_set: list[np.ndarray],
|
||||||
|
max_epoch: int,
|
||||||
|
patience: int,
|
||||||
|
display_loss: bool = False) -> list[float]:
|
||||||
|
plotter = self.plotter_cls(self) if display_loss else Plotter(self)
|
||||||
|
recon_0, kl_0 = self.loss(data_set)
|
||||||
|
self.recon_losses = [recon_0]
|
||||||
|
self.KL_losses = [kl_0]
|
||||||
|
epoch = 0
|
||||||
|
no_improv = 0
|
||||||
|
prev_loss = self.recon_losses[0] + self.KL_losses[0]
|
||||||
|
with tqdm(bar_format="{desc} {elapsed} {rate_fmt}") as lbar:
|
||||||
|
while True:
|
||||||
|
lbar.set_description(
|
||||||
|
f"{LOADER[epoch % len(LOADER)]} Training ({epoch=} loss={float(prev_loss):.6f})", # noqa
|
||||||
|
)
|
||||||
|
lbar.update()
|
||||||
|
dkl = 0
|
||||||
|
recon = 0
|
||||||
|
for x in tqdm(data_set, leave=False):
|
||||||
|
recon_i, dkl_i = self.train(x)
|
||||||
|
dkl += dkl_i
|
||||||
|
recon += recon_i
|
||||||
|
recon /= len(data_set)
|
||||||
|
dkl /= len(data_set)
|
||||||
|
loss = recon + dkl
|
||||||
|
dloss = prev_loss - loss
|
||||||
|
if dloss <= 0 or abs(dloss) < 1e-4:
|
||||||
|
no_improv += 1
|
||||||
|
else:
|
||||||
|
no_improv = 0
|
||||||
|
prev_loss = float(loss)
|
||||||
|
self.recon_losses.append(recon)
|
||||||
|
self.KL_losses.append(dkl)
|
||||||
|
if no_improv > patience:
|
||||||
|
break
|
||||||
|
if epoch > max_epoch:
|
||||||
|
break
|
||||||
|
plotter.update()
|
||||||
|
epoch += 1
|
||||||
|
|
||||||
def forward(self, v: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
def forward(self, v: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
||||||
code = self.encoder.forward(v)
|
code = self.encoder.forward(v)
|
||||||
sample = self.sampler.forward(code)
|
sample = self.sampler.forward(code)
|
||||||
out = self.decoder.forward(sample)
|
out = self.decoder.forward(sample)
|
||||||
return out, code
|
return out, code
|
||||||
|
|
||||||
|
def encode(self, v: np.ndarray) -> np.ndarray:
|
||||||
|
return self.sampler.forward(
|
||||||
|
self.encoder.forward(v)
|
||||||
|
)
|
||||||
|
|
||||||
|
def decode(self, v: np.ndarray) -> np.ndarray:
|
||||||
|
return self.decoder.forward(v)
|
||||||
|
|||||||
93
src/easyvae/plotters.py
Normal file
93
src/easyvae/plotters.py
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .autoencoder import AAutoencoder, VariationalAutoencoder
|
||||||
|
|
||||||
|
|
||||||
|
class Plotter:
|
||||||
|
def __init__(self, autoencoder: 'AAutoencoder'):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def update(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
self.close()
|
||||||
|
|
||||||
|
|
||||||
|
class CAPlotter(Plotter):
|
||||||
|
def __init__(self, autoencoder: 'AAutoencoder'):
|
||||||
|
self.autoencoder = autoencoder
|
||||||
|
plt.ion()
|
||||||
|
self.fig, self.ax = plt.subplots()
|
||||||
|
self.line, = self.ax.plot(
|
||||||
|
list(range(len(autoencoder.losses))),
|
||||||
|
autoencoder.losses,
|
||||||
|
label="Loss"
|
||||||
|
)
|
||||||
|
self.ax.set_xlabel("Epoch")
|
||||||
|
self.ax.set_ylabel("Loss")
|
||||||
|
self.ax.set_title("Training MSE Loss")
|
||||||
|
self.ax.legend()
|
||||||
|
self.update()
|
||||||
|
|
||||||
|
def update(self):
|
||||||
|
self.line.set_xdata(range(len(self.autoencoder.losses)))
|
||||||
|
self.line.set_ydata(self.autoencoder.losses)
|
||||||
|
self.ax.relim()
|
||||||
|
self.ax.autoscale_view()
|
||||||
|
plt.draw()
|
||||||
|
plt.pause(0.1)
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
plt.ioff()
|
||||||
|
plt.close(self.fig)
|
||||||
|
|
||||||
|
|
||||||
|
class VAEPlotter(Plotter):
|
||||||
|
def __init__(self, autoencoder: 'VariationalAutoencoder'):
|
||||||
|
self.autoencoder = autoencoder
|
||||||
|
plt.ion()
|
||||||
|
self.fig, (self.ax_recon, self.ax_dkl) = plt.subplots(1, 2)
|
||||||
|
self.line, = self.ax_recon.plot(
|
||||||
|
list(range(len(self.autoencoder.recon_losses))),
|
||||||
|
self.autoencoder.recon_losses,
|
||||||
|
label="Loss"
|
||||||
|
)
|
||||||
|
self.ax_recon.set_xlabel("Epoch")
|
||||||
|
self.ax_recon.set_ylabel("Loss")
|
||||||
|
self.ax_recon.set_title("Reconstruction MSE Loss")
|
||||||
|
self.ax_recon.legend()
|
||||||
|
|
||||||
|
self.dkl_line, = self.ax_dkl.plot(
|
||||||
|
list(range(len(self.autoencoder.KL_losses))),
|
||||||
|
self.autoencoder.KL_losses,
|
||||||
|
label="DKL Loss",
|
||||||
|
)
|
||||||
|
self.ax_dkl.set_xlabel("Epoch")
|
||||||
|
self.ax_dkl.set_ylabel("Loss")
|
||||||
|
self.ax_dkl.set_title("DKL Loss")
|
||||||
|
self.ax_dkl.legend()
|
||||||
|
self.update()
|
||||||
|
|
||||||
|
def update(self):
|
||||||
|
self.line.set_xdata(range(len(self.autoencoder.recon_losses)))
|
||||||
|
self.line.set_ydata(self.autoencoder.recon_losses)
|
||||||
|
self.ax_recon.relim()
|
||||||
|
self.ax_recon.autoscale_view()
|
||||||
|
|
||||||
|
self.dkl_line.set_xdata(range(len(self.autoencoder.KL_losses)))
|
||||||
|
self.dkl_line.set_ydata(self.autoencoder.KL_losses)
|
||||||
|
self.ax_dkl.relim()
|
||||||
|
self.ax_dkl.autoscale_view()
|
||||||
|
|
||||||
|
plt.draw()
|
||||||
|
plt.pause(0.1)
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
plt.ioff()
|
||||||
|
plt.close(self.fig)
|
||||||
@@ -1,6 +1,5 @@
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
|
|
||||||
def softmax(v: np.ndarray) -> np.ndarray:
|
def softmax(v: np.ndarray) -> np.ndarray:
|
||||||
@@ -21,26 +20,10 @@ def regularize(v: np.ndarray) -> np.ndarray:
|
|||||||
return (v - v_min) / (v_max - v_min)
|
return (v - v_min) / (v_max - v_min)
|
||||||
|
|
||||||
|
|
||||||
def dynamic_loss_plot_init(losses: list):
|
def interruptable(func):
|
||||||
plt.ion()
|
def inner(*args, **kwargs):
|
||||||
fig, ax = plt.subplots()
|
try:
|
||||||
line, = ax.plot([0], losses, label="Loss")
|
return func(*args, **kwargs)
|
||||||
ax.set_xlabel("Epoch")
|
except KeyboardInterrupt:
|
||||||
ax.set_ylabel("Loss")
|
pass
|
||||||
ax.set_title("Training Loss")
|
return inner
|
||||||
ax.legend()
|
|
||||||
return ax, line
|
|
||||||
|
|
||||||
|
|
||||||
def dynamic_loss_plot_update(ax, line, loss):
|
|
||||||
line.set_xdata(range(len(loss)))
|
|
||||||
line.set_ydata(loss)
|
|
||||||
ax.relim()
|
|
||||||
ax.autoscale_view()
|
|
||||||
plt.draw()
|
|
||||||
plt.pause(0.1)
|
|
||||||
|
|
||||||
|
|
||||||
def dynamic_loss_plot_finish():
|
|
||||||
plt.ioff()
|
|
||||||
plt.show()
|
|
||||||
|
|||||||
Reference in New Issue
Block a user