Issue Description
Simply replacing `torch.jit.script` or `torch.jit.trace` with `backend.jit` still fails for tc functions.
Example scripts:
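```python
import torch
import tensorcircuit as tc

tc.set_backend("pytorch")  # the issue concerns the PyTorch backend

@torch.jit.script  # reported to fail for tc functions
def f(param):
    c = tc.Circuit(6)
    for i in range(5):
        for j in range(5):
            c.rzz(i, i + 1, theta=param[i, j])
    return c.expectation_ps(z=[1])

f(torch.ones([5, 5]))
```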
or
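```python
from functools import partial

import torch
import tensorcircuit as tc

tc.set_backend("pytorch")  # the issue concerns the PyTorch backend

# torch.jit.trace records the ops executed on the example input
@partial(torch.jit.trace, example_inputs=torch.ones([5, 5]))
def f(param):
    c = tc.Circuit(6)
    for i in range(5):
        for j in range(5):
            c.rzz(i, i + 1, theta=param[i, j])
    return c.expectation_ps(z=[1])

f(torch.ones([5, 5]))
```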
Actually, the latter somehow works, but it is very fragile: for example, if the jit transformation is nested with a grad or vmap operation, torch mostly fails.
Proposed Solution
1. Wait for further development of torch.
2. Use the tf/jax backend with the torch interface instead.
3. Slightly fix the existing tc codebase; this may work, but I currently have no time to try it.
4. Try `torch.compile` later.
refraction-ray commented on Apr 10, 2023
On the other hand, `torch.vmap` seems to work fine, at least at the syntax level; its performance has not been benchmarked in detail.

refraction-ray commented on Apr 10, 2023
pytorch/pytorch#98724
refraction-ray commented on Apr 26, 2024
torch 2.3 works for nesting functional transformations, but this version no longer supports macOS x86:
https://dev-discuss.pytorch.org/t/pytorch-macos-x86-builds-deprecation-starting-january-2024/1690
pytorch/pytorch#114602
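A minimal pure-PyTorch sketch of the kind of functional-transformation nesting discussed in this thread, assuming PyTorch 2.x, where `torch.func.vmap` and `torch.func.grad` are available. The function `toy_energy` is a hypothetical stand-in for a tc circuit expectation and is not part of tensorcircuit; it only shows `vmap` on its own and `grad` nested inside `vmap`, without any jit step.

```python
import math

import torch
from torch.func import grad, vmap


def toy_energy(param: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for a circuit expectation: maps a (5, 5)
    # parameter matrix to a scalar, so grad() has a scalar to differentiate.
    return torch.cos(param).sum()


# vmap alone: evaluate toy_energy on a batch of parameter matrices.
batched_energy = vmap(toy_energy)

# Nested transformation: per-sample gradients via vmap(grad(f)).
batched_grad = vmap(grad(toy_energy))

params = torch.full((3, 5, 5), math.pi / 2)
energies = batched_energy(params)  # shape (3,), each value close to 0
grads = batched_grad(params)       # shape (3, 5, 5); d/dp cos(p) = -sin(p)

print(energies.shape, grads.shape)
```

Whether the same composition also survives an additional jit or `torch.compile` wrapper around tc circuits is exactly the open question of this issue, so the sketch deliberately stops at the eager transformations.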