diff --git a/python/paddle/distributed/launch/controllers/collective.py b/python/paddle/distributed/launch/controllers/collective.py index ce69ec0a4c781..017eb34de3814 100644 --- a/python/paddle/distributed/launch/controllers/collective.py +++ b/python/paddle/distributed/launch/controllers/collective.py @@ -45,7 +45,10 @@ def build_pod(self): ): return self._build_pod_with_args() else: - return self._build_pod_with_master() + if self.ctx.args.auto_parallel_config is None: + skip_run = True + # only when skip_run is Flase, should not reset pod + return self._build_pod_with_master(skip_run) def _build_pod_with_tuner(self): auto_parallel_config = self.ctx.args.auto_parallel_config @@ -150,7 +153,7 @@ def _build_pod_with_args(self): return True - def _build_pod_with_master(self): + def _build_pod_with_master(self, reset_pod=True): self.pod.replicas = self.pod_replicas() # rank will be reset when restart @@ -205,7 +208,8 @@ def _build_pod_with_master(self): job_endpoints = [i['endpoints'] for i in peer_list] - # self.pod.reset() + if reset_pod: + self.pod.reset() selected_dev_key = self.ctx.node.device.get_selected_device_key() selected_dev_list = self.ctx.node.device.get_selected_devices( self.ctx.args.devices