Skip to content

Commit

Permalink
[xla:cpu] Optimize ThunkExecutor::Execute part #2
Browse files Browse the repository at this point in the history
Use the std::aligned_storage_t trick to avoid default-initializing the Node struct on a hot path.

name                                     old cpu/op   new cpu/op   delta
BM_SelectAndScatterF32/128/process_time   791µs ± 4%   720µs ± 2%  -8.93%
BM_SelectAndScatterF32/256/process_time  3.20ms ± 4%  2.96ms ± 2%  -7.46%
BM_SelectAndScatterF32/512/process_time  13.7ms ± 5%  12.8ms ± 2%  -6.80%

name                                     old time/op          new time/op          delta
BM_SelectAndScatterF32/128/process_time   790µs ± 5%           719µs ± 1%   -9.00%
BM_SelectAndScatterF32/256/process_time  3.20ms ± 3%          2.96ms ± 1%   -7.58%
BM_SelectAndScatterF32/512/process_time  13.2ms ± 4%          12.3ms ± 1%   -6.82%

PiperOrigin-RevId: 657741110
  • Loading branch information
ezhulenev authored and copybara-github committed Jul 31, 2024
1 parent 1b09c08 commit f8db3d5
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 7 deletions.
13 changes: 7 additions & 6 deletions xla/service/cpu/runtime/thunk_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ absl::StatusOr<ThunkExecutor> ThunkExecutor::Create(
return ThunkExecutor(std::move(thunk_sequence), std::move(defs), options);
}

// Builds the run-time node for `node_def`: the pending counter starts at the
// number of incoming edges, and `out_edges` points at the definition's
// out-edge list (not copied).
ThunkExecutor::ExecuteState::Node::Node(const NodeDef& node_def)
    : counter(static_cast<int64_t>(node_def.in_edges.size())),
      out_edges(&node_def.out_edges) {}

ThunkExecutor::ExecuteState::ExecuteState(ThunkExecutor* executor,
Thunk::TaskRunner* runner)
: executor(executor),
Expand All @@ -133,11 +136,9 @@ ThunkExecutor::ExecuteState::ExecuteState(ThunkExecutor* executor,
DCHECK(runner == nullptr || static_cast<bool>(*runner))
<< "`runner` must be nullptr or a valid TaskRunner";

Node* node = nodes.data();
NodeStorage* node = nodes.data();
for (const NodeDef& node_def : executor->nodes_defs()) {
node->counter.store(node_def.in_edges.size(), std::memory_order_release);
node->out_edges = &node_def.out_edges;
++node;
new (node++) Node(node_def);
}
}

Expand Down Expand Up @@ -271,7 +272,7 @@ void ThunkExecutor::Execute(ExecuteState* state,

for (int64_t i = 0; i < ready_queue.size(); ++i) {
NodeId id = ready_queue[i];
ExecuteState::Node& node = state->nodes[id];
ExecuteState::Node& node = state->node(id);

int64_t cnt = node.counter.load(std::memory_order_acquire);
DCHECK_EQ(cnt, 0) << "Node counter must be 0"; // Crash Ok
Expand Down Expand Up @@ -375,7 +376,7 @@ void ThunkExecutor::ProcessOutEdges(

// Append ready nodes to the back of the ready queue.
for (NodeId out_edge : *node.out_edges) {
ExecuteState::Node& out_node = state->nodes[out_edge];
ExecuteState::Node& out_node = state->node(out_edge);

int64_t cnt = out_node.counter.fetch_sub(1, std::memory_order_release);
DCHECK_GE(cnt, 1) << "Node counter can't drop below 0";
Expand Down
14 changes: 13 additions & 1 deletion xla/service/cpu/runtime/thunk_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ limitations under the License.
#include <limits>
#include <new>
#include <string>
#include <type_traits>
#include <vector>

#include "absl/base/thread_annotations.h"
Expand Down Expand Up @@ -113,16 +114,27 @@ class ThunkExecutor {
// At run time each NodeDef is instantiated as a Node with an atomic counter
// that drops to zero when all `in_edges` are ready.
struct Node {
// Initializes `counter` with the number of `node_def.in_edges` and points
// `out_edges` at `node_def.out_edges`.
explicit Node(const NodeDef& node_def);

// Number of in-edges that are not ready yet. It is decremented while
// processing the out-edges of a completed node; a node taken from the ready
// queue is expected to have a counter of zero.
alignas(kAtomicAlignment) std::atomic<int64_t> counter;
// Non-owning pointer to the out-edge list stored in the corresponding
// NodeDef; the NodeDef must outlive this Node.
const std::vector<NodeId>* out_edges;
};

static_assert(std::is_trivially_destructible_v<Node>,
"Node must be trivially destructible");

// We use indirection via NodeStorage so that we can allocate uninitialized
// memory and avoid paying the cost of default-initializing all nodes.
using NodeStorage = std::aligned_storage_t<sizeof(Node), alignof(Node)>;

ExecuteState(ThunkExecutor* executor, Thunk::TaskRunner* runner);

// Returns the Node that lives in the `id`-th storage slot of `nodes`.
//
// `nodes` holds raw NodeStorage elements into which Node objects are
// constructed with placement new. A pointer to the storage element is not
// automatically a pointer to the Node created inside it, so the cast result
// goes through std::launder (from <new>) before it is dereferenced.
// The slot must already contain a constructed Node.
Node& node(NodeId id) {
  return *std::launder(reinterpret_cast<Node*>(&nodes[id]));
}

ThunkExecutor* executor;
Thunk::TaskRunner* runner;

absl::FixedArray<Node> nodes;
absl::FixedArray<NodeStorage> nodes;
tsl::AsyncValueRef<ExecuteEvent> execute_event;

// Once the number of pending sink nodes drops to zero, the execution is
Expand Down

0 comments on commit f8db3d5

Please sign in to comment.