diff --git a/flexinfer/inference/inference.py b/flexinfer/inference/inference.py index fee1dd0..15b9920 100644 --- a/flexinfer/inference/inference.py +++ b/flexinfer/inference/inference.py @@ -22,9 +22,8 @@ def __call__(self, imgs): Returns: outp (torch.float32) """ - with torch.no_grad(): - imgs = imgs.cuda() - outp = self.model(imgs) + imgs = imgs.cuda() + outp = self.model(imgs) return outp