Commit

fix device id env

kuizhiqing committed Mar 23, 2022
1 parent 17b8335 commit c519d83

Showing 7 changed files with 28 additions and 21 deletions.
3 changes: 2 additions & 1 deletion python/paddle/distributed/fleet/launch.py
@@ -242,7 +242,8 @@ def _parse_args():
     elastic_group.add_argument(
         "--force", type=bool, default=False, help="update np force")

-    return parser.parse_args()
+    known_args, _ = parser.parse_known_args()
+    return known_args


 def get_cluster_from_args(args, device_mode, devices_per_proc):
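
Note on the launch.py change: argparse's parse_args() aborts on options it does not recognize, while parse_known_args() returns them as a separate list, so extra flags aimed at other tooling no longer kill the fleet launcher. A minimal sketch of the difference, using a hypothetical --np flag rather than the real fleet argument set:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--np", type=int, default=1)

    # parse_known_args() keeps going and hands back the unrecognized flags
    known, unknown = parser.parse_known_args(["--np", "2", "--devices", "0,1"])
    print(known.np)   # 2
    print(unknown)    # ['--devices', '0,1']

    # parser.parse_args(["--np", "2", "--devices", "0,1"]) would instead exit
    # with "error: unrecognized arguments: --devices 0,1"
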
5 changes: 3 additions & 2 deletions python/paddle/distributed/launch/context/__init__.py
@@ -25,12 +25,13 @@ class Context(object):
     def __init__(self, enable_plugin=True):
         self.args, self.unknown_args = parse_args()
         self.envs = fetch_envs()
-        self.logger = self.get_logger()

-        self.set_env_in_args()

         self.node = Node()
         self.status = Status()

+        self.set_env_in_args()
+        self.logger = self.get_logger()
+
         # design for event queue, later
         self.events = []
22 changes: 9 additions & 13 deletions python/paddle/distributed/launch/context/device.py
@@ -57,7 +57,7 @@ def labels(self, lbs):
         else:
             self._labels = []

-    def get_selected_flag_key(self):
+    def get_selected_device_key(self):
         if self._dtype == DeviceType.CPU:
             return 'FLAGS_selected_cpus'
         if self._dtype == DeviceType.GPU:
@@ -70,19 +70,15 @@ def get_selected_flag_key(self):
             return 'FLAGS_selected_mlus'
         return 'FLAGS_selected_devices'

-    def get_selected_flag_label(self, idx):
-        if idx < len(self._labels):
-            return self._labels[idx]
+    def get_selected_devices(self, devices=''):
+        '''
+        return the device label/id relative to the visible devices
+        '''
+        if not devices:
+            return [str(x) for x in range(0, len(self._labels))]
         else:
-            return '0'
-
-    def selected_flags(self, idx=None):
-        if idx is None:
-            return {self.get_selected_flag_key(): ','.join(self._labels)}
-        else:
-            return {
-                self.get_selected_flag_key(): self.get_selected_flag_label(idx)
-            }
+            devs = [x.strip() for x in devices.split(',')]
+            return [str(self._labels.index(d)) for d in devs]

     @classmethod
     def parse_device(self):
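
Note on the device.py change: as the docstring says, get_selected_devices() returns indices relative to the visible devices rather than the raw labels. A standalone sketch, under the assumption that self._labels holds the node's visible device labels (e.g. CUDA_VISIBLE_DEVICES=4,5,6,7 gives labels ['4', '5', '6', '7']):

    # visible device labels on the node (assumed example)
    labels = ['4', '5', '6', '7']

    def get_selected_devices(devices=''):
        # no --devices argument: select every visible device by its position
        if not devices:
            return [str(x) for x in range(0, len(labels))]
        # otherwise map each requested label to its position among the visible ones
        devs = [x.strip() for x in devices.split(',')]
        return [str(labels.index(d)) for d in devs]

    print(get_selected_devices())       # ['0', '1', '2', '3']
    print(get_selected_devices('5,7'))  # ['1', '3']
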
7 changes: 5 additions & 2 deletions python/paddle/distributed/launch/controllers/collective.py
@@ -75,6 +75,9 @@ def build_pod(self):
         job_endpoints = [i['endpoints'] for i in peer_list]

         self.pod.reset()
+        selected_dev_key = self.ctx.node.device.get_selected_device_key()
+        selected_dev_list = self.ctx.node.device.get_selected_devices(
+            self.ctx.args.devices)
         for i in range(self.pod.replicas):
             e = {
                 "PADDLE_MASTER": collective_master,
@@ -90,9 +93,9 @@ def build_pod(self):
                 "PADDLE_RANK_IN_NODE": str(i),
             }
             if self.pod.replicas == 1:
-                e.update(self.ctx.node.device.selected_flags())
+                e.update({selected_dev_key: selected_dev_list})
             else:
-                e.update(self.ctx.node.device.selected_flags(i))
+                e.update({selected_dev_key: selected_dev_list[i]})
             self.add_container(envs=e, log_tag=i)

         return True
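
Note on the collective.py change: the device selection is now computed once per pod and written into each replica's environment. A sketch of the values this loop produces, based only on this hunk and assuming a GPU node (so the key would be FLAGS_selected_gpus) with two selected devices:

    selected_dev_key = 'FLAGS_selected_gpus'   # assumed key; depends on device type
    selected_dev_list = ['0', '1']

    for i, replicas in [(0, 1), (0, 2), (1, 2)]:
        e = {}
        if replicas == 1:
            e.update({selected_dev_key: selected_dev_list})
        else:
            e.update({selected_dev_key: selected_dev_list[i]})
        print(replicas, e)
    # 1 {'FLAGS_selected_gpus': ['0', '1']}
    # 2 {'FLAGS_selected_gpus': '0'}
    # 2 {'FLAGS_selected_gpus': '1'}
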
2 changes: 2 additions & 0 deletions python/paddle/distributed/launch/controllers/controller.py
@@ -210,6 +210,8 @@ def pod_replicas(self):

         if self.ctx.args.nproc_per_node:
             return int(self.ctx.args.nproc_per_node)
+        elif self.ctx.args.devices:
+            return len(self.ctx.args.devices.split(','))
         else:
             return self.ctx.node.device.count

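
Note on the controller.py change: pod_replicas now falls back to the --devices list before using the full device count. A small sketch of the precedence (argument names taken from this diff; the function shape is illustrative):

    def pod_replicas(nproc_per_node=None, devices='', device_count=8):
        # --nproc_per_node wins, then the number of --devices entries,
        # then every device visible on the node
        if nproc_per_node:
            return int(nproc_per_node)
        elif devices:
            return len(devices.split(','))
        else:
            return device_count

    print(pod_replicas(nproc_per_node=2, devices='0,1,2'))  # 2
    print(pod_replicas(devices='0,1,2'))                    # 3
    print(pod_replicas())                                   # 8
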
5 changes: 3 additions & 2 deletions python/paddle/distributed/launch/plugins/__init__.py
@@ -29,8 +29,9 @@ def process_args(ctx):
     #argdev = ctx.args.gpus or ctx.args.xpus or ctx.args.npus
     argdev = ctx.args.devices
     if argdev:
-        ctx.node.device.labels = argdev.split(',')
-        ctx.logger.debug('Device reset by args {}'.format(argdev))
+        for d in argdev.split(','):
+            assert d in ctx.node.device.labels, 'Device not found {}'.format(
+                argdev)


 def collective_compatible(ctx):
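
Note on the plugins change: --devices no longer overwrites the node's device labels; each requested device must already be among the visible labels or the launcher asserts. A sketch of the new validation, with the visible labels as an assumed example:

    visible_labels = ['0', '1', '2', '3']   # assumed stand-in for ctx.node.device.labels
    argdev = '1,3'

    for d in argdev.split(','):
        assert d in visible_labels, 'Device not found {}'.format(argdev)
    # argdev = '1,9' would instead raise AssertionError: Device not found 1,9
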
5 changes: 4 additions & 1 deletion python/paddle/fluid/tests/unittests/test_run.py
@@ -64,7 +64,10 @@ def pdrun(self, args, env=None):
         if args:
             cmd.extend(args.split(" "))
         cmd.extend([pyname])
-        proc = subprocess.Popen(cmd, env)
+        env = os.environ.copy()
+        # virtual devies for testing
+        env.update({'CUDA_VISIBLE_DEVICES': '0,1,2,3,4,5,6,7'})
+        proc = subprocess.Popen(cmd, env=env)
         return proc

     def test_collective_1(self):
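
Note on the test_run.py change: subprocess.Popen's second positional parameter is bufsize, not env, so the old Popen(cmd, env) put the value in the wrong slot instead of setting the child's environment; the fix copies os.environ, adds a virtual CUDA_VISIBLE_DEVICES, and passes it via the env keyword. A minimal sketch of the fixed pattern (the command is illustrative):

    import os
    import subprocess

    env = os.environ.copy()                                   # keep PATH, HOME, etc.
    env.update({'CUDA_VISIBLE_DEVICES': '0,1,2,3,4,5,6,7'})   # virtual devices
    proc = subprocess.Popen(['python', '--version'], env=env)
    proc.wait()
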

1 comment on commit c519d83

@paddle-bot-old

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉
