Skip to content

Commit

Permalink
fix bug of sync_parameters (#33955)
Browse files Browse the repository at this point in the history
  • Loading branch information
ForFishes authored Jul 5, 2021
1 parent 9254183 commit bd559a2
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions python/paddle/fluid/dygraph/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
import numpy as np
import warnings
from collections import OrderedDict
import itertools
import warnings

import paddle
from paddle.fluid import core
from paddle.fluid import framework
from paddle.fluid.dygraph import layers
Expand All @@ -26,9 +29,7 @@
from paddle.utils import deprecated
from ..layers import collective
from paddle.fluid.dygraph import base as imperative_base
import warnings
import paddle
import itertools
from paddle.fluid.framework import ParamBase

__all__ = ["prepare_context", "ParallelEnv", "DataParallel"]

Expand Down Expand Up @@ -353,8 +354,9 @@ def sync_params_buffers(model,
raise TypeError("The data type of '%s' must be Varbase" %
param.name)
# is_distributed param not need to sync when in mp mode
if is_model_parallel and param.is_distributed:
continue
if is_model_parallel and isinstance(param, ParamBase):
if param.is_distributed:
continue

model_vars.append(param.detach())
if len(model_vars) == 0:
Expand Down

0 comments on commit bd559a2

Please sign in to comment.