Skip to content

Commit

Permalink
Fix the re-use case
Browse files: browse the repository at this point in the history
  • Loading branch information
ybrnathan committed Aug 9, 2019
1 parent 04f70c0 commit 5967c87
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 12 deletions.
14 changes: 6 additions & 8 deletions onnxruntime/core/framework/allocation_planner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -595,14 +595,12 @@ class PlannerImpl {
if (arg && arg->Exists()) {
OrtValueIndex index = Index(arg->Name());
AllocPlanPerValue& value_plan = AllocPlan(index);

has_fence = value_plan.create_fence_if_async;
if (value_plan.alloc_kind == AllocKind::kReuse)
{
// Buffer reused, check original buffer to see if fence is shared.
has_fence = AllocPlan(value_plan.reused_buffer).create_fence_if_async;
}
else
{
has_fence = value_plan.create_fence_if_async;
has_fence = has_fence || AllocPlan(value_plan.reused_buffer).create_fence_if_async;
}
}

Expand All @@ -618,15 +616,15 @@ class PlannerImpl {

bool has_fence = false;
for (auto node_input : pnode->InputDefs()) {
has_fence |= HasFence(node_input);
has_fence = has_fence || HasFence(node_input);
}

for (auto node_input : pnode->ImplicitInputDefs()) {
has_fence |= HasFence(node_input);
has_fence = has_fence || HasFence(node_input);
}

for (auto node_output : pnode->OutputDefs()) {
has_fence |= HasFence(node_output);
has_fence = has_fence || HasFence(node_output);
}

plan_.node_has_fence[step.node_index] = has_fence;
Expand Down
6 changes: 2 additions & 4 deletions onnxruntime/core/framework/sequential_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,7 @@ Status SequentialExecutor::Execute(const SessionState& session_state, const std:

// sync before compute
int queue_id = p_op_kernel->KernelDef().ExecQueueId();
//if (seq_exec_plan.NodeHasFence(node_index)) {
{
if (seq_exec_plan.NodeHasFence(node_index)) {
for (int input_index = 0; input_index < op_kernel_context.InputCount(); ++input_index) {
Fence_t fence = op_kernel_context.InputFence(input_index);
if (fence) {
Expand Down Expand Up @@ -141,8 +140,7 @@ Status SequentialExecutor::Execute(const SessionState& session_state, const std:
}

// sync after compute for outputs
//if (seq_exec_plan.NodeHasFence(node_index)) {
{
if (seq_exec_plan.NodeHasFence(node_index)) {
for (int input_index = 0; input_index < op_kernel_context.InputCount(); ++input_index) {
Fence_t fence = op_kernel_context.InputFence(input_index);
if (fence) {
Expand Down

0 comments on commit 5967c87

Please sign in to comment.