diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index 3f46b69132b0..5d1e35ecafe6 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -595,14 +595,12 @@ class PlannerImpl { if (arg && arg->Exists()) { OrtValueIndex index = Index(arg->Name()); AllocPlanPerValue& value_plan = AllocPlan(index); + + has_fence = value_plan.create_fence_if_async; if (value_plan.alloc_kind == AllocKind::kReuse) { // Buffer reused, check original buffer to see if fence is shared. - has_fence = AllocPlan(value_plan.reused_buffer).create_fence_if_async; - } - else - { - has_fence = value_plan.create_fence_if_async; + has_fence = has_fence || AllocPlan(value_plan.reused_buffer).create_fence_if_async; } } @@ -618,15 +616,15 @@ class PlannerImpl { bool has_fence = false; for (auto node_input : pnode->InputDefs()) { - has_fence |= HasFence(node_input); + has_fence = has_fence || HasFence(node_input); } for (auto node_input : pnode->ImplicitInputDefs()) { - has_fence |= HasFence(node_input); + has_fence = has_fence || HasFence(node_input); } for (auto node_output : pnode->OutputDefs()) { - has_fence |= HasFence(node_output); + has_fence = has_fence || HasFence(node_output); } plan_.node_has_fence[step.node_index] = has_fence; diff --git a/onnxruntime/core/framework/sequential_executor.cc b/onnxruntime/core/framework/sequential_executor.cc index a1d5f0ecd76b..0f08e8613cc1 100644 --- a/onnxruntime/core/framework/sequential_executor.cc +++ b/onnxruntime/core/framework/sequential_executor.cc @@ -71,8 +71,7 @@ Status SequentialExecutor::Execute(const SessionState& session_state, const std: // sync before compute int queue_id = p_op_kernel->KernelDef().ExecQueueId(); - //if (seq_exec_plan.NodeHasFence(node_index)) { - { + if (seq_exec_plan.NodeHasFence(node_index)) { for (int input_index = 0; input_index < op_kernel_context.InputCount(); ++input_index) { Fence_t fence = op_kernel_context.InputFence(input_index); if (fence) { @@ -141,8 +140,7 @@ Status SequentialExecutor::Execute(const SessionState& session_state, const std: } // sync after compute for outputs - //if (seq_exec_plan.NodeHasFence(node_index)) { - { + if (seq_exec_plan.NodeHasFence(node_index)) { for (int input_index = 0; input_index < op_kernel_context.InputCount(); ++input_index) { Fence_t fence = op_kernel_context.InputFence(input_index); if (fence) {