-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add the basic apis for auto_parallel #33804
Changes from 50 commits
b985745
b79e749
1671850
ec55a43
25abc00
bf24fb7
8ea9363
9e4b3d8
e65f77e
8b95c1e
ccae6ae
d107751
f7e70ea
3111159
70cdb69
9e5b0f0
59936ef
27ee413
3a8ceef
d11f317
f5ef245
7293b4f
1240edc
05455fb
3e1b3a0
8950c35
b94a9f2
e121349
fe51aa3
4563d42
192580d
2e69980
608dd3f
cb9b6bf
8e6559e
00f5f4d
1aa94da
97a446c
c586fc6
fc6cde9
967d0e7
cd1e390
f48ec91
b07affa
a00fe9e
da9fe30
3daecf2
5640879
05b0f82
fe93d0e
9856d47
890c70c
4b90b03
c395b84
8390e01
f7d5631
3a2666e
773516b
685504f
7ac6299
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/fluid/framework/process_mesh_desc.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
|
||
int32_t ProcessMeshDesc::next_id = -1; | ||
|
||
ProcessMeshDesc::ProcessMeshDesc(const std::vector<int32_t> &topo, | ||
const std::vector<int32_t> &process_group, | ||
int32_t parent_id) { | ||
int32_t cur_id = ++next_id; | ||
desc_.set_id(cur_id); | ||
desc_.set_parent_id(parent_id); | ||
for (size_t i = 0; i != topo.size(); ++i) { | ||
desc_.add_topology(topo[i]); | ||
} | ||
for (size_t i = 0; i != process_group.size(); ++i) { | ||
desc_.add_process_group(process_group[i]); | ||
} | ||
ProcessMeshDescMap::GetInstance().Insert(cur_id, this); | ||
} | ||
|
||
std::vector<int32_t> ProcessMeshDesc::Topology() const { | ||
size_t size = desc_.topology_size(); | ||
std::vector<int32_t> ret(size); | ||
for (auto i = 0; i != desc_.topology_size(); ++i) { | ||
ret[i] = desc_.topology(i); | ||
} | ||
return ret; | ||
} | ||
|
||
std::vector<int32_t> ProcessMeshDesc::ProcessGroup() const { | ||
size_t size = desc_.process_group_size(); | ||
std::vector<int32_t> ret(size); | ||
for (auto i = 0; i != desc_.process_group_size(); ++i) { | ||
ret[i] = desc_.process_group(i); | ||
} | ||
return ret; | ||
} | ||
|
||
ProcessMeshDescMap &ProcessMeshDescMap::GetInstance() { | ||
static ProcessMeshDescMap g_process_mesh_desc_map; | ||
return g_process_mesh_desc_map; | ||
} | ||
|
||
} // namespace framework | ||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#pragma once | ||
|
||
#include <unordered_map> | ||
#include <vector> | ||
|
||
#include "paddle/fluid/framework/framework.pb.h" | ||
#include "paddle/fluid/framework/proto_desc.h" | ||
#include "paddle/fluid/platform/enforce.h" | ||
#include "paddle/fluid/platform/macros.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
|
||
class ProcessMeshDesc { | ||
public: | ||
ProcessMeshDesc(const std::vector<int32_t>& topo, | ||
const std::vector<int32_t>& process_group, int32_t parent_id); | ||
|
||
int32_t ID() const { return desc_.id(); } | ||
int32_t Parent() const { return desc_.parent_id(); } | ||
|
||
std::vector<int32_t> Topology() const; | ||
std::vector<int32_t> ProcessGroup() const; | ||
|
||
static int32_t next_id; | ||
|
||
private: | ||
proto::ProcessMeshDesc desc_; // not_own | ||
}; | ||
|
||
class ProcessMeshDescMap { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It will be convenient to expose a member func to get a set containing all processes from this map. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Now the c++ side codes are not used, hence no more util-functions are defined. But in the future, we can extend more functions if needed. |
||
public: | ||
static ProcessMeshDescMap& GetInstance(); | ||
|
||
bool Has(int32_t index) const { return map_.find(index) != map_.end(); } | ||
|
||
void Insert(int32_t index, ProcessMeshDesc* mesh) { | ||
PADDLE_ENFORCE_NE( | ||
Has(index), true, | ||
platform::errors::AlreadyExists("Index (%d) has been used.", index)); | ||
map_.insert(std::make_pair(index, mesh)); | ||
} | ||
|
||
private: | ||
ProcessMeshDescMap() = default; | ||
// Use raw pointer to avoid double free | ||
std::unordered_map<int32_t, ProcessMeshDesc*> map_; | ||
DISABLE_COPY_AND_ASSIGN(ProcessMeshDescMap); | ||
}; | ||
} // namespace framework | ||
} // namespace paddle |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这些新增的字段,在保存模型的时候,会被存下来吗?
我看示例代码,模型定义的时候就会添加这些字段,模型定义完再调用模型保存的时候,是不是会把这些字段都保存下来?什么时候把这些字段去掉呢?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
自动并行主要包括以下几个主要过程:1. 使用自动并行接口标识关键tensor或op;2. 自动补全:补全所有tensor和op的分布式属性;3. 逻辑切分;4. 物理映射;5. 执行训练。其中步骤1-3会使用到此处新增的字段;所以该接口新增的字段会在步骤1-3完成后删除,且该过程用户无感知。
常规的模型保存过程是 执行部分训练或全部训练完成后进行模型保存,这时,新增字段已经完全删除。
但存在一个特殊的情形,即用户完成组网后即刻保存模型,这时相关的字段会被保存下来。但我们认为,这种特殊情形是不应该存在的,因为完成组网后即保存模型是没有意义的。