Skip to content

Commit

Permalink
Remove gradient tracking
Browse files Browse the repository at this point in the history
  • Loading branch information
Ch0ronomato committed Dec 15, 2024
1 parent 6a39bcb commit 8a3bf6a
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions pytensor/link/pytorch/linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ class wrapper:
"""

def __init__(self, fn, gen_functors):
self.fn = torch.compile(fn)
with torch.no_grad():
self.fn = torch.compile(fn)
self.gen_functors = gen_functors.copy()

def __call__(self, *inputs, **kwargs):
Expand All @@ -62,7 +63,9 @@ def __call__(self, *inputs, **kwargs):
setattr(pytensor.link.utils, n[1:], fn)

# Torch does not accept numpy inputs and may return GPU objects
outs = self.fn(*(pytorch_typify(inp) for inp in inputs), **kwargs)
with torch.no_grad():
ins = (pytorch_typify(inp) for inp in inputs)
outs = self.fn(*ins, **kwargs)

# unset attrs
for n, _ in self.gen_functors:
Expand Down

0 comments on commit 8a3bf6a

Please sign in to comment.