[Inference] Fix bug of RunWithExternalStream API in new executor (#60122)

* fix bug of RunWithExternalStream API in new executor

* add test

* fix bug of RunWithExternalStream API in new executor

* reset flag in RunWithExternalStream

* fix bug

* add param switch_stream

* fix bug

* modify python api

* fix bug
ming1753 authored Jan 5, 2024
1 parent 488f367 commit 2b86637
Showing 16 changed files with 131 additions and 74 deletions.
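
This commit threads a new switch_stream flag from the inference entry point (AnalysisPredictor::ExpRunWithExternalStream) through ZeroCopyRun and NaiveExecutor::RunInterpreterCore down to the interpreter Run overloads, so that a program already built on one GPU stream is rebuilt when the caller hands in a different external stream. For orientation, a minimal client-side sketch of the external-stream path is below; it assumes the public inference header paddle_inference_api.h, the Config::SetExecStream and EnableNewExecutor options, and the experimental helper paddle_infer::experimental::InternalUtils::RunWithExternalStream, none of which are shown in full in this diff. Model paths and GPU settings are placeholders.

#include <cuda_runtime.h>
#include "paddle_inference_api.h"

int main() {
  cudaStream_t external_stream = nullptr;
  cudaStreamCreate(&external_stream);

  paddle_infer::Config config;
  config.SetModel("./model.pdmodel", "./model.pdiparams");  // placeholder paths
  config.EnableUseGpu(/*memory_pool_init_size_mb=*/256, /*gpu_id=*/0);
  config.EnableNewExecutor(true);          // the executor path this commit fixes
  config.SetExecStream(external_stream);   // bind GPU resources to this stream

  auto predictor = paddle_infer::CreatePredictor(config);
  // ... fill the predictor's input tensors here ...

  // Run the model on the externally owned stream.
  paddle_infer::experimental::InternalUtils::RunWithExternalStream(
      predictor.get(), external_stream);
  cudaStreamSynchronize(external_stream);

  cudaStreamDestroy(external_stream);
  return 0;
}
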
6 changes: 4 additions & 2 deletions paddle/fluid/framework/naive_executor.cc
@@ -72,12 +72,14 @@ void NaiveExecutor::PrepareInterpreterCore(
}

void NaiveExecutor::RunInterpreterCore(
const std::vector<std::string> &feed_names, bool need_fetch) {
const std::vector<std::string> &feed_names,
bool need_fetch,
bool switch_stream) {
platform::ScopedFlushDenormal flush;
#ifdef PADDLE_WITH_NVTX
platform::CudaNvtxRangePush("model", platform::NvtxRangeColor::Yellow);
#endif
interpreter_core_->Run(feed_names, need_fetch);
interpreter_core_->Run(feed_names, need_fetch, false, false, switch_stream);
#ifdef PADDLE_WITH_NVTX
platform::CudaNvtxRangePop();
#endif
3 changes: 2 additions & 1 deletion paddle/fluid/framework/naive_executor.h
@@ -77,7 +77,8 @@ class NaiveExecutor {
void Run();

void RunInterpreterCore(const std::vector<std::string>& feed_names = {},
bool need_fetch = false);
bool need_fetch = false,
bool switch_stream = false);

// Get an tensor to operating directly, without the need for feed_ops.
phi::DenseTensor* FindTensor(const std::string& name);
6 changes: 4 additions & 2 deletions paddle/fluid/framework/new_executor/interpreter_base_impl.h
@@ -68,13 +68,15 @@ class InterpreterBaseImpl {
const std::vector<std::string>& feed_names,
const std::vector<phi::DenseTensor>& feed_tensors,
bool need_fetch = true,
bool enable_job_schedule_profiler = false) = 0;
bool enable_job_schedule_profiler = false,
bool switch_stream = false) = 0;

virtual paddle::framework::FetchList Run(
const std::vector<std::string>& feed_names,
bool need_fetch = true,
bool enable_job_schedule_profiler = false,
bool enable_op_profiling = false) = 0;
bool enable_op_profiling = false,
bool switch_stream = false) = 0;

virtual void ShareWorkQueueFrom(InterpreterBaseImpl* src) = 0;

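One detail worth noting in the interface change above: the switch_stream = false default appears on these pure-virtual declarations and is repeated on every override below (in PirInterpreter and ProgramInterpreter). In C++, default arguments are resolved from the static type of the object expression, not the dynamic type, so keeping the defaults identical across the base and the overrides avoids calls silently picking up a mismatched default. A small standalone illustration (not Paddle code):

#include <iostream>

struct Base {
  virtual void Run(bool switch_stream = false) {
    std::cout << "Base::Run(" << switch_stream << ")\n";
  }
  virtual ~Base() = default;
};

struct Derived : Base {
  // If this default differed from Base's, a call through a Base* would still
  // use Base's default while dispatching to Derived::Run.
  void Run(bool switch_stream = false) override {
    std::cout << "Derived::Run(" << switch_stream << ")\n";
  }
};

int main() {
  Derived d;
  Base* b = &d;
  b->Run();  // default comes from Base's declaration, body from Derived
  d.Run();   // default comes from Derived's declaration
  return 0;
}
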
16 changes: 11 additions & 5 deletions paddle/fluid/framework/new_executor/interpretercore.cc
@@ -67,19 +67,25 @@ FetchList InterpreterCore::Run(
const std::vector<std::string>& feed_names,
const std::vector<phi::DenseTensor>& feed_tensors,
bool need_fetch,
bool enable_job_schedule_profiler) {
return impl_->Run(
feed_names, feed_tensors, need_fetch, enable_job_schedule_profiler);
bool enable_job_schedule_profiler,
bool switch_stream) {
return impl_->Run(feed_names,
feed_tensors,
need_fetch,
enable_job_schedule_profiler,
switch_stream);
}

FetchList InterpreterCore::Run(const std::vector<std::string>& feed_names,
bool need_fetch,
bool enable_job_schedule_profiler,
bool enable_op_profiling) {
bool enable_op_profiling,
bool switch_stream) {
return impl_->Run(feed_names,
need_fetch,
enable_job_schedule_profiler,
enable_op_profiling);
enable_op_profiling,
switch_stream);
}

void InterpreterCore::ShareWorkQueueFrom(std::shared_ptr<InterpreterCore> src) {
6 changes: 4 additions & 2 deletions paddle/fluid/framework/new_executor/interpretercore.h
@@ -49,12 +49,14 @@ class InterpreterCore {
const std::vector<std::string>& feed_names,
const std::vector<phi::DenseTensor>& feed_tensors,
bool need_fetch = true,
bool enable_job_schedule_profiler = false);
bool enable_job_schedule_profiler = false,
bool switch_stream = false);

paddle::framework::FetchList Run(const std::vector<std::string>& feed_names,
bool need_fetch = true,
bool enable_job_schedule_profiler = false,
bool enable_op_profiling = false);
bool enable_op_profiling = false,
bool switch_stream = false);

void RunProfile(const std::vector<std::string>& feed_names);

18 changes: 16 additions & 2 deletions paddle/fluid/framework/new_executor/pir_interpreter.cc
@@ -1255,7 +1255,8 @@ paddle::framework::FetchList PirInterpreter::Run(
const std::vector<std::string>& feed_names,
const std::vector<phi::DenseTensor>& feed_tensors,
bool need_fetch,
bool enable_job_schedule_profiler) {
bool enable_job_schedule_profiler,
bool switch_stream) {
enable_job_schedule_profiler_ = enable_job_schedule_profiler;

auto FeedInput = [&] {
@@ -1318,6 +1319,12 @@ paddle::framework::FetchList PirInterpreter::Run(
is_build_ = true;
is_shared_results_build_ = true;
} else {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (switch_stream) {
BuildInstruction();
VLOG(4) << "Done BuildInstruction";
}
#endif
if (FLAGS_enable_pir_in_executor_trace_run || nccl_op_num_ > 1 ||
execution_config_.used_for_inference ||
((execution_config_.used_for_jit || execution_config_.used_for_cinn) &&
@@ -1350,7 +1357,8 @@ paddle::framework::FetchList PirInterpreter::Run(
FetchList PirInterpreter::Run(const std::vector<std::string>& feed_names,
bool need_fetch,
bool enable_job_schedule_profiler,
bool enable_op_profiling) {
bool enable_op_profiling,
bool switch_stream) {
enable_job_schedule_profiler_ = enable_job_schedule_profiler;

if (enable_op_profiling) {
@@ -1401,6 +1409,12 @@ FetchList PirInterpreter::Run(const std::vector<std::string>& feed_names,
is_build_ = true;
is_shared_results_build_ = true;
} else {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (switch_stream) {
BuildInstruction();
VLOG(4) << "Done BuildInstruction";
}
#endif
if (FLAGS_enable_pir_in_executor_trace_run || nccl_op_num_ > 1 ||
execution_config_.used_for_inference ||
((execution_config_.used_for_jit || execution_config_.used_for_cinn) &&
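The two hunks above carry the substance of the fix on the PIR side: when the interpreter has already been built (is_build_ is true) but the caller signals a stream switch, BuildInstruction() is run again so every instruction picks up device contexts bound to the new stream rather than the cached ones. The toy below illustrates that rebuild-on-switch pattern in isolation; it is not Paddle code, and all names are made up for illustration.

#include <iostream>
#include <vector>

// Stand-in for a device context bound to one particular stream.
struct FakeContext {
  int stream_id;
};

class TinyInterpreter {
 public:
  explicit TinyInterpreter(FakeContext* ctx) : ctx_(ctx) {}

  void Run(bool switch_stream = false) {
    if (!is_build_) {
      BuildInstructions();
      is_build_ = true;
    } else if (switch_stream) {
      // Cached instructions still reference the old stream's context,
      // so rebuild them before executing.
      BuildInstructions();
    }
    for (int op : instructions_) {
      std::cout << "run op " << op << " on stream " << bound_stream_ << "\n";
    }
  }

 private:
  void BuildInstructions() {
    bound_stream_ = ctx_->stream_id;  // instructions capture the current stream
    instructions_ = {0, 1, 2};
  }

  FakeContext* ctx_;
  std::vector<int> instructions_;
  int bound_stream_ = -1;
  bool is_build_ = false;
};

int main() {
  FakeContext ctx{1};
  TinyInterpreter interp(&ctx);
  interp.Run();                        // first run: build, then execute
  ctx.stream_id = 2;                   // caller switched to another stream
  interp.Run(/*switch_stream=*/true);  // rebuild so ops run with the new stream
  return 0;
}
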
6 changes: 4 additions & 2 deletions paddle/fluid/framework/new_executor/pir_interpreter.h
@@ -57,12 +57,14 @@ class PirInterpreter : public InterpreterBaseImpl {
const std::vector<std::string>& feed_names,
const std::vector<phi::DenseTensor>& feed_tensors,
bool need_fetch = true,
bool enable_job_schedule_profiler = false) override;
bool enable_job_schedule_profiler = false,
bool switch_stream = false) override;

paddle::framework::FetchList Run(const std::vector<std::string>& feed_names,
bool need_fetch = true,
bool enable_job_schedule_profiler = false,
bool enable_op_profiling = false) override;
bool enable_op_profiling = false,
bool switch_stream = false) override;

void ShareWorkQueueFrom(InterpreterBaseImpl* src) override;

96 changes: 55 additions & 41 deletions paddle/fluid/framework/new_executor/program_interpreter.cc
@@ -144,7 +144,8 @@ void ProgramInterpreter::RunImpl() {
FetchList ProgramInterpreter::Run(const std::vector<std::string>& feed_names,
bool need_fetch,
bool enable_job_schedule_profiler,
bool enable_op_profiling) {
bool enable_op_profiling,
bool switch_stream) {
enable_job_schedule_profiler_ = enable_job_schedule_profiler;
is_in_op_profiling_mode_ = enable_op_profiling;

@@ -163,6 +164,11 @@ FetchList ProgramInterpreter::Run(const std::vector<std::string>& feed_names,
is_build_ = true;
is_shared_results_build_ = true;
} else {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (switch_stream) {
BuildOpFuncNode(&op_func_nodes);
}
#endif
RunImpl();
}

@@ -233,7 +239,8 @@ FetchList ProgramInterpreter::Run(
const std::vector<std::string>& feed_names,
const std::vector<phi::DenseTensor>& feed_tensors,
bool need_fetch,
bool enable_job_schedule_profiler) {
bool enable_job_schedule_profiler,
bool switch_stream) {
enable_job_schedule_profiler_ = enable_job_schedule_profiler;

SetDeviceId(place_);
@@ -244,7 +251,7 @@ FetchList ProgramInterpreter::Run(
#endif

bool is_build = is_build_;
Prepare(feed_names, feed_tensors, is_build);
Prepare(feed_names, feed_tensors, is_build, switch_stream);

if (is_build) {
RunImpl();
@@ -671,42 +678,7 @@ std::tuple<double, double> ProgramInterpreter::InterpreterRunTime() {
void ProgramInterpreter::Convert(
std::vector<paddle::framework::OpFuncNode>* op_func_nodes) {
auto& vec_meta_info = var_scope_.MutableVecMetaInfo();
auto nodes = *op_func_nodes;
auto op_nums = nodes.size();
vec_instruction_.clear();
vec_instruction_.reserve(op_nums);
for (size_t op_idx = 0; op_idx < op_nums; ++op_idx) {
auto& op_func_node = nodes[op_idx];
stream_analyzer_.SetForceEventsToWaitInfo(force_evnets_to_wait_);
auto* dev_ctx_ = stream_analyzer_.ParseDeviceContext(op_func_node);
#ifdef PADDLE_WITH_CUDA
if (FLAGS_new_executor_use_cuda_graph) {
auto& op = op_func_node.operator_base_;
auto& op_type = op->Type();
if (op_type == interpreter::kMemcpyD2H ||
op_type == interpreter::kMemcpyH2D) {
PADDLE_THROW(paddle::platform::errors::Fatal(
"Cuda memory copy d2h/h2d is not allowed while using cuda graph."));
}
PADDLE_ENFORCE_EQ(typeid(*dev_ctx_) == typeid(phi::GPUContext),
true,
platform::errors::InvalidArgument(
"Device context of op %s must be [%s] while using "
"cuda graph, but got [%s].",
op_type,
typeid(phi::GPUContext).name(),
typeid(*dev_ctx_).name()));
// cuda graph needs to record all stream
phi::backends::gpu::CUDAGraphContextManager::Instance()
.RecordCapturingDeviceContext(dev_ctx_);
}
#endif
vec_instruction_.emplace_back(op_idx, std::move(op_func_node), *dev_ctx_);

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
vec_instruction_.back().UpdataRecordStreamForGcInfo();
#endif
}
BuildOpFuncNode(op_func_nodes);

BuildOperatorDependences();

@@ -743,6 +715,7 @@ void ProgramInterpreter::Convert(
}

// calculate last_live_ops_
auto op_nums = (*op_func_nodes).size();
for (size_t op_idx = 0; op_idx < op_nums; ++op_idx) {
Instruction& instr = vec_instruction_[op_idx];
OpInOutInfo info;
@@ -879,6 +852,46 @@ void ProgramInterpreter::Convert(
AnalyseExecuteOrderForTrace();
}

void ProgramInterpreter::BuildOpFuncNode(
std::vector<paddle::framework::OpFuncNode>* op_func_nodes) {
auto nodes = *op_func_nodes;
auto op_nums = nodes.size();
vec_instruction_.clear();
vec_instruction_.reserve(op_nums);
for (size_t op_idx = 0; op_idx < op_nums; ++op_idx) {
auto& op_func_node = nodes[op_idx];
stream_analyzer_.SetForceEventsToWaitInfo(force_evnets_to_wait_);
auto* dev_ctx_ = stream_analyzer_.ParseDeviceContext(op_func_node);
#ifdef PADDLE_WITH_CUDA
if (FLAGS_new_executor_use_cuda_graph) {
auto& op = op_func_node.operator_base_;
auto& op_type = op->Type();
if (op_type == interpreter::kMemcpyD2H ||
op_type == interpreter::kMemcpyH2D) {
PADDLE_THROW(paddle::platform::errors::Fatal(
"Cuda memory copy d2h/h2d is not allowed while using cuda graph."));
}
PADDLE_ENFORCE_EQ(typeid(*dev_ctx_) == typeid(phi::GPUContext),
true,
platform::errors::InvalidArgument(
"Device context of op %s must be [%s] while using "
"cuda graph, but got [%s].",
op_type,
typeid(phi::GPUContext).name(),
typeid(*dev_ctx_).name()));
// cuda graph needs to record all stream
phi::backends::gpu::CUDAGraphContextManager::Instance()
.RecordCapturingDeviceContext(dev_ctx_);
}
#endif
vec_instruction_.emplace_back(op_idx, std::move(op_func_node), *dev_ctx_);

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
vec_instruction_.back().UpdataRecordStreamForGcInfo();
#endif
}
}

void ProgramInterpreter::BuildSkipShareLoDInfo() {
for (size_t i = 0; i < vec_instruction_.size(); ++i) {
bool can_skip_lod = true;
@@ -1494,7 +1507,8 @@ void ProgramInterpreter::CheckGC(const Instruction& instr) {
void ProgramInterpreter::Prepare(
const std::vector<std::string>& feed_names,
const std::vector<phi::DenseTensor>& feed_tensors,
bool prepare_feed) {
bool prepare_feed,
bool switch_stream) {
PADDLE_ENFORCE_EQ(feed_names.size(),
feed_tensors.size(),
platform::errors::PreconditionNotMet(
@@ -1517,7 +1531,7 @@ void ProgramInterpreter::Prepare(
}
};

if (!is_build_) {
if (!is_build_ || switch_stream) {
paddle::framework::interpreter::BuildVariableScope(
block_, execution_config_, &var_scope_);
FeedInput();
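Most of this file's diff is a refactor in support of the same idea: the instruction-building loop is lifted out of Convert() into BuildOpFuncNode() so the switch_stream path in Run() can reuse it, and Prepare() now rebuilds the variable scope when switch_stream is set even though is_build_ already holds. The moved loop also keeps the existing CUDA Graph checks: with FLAGS_new_executor_use_cuda_graph on, every op must use a phi::GPUContext and D2H/H2D memcpy ops are rejected, presumably because graph capture only records asynchronous work on capturable streams and host-synchronizing copies cannot be captured. A small standalone CUDA runtime sketch of stream capture, unrelated to Paddle and assuming a CUDA 12-style cudaGraphInstantiate signature:

#include <cstdio>
#include <cuda_runtime.h>

int main() {
  const size_t bytes = 1 << 20;
  void* d = nullptr;
  cudaMalloc(&d, bytes);

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  cudaGraph_t graph;
  cudaGraphExec_t exec;

  // Capture asynchronous work into a graph. A blocking cudaMemcpy or a
  // cudaStreamSynchronize issued here would invalidate the capture, which is
  // the same class of restriction the check above enforces for memcpy ops.
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  cudaMemsetAsync(d, 0, bytes, stream);
  cudaStreamEndCapture(stream, &graph);

  cudaGraphInstantiate(&exec, graph, /*flags=*/0);  // older toolkits use the
                                                    // 5-argument overload
  cudaGraphLaunch(exec, stream);
  cudaStreamSynchronize(stream);
  printf("replayed captured graph\n");

  cudaGraphExecDestroy(exec);
  cudaGraphDestroy(graph);
  cudaStreamDestroy(stream);
  cudaFree(d);
  return 0;
}
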
11 changes: 8 additions & 3 deletions paddle/fluid/framework/new_executor/program_interpreter.h
@@ -49,12 +49,14 @@ class ProgramInterpreter : public InterpreterBaseImpl {
const std::vector<std::string>& feed_names,
const std::vector<phi::DenseTensor>& feed_tensors,
bool need_fetch = true,
bool enable_job_schedule_profiler = false) override;
bool enable_job_schedule_profiler = false,
bool switch_stream = false) override;

paddle::framework::FetchList Run(const std::vector<std::string>& feed_names,
bool need_fetch = true,
bool enable_job_schedule_profiler = false,
bool enable_op_profiling = false) override;
bool enable_op_profiling = false,
bool switch_stream = false) override;

std::shared_ptr<ProgramDesc> GetMutableCopyProgram() override;

@@ -125,6 +127,8 @@ class ProgramInterpreter : public InterpreterBaseImpl {
void BuildSkipShareLoDInfo();
void UpdateSyncOpNum();
void AnalyseExecuteOrderForTrace();
void BuildOpFuncNode(
std::vector<paddle::framework::OpFuncNode>* op_func_nodes);

// inplace
void BuildInplace();
Expand All @@ -150,7 +154,8 @@ class ProgramInterpreter : public InterpreterBaseImpl {
// only used when program contains no feed op
void Prepare(const std::vector<std::string>& feed_names,
const std::vector<phi::DenseTensor>& feed_tensors,
bool prepare_feed);
bool prepare_feed,
bool switch_stream = false);

void RecordMemcpyD2H(const Instruction& instr_node);

10 changes: 5 additions & 5 deletions paddle/fluid/inference/api/analysis_predictor.cc
@@ -2249,7 +2249,7 @@ std::unique_ptr<ZeroCopyTensor> AnalysisPredictor::GetOutputTensor(
return res;
}

bool AnalysisPredictor::ZeroCopyRun() {
bool AnalysisPredictor::ZeroCopyRun(bool switch_stream) {
inference::DisplayMemoryInfo(place_, "before run");
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
if (config_.dist_config().use_dist_model()) {
@@ -2312,7 +2312,7 @@ bool AnalysisPredictor::ZeroCopyRun() {
#endif

if (config_.new_executor_enabled()) {
executor_->RunInterpreterCore();
executor_->RunInterpreterCore({}, false, switch_stream);
} else {
executor_->Run();
}
@@ -2353,7 +2353,7 @@ bool AnalysisPredictor::ExpRunWithExternalStream(const gpuStream_t stream) {
"Please use config.SetExecStream to init gpu resources, and then we "
"will bind gpu resources to execution stream."));
}

bool switch_stream = false;
if (stream != predictor_stream_) {
#ifdef PADDLE_WITH_HIP
hipStreamSynchronize(static_cast<gpuStream_t>(predictor_stream_));
@@ -2383,9 +2383,9 @@ bool AnalysisPredictor::ExpRunWithExternalStream(const gpuStream_t stream) {
}));
auto &pool = paddle::experimental::DeviceContextPool::Instance();
pool.SyncDeviceContext(place_);
switch_stream = true;
}

return ZeroCopyRun();
return ZeroCopyRun(switch_stream);
}
#endif

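The predictor-level hunks above are where the flag originates: ExpRunWithExternalStream now records whether the incoming stream differs from predictor_stream_, rebinds the GPU resources, and passes switch_stream = true into ZeroCopyRun, which forwards it to RunInterpreterCore when the new executor is enabled. A hedged sketch of the scenario this enables (and that the "add test" commit presumably exercises): running one predictor first on one external stream and then on another, with the second call triggering the instruction rebuild. The InternalUtils helper and the stream handling are assumptions based on this diff and Paddle's public inference headers.

#include <cuda_runtime.h>
#include "paddle_inference_api.h"

// Sketch only: exercise the switch_stream path with two external streams.
void RunOnTwoStreams(paddle_infer::Predictor* predictor) {
  cudaStream_t s1 = nullptr, s2 = nullptr;
  cudaStreamCreate(&s1);
  cudaStreamCreate(&s2);

  // First call; if s1 is not the stream the predictor was configured with,
  // this already takes the switch path and rebinds GPU resources to s1.
  paddle_infer::experimental::InternalUtils::RunWithExternalStream(predictor, s1);
  cudaStreamSynchronize(s1);

  // Second call with a different stream: inside AnalysisPredictor this sets
  // switch_stream = true, so the executor rebuilds its instructions against
  // device contexts bound to s2 before running.
  paddle_infer::experimental::InternalUtils::RunWithExternalStream(predictor, s2);
  cudaStreamSynchronize(s2);

  cudaStreamDestroy(s1);
  cudaStreamDestroy(s2);
}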