Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix bug of RunWithExternalStream API in new executor #60122

Merged
merged 11 commits into from
Jan 5, 2024
49 changes: 49 additions & 0 deletions paddle/fluid/framework/new_executor/program_interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@
PHI_DECLARE_bool(dynamic_static_unified_comm);
#endif

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PHI_DECLARE_bool(inference_switch_stream);
#endif
PD_DECLARE_bool(enable_host_event_recorder_hook);
PD_DECLARE_bool(log_memory_stats);
PHI_DECLARE_string(static_runtime_data_save_path);
Expand Down Expand Up @@ -163,6 +166,12 @@ FetchList ProgramInterpreter::Run(const std::vector<std::string>& feed_names,
is_build_ = true;
is_shared_results_build_ = true;
} else {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (FLAGS_inference_switch_stream) {
UpdateDevCtx(&op_func_nodes);
FLAGS_inference_switch_stream = false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个FLAGS在analysis_predictor中被设置,建议在相同的模块里重置,不要把设置和重置的操作隔离在两个完全不同的地方。另外,更建议直接在run接口新增参数开关传递是否需要重新构造ctx的信息,FLAGS的灵活性太高了。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改,现弃用FLAG,run接口新增switch_stream参数传递是否需要重新构造ctx的信息。

}
#endif
RunImpl();
}

Expand Down Expand Up @@ -879,6 +888,46 @@ void ProgramInterpreter::Convert(
AnalyseExecuteOrderForTrace();
}

void ProgramInterpreter::UpdateDevCtx(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. 看代码无法直接理解重新获取dev_ctx的目的以及在什么场景下需要这么做,建议在代码中添加注释说明。
  2. UpdateDevCtx命名无法体现这个函数所实际执行的操作,需要通过细看代码才能知道这个调用到底在做什么。这个函数里代码做的事情是把op_func_nodes清空,重新构建一遍,建议考虑build_op_func_noderebuild_op_func_node等更贴切的命名。
  3. 这份函数里的代码大部分是直接对Convert函数里的相关代码重复拷贝了一份,建议考虑代码复用,而不是直接大段拷贝。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

std::vector<paddle::framework::OpFuncNode>* op_func_nodes) {
auto nodes = *op_func_nodes;
auto op_nums = nodes.size();
vec_instruction_.clear();
vec_instruction_.reserve(op_nums);
for (size_t op_idx = 0; op_idx < op_nums; ++op_idx) {
auto& op_func_node = nodes[op_idx];
stream_analyzer_.SetForceEventsToWaitInfo(force_evnets_to_wait_);
auto* dev_ctx_ = stream_analyzer_.ParseDeviceContext(op_func_node);
#ifdef PADDLE_WITH_CUDA
if (FLAGS_new_executor_use_cuda_graph) {
auto& op = op_func_node.operator_base_;
auto& op_type = op->Type();
if (op_type == interpreter::kMemcpyD2H ||
op_type == interpreter::kMemcpyH2D) {
PADDLE_THROW(paddle::platform::errors::Fatal(
"Cuda memory copy d2h/h2d is not allowed while using cuda graph."));
}
PADDLE_ENFORCE_EQ(typeid(*dev_ctx_) == typeid(phi::GPUContext),
true,
platform::errors::InvalidArgument(
"Device context of op %s must be [%s] while using "
"cuda graph, but got [%s].",
op_type,
typeid(phi::GPUContext).name(),
typeid(*dev_ctx_).name()));
// cuda graph needs to record all stream
phi::backends::gpu::CUDAGraphContextManager::Instance()
.RecordCapturingDeviceContext(dev_ctx_);
}
#endif
vec_instruction_.emplace_back(op_idx, std::move(op_func_node), *dev_ctx_);

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
vec_instruction_.back().UpdataRecordStreamForGcInfo();
#endif
}
}

void ProgramInterpreter::BuildSkipShareLoDInfo() {
for (size_t i = 0; i < vec_instruction_.size(); ++i) {
bool can_skip_lod = true;
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/framework/new_executor/program_interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ class ProgramInterpreter : public InterpreterBaseImpl {
void BuildSkipShareLoDInfo();
void UpdateSyncOpNum();
void AnalyseExecuteOrderForTrace();
void UpdateDevCtx(std::vector<paddle::framework::OpFuncNode>* op_func_nodes);

// inplace
void BuildInplace();
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/inference/api/analysis_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@

PHI_DECLARE_bool(enable_pir_in_executor);
PHI_DECLARE_bool(pir_apply_inplace_pass);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PHI_DECLARE_bool(inference_switch_stream);
#endif

namespace paddle {
namespace {
Expand Down Expand Up @@ -2362,6 +2365,7 @@ bool AnalysisPredictor::ExpRunWithExternalStream(const gpuStream_t stream) {
}));
auto &pool = paddle::experimental::DeviceContextPool::Instance();
pool.SyncDeviceContext(place_);
FLAGS_inference_switch_stream = true;
}

return ZeroCopyRun();
Expand Down
14 changes: 14 additions & 0 deletions paddle/phi/core/flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1108,6 +1108,20 @@ PHI_DEFINE_EXPORTED_bool(new_executor_use_cuda_graph,
false,
"Use CUDA Graph in new executor");

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
/*
* Inference switch stream related FLAG
* Name: FLAGS_inference_switch_stream
* Since Version: 2.6
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

现在新增的代码应该是在2.7版本发布了吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

* Value Range: bool, default=false
* Example: FLAGS_inference_switch_stream=true would switch
* It is possible for this flag to be set to true in RunWithExternalStream API.
*/
// Fix typo in the user-visible flag help string ("Swich" -> "Switch").
PHI_DEFINE_EXPORTED_bool(inference_switch_stream,
                         false,
                         "Switch stream when inference");
#endif

/*
* Executor related FLAG
* Name: FLAGS_executor_log_deps_every_microseconds
Expand Down
4 changes: 2 additions & 2 deletions test/cpp/inference/api/analysis_predictor_tester.cc
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,7 @@ TEST(Tensor, RunWithExternalStream) {
cudaStream_t stream;
cudaStreamCreate(&stream);
config.SetExecStream(stream);
config.EnableNewExecutor();
auto predictor = CreatePredictor(config);

auto w0 = predictor->GetInputHandle("firstw");
Expand Down Expand Up @@ -703,8 +704,7 @@ TEST(Tensor, RunWithExternalStream) {

cudaStream_t external_stream;
cudaStreamCreate(&external_stream);
Config tmp_config(config);
tmp_config.SetExecStream(external_stream);

predictor->Run();
paddle_infer::experimental::InternalUtils::RunWithExternalStream(
predictor.get(), external_stream);
Expand Down