diff --git a/torchgeo/trainers/classification.py b/torchgeo/trainers/classification.py index d715ee3973f..861c9299a8f 100644 --- a/torchgeo/trainers/classification.py +++ b/torchgeo/trainers/classification.py @@ -6,6 +6,7 @@ import os from typing import Any, Dict, cast +import matplotlib.pyplot as plt import pytorch_lightning as pl import timm import torch @@ -199,6 +200,7 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None: summary_writer.add_figure( f"image/{batch_idx}", fig, global_step=self.global_step ) + plt.close() except AttributeError: pass diff --git a/torchgeo/trainers/detection.py b/torchgeo/trainers/detection.py index 3d16100e696..12679778d70 100644 --- a/torchgeo/trainers/detection.py +++ b/torchgeo/trainers/detection.py @@ -5,6 +5,7 @@ from typing import Any, Dict, List, cast +import matplotlib.pyplot as plt import pytorch_lightning as pl import torch import torchmetrics @@ -190,6 +191,7 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None: summary_writer.add_figure( f"image/{batch_idx}", fig, global_step=self.global_step ) + plt.close() except AttributeError: pass diff --git a/torchgeo/trainers/regression.py b/torchgeo/trainers/regression.py index bad9717d638..c9269e3daa2 100644 --- a/torchgeo/trainers/regression.py +++ b/torchgeo/trainers/regression.py @@ -5,6 +5,7 @@ from typing import Any, Dict, cast +import matplotlib.pyplot as plt import pytorch_lightning as pl import torch import torch.nn as nn @@ -140,6 +141,7 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None: summary_writer.add_figure( f"image/{batch_idx}", fig, global_step=self.global_step ) + plt.close() except AttributeError: pass diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index 890b47aa2e9..222351e4b6d 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -6,6 +6,7 @@ import warnings from typing import Any, Dict, cast +import matplotlib.pyplot as plt import pytorch_lightning as pl import segmentation_models_pytorch as smp import torch @@ -194,6 +195,7 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None: summary_writer.add_figure( f"image/{batch_idx}", fig, global_step=self.global_step ) + plt.close() except AttributeError: pass