[RFC] Generalize pytorch content for non-native device execution #66

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions RFC-0039-generalize-pytorch-ut.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@

# [RFC] Generalization of PyTorch framework UTs for non-CUDA device execution

**Authors:**
* @ankurneog


## **Summary**
Modify the PyTorch framework UTs so that non-CUDA devices such as Intel Gaudi and Intel XPU are able to harness the content and improve quality.


## **Motivation**
The PyTorch framework UTs are a good indicator of device-stack health; however, they are mostly written for CPU and CUDA devices, which restricts their use for non-CUDA devices.

We propose to modify the content wherever possible to make it available for non-CUDA device execution.

This will also encourage greater participation in content enhancement.

### **Examples**

* Execution is blocked for non-native devices through decorators such as ```onlyNativeDevices``` (a sketch of this pattern follows the list).
* Execution is restricted to CUDA through decorators such as ```onlyNCCL``` or ```onlyCUDA```.
* A scalable mechanism is needed to select dtypes per op as described in OpInfo or ModuleInfo, instead of a separate per-device variable such as ```dtypesIfCUDA```.
* A scalable mechanism is needed to skip tests per device, instead of device-specific decorators such as ```skipIfCUDA```.
* The dynamo content should be refactored to allow tweaking per platform/device, e.g. adding custom backends or skipping unsupported backends.
* The distributed content assumes most execution is done with NCCL and Gloo, with almost the entire non-CPU content hard-coded for NCCL.
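
The snippet below is a minimal sketch of the gating pattern these bullets describe. The decorator names come from `torch.testing._internal.common_device_type`; the test bodies themselves are illustrative.

```python
# Minimal sketch of the current gating pattern (decorators from
# torch.testing._internal.common_device_type; test bodies are illustrative).
import torch
from torch.testing._internal.common_device_type import (
    dtypes,
    dtypesIfCUDA,
    instantiate_device_type_tests,
    onlyCUDA,
    onlyNativeDeviceTypes,
)
from torch.testing._internal.common_utils import TestCase, run_tests


class TestAddExample(TestCase):
    @onlyNativeDeviceTypes                     # skips devices outside the native list (e.g. hpu)
    @dtypes(torch.float32)
    @dtypesIfCUDA(torch.float32, torch.half)   # per-device dtype list via a CUDA-only variable
    def test_add(self, device, dtype):
        a = torch.ones(4, device=device, dtype=dtype)
        self.assertEqual((a + a).sum().item(), 8.0)

    @onlyCUDA                                  # hard gate: runs on CUDA and nothing else
    def test_cuda_only(self, device):
        self.assertTrue("cuda" in device)


instantiate_device_type_tests(TestAddExample, globals())

if __name__ == "__main__":
    run_tests()
```

A device such as Gaudi never gets a generated test class here, because the gating decorators and the CUDA-specific dtype variable only know about the native device types.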

## **Proposed Implementation**
**Review comment (Contributor, @albanD):**

I think the idea makes a lot of sense.
I think we would need more details, and how this interacts with existing features like the device-generic tests (https://github.com/pytorch/pytorch/blob/main/torch/testing/_internal/common_device_type.py), which by the way already partially work for PrivateUse1, and the OpInfo consistency tests.

Also, I don't think we want to aim at running the full test suite with the side device available, but rather at selecting the specific device-dependent tests that need to be run for each device we support.

**Reply (Author, @ankurneog):**

@albanD: thanks for your comment. Yes, I believe we need some extensive hooks to enable this, e.g. the one we introduced in
https://github.com/pytorch/pytorch/pull/128584/files#diff-d183f2afc51d6a59bc70094e8f476d2468c45e415500f6eb60abad955e065156R531
@onlyNativeDeviceTypesAnd(["hpu"])
Other devices can add themselves to such a list if they support the test case.

We can modify other hooks, such as skipIfDevice, in a similar fashion.

The common_device_type framework is usable if we replace the onlyNativeDeviceTypes decorator.
It was widely used in the earlier files, but most recent files (e.g. dynamo/distributed) do not use it and instead make .cuda() calls directly.

However, this content should not be too difficult to migrate.

I believe that, in general, we should ensure new test content uses the common_device_type framework and open up the content for "non-native" device execution.
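
As a rough illustration of the hook mentioned above, here is a hypothetical sketch of an `onlyNativeDeviceTypesAnd`-style decorator. The actual implementation added in pytorch/pytorch#128584 lives inside the device-type test framework and may differ in detail; the `NATIVE_DEVICES` set and the skip mechanics below are assumptions.

```python
# Hypothetical sketch only: the real hook from pytorch/pytorch#128584 is part of
# torch.testing._internal.common_device_type and may be structured differently.
import unittest
from functools import wraps

# Assumption: mirrors the framework's notion of "native" device types.
NATIVE_DEVICES = {"cpu", "cuda", "meta"}


def onlyNativeDeviceTypesAnd(extra_device_types=()):
    """Allow a test on native device types plus any explicitly listed extras."""
    allowed = NATIVE_DEVICES | set(extra_device_types)

    def decorator(fn):
        @wraps(fn)
        def wrapper(self, *args, **kwargs):
            # Device-generic test classes carry the device type they were
            # instantiated for; skip when it is not on the allow-list.
            if self.device_type not in allowed:
                raise unittest.SkipTest(
                    f"only runs on {sorted(allowed)}, got {self.device_type}"
                )
            return fn(self, *args, **kwargs)

        return wrapper

    return decorator


# Usage, as in the referenced PR: a device that supports the test case simply
# appends its device type to the allow-list.
#
#   @onlyNativeDeviceTypesAnd(["hpu"])
#   def test_foo(self, device):
#       ...
```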

Since the content is huge, we propose a staggered approach to the implementation.
Steps:
* Remove the restriction imposed through @onlyNativeDevices in core content and replace it with hooks so that supported devices can enable their content selectively.
These hooks should be flexible enough to support both in-tree and out-of-tree devices.
* Dtypes for a device should be loaded dynamically per op from a common dictionary, instead of using a different variable per device, e.g. dtypesIfCUDA (see the sketch after this list).
* Miscellaneous decorators such as @skipIfCuda should be generalized to @skipIfDevice.
* Extend the use of instantiate_device_type_tests to all content, so that developers are forced to write generalized device code rather than hard-coding "cuda" or "cpu".
* Generalize the common distributed content so that it can be extended to non-NCCL backends such as Intel's HCCL and CCL.
* Generalize the dynamo content for specific backends that other devices might want to verify with the existing content; the backends should always be taken from
an abstracted-out list that can be appended to per device and per test case.
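
To make the dtype step above concrete, here is a minimal sketch of the "common dictionary" idea, assuming a hypothetical per-op table keyed by device type. The table name, helper function, and dtype lists below are illustrative, not an existing OpInfo API.

```python
# Sketch of a per-op dtype table keyed by device type. Everything here is
# illustrative; OpInfo does not currently expose such a dictionary.
import torch

# One table per op: device type -> dtypes to test. An out-of-tree device
# registers its entry here instead of adding a new dtypesIfXXX variable.
ADD_OP_DTYPES = {
    "cpu":  (torch.float32, torch.float64),
    "cuda": (torch.float32, torch.float64, torch.half),
    "hpu":  (torch.float32, torch.bfloat16),  # added by the device backend
}


def dtypes_for(op_dtype_table, device_type, default=()):
    """Return the dtypes to exercise for this op on the given device type."""
    return op_dtype_table.get(device_type, default)


# Inside a device-generic test, the loop no longer names any device explicitly:
#
#   for dtype in dtypes_for(ADD_OP_DTYPES, self.device_type):
#       run_the_op_and_compare_against_cpu(dtype)
```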
**Review comment (@jgong5):**

This is a good list of items for generalizing the test cases. Does the proposal focus only on devices that have dedicated device tags in PyTorch core, or does it also support the PrivateUse1 device, which is used to extend PyTorch with any out-of-tree devices?

**Reply (Author, @ankurneog):**

@jgong5: Thanks for your comment. Since I investigated mostly along the lines of supporting Intel Gaudi, which has a dedicated device tag, I have not checked the impact or the support needed for PrivateUse1 devices.

**Review comment:**

Perhaps @FFFrog @Yikun have more thoughts on this?

**Review comment:**

@jgong5 @ankurneog Sorry for the late reply.

In theory, in the PyTorch test framework, dedicated keys are treated almost the same as the public key (PrivateUse1), and PrivateUse1 is already supported in the test framework.

First of all, I couldn't agree more with this proposal, because Ascend NPU is currently facing the problems described above. The solution proposed by @ankurneog can solve most of the problems we have encountered. We are currently sorting out all the issues we have run into and will add them to this RFC later; we hope that the additions will help make the proposal more complete.

By the way, if possible, we can work together to complete this proposal and make it land in PyTorch :D




#### Metrics
Other devices can track the pass percentage and become part of the CI if the coverage and pass percentage are good.

#### Additional Context
Towards adding support for Intel Gaudi devices, we have already made a couple of changes in this regard.
* Removing onlyNativeDevice: https://github.com/pytorch/pytorch/pull/128584

* Changing Dynamo Content: https://github.com/pytorch/pytorch/pull/130714

* Generalizing Distributed Content: https://github.com/pytorch/pytorch/pull/131758

* Generalizing FSDP Content: https://github.com/pytorch/pytorch/pull/133209

More to follow


### Next Steps
As part of introducing support for Intel Gaudi, which is an out-of-tree device, we have already been introducing changes in a manner that can be used by other devices as well.