Skip to content

Commit

Permalink
[MLU] fix mlu ctest final. (#44404)
Browse files Browse the repository at this point in the history
  • Loading branch information
ShawnNew authored Jul 18, 2022
1 parent 1d12832 commit b2224e6
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,4 @@
set -e
# use default values
# FIXME: random fails on Unknown command lines -c (or -m).
launch_py=${PADDLE_BINARY_DIR}/python/paddle/distributed/launch.py
MLU_VISIBLE_DEVICES=0,1 python ${launch_py} c_comm_init_op_mlu.py
MLU_VISIBLE_DEVICES=0,1 python -m paddle.distributed.launch c_comm_init_op_mlu.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
import paddle.nn.functional as F
import paddle.fluid as fluid
import paddle
import sys

sys.path.append("..")
from op_test import OpTest

import numpy as np
import unittest
import sys

sys.path.append("..")

paddle.enable_static()
SEED = 2021
Expand Down
7 changes: 4 additions & 3 deletions python/paddle/fluid/tests/unittests/mlu/test_spawn_mlu.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,13 @@ def test_get_default_nprocs(self):
self.assertEqual(nprocs, core.get_mlu_device_count())

def test_spawn(self):
context = dist.spawn(train, backend='cncl', nprocs=4)
num_devs = core.get_mlu_device_count()
context = dist.spawn(train, backend='cncl', nprocs=num_devs)
rank_list = []
for i in range(4):
for i in range(num_devs):
rank_list.append(context.return_queues[i].get())
rank_list.sort()
self.assertEqual(rank_list, list(range(4)))
self.assertEqual(rank_list, list(range(num_devs)))


if __name__ == '__main__':
Expand Down

0 comments on commit b2224e6

Please sign in to comment.