
Commit bf24fb7

add ut, test=develop
sandyhouse committed Jul 9, 2021
1 parent 25abc00 commit bf24fb7
Showing 7 changed files with 101 additions and 14 deletions.
8 changes: 4 additions & 4 deletions paddle/fluid/pybind/protobuf.cc
@@ -177,10 +177,10 @@ void BindVarDsec(pybind11::module *m) {
.def("set_persistable", &pd::VarDesc::SetPersistable)
.def("need_check_feed", &pd::VarDesc::NeedCheckFeed)
.def("set_need_check_feed", &pd::VarDesc::SetNeedCheckFeed)
.def("has_attr", &pd::VarDesc::HasDistributedAttr)
.def("_set_attr", &pd::VarDesc::SetDistributedAttr)
.def("remove_attr", &pd::VarDesc::RemoveDistributedAttr)
.def("attr", &pd::VarDesc::GetDistributedAttr);
.def("has_distributed_attr", &pd::VarDesc::HasDistributedAttr)
.def("_set_distributed_attr", &pd::VarDesc::SetDistributedAttr)
.def("remove_distributed_attr", &pd::VarDesc::RemoveDistributedAttr)
.def("distributed_attr", &pd::VarDesc::GetDistributedAttr);

pybind11::enum_<pd::proto::VarType::Type> vartype(var_desc, "VarType", "");
g_vartype_pytype = (PyTypeObject *)vartype.ptr(); // NOLINT
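For reference, these renamed bindings are what the Python side reaches through a variable's desc. A minimal sketch of the round trip, assuming x is a static-graph Variable and reusing the attribute name from the unit test added in this commit:

# Minimal sketch, assuming `x` is a static-graph Variable; the attribute name
# and value are illustrative and mirror the unit test below.
x.desc._set_distributed_attr('mesh_topology', [2, 3])   # write an attribute
topology = x.desc.distributed_attr('mesh_topology')     # read it back -> [2, 3]
# has_distributed_attr / remove_distributed_attr presumably take the same
# attribute name, but only the two calls above are exercised in this diff.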
15 changes: 14 additions & 1 deletion python/paddle/distributed/__init__.py
@@ -36,6 +36,13 @@
from .collective import send # noqa: F401
from .collective import wait # noqa: F401

+from .auto_parallel import shard_tensor  # noqa: F401
+from .auto_parallel import shard_op  # noqa: F401
+from .auto_parallel import set_shard_mask  # noqa: F401
+from .auto_parallel import set_offload_device  # noqa: F401
+from .auto_parallel import set_pipeline_stage  # noqa: F401
+from .auto_parallel import ProcessMesh  # noqa: F401

from .fleet import BoxPSDataset # noqa: F401

from .entry_attr import ProbabilityEntry # noqa: F401
@@ -69,5 +76,11 @@
"ReduceOp",
"wait",
"get_rank",
"ProbabilityEntry"
"ProbabilityEntry",
"shard_tensor",
"shard_op",
"set_shard_mask",
"set_offload_device",
"set_pipeline_stage",
"ProcessMesh",
]
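Together these re-exports form the public entry points exercised by the new unit test. A rough usage sketch under static graph mode; the paddle.static.data placeholder is an assumption for illustration, not part of this commit:

# Rough sketch of the newly exported API under static graph mode; the input
# placeholder below is assumed, not taken from this commit.
import paddle
import paddle.distributed as dist

paddle.enable_static()

mesh = dist.ProcessMesh([2, 3])                        # 2 x 3 logical process mesh
x = paddle.static.data(name='x', shape=[8, 4], dtype='float32')
x = dist.shard_tensor(x, mesh, dims_mapping=[0, -1])   # dim 0 sharded over mesh dim 0, dim 1 left unsharded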
7 changes: 7 additions & 0 deletions python/paddle/distributed/auto_parallel/__init__.py
@@ -12,4 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+from .interface import shard_tensor  # noqa: F401
+from .interface import shard_op  # noqa: F401
+from .interface import set_shard_mask  # noqa: F401
+from .interface import set_offload_device  # noqa: F401
+from .interface import set_pipeline_stage  # noqa: F401
+from .utils import ProcessMesh  # noqa: F401

__all__ = []
21 changes: 13 additions & 8 deletions python/paddle/distributed/auto_parallel/interface.py
@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-__all__ = []
+
+def validate_check():
+    pass


def shard_tensor(tensor, mesh, dims_mapping):
@@ -26,8 +28,9 @@ def shard_tensor(tensor, mesh, dims_mapping):
The tensor itself.
"""
validate_check()
-    tensor.distributed_attr['mesh'] = mesh
-    tensor.distributed_attr['dims_mapping'] = dims_mapping
+    tensor.desc._set_distributed_attr('mesh_topology', mesh.get_mesh())
+    tensor.desc._set_distributed_attr('mesh_group', mesh.get_process_group())
+    tensor.desc._set_distributed_attr('dims_mapping', dims_mapping)
return tensor


@@ -41,7 +44,8 @@ def set_shard_mask(tensor, mask):
The tensor itself.
"""
validate_check()
-    tensor.distributed_attr['mask'] = mask
+    tensor.desc._set_distributed_attr('mask_shape', mask.shape)
+    tensor.desc._set_distributed_attr('mask_value', mask.tolist())
return tensor


@@ -58,9 +62,10 @@ def shard_op(op_name, mesh, input_dims_mapping, output_dims_mapping):
"""
validate_check()
# op_mapping[op_name](parameter list from input_dims_mapping)
-    op.distributed_attr['mesh'] = mesh
-    op.distributed_attr['input_dims_mapping'] = input_dims_mapping
-    op.distributed_attr['output_dims_mapping'] = output_dims_mapping
+    op.desc._set_distributed_attr('mesh_topology', mesh.get_mesh())
+    op.desc._set_distributed_attr('mesh_group', mesh.get_process_group())
+    op.desc._set_distributed_attr('input_dims_mapping', input_dims_mapping)
+    op.desc._set_distributed_attr('output_dims_mapping', output_dims_mapping)
# input_dims_mapping = {index: {'name': in_name, 'dims_mapping': dims_mapping}}


@@ -73,7 +78,7 @@ def set_offload_device(tensor, dst_device):
Returns:
None.
"""
-    tensor.distributed_attr['offload_device'] = dst_device
+    tensor.desc._set_distributed_attr('offload_device', dst_device)


def set_pipeline_stage(stage):
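Each helper now writes its settings onto the underlying desc rather than a Python-side dict, so they can be read back through the renamed VarDesc bindings above. A hypothetical illustration, assuming x is a static-graph Variable, mesh = ProcessMesh([2, 3]), and a NumPy mask (the .shape / .tolist() calls suggest a NumPy array):

# Hypothetical illustration of what the helpers above store; `x` and `mesh`
# are assumed to exist as in the unit test, and the mask is a NumPy array.
import numpy as np
import paddle.distributed as dist

x = dist.shard_tensor(x, mesh, dims_mapping=[0, -1])
mask = np.ones([2, 3], dtype='int32')        # one flag per process in the 2 x 3 mesh
x = dist.set_shard_mask(x, mask)             # stored as 'mask_shape' and 'mask_value'

x.desc.distributed_attr('mesh_topology')     # the mesh shape, e.g. [2, 3]
x.desc.distributed_attr('dims_mapping')      # [0, -1]
x.desc.distributed_attr('mask_value')        # mask.tolist()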
4 changes: 3 additions & 1 deletion python/paddle/distributed/auto_parallel/utils.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+import numpy as np

__all__ = []


@@ -26,7 +28,7 @@ def __init__(self, mesh, process_group=None):
dp_degree=pp_degree=mp_degree=2
mesh = ProcessMesh([dp_degree, pp_degree, mp_degree])
"""
-        process_num = product(mesh)
+        process_num = np.prod(mesh)
if process_group is None:
process_group = list(range(process_num))
assert len(process_group) == process_num
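product(mesh) on the removed line is not a Python builtin, so np.prod is the working replacement; note that it returns a NumPy integer rather than a Python int, which range() and the assertion below still accept. A small sketch of the sizing logic, based only on the lines shown in this diff:

# Sketch of the ProcessMesh sizing logic, based only on the lines in this diff.
import numpy as np

mesh_shape = [2, 3]
process_num = np.prod(mesh_shape)            # 6 (a NumPy integer, accepted by range)
process_group = list(range(process_num))     # default group: [0, 1, 2, 3, 4, 5]
assert len(process_group) == process_num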
59 changes: 59 additions & 0 deletions python/paddle/fluid/tests/unittests/test_auto_parallel_api.py
@@ -0,0 +1,59 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import paddle
import paddle.fluid as fluid
import paddle.nn as nn

paddle.enable_static()

mesh = paddle.distributed.ProcessMesh([2, 3])


class SimpleNet(nn.Layer):
    def __init__(self, vocab_size=128, hidden_size=4):
        super(SimpleNet, self).__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.dense1 = nn.Linear(hidden_size, hidden_size)
        self.dense2 = nn.Linear(hidden_size, hidden_size // 2)

    def forward(self, x, y):
        x = paddle.distributed.shard_tensor(x, mesh, dims_mapping=[0, -1])
        emb_out = self.word_embeddings(x)

        y = paddle.distributed.shard_tensor(y, mesh, dims_mapping=[0, -1])
        linear1 = self.dense1(y)
        out = self.dense2(linear1)
        return emb_out, linear1, out


class TestAutoParallelAPI(unittest.TestCase):
    def test_api(self):
        net = SimpleNet()
        x = fluid.layers.fill_constant(shape=[2, 4], value=1, dtype="int64")
        y = fluid.layers.fill_constant(shape=[2, 4], value=2, dtype="float32")
        emb_out, linear1, out = net.forward(x, y)
        self.assertEqual(x.desc.distributed_attr('mesh_topology'), [2, 3])
        self.assertEqual(
            x.desc.distributed_attr('mesh_group'), [0, 1, 2, 3, 4, 5])
        self.assertEqual(y.desc.distributed_attr('mesh_topology'), [2, 3])
        self.assertEqual(
            y.desc.distributed_attr('mesh_group'), [0, 1, 2, 3, 4, 5])


if __name__ == '__main__':
    unittest.main()
1 change: 1 addition & 0 deletions python/setup.py.in
@@ -161,6 +161,7 @@ packages=['paddle',
'paddle.distributed.fleet.meta_parallel',
'paddle.distributed.fleet.meta_parallel.pp_utils',
'paddle.distributed.fleet.meta_parallel.parallel_layers',
+'paddle.distributed.auto_parallel',
'paddle.framework',
'paddle.jit',
'paddle.jit.dy2static',
