Fix multi gpu dist train #10143

Merged: 2 commits, Apr 24, 2018
paddle/fluid/framework/details/multi_devices_graph_builder.cc (42 additions, 3 deletions)

@@ -78,6 +78,33 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
   }
 }

+bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
+                                            OpDesc *send_op) const {
+  if (send_op == nullptr) {
+    return false;
+  }
+
+  auto checker = [&](const std::vector<std::string> opvars,
+                     const std::vector<std::string> sendvars) -> bool {
+    bool is_dist_train_op = false;
+    for (auto &var : opvars) {
+      if (var.find(".block") != std::string::npos &&
+          std::find(sendvars.begin(), sendvars.end(), var) != sendvars.end()) {
+        is_dist_train_op = true;
+        break;
+      }
+    }
+    return is_dist_train_op;
+  };
+
+  if (op.Type() == "split") {
+    return checker(op.OutputArgumentNames(), send_op->InputArgumentNames());
+  } else if (op.Type() == "concat") {
+    return checker(op.InputArgumentNames(), send_op->OutputArgumentNames());
+  }
+  return false;
+}
+
 std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
     const ProgramDesc &program) const {
   auto graph = new SSAGraph();
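
For context, IsDistTrainOp flags a "split" or "concat" op as part of the parameter-server plumbing when one of its variables both carries a ".block" suffix and appears among the send op's argument names. A minimal standalone sketch of that matching rule, with hypothetical variable names (the real code wraps this logic in the checker lambda above):

```cpp
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// Sketch of the matching rule in IsDistTrainOp: a variable counts as a
// dist-train variable when its name contains ".block" AND it also appears
// in the send op's variable list.
bool Matches(const std::vector<std::string> &opvars,
             const std::vector<std::string> &sendvars) {
  for (const auto &var : opvars) {
    if (var.find(".block") != std::string::npos &&
        std::find(sendvars.begin(), sendvars.end(), var) != sendvars.end()) {
      return true;
    }
  }
  return false;
}

int main() {
  // Hypothetical names: outputs of a "split" op and inputs of the "send" op.
  const std::vector<std::string> split_outputs = {"w.block0", "w.block1"};
  const std::vector<std::string> send_inputs = {"w.block0", "w.block1"};
  const std::vector<std::string> plain_outputs = {"fc_0.tmp_0"};

  std::cout << Matches(split_outputs, send_inputs) << "\n";  // 1: dist-train op
  std::cout << Matches(plain_outputs, send_inputs) << "\n";  // 0: ordinary op
}
```
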
@@ -89,19 +116,30 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
       std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
       places_.size());

+  // Find the "send" op first, since the "split" ops appear before it.
+  OpDesc *send_op = nullptr;
+  for (auto *op : program.Block(0).AllOps()) {
+    if (op->Type() == "send") {
+      send_op = op;
+      break;
+    }
+  }
+
   bool is_forwarding = true;
   for (auto *op : program.Block(0).AllOps()) {
     if (op->Type() == "send") {
       // Append the send op if the program is a distributed trainer's main
       // program; always use the first device.
       CreateSendOp(&result, *op);
+    } else if (IsDistTrainOp(*op, send_op)) {
+      CreateComputationalOps(&result, *op, 1);
     } else if (IsScaleLossOp(*op)) {
       if (!skip_scale_loss_) {
         CreateScaleLossGradOp(&result);
       }
       is_forwarding = false;
     } else {
-      CreateComputationalOps(&result, *op);
+      CreateComputationalOps(&result, *op, places_.size());
       if (!is_forwarding) {
         // Currently, we assume that once gradient is generated, it can be
         // broadcast, and each gradient is only broadcast once. But there are no
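
With this change, Build() routes ops three ways: "send" goes through CreateSendOp on the first device, ops flagged by IsDistTrainOp get a single computational copy, and everything else is replicated across all places. A condensed, runnable sketch of just the placement decision (Placement and PlaceOp are illustrative names, not part of the real API):

```cpp
#include <iostream>
#include <string>

enum class Placement { kFirstDeviceOnly, kAllDevices };

// Illustrative stand-in for the dispatch in Build(): "send" and dist-train
// ops are materialized once; every other op gets one copy per device.
Placement PlaceOp(const std::string &type, bool is_dist_train_op) {
  if (type == "send" || is_dist_train_op) return Placement::kFirstDeviceOnly;
  return Placement::kAllDevices;
}

int main() {
  std::cout << (PlaceOp("split", true) == Placement::kFirstDeviceOnly) << "\n";  // 1
  std::cout << (PlaceOp("mul", false) == Placement::kAllDevices) << "\n";        // 1
}
```
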
@@ -199,8 +237,9 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
 }

 void MultiDevSSAGraphBuilder::CreateComputationalOps(SSAGraph *result,
-                                                     const OpDesc &op) const {
-  for (size_t scope_idx = 0; scope_idx < places_.size(); ++scope_idx) {
+                                                     const OpDesc &op,
+                                                     size_t num_places) const {
+  for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
     auto p = places_[scope_idx];
     auto s = local_scopes_[scope_idx];
     result->ops_.emplace_back(new ComputationOpHandle(op, s, p));
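
The new num_places argument is what enables single-device placement: callers pass 1 for dist-train ops and places_.size() for ordinary ops. A hypothetical stand-in that only shows the fan-out of the loop above:

```cpp
#include <cstddef>
#include <iostream>
#include <string>

// Stand-in for CreateComputationalOps: the real method builds one
// ComputationOpHandle per (scope, place) pair; here we only print the fan-out.
void CreateComputationalOpsSketch(const std::string &op_type,
                                  size_t num_places) {
  for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
    std::cout << op_type << " -> place " << scope_idx << "\n";
  }
}

int main() {
  CreateComputationalOpsSketch("split", 1);  // dist-train op: first place only
  CreateComputationalOpsSketch("mul", 4);    // regular op: one copy per place
}
```
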
paddle/fluid/framework/details/multi_devices_graph_builder.h (4 additions, 1 deletion)

@@ -65,7 +65,10 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {

   void CreateSendOp(SSAGraph *result, const OpDesc &op) const;

-  void CreateComputationalOps(SSAGraph *result, const OpDesc &op) const;
+  bool IsDistTrainOp(const OpDesc &op, OpDesc *send_op) const;
+
+  void CreateComputationalOps(SSAGraph *result, const OpDesc &op,
+                              size_t num_places) const;

   void CreateScaleLossGradOp(SSAGraph *result) const;