-
Notifications
You must be signed in to change notification settings - Fork 867
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: add PyTorch/XLA support #2182
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very cool! Left some minor questions on the PR directly
One question I had was whether this is the right way to uses torch/xla nowadays or whether users are recommended to pass in an XLA backend to torch.compile()
Since most of machines are running on AWS in CI it's unlikely we'll get a TPU available to fuly test this but I'm assuming this should work just fine on GPU as well, in which case a quick test would also be super helpful
ts/torch_handler/base_handler.py
Outdated
@@ -278,6 +303,9 @@ def inference(self, data, *args, **kwargs): | |||
with torch.no_grad(): | |||
marshalled_data = data.to(self.device) | |||
results = self.model(marshalled_data, *args, **kwargs) | |||
if torch_xla_enabled: | |||
xm.mark_step() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not super familiar with xla internals but what does this line do?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed the xm.mark_step()
cause this is essential for training, optional for inferencing. In short, the value /calculation is upon either a xm.mark_step()
or when it gets retrieved. In our case it's the latter one.
ts/torch_handler/base_handler.py
Outdated
@@ -59,6 +59,24 @@ def check_pt2_enabled(): | |||
) | |||
|
|||
|
|||
def check_torch_xla_enabled() -> bool: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@lxning another good candidate for your new config change, it might be possible that a user has xla installed but doesnt want to necessarily comile the model with XLA
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@msaroufim yes. the model yaml config can make this much easier. I'll send the PR early next week to unblock this PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@lxning another good candidate for your new config change, it might be possible that a user has xla installed but doesnt want to necessarily comile the model with XLA
IIUC, the above mentioned scenario applies to gpu. Though, I have torch.cuda.is_available() and properties.get("gpu_id") is not None:
as the prioritized condition. For accelerator type the require torch_xla, users do have option to choose to compile the torchxla_trace_once
, which is an experimental backend for Dynamo.
torch.compile() is a good point. I'm guessing we'll need both this version to support pytorch <2.0, and another change to support pytorch 2.0 models. |
So we do actually already support torch.compile #1960 and you can pass in a custom backend via a I don't think supporting both workflows is a huge deal but curious which one would you prefer people use assuming people have 2.0 installed |
As discussed, we decided to prioritize pytorch/xla 2.0 and above. |
Added |
Codecov Report
@@ Coverage Diff @@
## master #2182 +/- ##
==========================================
+ Coverage 71.31% 71.41% +0.10%
==========================================
Files 73 73
Lines 3336 3348 +12
Branches 57 57
==========================================
+ Hits 2379 2391 +12
Misses 954 954
Partials 3 3
... and 1 file with indirect coverage changes 📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more |
PR looks good but I was hoping we could have the test you're running checked in and only run it if a TPU is found |
Added test, PTAL |
LGTM thank you, as FYI we're killing the compile.json in the next release but I'll make the change and test out the kokoro CI directly |
Thanks for the heads up! |
@lxning , follow up for review request |
Description
This PR is to add PyTorch/XLA support in TorchServe backend base handler.
Type of change