From dea5a8ada1b6d55c8542a40d284f41707134d445 Mon Sep 17 00:00:00 2001 From: Martin Cerman Date: Wed, 2 Oct 2024 20:53:27 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A8=20[FIX]=20Fixes=20memory=20leak=20?= =?UTF-8?q?(#83)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fixes memory leak * Changed total_loss to use float type and adjusted collection of loss --------- Co-authored-by: Martin Cerman --- yolo/tools/solver.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/yolo/tools/solver.py b/yolo/tools/solver.py index 51ceffc..4d179b8 100644 --- a/yolo/tools/solver.py +++ b/yolo/tools/solver.py @@ -86,7 +86,7 @@ def train_one_batch(self, images: Tensor, targets: Tensor): def train_one_epoch(self, dataloader): self.model.train() - total_loss = defaultdict(lambda: torch.tensor(0.0, device=self.device)) + total_loss = defaultdict(float) total_samples = 0 self.optimizer.next_epoch(len(dataloader)) for batch_size, images, targets, *_ in dataloader: @@ -96,7 +96,7 @@ def train_one_epoch(self, dataloader): for loss_name, loss_val in loss_each.items(): if self.use_ddp: # collecting loss for each batch distributed.all_reduce(loss_val, op=distributed.ReduceOp.AVG) - total_loss[loss_name] += loss_val * batch_size + total_loss[loss_name] += loss_val.item() * batch_size total_samples += batch_size self.progress.one_batch(loss_each)