diff --git a/pytensor/link/pytorch/linker.py b/pytensor/link/pytorch/linker.py index d47aa43dda..48675c5a4d 100644 --- a/pytensor/link/pytorch/linker.py +++ b/pytensor/link/pytorch/linker.py @@ -51,7 +51,8 @@ class wrapper: """ def __init__(self, fn, gen_functors): - self.fn = torch.compile(fn) + with torch.no_grad(): + self.fn = torch.compile(fn) self.gen_functors = gen_functors.copy() def __call__(self, *inputs, **kwargs): @@ -62,7 +63,9 @@ def __call__(self, *inputs, **kwargs): setattr(pytensor.link.utils, n[1:], fn) # Torch does not accept numpy inputs and may return GPU objects - outs = self.fn(*(pytorch_typify(inp) for inp in inputs), **kwargs) + with torch.no_grad(): + ins = (pytorch_typify(inp) for inp in inputs) + outs = self.fn(*ins, **kwargs) # unset attrs for n, _ in self.gen_functors: