Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 2 additions & 10 deletions onnxruntime/core/common/profiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ void Profiler::Initialize(const logging::Logger* session_logger) {

void Profiler::StartProfiling(const logging::Logger* custom_logger) {
ORT_ENFORCE(custom_logger != nullptr);
enabled_ = true;
profile_with_logger_ = true;
custom_logger_ = custom_logger;
profiling_start_time_ = StartTime();
Expand All @@ -35,13 +36,11 @@ void Profiler::EndTimeAndRecordEvent(EventCategory category,
TimePoint& start_time,
const std::initializer_list<std::pair<std::string, std::string>>& event_args,
bool /*sync_gpu*/) {
if (!enabled_ && !profile_with_logger_)
return;
long long dur = TimeDiffMicroSeconds(start_time);
long long ts = TimeDiffMicroSeconds(profiling_start_time_, start_time);

EventRecord event(category, logging::GetProcessId(),
logging::GetThreadId(), event_name, ts, dur, { event_args.begin(), event_args.end() });
logging::GetThreadId(), event_name, ts, dur, {event_args.begin(), event_args.end()});
if (profile_with_logger_) {
custom_logger_->SendProfileEvent(event);
} else {
Expand Down Expand Up @@ -99,12 +98,5 @@ std::string Profiler::EndProfiling() {
return profile_stream_file_;
}

//
// Placeholder for GPU synchronization. The body only invokes ORT_NOT_IMPLEMENTED
// with a fixed message; it takes no arguments, checks no flag and performs no sync.
// NOTE(review): ORT_NOT_IMPLEMENTED presumably raises a "not implemented" error -
// confirm against the macro's definition.
//
void ProfilerSyncGpu() {
ORT_NOT_IMPLEMENTED("Needs to implement only for gpus");
}

} // namespace profiling
} // namespace onnxruntime
4 changes: 4 additions & 0 deletions onnxruntime/core/common/profiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ class Profiler {
*/
TimePoint StartTime() const;

/*
Returns the current value of the enabled_ flag, i.e. whether this profiler
is switched on. Does not modify the profiler.
*/
bool FEnabled() const { return enabled_; }

/*
Record a single event. Time is measured till the call of this function from
the start_time.
Expand Down
64 changes: 38 additions & 26 deletions onnxruntime/core/framework/parallel_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ Status ParallelExecutor::Execute(const SessionState& session_state,
const std::vector<std::string>& output_names,
std::vector<MLValue>& fetches,
const logging::Logger& logger) {
auto tp = session_state.Profiler().StartTime();
TimePoint tp;
bool f_profiler_enabled = session_state.Profiler().FEnabled();
if (f_profiler_enabled) {
tp = session_state.Profiler().StartTime();
}

root_frame_ = std::make_unique<ExecutionFrame>(feeds, output_names, fetches, session_state);
//std::cout << "start nodes:" << std::endl;
Expand Down Expand Up @@ -72,7 +76,9 @@ Status ParallelExecutor::Execute(const SessionState& session_state,
}
}

session_state.Profiler().EndTimeAndRecordEvent(profiling::SESSION_EVENT, "ParallelExecutor::Execute", tp);
if (f_profiler_enabled) {
session_state.Profiler().EndTimeAndRecordEvent(profiling::SESSION_EVENT, "ParallelExecutor::Execute", tp);
}
return Status::OK();
}

Expand All @@ -83,7 +89,7 @@ void ParallelExecutor::RunNodeAsync(size_t p_node_index,
RunNodeAsyncInternal(p_node_index, session_state, logger);
} catch (...) {
FinishNodeRun();
throw;
throw;
}
}

Expand All @@ -95,6 +101,9 @@ void ParallelExecutor::RunNodeAsyncInternal(size_t p_node_index,
size_t node_index = p_node_index;
bool keep_running = true;
auto graph_viewer = session_state.GetGraphViewer();
TimePoint sync_time_begin;
TimePoint kernel_begin_time;
bool f_profiler_enabled = session_state.Profiler().FEnabled();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would `bool profiler_enabled = session_state.Profiler().Enabled()` read better?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as with the function name: the `f` prefix means "flag" here.

// Avoid context switching if possible.
while (keep_running) {
// TODO: Convert RunNodeAsync return Status.
Expand All @@ -109,14 +118,16 @@ void ParallelExecutor::RunNodeAsyncInternal(size_t p_node_index,
// if a kernel has been added in the session state, it better be NON-null.
if (p_op_kernel == nullptr) {
ORT_THROW("Got nullptr from GetKernel for node: ",
graph_viewer->GetNode(node_index)->Name());
graph_viewer->GetNode(node_index)->Name());
}

OpKernelContextInternal op_kernel_context(*root_frame_, *p_op_kernel, logger,
p_op_kernel->Node().ImplicitInputDefs(),
terminate_flag_);

auto sync_time_begin = session_state.Profiler().StartTime();
if (f_profiler_enabled) {
sync_time_begin = session_state.Profiler().StartTime();
}
// sync before compute
int queue_id = p_op_kernel->KernelDef().ExecQueueId();

Expand All @@ -141,31 +152,31 @@ void ParallelExecutor::RunNodeAsyncInternal(size_t p_node_index,
}
}

const std::string& node_name = p_op_kernel->Node().Name();
const std::string& op_name = p_op_kernel->KernelDef().OpName();
if (f_profiler_enabled) {
session_state.Profiler().EndTimeAndRecordEvent(profiling::NODE_EVENT,
p_op_kernel->Node().Name() + "_fence_before",
sync_time_begin,
{{"op_name", p_op_kernel->KernelDef().OpName()}});

session_state.Profiler().EndTimeAndRecordEvent(profiling::NODE_EVENT,
node_name + "_fence_before",
sync_time_begin,
{{"op_name", op_name}});
kernel_begin_time = session_state.Profiler().StartTime();
}

// call compute on the kernel
VLOGS(logger, 1) << "Computing kernel: " << p_op_kernel->Node().Name();

auto kernel_begin_time = session_state.Profiler().StartTime();

// Execute the kernel.
auto status = p_op_kernel->Compute(&op_kernel_context);
if (!status.IsOK()) {
ORT_THROW("Compute failed for node: ", graph_viewer->GetNode(node_index)->Name());
}
if (f_profiler_enabled) {
session_state.Profiler().EndTimeAndRecordEvent(profiling::NODE_EVENT,
p_op_kernel->Node().Name() + "_kernel_time",
kernel_begin_time,
{{"op_name", p_op_kernel->KernelDef().OpName()}});

session_state.Profiler().EndTimeAndRecordEvent(profiling::NODE_EVENT,
node_name + "_kernel_time",
kernel_begin_time,
{{"op_name", op_name}});

sync_time_begin = session_state.Profiler().StartTime();
sync_time_begin = session_state.Profiler().StartTime();
}
// sync after compute for outputs
for (int input_index = 0; input_index < op_kernel_context.InputCount(); ++input_index) {
Fence_t fence = op_kernel_context.InputFence(input_index);
Expand All @@ -187,11 +198,12 @@ void ParallelExecutor::RunNodeAsyncInternal(size_t p_node_index,
fence->AfterUsedAsOutput(queue_id);
}
}
session_state.Profiler().EndTimeAndRecordEvent(profiling::NODE_EVENT,
node_name + "_fence_after",
sync_time_begin,
{{"op_name", op_name}});

if (f_profiler_enabled) {
session_state.Profiler().EndTimeAndRecordEvent(profiling::NODE_EVENT,
p_op_kernel->Node().Name() + "_fence_after",
sync_time_begin,
{{"op_name", p_op_kernel->KernelDef().OpName()}});
}
//std::cout << "Run async node finish: " << p_node_index << std::endl;

keep_running = false;
Expand Down Expand Up @@ -241,8 +253,8 @@ Status ParallelExecutor::FetchOutput(const MLValueNameIdxMap& name_idx_map,
} else {
// this should've been checked before already
ORT_ENFORCE(output_names.size() == fetches.size(),
"output_names vector size: " + std::to_string(output_names.size()) +
" does not match that of fetches vector: " + std::to_string(fetches.size()));
"output_names vector size: " + std::to_string(output_names.size()) +
" does not match that of fetches vector: " + std::to_string(fetches.size()));
}

auto idx = 0;
Expand Down
67 changes: 43 additions & 24 deletions onnxruntime/core/framework/sequential_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,14 @@ Status SequentialExecutor::Execute(const SessionState& session_state,
const std::vector<std::string>& output_names,
std::vector<MLValue>& fetches,
const logging::Logger& logger) {
auto tp = session_state.Profiler().StartTime();
bool f_profiler_enabled = session_state.Profiler().FEnabled();
TimePoint tp;
TimePoint sync_time_begin;
TimePoint kernel_begin_time;

if (f_profiler_enabled) {
tp = session_state.Profiler().StartTime();
}

ExecutionFrame frame{feeds, output_names, fetches, session_state};

Expand All @@ -55,17 +62,17 @@ Status SequentialExecutor::Execute(const SessionState& session_state,
// if a kernel has been added in the session state, it better be NON-null.
if (p_op_kernel == nullptr)
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Got nullptr from GetKernel for node: ",
session_state.GetGraphViewer()->GetNode(node_index)->Name());
session_state.GetGraphViewer()->GetNode(node_index)->Name());

const std::string& node_name = p_op_kernel->Node().Name();
const std::string& op_name = p_op_kernel->KernelDef().OpName();
// construct OpKernelContext
// TODO: log kernel inputs?
OpKernelContextInternal op_kernel_context(frame, *p_op_kernel, logger, p_op_kernel->Node().ImplicitInputDefs(),
terminate_flag_);
// TODO: log kernel outputs?
if (f_profiler_enabled) {
sync_time_begin = session_state.Profiler().StartTime();
}

auto sync_time_begin = session_state.Profiler().StartTime();
// sync before compute
int queue_id = p_op_kernel->KernelDef().ExecQueueId();
for (int input_index = 0; input_index < op_kernel_context.InputCount(); ++input_index) {
Expand All @@ -89,22 +96,28 @@ Status SequentialExecutor::Execute(const SessionState& session_state,
}
}

session_state.Profiler().EndTimeAndRecordEvent(profiling::NODE_EVENT,
node_name + "_fence_before",
sync_time_begin,
{{"op_name", op_name}});
if (f_profiler_enabled) {
session_state.Profiler().EndTimeAndRecordEvent(profiling::NODE_EVENT,
p_op_kernel->Node().Name() + "_fence_before",
sync_time_begin,
{{"op_name", p_op_kernel->KernelDef().OpName()}});

// call compute on the kernel
VLOGS(logger, 1) << "Computing kernel: " << p_op_kernel->Node().Name();
// call compute on the kernel
VLOGS(logger, 1) << "Computing kernel: " << p_op_kernel->Node().Name();

auto kernel_begin_time = session_state.Profiler().StartTime();
kernel_begin_time = session_state.Profiler().StartTime();
}
ORT_RETURN_IF_ERROR(p_op_kernel->Compute(&op_kernel_context));
session_state.Profiler().EndTimeAndRecordEvent(profiling::NODE_EVENT,
node_name + "_kernel_time",
kernel_begin_time,
{{"op_name", op_name}});

sync_time_begin = session_state.Profiler().StartTime();
if (f_profiler_enabled) {
session_state.Profiler().EndTimeAndRecordEvent(profiling::NODE_EVENT,
p_op_kernel->Node().Name() + "_kernel_time",
kernel_begin_time,
{{"op_name", p_op_kernel->KernelDef().OpName()}});

sync_time_begin = session_state.Profiler().StartTime();
}

// sync after compute for outputs
for (int input_index = 0; input_index < op_kernel_context.InputCount(); ++input_index) {
Fence_t fence = op_kernel_context.InputFence(input_index);
Expand All @@ -126,10 +139,13 @@ Status SequentialExecutor::Execute(const SessionState& session_state,
fence->AfterUsedAsOutput(queue_id);
}
}
session_state.Profiler().EndTimeAndRecordEvent(profiling::NODE_EVENT,
node_name + "_fence_after",
sync_time_begin,
{{"op_name", op_name}});

if (f_profiler_enabled) {
session_state.Profiler().EndTimeAndRecordEvent(profiling::NODE_EVENT,
p_op_kernel->Node().Name() + "_fence_after",
sync_time_begin,
{{"op_name", p_op_kernel->KernelDef().OpName()}});
}

// free ml-values corresponding to this node
VLOGS(logger, 1) << "Releasing node ML values after computing kernel: " << p_op_kernel->Node().Name();
Expand Down Expand Up @@ -158,7 +174,10 @@ Status SequentialExecutor::Execute(const SessionState& session_state,
}
}

session_state.Profiler().EndTimeAndRecordEvent(profiling::SESSION_EVENT, "SequentialExecutor::Execute", tp);
if (f_profiler_enabled) {
session_state.Profiler().EndTimeAndRecordEvent(profiling::SESSION_EVENT, "SequentialExecutor::Execute", tp);
}

return Status::OK();
}

Expand All @@ -172,8 +191,8 @@ static Status FetchOutput(const MLValueNameIdxMap& name_idx_map,
} else {
// this should've been checked before already
ORT_ENFORCE(output_names.size() == fetches.size(),
"output_names vector size: " + std::to_string(output_names.size()) +
" does not match that of fetches vector: " + std::to_string(fetches.size()));
"output_names vector size: " + std::to_string(output_names.size()) +
" does not match that of fetches vector: " + std::to_string(fetches.size()));
}

auto idx = 0;
Expand Down
24 changes: 18 additions & 6 deletions onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,9 @@ class InferenceSession::Impl {
LOGS(*session_logger_, ERROR) << "Unknown exception in Load()";
return Status(common::ONNXRUNTIME, common::RUNTIME_EXCEPTION, "Encountered unknown exception in Load()");
}
session_profiler_.EndTimeAndRecordEvent(profiling::SESSION_EVENT, "model_loading_uri", tp);
if (session_profiler_.FEnabled()) {
session_profiler_.EndTimeAndRecordEvent(profiling::SESSION_EVENT, "model_loading_uri", tp);
}
return common::Status::OK();
}

Expand Down Expand Up @@ -176,7 +178,9 @@ class InferenceSession::Impl {
LOGS(*session_logger_, ERROR) << "Unknown exception in Load()";
return Status(common::ONNXRUNTIME, common::RUNTIME_EXCEPTION, "Encountered unknown exception in Load()");
}
session_profiler_.EndTimeAndRecordEvent(profiling::SESSION_EVENT, "model_loading_proto", tp);
if (session_profiler_.FEnabled()) {
session_profiler_.EndTimeAndRecordEvent(profiling::SESSION_EVENT, "model_loading_proto", tp);
}
return Status::OK();
}

Expand Down Expand Up @@ -207,7 +211,9 @@ class InferenceSession::Impl {
LOGS(*session_logger_, ERROR) << "Unknown exception in Load()";
return Status(common::ONNXRUNTIME, common::RUNTIME_EXCEPTION, "Encountered unknown exception in Load()");
}
session_profiler_.EndTimeAndRecordEvent(profiling::SESSION_EVENT, "model_loading_proto", tp);
if (session_profiler_.FEnabled()) {
session_profiler_.EndTimeAndRecordEvent(profiling::SESSION_EVENT, "model_loading_proto", tp);
}
return Status::OK();
}

Expand Down Expand Up @@ -244,7 +250,9 @@ class InferenceSession::Impl {
LOGS(*session_logger_, ERROR) << "Unknown exception in Load()";
return Status(common::ONNXRUNTIME, common::RUNTIME_EXCEPTION, "Encountered unknown exception in Load()");
}
session_profiler_.EndTimeAndRecordEvent(profiling::SESSION_EVENT, "model_loading_istream", tp);
if (session_profiler_.FEnabled()) {
session_profiler_.EndTimeAndRecordEvent(profiling::SESSION_EVENT, "model_loading_istream", tp);
}
return common::Status::OK();
}

Expand Down Expand Up @@ -419,7 +427,9 @@ class InferenceSession::Impl {
LOGS(*session_logger_, ERROR) << status.ErrorMessage();
}

session_profiler_.EndTimeAndRecordEvent(profiling::SESSION_EVENT, "session_initialization", tp);
if (session_profiler_.FEnabled()) {
session_profiler_.EndTimeAndRecordEvent(profiling::SESSION_EVENT, "session_initialization", tp);
}
return status;
}

Expand Down Expand Up @@ -841,7 +851,9 @@ class InferenceSession::Impl {
ORT_CHECK_AND_SET_RETVAL(xp->OnRunEnd());

--current_num_runs_;
session_profiler_.EndTimeAndRecordEvent(profiling::SESSION_EVENT, "model_run", tp);
if (session_profiler_.FEnabled()) {
session_profiler_.EndTimeAndRecordEvent(profiling::SESSION_EVENT, "model_run", tp);
}
return retval;
}

Expand Down