
[question] [docs] Short/mid/long-term status of TorchScript / JIT / torch.jit.trace / FX / symbolic tracing and its replacement by Dynamo #103841

Open
vadimkantorov opened this issue Jun 19, 2023 · 6 comments
Labels
module: docs Related to our documentation, both in docs/ and docblocks oncall: jit Add this issue/PR to JIT oncall triage queue triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Comments

@vadimkantorov
Contributor

vadimkantorov commented Jun 19, 2023

📚 The doc issue

In pytorch/vision#7624 (comment), I stumbled upon what seems to be a lack of official communication about the short-term / mid-term / long-term status of TorchScript and the JIT interpreter.

In that case, the need to support TorchScript forced the use of inelegant indexing instead of the simpler `tensor[..., perm, :, :]`, which does not seem to be supported by TorchScript. (By the way, is this construct supported by Dynamo?)

Is there any official guidance on whether new code in domain libraries (or in general) needs to support TorchScript? Does TorchScript continue to receive development? Is there a deprecation timeline? I would propose making this information / overview (even if no decisions have been made yet) very clear directly on the main page of https://pytorch.org/docs (preferably at the top of that long page).

This question comes from the fact that PyTorch has offered many tracing/scripting/compilation technologies over the years, their recipes/tutorials are all still available, and users need information about at least the mid-term plans and recommendations of the core team. My related question on deployment is here: https://discuss.pytorch.org/t/torch-compiles-deployment-story-to-non-python-host-processes/180943/4; it would be good to have that explained very clearly as well.

cc @EikanWang @jgong5 @wenzhe-nrv @sanchitintel @svekars @carljparker @ezyang @albanD

@ezyang
Contributor

ezyang commented Jun 23, 2023

> In that case, the need to support TorchScript forced the use of inelegant indexing instead of the simpler `tensor[..., perm, :, :]`, which does not seem to be supported by TorchScript. (By the way, is this construct supported by Dynamo?)

It is supported in Dynamo.
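To compare the two indexing forms discussed above, here is a minimal sketch, assuming the intent is to reorder dimension -3 (for example the channel dimension of a `...xCxHxW` image tensor). The function names are hypothetical, and `Tensor.index_select` is shown only as one op-based alternative, not as the workaround actually used in the torchvision PR.

```python
import torch


def permute_dim3_indexing(img: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
    # Concise advanced-indexing form from the issue: reorder dim -3 by `perm`.
    return img[..., perm, :, :]


def permute_dim3_index_select(img: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
    # Op-based alternative: pick entries of dim -3 in the order given by `perm`.
    return img.index_select(-3, perm)


if __name__ == "__main__":
    img = torch.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
    perm = torch.tensor([2, 0, 1])
    # Compare the two forms in eager mode.
    print(torch.equal(permute_dim3_indexing(img, perm), permute_dim3_index_select(img, perm)))
```

Both functions operate on the same dimension, so they should agree for any tensor with at least three dimensions and any valid permutation `perm` of that dimension.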

@vadimkantorov
Contributor Author

vadimkantorov commented Jun 23, 2023

The hacks in pytorch/vision#7624 (comment) are new code yet to be merged, so if TorchScript is about to be deprecated, maybe new code should not be written around its quirks... It seems that some public (and non-public) communication on the fate of older tech like TorchScript / JIT is badly needed :)

Also, I wonder: could there now be some function decorator that does the `if not in scripting: log_usage_once` check to avoid the repetition, or maybe some way to enable this logging externally, so that it doesn't pollute the code everywhere?
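A minimal eager-mode sketch of such a decorator, assuming it only needs to wrap plain Python functions. `log_api_usage` is a hypothetical helper, not an existing PyTorch or torchvision API; it forwards a key to `torch._C._log_api_usage_once`, the private usage-logging hook PyTorch itself calls. A `*args, **kwargs` wrapper like this is generally not scriptable, so in this sketch TorchScript is pointed at the undecorated function, which `functools.wraps` exposes as `__wrapped__`.

```python
import functools
from typing import Any, Callable, TypeVar

import torch

F = TypeVar("F", bound=Callable[..., Any])


def log_api_usage(fn: F) -> F:
    # Hypothetical helper: report usage of `fn` under its qualified name, then call it.
    key = f"{fn.__module__}.{fn.__qualname__}"

    @functools.wraps(fn)  # keeps __name__/__doc__ and stores the original as __wrapped__
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        torch._C._log_api_usage_once(key)
        return fn(*args, **kwargs)

    return wrapper  # type: ignore[return-value]


@log_api_usage
def scale(x: torch.Tensor, factor: float = 2.0) -> torch.Tensor:
    return x * factor


if __name__ == "__main__":
    x = torch.ones(3)
    print(scale(x))  # eager call goes through the logging wrapper
    scripted = torch.jit.script(scale.__wrapped__)  # script the undecorated function
    print(scripted(x, 3.0))
```

The design trade-off is that the logging lives in one place for eager callers, while anything that must stay scriptable still has to reference the unwrapped function explicitly.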

@soumith
Member

soumith commented Jun 23, 2023

TorchScript is frozen. It will not receive new feature development.
We have not deprecated it yet because we don't offer an equivalent replacement yet (we hope to with torch.export).

@vadimkantorov
Contributor Author

vadimkantorov commented Jun 24, 2023

If this is imminent, I propose announcing it officially anyway, so that torchvision doesn't have to develop new hacks for torch.jit.script or torch.jit.trace :)

@malfet added the `oncall: jit`, `module: docs`, and `triaged` labels and removed the `triage review` label Jun 26, 2023
@vadimkantorov changed the title from [question] [docs] Short/mid/long-term status of TorchScript / JIT / FX / symbolic tracing and its replacement by Dynamo to [question] [docs] Short/mid/long-term status of TorchScript / JIT / torch.jit.trace / FX / symbolic tracing and its replacement by Dynamo Jul 5, 2023
pytorchmergebot pushed a commit that referenced this issue Aug 11, 2023
Although the sun is setting for torchscript, it is not [officially deprecated](#103841 (comment)) since nothing currently fully replaces it. Thus, "downstream" libraries like TorchVision, which started offering torchscript support, still need to support it for backward compatibility (BC).

torchscript has forced us to use workaround after workaround since forever. Although this makes the code harder to read and maintain, we made our peace with it. However, we are currently looking into more elaborate API designs that are severely hampered by our torchscript BC guarantees.

Although likely not intended as such, while looking for ways to enable our design while keeping a subset of it scriptable, we found the undocumented `__prepare_scriptable__` escape hatch:

https://github.com/pytorch/pytorch/blob/0cf918947d161e02f208a6e93d204a0f29aaa643/torch/jit/_script.py#L977

If an object defines this method and you call `torch.jit.script` on it, the object returned by the method is scripted instead of the original object. In TorchVision we are using exactly [this mechanism to enable BC](https://github.com/pytorch/vision/blob/3966f9558bfc8443fc4fe16538b33805dd42812d/torchvision/transforms/v2/_transform.py#L122-L136) while allowing the object in eager mode to be a lot more flexible (`*args, **kwargs`, dynamic dispatch, ...).

Unfortunately, this escape hatch is only available for `nn.Module`s:

https://github.com/pytorch/pytorch/blob/0cf918947d161e02f208a6e93d204a0f29aaa643/torch/jit/_script.py#L1279-L1283

This was fine for the example above, since we were subclassing from `nn.Module` anyway. However, we recently also hit a case [where this wasn't true](pytorch/vision#7747 (comment)).

Given the frozen state of the JIT, would it be possible to give us a general escape hatch so that we can move forward with the design unconstrained while still keeping BC?

This PR implements just this by re-using the `__prepare_scriptable__` hook.
Pull Request resolved: #106229
Approved by: https://github.com/lezcano, https://github.com/ezyang
andreaskoepf added a commit to gpu-mode/resource-stream that referenced this issue Jan 2, 2024
@sebffischer

sebffischer commented Dec 15, 2024

@soumith What exactly will the deprecation of TorchScript look like? Will only the PyTorch API be deprecated, or will the LibTorch C++ API also be deprecated? I am asking because an R version of PyTorch (https://github.com/mlverse/torch) uses TorchScript to export models. Since the new replacements for TorchScript JIT compilation seem to be available only in Python, without a C++ API, deprecating TorchScript would leave non-Python interfaces to LibTorch unable to export models, which would be unfortunate.

@gmagogsfm
Contributor

@sebffischer Hi, we generally do not consider the LibTorch C++ API to be part of TorchScript; there is no plan to deprecate it.

As for why the new replacement (torch.export) doesn't have a C++ runtime API: it is just a matter of timing. We plan to open-source a C++ runtime in 2025.


7 participants