MP overlap for 1f1b (PaddlePaddle#57446)
* B-F overlap

* Add column_parallel_linear_backward_overlapping

* Add cost model

* Insert reshape for ColumnParallelLinearBackwardOverlappingPass

* Add cross-program event dependency

* Refine split program in _backward_forward_overlap

* Add empirical op cost

* Add NOTE

* Remove some redundant code

* Remove some redundant code

* Fix UTs
From00 authored and Frida-a committed Oct 14, 2023
1 parent 204d02d commit e517aad
Showing 12 changed files with 664 additions and 112 deletions.
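
Background on the overlap being added (a conceptual sketch, not code from this commit): in a ColumnParallelLinear layer, the backward pass produces an input gradient that must be allreduced across the model-parallel group, while the weight gradient is purely local. The pass introduced here launches that allreduce asynchronously so it can overlap with the weight-gradient matmul. The helper names below (all_reduce_async, column_parallel_linear_backward) are illustrative stand-ins.

import numpy as np

def all_reduce_async(tensor):
    # Stand-in for an asynchronous allreduce over the model-parallel group;
    # a real implementation would return a communication task whose wait()
    # blocks until the reduction has finished.
    class _Task:
        def wait(self):
            pass
    return _Task()

def column_parallel_linear_backward(x, w_shard, dout):
    dx_partial = dout @ w_shard.T        # input gradient: needs mp allreduce
    task = all_reduce_async(dx_partial)  # start communication early ...
    dw = x.T @ dout                      # ... and overlap it with the dw matmul
    task.wait()                          # dx is only valid after wait()
    return dx_partial, dw

x = np.random.rand(8, 16).astype("float32")         # layer input
w_shard = np.random.rand(16, 4).astype("float32")   # column shard of the weight
dout = np.random.rand(8, 4).astype("float32")       # upstream gradient
dx, dw = column_parallel_linear_backward(x, w_shard, dout)
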
34 changes: 29 additions & 5 deletions python/paddle/distributed/auto_parallel/static/cluster.py
@@ -61,6 +61,8 @@ def __init__(self, global_id, local_id, machine):
self._dp_gflops = None
# Single precision GFLOPS
self._sp_gflops = None
# Half precision GFLOPS
self._hp_gflops = None
# Memory is stored by GB
self._memory = None

@@ -120,6 +122,14 @@ def sp_gflops(self):
def sp_gflops(self, value):
self._sp_gflops = value

@property
def hp_gflops(self):
return self._hp_gflops

@hp_gflops.setter
def hp_gflops(self, value):
self._hp_gflops = value

@property
def memory(self):
return self._memory
@@ -130,14 +140,15 @@ def memory(self, value):

def __str__(self):
str = ""
str += "global_id: {}, local_id: {}, machine_id: {}, type: {}, model: {}, dp_flops: {}, sp_flops: {}, memory: {}".format(
str += "global_id: {}, local_id: {}, machine_id: {}, type: {}, model: {}, dp_flops: {}, sp_flops: {}, hp_flops: {}, memory: {}".format(
self.global_id,
self.local_id,
self.machine.id,
self.type.name,
self.model,
self.dp_gflops,
self.sp_gflops,
self.hp_gflops,
self.memory,
)
return str
@@ -443,6 +454,7 @@ def gen_default_config_cluster(
intra_bandwidth=235,
gpu_dp_gflops=7800,
gpu_sp_gflops=15700,
gpu_hp_gflops=31400,
cpu_dp_gflops=75,
cpu_sp_gflops=150,
):
@@ -524,17 +536,16 @@ def _convert_to_cpu_info(cpu_model):
local_id += 1
type = _convert_to_type(gpu_model)
model = _convert_to_model(gpu_model, gpu_memory)
dp_gflops = gpu_dp_gflops
sp_gflops = gpu_dp_gflops
memory = gpu_memory

device["global_id"] = global_id
device["local_id"] = local_id
device["type"] = type
device["model"] = model
device["memory"] = memory
device["sp_gflops"] = sp_gflops
device["dp_gflops"] = dp_gflops
device["sp_gflops"] = gpu_sp_gflops
device["dp_gflops"] = gpu_dp_gflops
device["hp_gflops"] = gpu_hp_gflops
# hard code
device["type"] = "GPU"
global_id_to_device_type[global_id] = type
@@ -694,6 +705,7 @@ def _build_from_dict(self, cluster_info):
device.model = device_info.get("model", None)
device.dp_gflops = float(device_info.get("dp_gflops", 0))
device.sp_gflops = float(device_info.get("sp_gflops", 0))
device.hp_gflops = float(device_info.get("hp_gflops", 0))
device.memory = float(device_info.get("memory", 0))
self.add_device(device)
self.add_machine(machine)
@@ -909,10 +921,22 @@ def is_by_json_config(json_config):
os.getenv("PADDLE_CURRENT_ENDPOINT", None),
)
)

gflops_info = {
"V100": {"dp": 7800, "sp": 15700, "hp": 125000},
"A100": {"dp": 9700, "sp": 19500, "hp": 624000},
}
default_gflops = (
gflops_info["A100"] if gpu_model == "A100" else gflops_info["V100"]
)

cluster.gen_default_config_cluster(
node_count=node_count,
device_count=local_device_count,
gpu_model=gpu_model,
gpu_memory=memory,
gpu_dp_gflops=default_gflops["dp"],
gpu_sp_gflops=default_gflops["sp"],
gpu_hp_gflops=default_gflops["hp"],
)
return cluster
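
With the new field, a device entry in a cluster config can carry half-precision throughput alongside the existing double- and single-precision numbers. A minimal sketch of such an entry, assuming the same keys that _build_from_dict and gen_default_config_cluster use above (the values are the A100-class defaults from this diff, not measured numbers):

device_info = {
    "global_id": 0,
    "local_id": 0,
    "type": "GPU",
    "model": "A100-SXM4-40GB",   # illustrative model string
    "memory": 40,                # GB
    "dp_gflops": 9700,           # double precision
    "sp_gflops": 19500,          # single precision
    "hp_gflops": 624000,         # half precision, new in this commit
}
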
35 changes: 28 additions & 7 deletions python/paddle/distributed/auto_parallel/static/cost/base_cost.py
@@ -17,9 +17,10 @@
import numpy as np

import paddle
from paddle.base.core import VarDesc
from paddle.utils.flops import flops

from ..cluster import LinkType, get_default_cluster
from ..cluster import DeviceType, LinkType, get_default_cluster
from ..dist_tensor import DistributedTensor
from ..process_group import get_process_group
from ..utils import _get_comm_group, _get_idx_in_axis
@@ -936,7 +937,13 @@ def calc_time_by_cost_model(op, cluster=None):
)
if not cluster:
cluster = get_default_cluster()
time = 0.0

assert cluster._gpu_model in [
"V100",
"A100",
], "Only A100 and V100 gpu has been supported currently."

time = 0.0 # microsecond
op_type = op.type
# calc comp op time by flops
if op_type not in NON_COMP_TYPE:
@@ -958,15 +965,29 @@
else:
flops_count = flops(op_type, inputs, attrs)

if cluster._gpu_model == "V100":
time = flops_count * 2.9e-7 * 2.6
elif cluster._gpu_model == "A100":
time = flops_count * 2.9e-7
# FIXME(Ruibiao): Need a better way to get dtype
var_name = op.output_arg_names[0]
dtype = op.block._var_recursive(var_name).dtype
device = cluster.get_device(0)
assert (
device.type == DeviceType.GPU
), "Only GPU device is supported currently."

gflops = 0.0
if dtype == VarDesc.VarType.FP64:
gflops = device.dp_gflops
elif dtype == VarDesc.VarType.FP32:
gflops = device.sp_gflops
elif dtype == VarDesc.VarType.FP16 or dtype == VarDesc.VarType.BF16:
gflops = device.hp_gflops
else:
raise ValueError(
"Only A100 and V100 gpu has been supported currently."
f"Unsupported modeling compute time for dtype: {dtype}."
)

utilization_rate = 0.98
time = flops_count / (utilization_rate * gflops) * 1e-3

# calc comm op time by communication modeling formula
elif op_type in COMM_OP_TYPE:
op_cost = _g_op_cost_factory[op_type](
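
A worked example of the compute-time model above (a sketch of the formula only, using the default utilization rate of 0.98 and the A100 half-precision default from this diff):

def estimate_compute_time_us(flops_count, gflops, utilization_rate=0.98):
    # flops / GFLOPS yields nanoseconds, and the 1e-3 factor converts to
    # microseconds, matching the `time` unit in calc_time_by_cost_model.
    return flops_count / (utilization_rate * gflops) * 1e-3

# FP16 matmul of two 4096 x 4096 matrices: 2 * M * N * K FLOPs.
flops_count = 2 * 4096 * 4096 * 4096
print(estimate_compute_time_us(flops_count, gflops=624000))  # ~225 us
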
1 change: 1 addition & 0 deletions python/paddle/distributed/auto_parallel/static/engine.py
@@ -1031,6 +1031,7 @@ def fit(
use_program_cache=self._strategy.use_cache,
return_numpy=self._strategy.return_numpy,
)

lr = auto_utils.get_lr(self.optimizer)
logs = self._prepare_logger(
outs,
12 changes: 12 additions & 0 deletions python/paddle/distributed/auto_parallel/static/parallelizer_v2.py
@@ -14,6 +14,7 @@

import copy
import logging
import os
import time

from paddle.distributed.passes import PassManager, new_pass
@@ -359,6 +360,17 @@ def _apply_post_optimization(
)
params_grads = self._pass_context.get_attr("params_grads")

mp_async_allreduce_in_backward = os.getenv(
"FLAGS_mp_async_allreduce_in_backward"
) in [1, "1", True, "True"]
if mp_async_allreduce_in_backward:
column_parallel_linear_backward_overlapping_pass = new_pass(
"column_parallel_linear_backward_overlapping", {}
)
column_parallel_linear_backward_overlapping_pass.apply(
[main_program], [startup_program], self._pass_context
)

if self.is_train:
# GradClip is train-only optimization
config = copy.deepcopy(self._strategy.sharding.to_dict())
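
Usage sketch (an assumption about the typical workflow, not taken from this diff): the new pass only runs when the flag below is present in the environment before the auto-parallel program is parallelized, e.g. before Engine.fit is called.

import os

# Must be set before parallelization so _apply_post_optimization sees it.
os.environ["FLAGS_mp_async_allreduce_in_backward"] = "1"
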
7 changes: 5 additions & 2 deletions python/paddle/distributed/passes/__init__.py
@@ -13,7 +13,7 @@
# limitations under the License.

from .pass_base import new_pass, PassManager, PassContext
from .fuse_all_reduce import * # noqa: F403

from .auto_parallel_gradient_merge import * # noqa: F403
from .auto_parallel_sharding import * # noqa: F403
from .auto_parallel_amp import * # noqa: F403
@@ -24,11 +24,14 @@
from .auto_parallel_grad_clip import * # noqa: F403
from .auto_parallel_supplement_explicit_dependencies import * # noqa: F403
from .auto_parallel_pipeline import * # noqa: F403
from .pipeline_scheduler_pass import * # noqa: F403
from .column_parallel_linear_backward_overlapping import * # noqa: F403
from .cpp_pass import * # noqa: F403
from .fuse_all_reduce import * # noqa: F403
from .pipeline_scheduler_pass import * # noqa: F403
from .ps_trainer_pass import * # noqa: F403
from .ps_server_pass import * # noqa: F403


__all__ = [
'new_pass',
'PassManager',
7 changes: 3 additions & 4 deletions python/paddle/distributed/passes/auto_parallel_sharding.py
@@ -43,6 +43,7 @@
from paddle.utils import unique_name

from .pass_base import PassBase, register_pass
from .pass_utils import AutoParallelStreamType

OpRole = core.op_proto_and_checker_maker.OpRole
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
@@ -748,13 +749,11 @@ def _fuse_overlap_parameter_comm_stage_two(self, sharding_info):
group = sharding_info.group
else:
group = new_process_group(ranks, force_new_group=True)
# NOTE here stream is just a presentation with different name,
# it is up to executor to create the exact streams given the name.
stream = f"sharding_param_comm_stream{i}"

self.param_comm_group_stream_pairs.append(
{
"comm_group": group,
"comm_stream": stream,
"comm_stream": AutoParallelStreamType.SHARDING_STREAM.value,
}
)
_logger.info(
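
Here the per-group stream names (sharding_param_comm_stream{i}) are replaced by a single shared stream name, so all sharding parameter communication is issued on one named stream that other passes can reference consistently. Illustrative only: an enum of roughly this shape is what the pass reads the name from; the member values below are placeholders, not copied from pass_utils.py.

from enum import Enum

class AutoParallelStreamType(Enum):
    CALC_STREAM = "default"                     # placeholder value
    MP_STREAM = "auto_parallel_mp"              # placeholder value
    SHARDING_STREAM = "auto_parallel_sharding"  # placeholder value

# Every sharding parameter-communication group now shares this one stream name:
comm_stream = AutoParallelStreamType.SHARDING_STREAM.value
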
(Diffs for the remaining 6 changed files are not shown here.)
