-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix bug of RunWithExternalStream API in new executor #60122
Changes from 2 commits
a4a378d
a944a65
f0644ac
ce2f9e8
f61d24a
08b692b
59b39f2
a84ff38
e560c8a
e5ee225
1672bb8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -41,6 +41,9 @@ | |
PHI_DECLARE_bool(dynamic_static_unified_comm); | ||
#endif | ||
|
||
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) | ||
PHI_DECLARE_bool(inference_switch_stream); | ||
#endif | ||
PD_DECLARE_bool(enable_host_event_recorder_hook); | ||
PD_DECLARE_bool(log_memory_stats); | ||
PHI_DECLARE_string(static_runtime_data_save_path); | ||
|
@@ -163,6 +166,12 @@ FetchList ProgramInterpreter::Run(const std::vector<std::string>& feed_names, | |
is_build_ = true; | ||
is_shared_results_build_ = true; | ||
} else { | ||
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) | ||
if (FLAGS_inference_switch_stream) { | ||
UpdateDevCtx(&op_func_nodes); | ||
FLAGS_inference_switch_stream = false; | ||
} | ||
#endif | ||
RunImpl(); | ||
} | ||
|
||
|
@@ -879,6 +888,46 @@ void ProgramInterpreter::Convert( | |
AnalyseExecuteOrderForTrace(); | ||
} | ||
|
||
void ProgramInterpreter::UpdateDevCtx( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已修改 |
||
std::vector<paddle::framework::OpFuncNode>* op_func_nodes) { | ||
auto nodes = *op_func_nodes; | ||
auto op_nums = nodes.size(); | ||
vec_instruction_.clear(); | ||
vec_instruction_.reserve(op_nums); | ||
for (size_t op_idx = 0; op_idx < op_nums; ++op_idx) { | ||
auto& op_func_node = nodes[op_idx]; | ||
stream_analyzer_.SetForceEventsToWaitInfo(force_evnets_to_wait_); | ||
auto* dev_ctx_ = stream_analyzer_.ParseDeviceContext(op_func_node); | ||
#ifdef PADDLE_WITH_CUDA | ||
if (FLAGS_new_executor_use_cuda_graph) { | ||
auto& op = op_func_node.operator_base_; | ||
auto& op_type = op->Type(); | ||
if (op_type == interpreter::kMemcpyD2H || | ||
op_type == interpreter::kMemcpyH2D) { | ||
PADDLE_THROW(paddle::platform::errors::Fatal( | ||
"Cuda memory copy d2h/h2d is not allowed while using cuda graph.")); | ||
} | ||
PADDLE_ENFORCE_EQ(typeid(*dev_ctx_) == typeid(phi::GPUContext), | ||
true, | ||
platform::errors::InvalidArgument( | ||
"Device context of op %s must be [%s] while using " | ||
"cuda graph, but got [%s].", | ||
op_type, | ||
typeid(phi::GPUContext).name(), | ||
typeid(*dev_ctx_).name())); | ||
// cuda graph needs to record all stream | ||
phi::backends::gpu::CUDAGraphContextManager::Instance() | ||
.RecordCapturingDeviceContext(dev_ctx_); | ||
} | ||
#endif | ||
vec_instruction_.emplace_back(op_idx, std::move(op_func_node), *dev_ctx_); | ||
|
||
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) | ||
vec_instruction_.back().UpdataRecordStreamForGcInfo(); | ||
#endif | ||
} | ||
} | ||
|
||
void ProgramInterpreter::BuildSkipShareLoDInfo() { | ||
for (size_t i = 0; i < vec_instruction_.size(); ++i) { | ||
bool can_skip_lod = true; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1108,6 +1108,20 @@ PHI_DEFINE_EXPORTED_bool(new_executor_use_cuda_graph, | |
false, | ||
"Use CUDA Graph in new executor"); | ||
|
||
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
/*
 * Inference switch stream related FLAG
 * Name: FLAGS_inference_switch_stream
 * Since Version: 2.6
 * Value Range: bool, default=false
 * Example: FLAGS_inference_switch_stream=true would switch the stream used
 *          by the new executor on its next run.
 * It is possible for this flag to be set to true in RunWithExternalStream API.
 */
PHI_DEFINE_EXPORTED_bool(inference_switch_stream,
                         false,
                         "Switch stream when inference");
#endif
|
||
/* | ||
* Executor related FLAG | ||
* Name: FLAGS_executor_log_deps_every_microseconds | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个FLAGS在
analysis_predictor
中被设置,建议在相同的模块里重置,不要把设置和重置的操作隔离在两个完全不同的地方。另外,更建议直接在run接口新增参数开关传递是否需要重新构造ctx的信息,FLAGS的灵活性太高了。There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改,现弃用FLAG,run接口新增switch_stream参数传递是否需要重新构造ctx的信息。