Skip to content

Commit

Permalink
Polish reference count pass (#21324)
Browse files Browse the repository at this point in the history
* fix ref_cnt pass, test=develop

* add cpp unittests to reference_count_pass, test=develop

* follow comments, test=develop
  • Loading branch information
sneaxiy authored Nov 28, 2019
1 parent b39f947 commit 8996652
Show file tree
Hide file tree
Showing 6 changed files with 298 additions and 111 deletions.
14 changes: 12 additions & 2 deletions paddle/fluid/framework/details/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ if(WITH_GPU)
nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
dynload_cuda variable_visitor)
nv_library(fused_all_reduce_op_handle SRCS fused_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
dynload_cuda variable_visitor place)
dynload_cuda variable_visitor place device_memory_aligment)

if(WITH_DGC)
nv_library(sparse_all_reduce_op_handle SRCS sparse_all_reduce_op_handle.cc DEPS op_handle_base scope
Expand All @@ -46,7 +46,7 @@ else()
cc_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
variable_visitor)
cc_library(fused_all_reduce_op_handle SRCS fused_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
variable_visitor place)
variable_visitor place device_memory_aligment)
if(WITH_DISTRIBUTE)
cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
ddim selected_rows_functor sendrecvop_rpc)
Expand Down Expand Up @@ -103,4 +103,14 @@ cc_library(build_strategy SRCS build_strategy.cc DEPS
lock_free_optimize_pass
coalesce_grad_tensor_pass fuse_all_reduce_op_pass backward_optimizer_op_deps_pass
fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass
sync_batch_norm_pass runtime_context_cache_pass
pass_builder
${NGRAPH_BS_DEPS})

if (WITH_MKLDNN)
target_link_libraries(build_strategy mkldnn_placement_pass)
endif()

if (WITH_NGRAPH)
target_link_libraries(build_strategy ngraph_subgraph_pass)
endif()
2 changes: 2 additions & 0 deletions paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,5 @@ cc_library(memory_reuse_pass SRCS memory_reuse_pass.cc DEPS computation_op_handl

cc_library(buffer_shared_inplace_op_pass SRCS buffer_shared_inplace_op_pass.cc DEPS memory_reuse_pass)
cc_library(buffer_shared_cross_op_memory_reuse_pass SRCS buffer_shared_cross_op_memory_reuse_pass.cc DEPS memory_reuse_pass)

cc_test(test_reference_count_pass_last_lived_ops SRCS test_reference_count_pass_last_lived_ops.cc DEPS parallel_executor elementwise_mul_op elementwise_add_op scale_op)
146 changes: 53 additions & 93 deletions paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -202,57 +202,15 @@ static bool ShrinkNoNeedBufferVarOpDependency(
}
}

/**
* Find the nearest downstream computation op handle. If the op is a
* computation op, just return itself.
*/
static details::ComputationOpHandle *FindNextComputationOpHandleOrReturnItself(
details::OpHandleBase *op, size_t scope_idx) {
std::queue<details::OpHandleBase *> q;
std::unordered_set<details::OpHandleBase *> visited;
q.push(op);
while (!q.empty()) {
auto *op = q.front();
q.pop();
auto *compute_op = dynamic_cast<details::ComputationOpHandle *>(op);
if (compute_op != nullptr && compute_op->GetScopeIdx() == scope_idx) {
return compute_op;
}
for (auto *out_var : op->Outputs()) {
for (auto *pending_op : out_var->PendingOps()) {
if (visited.count(pending_op)) continue;
visited.insert(pending_op);
q.push(pending_op);
}
}
}
return nullptr;
}

enum LastLiveOpSearchStatus { kSuccess, kFailure, kShouldPrecede };
enum LastLiveOpSearchStatus { kSuccess, kFailure };

static std::unordered_set<details::ComputationOpHandle *>
ExtractComputationOpFromLastLivedVar(details::VarHandle *var, size_t scope_idx,
const std::string &var_name,
const ShrinkDepsOpFunctor &shrink_func,
LastLiveOpSearchStatus *status) {
// stage one. Get last op for variable.
std::unordered_set<details::OpHandleBase *> candidates;
{
if (var->PendingOps().empty() && var->GeneratedOp()) {
// No operator depends on this variable. So the last operator is the op
// who generates this variable.
candidates.emplace(var->GeneratedOp());
} else {
candidates = var->PendingOps();
}

// No pending ops or generated op is nullptr
if (candidates.empty()) {
*status = LastLiveOpSearchStatus::kFailure;
return {};
}
}
auto candidates = var->PendingOps();

// stage two. Try to cast them to computation op.
// return (*status=kFailure) when failed.
Expand All @@ -262,37 +220,41 @@ ExtractComputationOpFromLastLivedVar(details::VarHandle *var, size_t scope_idx,
// some op handle may operate on many DeviceContext, however, our garbage
// collector can only wait one DeviceContext for now. So currently, we wait
// the nearest compute op.
std::unordered_set<details::ComputationOpHandle *> computation_op;
std::unordered_set<details::ComputationOpHandle *> computation_ops;
{
for (auto *op : candidates) {
auto *compute_op =
FindNextComputationOpHandleOrReturnItself(op, scope_idx);
if (compute_op == nullptr) {
auto *compute_op = dynamic_cast<details::ComputationOpHandle *>(op);
if (compute_op && compute_op->GetScopeIdx() == scope_idx) {
computation_ops.emplace(compute_op);
} else {
*status = LastLiveOpSearchStatus::kFailure;
return {};
}
computation_op.emplace(compute_op);
}

auto *generated_op =
dynamic_cast<details::ComputationOpHandle *>(var->GeneratedOp());
if (generated_op && generated_op->GetScopeIdx() == scope_idx) {
computation_ops.emplace(generated_op);
}
}

// stage three. Try to shrink computation op if any of them does
// not need the buffer of var_name.
// If all computation ops do not need the buffer of var_name,
// return empty computation op set, and mark the status as kShouldPrecede,
// which means that the last living ops of var_name should be
// found in the previous version of var_name.
if (ShrinkNoNeedBufferVarOpDependency(var_name, &computation_op)) {
*status = LastLiveOpSearchStatus::kShouldPrecede;
if (computation_ops.empty() ||
ShrinkNoNeedBufferVarOpDependency(var_name, &computation_ops)) {
*status = LastLiveOpSearchStatus::kFailure;
return {};
}

PADDLE_ENFORCE(!computation_op.empty(),
"Computation ops should not be empty");
PADDLE_ENFORCE_EQ(
computation_ops.empty(), false,
platform::errors::InvalidArgument("Computation ops should not be empty"));

// stage four. Try to shrink computation op if they depend on each other.
// Get the smallest set of the most ops.
*status = LastLiveOpSearchStatus::kSuccess;
return shrink_func(computation_op);
return shrink_func(computation_ops);
}

void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
Expand Down Expand Up @@ -344,47 +306,45 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {

PADDLE_ENFORCE_EQ(var_desc->Name(), var_name);

for (auto iter = var_handles.rbegin(); iter != var_handles.rend();
++iter) {
if ((*iter)->Node()->IsCtrlVar()) {
break;
}
PADDLE_ENFORCE_EQ(
var_handles.empty(), false,
platform::errors::InvalidArgument("Variable %s not found", var_name));
auto last_ver_var = var_handles.back();

VLOG(10) << "Try to find last living ops of " << var_name << " "
<< (iter - var_handles.rbegin()) << " time";
LastLiveOpSearchStatus status = LastLiveOpSearchStatus::kFailure;
auto result = ExtractComputationOpFromLastLivedVar(
*iter, i, var_name, shrink_func, &status);

// Seldomly, some vars may have no pending or preceding computation ops
// Just break;
if (status == LastLiveOpSearchStatus::kFailure) {
VLOG(1) << "Cannot find last live ops of variable " << var_name
<< " in scope " << (*iter)->scope_idx();
break;
}
if (last_ver_var->Node()->IsCtrlVar()) {
continue;
}

if (status == LastLiveOpSearchStatus::kShouldPrecede) {
VLOG(10) << "Try to precede reference count computing at var "
<< var_name;
continue;
}
LastLiveOpSearchStatus status = LastLiveOpSearchStatus::kFailure;
auto result = ExtractComputationOpFromLastLivedVar(
last_ver_var, i, var_name, shrink_func, &status);

// Seldomly, some vars may have no pending or preceding computation ops
// Just break;
if (status == LastLiveOpSearchStatus::kFailure) {
VLOG(1) << "Cannot find last live ops of variable " << var_name
<< " in scope " << last_ver_var->scope_idx();
continue;
}

PADDLE_ENFORCE_EQ(status, LastLiveOpSearchStatus::kSuccess);
PADDLE_ENFORCE(!result.empty(), "Last living ops of %s cannot be empty",
var_name);
PADDLE_ENFORCE_EQ(
status, LastLiveOpSearchStatus::kSuccess,
platform::errors::InvalidArgument("status must be success"));
PADDLE_ENFORCE_EQ(result.empty(), false,
platform::errors::NotFound(
"Last living ops of %s cannot be empty", var_name));

VLOG(10) << "Extract " << result.size() << " ops of var " << var_name;
var_infos[i][var_name].reset(
new MemOptVarInfo(var_name, result.size()));
auto &last_live_ops_of_var = last_live_ops_of_vars[i][var_name];
last_live_ops_of_var.set_var(*iter);
*(last_live_ops_of_var.mutable_ops()) = std::move(result);
break;
std::string last_live_ops_log_str;
for (auto &each_ret : result) {
last_live_ops_log_str += (" " + each_ret->GetOp()->Type());
}
VLOG(10) << "Extract " << result.size() << " ops of var " << var_name
<< " : " << last_live_ops_log_str;

// Seldomly, all preceding trying failed.
// Just skip this corner case
var_infos[i][var_name].reset(new MemOptVarInfo(var_name, result.size()));
auto &last_live_ops_of_var = last_live_ops_of_vars[i][var_name];
last_live_ops_of_var.set_var(last_ver_var);
*(last_live_ops_of_var.mutable_ops()) = std::move(result);
}
}
}
Expand Down
Loading

0 comments on commit 8996652

Please sign in to comment.