Conversation
@RUAN-ZX RUAN-ZX commented Oct 26, 2023

Unit tests fail or are skipped when device=npu, and we definitely want to exercise all these wonderful features through the official unit tests.
This commit adds NPU support to the unit tests. P.S. See what we have already done in #4567.

What I do in this commit

  1. Add an NPU logic branch
    feat: Add NPU support for skip_on_arch in tests/unit/util.py
    feat: Add NPU support for skip_on_cuda in tests/unit/util.py
    feat: Add NPU support for tests/unit/common.py

  2. Call the accelerator's set_device before deepspeed.init_distributed in tests/unit/common.py
    It is friendlier and easier for other devices like NPU if we call set_device on the accelerator before init_distributed. Besides, setting the device parameter before init is more logical.

  3. Fix calling get_accelerator().random().fork_rng on non-CUDA devices
    The function train_cifar() in tests/unit/alexnet_model.py calls get_accelerator().random().fork_rng without passing device_type explicitly. Unfortunately, torch.random.fork_rng() defaults to device_type="cuda", so non-CUDA devices fail to run. My solution is to pass device_type=get_accelerator().device_name() explicitly, so that both CUDA and non-CUDA devices behave correctly.


RUAN-ZX commented Oct 26, 2023

@microsoft-github-policy-service agree

@RUAN-ZX RUAN-ZX marked this pull request as draft October 26, 2023 03:43
@RUAN-ZX RUAN-ZX changed the title [NPU] Add NPU support for unit test #4568 [WIP] [NPU] Add NPU support for unit test #4568 Oct 26, 2023
@RUAN-ZX RUAN-ZX changed the title [WIP] [NPU] Add NPU support for unit test #4568 [WIP] [NPU] Add NPU support for unit test Oct 26, 2023
@hipudding

Please add a CUDA check in "check_environment" to avoid a warning message when using NPU as the backend.
Please squash these commits into one (use git rebase -i) and describe all the changes.
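The suggested squash can be sketched as below. Interactively it is `git rebase -i HEAD~3` (mark all but the first commit as "squash"); the sketch uses a non-interactive equivalent on a throwaway repo, and the commit count and message are illustrative.

```shell
set -e
cd "$(mktemp -d)"
git init -q demo && cd demo
git config user.email "you@example.com" && git config user.name "you"
# Create three work-in-progress commits to squash.
for i in 1 2 3; do echo "$i" > "file$i"; git add "file$i"; git commit -qm "wip $i"; done
# Non-interactive equivalent of squashing the last two commits into their parent:
git reset --soft HEAD~2      # keep the tree, drop the last two commit objects
git commit -qm "[NPU] Add NPU support for unit test"
git log --oneline            # two commits remain: base + squashed commit
# On a real PR branch you would then: git push --force-with-lease
```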

@RUAN-ZX RUAN-ZX changed the title [WIP] [NPU] Add NPU support for unit test [NPU] Add NPU support for unit test Oct 30, 2023
@RUAN-ZX RUAN-ZX marked this pull request as ready for review October 30, 2023 06:30

RUAN-ZX commented Nov 1, 2023

@tjruwase Would you be so kind as to review this commit, since we have some other commits based on it? Or perhaps you could invite other reviewers to do the job? Thank you.


RUAN-ZX commented Nov 1, 2023

Please add a CUDA check in "check_environment" to avoid a warning message when using NPU as the backend. Please squash these commits into one (use git rebase -i) and describe all the changes.

NPU support for check_environment will be done in another PR :)
As for squashing, I think these commits should be kept separate for clarity :)
If you have more suggestions, please let me know. Thank you for your advice.


RUAN-ZX commented Nov 1, 2023

@tjruwase It seems that the CI problem might have been fixed by loadams in #4590? If so, would you please launch CI for me again?
Several PRs (#4588, #4585, #4578, etc.) have hit the same failure shown below:
unit/inference/test_inference.py::TestMPSize::test[fp16-bloom] FAILED [ 91%]

P.S. I have pushed new code to solve the problem; see item 3 of the description in the first comment. Thanks :)


RUAN-ZX commented Nov 3, 2023

@tjruwase Perhaps #4591 has solved the CI problem? I see that the latest commit #4598 manages to pass. Would you launch CI for me again?


RUAN-ZX commented Nov 4, 2023

@tjruwase Perhaps #4591 has solved the CI problem? I see that the latest commit #4598 manages to pass. Would you launch CI for me again?

@tjruwase Could you launch CI for me again? The last two commits have already passed.


RUAN-ZX commented Nov 6, 2023

@tjruwase Would you please launch CI for this PR? I have been waiting for a long time. Or is there something I need to improve? If so, I really hope you can point it out for me so that I can fix it ASAP and do better next time. Thanks :)


tjruwase commented Nov 7, 2023

@RUAN-ZX, apologies for the delay in merging this PR. We had to push out a scheduled release last week. Unfortunately, I noticed a new CI failure https://github.com/microsoft/DeepSpeed/actions/runs/6785863907/job/18445092642?pr=4569.

Can you please take a look? It seems to have been caused by a recent merge. I will pay close attention to ensure this PR is merged this week. Thanks for your patience.


RUAN-ZX commented Nov 8, 2023

@RUAN-ZX, apologies for the delay in merging this PR. We had to push out a scheduled release last week. Unfortunately, I noticed a new CI failure https://github.com/microsoft/DeepSpeed/actions/runs/6785863907/job/18445092642?pr=4569.

Can you please take a look? It seems to have been caused by a recent merge. I will pay close attention to ensure this PR is merged this week. Thanks for your patience.

Thank you! @tjruwase About the failure, I found an assert error, CUDA_HOME does not exist, unable to compile CUDA op(s), from op_builder/builder.py; ultimately torch/utils/cpp_extension.py uses cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH') to resolve CUDA_HOME.
No code changes this environment variable, so I don't understand why CUDA_HOME is None :)
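The lookup quoted above can be sketched as a tiny helper. This is a simplification for illustration (resolve_cuda_home is our name, not torch's): the real cpp_extension code additionally falls back to probing nvcc on PATH and /usr/local/cuda before giving up.

```python
# Simplified sketch of how torch's cpp_extension resolves CUDA_HOME:
# CUDA_HOME first, then CUDA_PATH, else None (which triggers the
# "unable to compile CUDA op(s)" assert seen in the CI log above).
def resolve_cuda_home(env):
    return env.get("CUDA_HOME") or env.get("CUDA_PATH")
```

Taking a dict instead of reading os.environ directly keeps the sketch deterministic and easy to exercise.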


RUAN-ZX commented Nov 13, 2023

@tjruwase I have fixed the problems raised by the pre-commit hooks; please launch CI again :)

@tjruwase tjruwase added this pull request to the merge queue Nov 13, 2023
Merged via the queue into deepspeedai:master with commit 4b7cae7 Nov 13, 2023
mrwyattii added a commit that referenced this pull request Dec 15, 2023
Our torch 1.10 tests have been failing since the merge of #4569, which
added a `device_type` kwarg to the `torch.random.fork_rng` call. That kwarg
is not compatible with older versions of torch; it was added in
pytorch/pytorch#98069

Fixes #4644, #4503
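One way to reconcile the two constraints above (explicit device_type for NPU, but no such kwarg in torch 1.10) is a small version guard. This is a hypothetical sketch, not the repository's actual follow-up fix; fork_rng_compat is our name.

```python
import inspect

import torch


def fork_rng_compat(device_type="cuda"):
    """Call torch.random.fork_rng with device_type when the running torch
    supports it (added in pytorch/pytorch#98069); on older torch such as
    1.10, fall back to the plain, CUDA-only call."""
    params = inspect.signature(torch.random.fork_rng).parameters
    if "device_type" in params:
        return torch.random.fork_rng(device_type=device_type)
    return torch.random.fork_rng()
```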
mauryaavinash95 pushed a commit to mauryaavinash95/DeepSpeed that referenced this pull request Feb 17, 2024

Co-authored-by: ryan <ruanzhixiang1@huawei.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
mauryaavinash95 pushed a commit to mauryaavinash95/DeepSpeed that referenced this pull request Feb 17, 2024