From 8a3bf6ab3c029c158a4a29c2f2a485ef7618736a Mon Sep 17 00:00:00 2001 From: Ch0ronomato Date: Sun, 15 Dec 2024 09:51:58 -0800 Subject: [PATCH] Remove gradient tracking --- pytensor/link/pytorch/linker.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pytensor/link/pytorch/linker.py b/pytensor/link/pytorch/linker.py index d47aa43dda..48675c5a4d 100644 --- a/pytensor/link/pytorch/linker.py +++ b/pytensor/link/pytorch/linker.py @@ -51,7 +51,8 @@ class wrapper: """ def __init__(self, fn, gen_functors): - self.fn = torch.compile(fn) + with torch.no_grad(): + self.fn = torch.compile(fn) self.gen_functors = gen_functors.copy() def __call__(self, *inputs, **kwargs): @@ -62,7 +63,9 @@ def __call__(self, *inputs, **kwargs): setattr(pytensor.link.utils, n[1:], fn) # Torch does not accept numpy inputs and may return GPU objects - outs = self.fn(*(pytorch_typify(inp) for inp in inputs), **kwargs) + with torch.no_grad(): + ins = (pytorch_typify(inp) for inp in inputs) + outs = self.fn(*ins, **kwargs) # unset attrs for n, _ in self.gen_functors: