Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add the basic apis for auto_parallel #33804

Merged
merged 60 commits into from
Aug 11, 2021
Merged
Show file tree
Hide file tree
Changes from 50 commits
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
b985745
add auto_parallel dir
Jun 28, 2021
b79e749
mv to paddle.distributed
Jun 28, 2021
1671850
add shard_xx api
Jul 1, 2021
ec55a43
add distributed attrs for var
Jul 8, 2021
25abc00
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Jul 9, 2021
bf24fb7
add ut, test=develop
Jul 9, 2021
8ea9363
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Jul 18, 2021
9e4b3d8
add dist
Jul 21, 2021
e65f77e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Jul 22, 2021
8b95c1e
update
Jul 26, 2021
ccae6ae
update
Jul 26, 2021
d107751
update
Jul 27, 2021
f7e70ea
update
Jul 27, 2021
3111159
update
Jul 27, 2021
70cdb69
update, test=develop
Jul 27, 2021
9e5b0f0
update, test=develop
Jul 27, 2021
59936ef
update, test=develop
Jul 27, 2021
27ee413
update, test=develop
Jul 27, 2021
3a8ceef
update, test=develop
Jul 27, 2021
d11f317
update, test=develop
Jul 28, 2021
f5ef245
update, test=develop
Jul 28, 2021
7293b4f
update
Jul 28, 2021
1240edc
update
Jul 28, 2021
05455fb
update
Jul 28, 2021
3e1b3a0
update
Jul 28, 2021
8950c35
update
Jul 28, 2021
b94a9f2
update, test=develop
Jul 28, 2021
e121349
update, test=develop
Jul 28, 2021
fe51aa3
update
Jul 28, 2021
4563d42
update
Jul 28, 2021
192580d
Merge branch 'develop' into auto_parallel_basic
Jul 28, 2021
2e69980
delete unused proto
Jul 28, 2021
608dd3f
resotre op_desc
Jul 28, 2021
cb9b6bf
restore type_defs
Jul 28, 2021
8e6559e
update var_desc
Jul 28, 2021
00f5f4d
remove dimss_mapping for proto_pybind
Jul 28, 2021
1aa94da
update interface.py
Jul 28, 2021
97a446c
update framework.py
Jul 28, 2021
c586fc6
update
Jul 28, 2021
fc6cde9
update
Jul 29, 2021
967d0e7
fix process_mesh ut
Jul 29, 2021
cd1e390
fix process_mesh ut
Jul 29, 2021
f48ec91
update
Jul 29, 2021
b07affa
update, test=develop
Jul 30, 2021
a00fe9e
update
Jul 30, 2021
da9fe30
update
Aug 2, 2021
3daecf2
update
Aug 2, 2021
5640879
fix doc sample codes, test=develop
Aug 2, 2021
05b0f82
improve coverage, test=develop
Aug 2, 2021
fe93d0e
add static_mode check, test=develop
Aug 2, 2021
9856d47
update, test=develop
Aug 4, 2021
890c70c
add set_placement, test=develop
Aug 5, 2021
4b90b03
update doc, test=develop
Aug 5, 2021
c395b84
update doc, test=develop
Aug 5, 2021
8390e01
update doc, test=develop
Aug 5, 2021
f7d5631
update doc, test=develop
Aug 6, 2021
3a2666e
update, test=develop
Aug 6, 2021
773516b
update ndarray to nested list, test=develop
Aug 10, 2021
685504f
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Aug 10, 2021
7ac6299
update, test=develop
Aug 10, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/fluid/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ cc_test(operator_exception_test SRCS operator_exception_test.cc DEPS operator op
cc_library(version SRCS version.cc)
cc_test(version_test SRCS version_test.cc DEPS version)

cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute shape_inference op_info operator glog version)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc process_mesh_desc.cc DEPS attribute shape_inference op_info operator glog version)

cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc)

Expand Down
17 changes: 17 additions & 0 deletions paddle/fluid/framework/framework.proto
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ enum AttrType {
FLOAT64S = 12;
}

message ProcessMeshDesc {
required int32 id = 1;
required int32 parent_id = 2;
repeated int32 topology = 3;
repeated int32 process_group = 4;
};

// OpDesc describes an instance of a C++ framework::OperatorBase
// derived class type.
message OpDesc {
Expand Down Expand Up @@ -167,6 +174,15 @@ message VarType {
}

message VarDesc {

message Attr {
required string name = 1;
required AttrType type = 2;
optional int32 i = 3;
optional string s = 4;
repeated int32 ints = 5;
};

required string name = 1;
required VarType type = 2;
optional bool persistable = 3 [ default = false ];
Expand All @@ -175,6 +191,7 @@ message VarDesc {
optional bool need_check_feed = 4 [ default = false ];
optional bool is_parameter = 5 [ default = false ];
optional bool stop_gradient = 6 [ default = false ];
repeated Attr attrs = 7;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这些新增的字段,在保存模型的时候,会被存下来吗?
我看示例代码,模型定义的时候就会添加这些字段,模型定义完再调用模型保存的时候,是不是会把这些字段都保存下来?什么时候把这些字段去掉呢?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

自动并行主要包括以下几个主要过程:1. 使用自动并行接口标识关键tensor或op;2. 自动补全:补全所有tensor和op的分布式属性;3. 逻辑切分;4. 物理映射;5. 执行训练。其中步骤1-3会使用到此处新增的字段;所以该接口新增的字段会在步骤1-3完成后删除,且该过程用户无感知。

常规的模型保存过程是 执行部分训练或全部训练完成后进行模型保存,这时,新增字段已经完全删除。

但存在一个特殊的情形,即用户完成组网后即刻保存模型,这时相关的字段会被保存下来。但我们认为,这种特殊情形是不应该存在的,因为完成组网后即保存模型是没有意义的。

}

message BlockDesc {
Expand Down
61 changes: 61 additions & 0 deletions paddle/fluid/framework/process_mesh_desc.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/process_mesh_desc.h"

namespace paddle {
namespace framework {

int32_t ProcessMeshDesc::next_id = -1;

ProcessMeshDesc::ProcessMeshDesc(const std::vector<int32_t> &topo,
const std::vector<int32_t> &process_group,
int32_t parent_id) {
int32_t cur_id = ++next_id;
desc_.set_id(cur_id);
desc_.set_parent_id(parent_id);
for (size_t i = 0; i != topo.size(); ++i) {
desc_.add_topology(topo[i]);
}
for (size_t i = 0; i != process_group.size(); ++i) {
desc_.add_process_group(process_group[i]);
}
ProcessMeshDescMap::GetInstance().Insert(cur_id, this);
}

std::vector<int32_t> ProcessMeshDesc::Topology() const {
size_t size = desc_.topology_size();
std::vector<int32_t> ret(size);
for (auto i = 0; i != desc_.topology_size(); ++i) {
ret[i] = desc_.topology(i);
}
return ret;
}

std::vector<int32_t> ProcessMeshDesc::ProcessGroup() const {
size_t size = desc_.process_group_size();
std::vector<int32_t> ret(size);
for (auto i = 0; i != desc_.process_group_size(); ++i) {
ret[i] = desc_.process_group(i);
}
return ret;
}

ProcessMeshDescMap &ProcessMeshDescMap::GetInstance() {
static ProcessMeshDescMap g_process_mesh_desc_map;
return g_process_mesh_desc_map;
}

} // namespace framework
} // namespace paddle
65 changes: 65 additions & 0 deletions paddle/fluid/framework/process_mesh_desc.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/proto_desc.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/macros.h"

namespace paddle {
namespace framework {

class ProcessMeshDesc {
public:
ProcessMeshDesc(const std::vector<int32_t>& topo,
const std::vector<int32_t>& process_group, int32_t parent_id);

int32_t ID() const { return desc_.id(); }
int32_t Parent() const { return desc_.parent_id(); }

std::vector<int32_t> Topology() const;
std::vector<int32_t> ProcessGroup() const;

static int32_t next_id;

private:
proto::ProcessMeshDesc desc_; // not_own
};

class ProcessMeshDescMap {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will be convenient to expose a member func to get a set containing all processes from this map.

Copy link
Author

@sandyhouse sandyhouse Aug 5, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now the c++ side codes are not used, hence no more util-functions are defined. But in the future, we can extend more functions if needed.

public:
static ProcessMeshDescMap& GetInstance();

bool Has(int32_t index) const { return map_.find(index) != map_.end(); }

void Insert(int32_t index, ProcessMeshDesc* mesh) {
PADDLE_ENFORCE_NE(
Has(index), true,
platform::errors::AlreadyExists("Index (%d) has been used.", index));
map_.insert(std::make_pair(index, mesh));
}

private:
ProcessMeshDescMap() = default;
// Use raw pointer to avoid double free
std::unordered_map<int32_t, ProcessMeshDesc*> map_;
DISABLE_COPY_AND_ASSIGN(ProcessMeshDescMap);
};
} // namespace framework
} // namespace paddle
8 changes: 8 additions & 0 deletions paddle/fluid/framework/proto_desc.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,13 @@ constexpr int kRootBlockIndex = 0;
// The Parent Index of root Block, this block does not exist.
constexpr int kNoneBlockIndex = -1;

// The Parent Index of root ProcessMesh, this ProcessMesh does not exist.
constexpr int kNoneProcessMeshIndex = -1;

// If a attribute name has a certain suffix, it means that the
// atrribute is a distributed-related attribute for auto parallel.
// e.g., "mesh_id@PARALLEL".
constexpr char kAutoParallelSuffix[] = "@PARALLEL";

} // namespace framework
} // namespace paddle
40 changes: 40 additions & 0 deletions paddle/fluid/framework/var_desc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,46 @@ std::vector<proto::VarType::TensorDesc *> VarDesc::mutable_tensor_descs() {
}
}

std::vector<std::string> VarDesc::AttrNames() const {
std::vector<std::string> retv;
retv.reserve(attrs_.size());
for (auto &attr : attrs_) {
retv.push_back(attr.first);
}
return retv;
}

void VarDesc::RemoveAttr(const std::string &name) { attrs_.erase(name); }

void VarDesc::SetAttr(const std::string &name, const Attribute &v) {
// NOTICE(sandyhouse): pybind11 will take the empty list in python as
// the std::vector<int> type in C++; so we have to change the attr's type
// here if we meet this issue
proto::AttrType attr_type = static_cast<proto::AttrType>(v.which() - 1);
if (attr_type == proto::AttrType::INTS &&
BOOST_GET_CONST(std::vector<int>, v).size() == 0u) {
// Find current attr via attr name and set the correct attribute value
this->attrs_[name] = std::vector<int>();
return;
}
bool valid = attr_type == proto::AttrType::INT ||
attr_type == proto::AttrType::STRING ||
attr_type == proto::AttrType::INTS;
PADDLE_ENFORCE_EQ(valid, true, platform::errors::InvalidArgument(
"The value for attr (%s) must be "
"one of list or int or string.",
name));

this->attrs_[name] = v;
}

Attribute VarDesc::GetAttr(const std::string &name) const {
auto it = attrs_.find(name);
PADDLE_ENFORCE_NE(it, attrs_.end(), platform::errors::NotFound(
"Attribute %s is not found.", name));
return it->second;
}

bool operator==(const VarDesc &left, const VarDesc &right) {
return left.Proto()->SerializeAsString() ==
right.Proto()->SerializeAsString();
Expand Down
23 changes: 23 additions & 0 deletions paddle/fluid/framework/var_desc.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ limitations under the License. */
#include <vector>

#include "glog/logging.h"
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/type_defs.h"

namespace paddle {
namespace framework {
Expand Down Expand Up @@ -137,13 +139,34 @@ class VarDesc {
desc_.set_need_check_feed(need_check_feed);
}

bool HasAttr(const std::string &name) const {
return attrs_.find(name) != attrs_.end();
}

std::vector<std::string> AttrNames() const;

void SetAttr(const std::string &name, const Attribute &v);
void RemoveAttr(const std::string &name);

Attribute GetAttr(const std::string &name) const;

template <typename T>
T GetAttrIfExists(const std::string &name) const {
T result{};
if (HasAttr(name)) {
result = BOOST_GET_CONST(T, GetAttr(name));
}
return result;
}

private:
const proto::VarType::TensorDesc &tensor_desc() const;
std::vector<proto::VarType::TensorDesc> tensor_descs() const;
proto::VarType::TensorDesc *mutable_tensor_desc();
std::vector<proto::VarType::TensorDesc *> mutable_tensor_descs();

proto::VarDesc desc_;
AttributeMap attrs_;
};

bool operator==(const VarDesc &left, const VarDesc &right);
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/pybind/const_value.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/proto_desc.h"

#if defined(PADDLE_WITH_DGC)
#include "paddle/fluid/framework/details/dgc_const_values.h"
Expand All @@ -33,6 +34,9 @@ void BindConstValue(pybind11::module* m) {
m->def("kControlDepVarName",
[] { return framework::ir::Node::kControlDepVarName; });
m->def("kNewGradSuffix", [] { return framework::kNewGradSuffix; });
m->def("kAutoParallelSuffix", [] { return framework::kAutoParallelSuffix; });
m->def("kNoneProcessMeshIndex",
[] { return framework::kNoneProcessMeshIndex; });

auto op_proto_and_checker_maker =
m->def_submodule("op_proto_and_checker_maker");
Expand Down
19 changes: 18 additions & 1 deletion paddle/fluid/pybind/protobuf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ limitations under the License. */

#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/process_mesh_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/framework/version.h"
Expand Down Expand Up @@ -84,6 +85,17 @@ void BindProgramDesc(pybind11::module *m) {
[](pd::ProgramDesc &self) -> int64_t { return self.Version(); });
}

void BindProcessMeshDesc(pybind11::module *m) {
pybind11::class_<pd::ProcessMeshDesc>(*m, "ProcessMeshDesc", "")
.def(pybind11::init<const std::vector<int32_t> &,
const std::vector<int32_t> &, int32_t>())
.def_property_readonly("id", &pd::ProcessMeshDesc::ID)
.def_property_readonly("parent", &pd::ProcessMeshDesc::Parent)
.def_property_readonly("topology", &pd::ProcessMeshDesc::Topology)
.def_property_readonly("process_group",
&pd::ProcessMeshDesc::ProcessGroup);
}

void BindBlockDesc(pybind11::module *m) {
pybind11::class_<pd::BlockDesc> blockdesc(*m, "BlockDesc", "");
g_blockdesc_pytype = (PyTypeObject *)blockdesc.ptr(); // NOLINT
Expand Down Expand Up @@ -184,7 +196,12 @@ void BindVarDsec(pybind11::module *m) {
.def("clear_stop_gradient", &pd::VarDesc::ClearStopGradient)
.def("has_stop_gradient", &pd::VarDesc::HasStopGradient)
.def("need_check_feed", &pd::VarDesc::NeedCheckFeed)
.def("set_need_check_feed", &pd::VarDesc::SetNeedCheckFeed);
.def("set_need_check_feed", &pd::VarDesc::SetNeedCheckFeed)
.def("has_attr", &pd::VarDesc::HasAttr)
.def("attr_names", &pd::VarDesc::AttrNames)
.def("_set_attr", &pd::VarDesc::SetAttr)
.def("remove_attr", &pd::VarDesc::RemoveAttr)
.def("attr", &pd::VarDesc::GetAttr);

pybind11::enum_<pd::proto::VarType::Type> vartype(var_desc, "VarType", "");
g_vartype_pytype = (PyTypeObject *)vartype.ptr(); // NOLINT
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/pybind/protobuf.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ void BindProgramDesc(pybind11::module* m);
void BindBlockDesc(pybind11::module* m);
void BindVarDsec(pybind11::module* m);
void BindOpDesc(pybind11::module* m);
void BindProcessMeshDesc(pybind11::module* m);

} // namespace pybind
} // namespace paddle
1 change: 1 addition & 0 deletions paddle/fluid/pybind/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1998,6 +1998,7 @@ All parameter, weight, gradient are variables in Paddle.
BindOpDesc(&m);
BindConstValue(&m);
BindGlobalValueGetterSetter(&m);
BindProcessMeshDesc(&m);

py::class_<framework::LoDRankTable>(m, "LodRankTable")
.def("items", [](framework::LoDRankTable &table) {
Expand Down
15 changes: 14 additions & 1 deletion python/paddle/distributed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,13 @@
from .collective import send # noqa: F401
from .collective import wait # noqa: F401

from .auto_parallel import shard_tensor # noqa: F401
from .auto_parallel import shard_op # noqa: F401
from .auto_parallel import set_shard_mask # noqa: F401
from .auto_parallel import set_offload_device # noqa: F401
from .auto_parallel import set_pipeline_stage # noqa: F401
from .auto_parallel import ProcessMesh # noqa: F401

from .fleet import BoxPSDataset # noqa: F401

from .entry_attr import ProbabilityEntry # noqa: F401
Expand Down Expand Up @@ -69,5 +76,11 @@
"ReduceOp",
"wait",
"get_rank",
"ProbabilityEntry"
"ProbabilityEntry",
"shard_tensor",
"shard_op",
"set_shard_mask",
"set_offload_device",
"set_pipeline_stage",
"ProcessMesh",
]
Loading