refactor: move plot logic to plotters.py
This commit is contained in:
48
src/easyvae/plotters.py
Normal file
48
src/easyvae/plotters.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import matplotlib.pyplot as plt
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .autoencoder import AAutoencoder
|
||||
|
||||
|
||||
class Plotter:
|
||||
def __init__(self, autoencoder: 'AAutoencoder'):
|
||||
pass
|
||||
|
||||
def update(self):
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
|
||||
|
||||
class CAPlotter(Plotter):
|
||||
def __init__(self, autoencoder: 'AAutoencoder'):
|
||||
self.autoencoder = autoencoder
|
||||
plt.ion()
|
||||
self.fig, self.ax = plt.subplots()
|
||||
self.line, = self.ax.plot(
|
||||
list(range(len(autoencoder.losses))),
|
||||
autoencoder.losses,
|
||||
label="Loss"
|
||||
)
|
||||
self.ax.set_xlabel("Epoch")
|
||||
self.ax.set_ylabel("Loss")
|
||||
self.ax.set_title("Training MSE Loss")
|
||||
self.ax.legend()
|
||||
self.update()
|
||||
|
||||
def update(self):
|
||||
self.line.set_xdata(range(len(self.autoencoder.losses)))
|
||||
self.line.set_ydata(self.autoencoder.losses)
|
||||
self.ax.relim()
|
||||
self.ax.autoscale_view()
|
||||
plt.draw()
|
||||
plt.pause(0.1)
|
||||
|
||||
def close(self):
|
||||
plt.ioff()
|
||||
plt.show()
|
||||
Reference in New Issue
Block a user