
graph, benchdnn: support implicit causal mask in graph API #2330

Merged: 16 commits, Jan 9, 2025. Changes from all commits.
Binary file modified doc/graph/fusion_patterns/images/sdpa.png
39 changes: 27 additions & 12 deletions doc/graph/fusion_patterns/sdpa.md
@@ -44,22 +44,36 @@ optional.
MatMul with a scaling factor. It can be constructed by [Multiply](@ref dev_guide_op_multiply)
or [Divide](@ref dev_guide_op_divide) operation in Graph API. The scaling
factor is given by users as an input of SDPA. \f$\sqrt{d_k}\f$ in the formula
is not considered part of the SDPA pattern because it is a constant.
3. The Mask node is optional and is used to apply an attention mask to the
output of the previous Scale node. There are two types of masks that can
be applied:

1. Explicit user-generated mask: You can explicitly create a mask tensor
and pass it to the library for the computation of SDPA. In this case, mask
can be constructed by [Add](@ref dev_guide_op_add)
or [Select](@ref dev_guide_op_select) operation in Graph API for different
mask policies (for example, causal mask or padding mask). When the
Add operation is used to apply the mask, the input mask is usually an upper
triangular matrix with all the elements above the diagonal filled with
`-inf` and zeroes elsewhere. The `-inf` entries will become zero probability
after Softmax is applied in the next step.
Alternatively, a Select operation may be used. In this case, the
input is a boolean tensor (for example, with the boolean value set to `true`
on and below the diagonal, and `false` above the diagonal).
A `false` element in the mask forces the corresponding element of the scaled
output to `-inf`, while a `true` element leaves it unchanged.

![SDPA-mask-1](images/sdpa-mask-1.png) ![SDPA-mask-2](images/sdpa-mask-2.png)

2. Implicit library-generated mask: You can use the operations in the library
to generate a mask by constructing a subgraph. Currently, Graph API supports
generating an implicit causal mask (top-left aligned) using operations of
[GenIndex](@ref dev_guide_op_genindex), [GreaterEqual](@ref dev_guide_op_greaterequal)
and [Select](@ref dev_guide_op_select).

![SDPA-mask-3](images/sdpa-mask-3.png)

4. The SoftMax operation takes the masked output and transforms it into
probabilities between 0 and 1. See [SoftMax](@ref dev_guide_op_softmax)
operation in Graph API.
@@ -97,7 +111,8 @@ platforms follow the general description in @ref dev_guide_data_types.
softmax primitives. The reference implementation requires memory to store the
intermediate results of the dot products between Query and Key, which takes
\f$O(S^2)\f$ memory. It may lead to an out-of-memory error when computing long
sequence length input on platforms with limited memory. For an implicit
causal mask, the reference implementation is only available on CPU.
2. The SDPA patterns functionally support all input shapes meeting the shape
requirements of each operation in the graph. For example, Add, Multiply,
Divide, and Select operations require the input tensors to have the same
39 changes: 39 additions & 0 deletions doc/graph/operations/GenIndex.md
@@ -0,0 +1,39 @@
GenIndex{#dev_guide_op_genindex}
================================

## General
Review comment (Contributor), suggested change: replace `## General` with `## Overview`.

Author reply: Thanks for the suggestion. As other operations also use "General", I suggest changing them all in a separate PR.

The GenIndex operation creates an index tensor along a specified axis of
an input tensor. The resulting index tensor has the same shape as the
input tensor, with each element representing the index along the
specified axis.

## Operation Attributes

| Attribute Name | Description | Value Type | Supported Values | Required or Optional |
|:------------------------------------------|:----------------------------------------------------------------|:-----------|:-----------------------------------------------------------|:---------------------|
| [axis](@ref dnnl::graph::op::attr::axis) | Specifies the dimension along which index values are generated. | s64 | An s64 value in the range of [-r, r-1], where r = rank(src) | Required |

## Execution Arguments

### Input

| Index | Argument Name | Required or Optional |
|:------|:--------------|:---------------------|
| 0 | `src` | Required |

### Output

| Index | Argument Name | Required or Optional |
|:------|:--------------|:---------------------|
| 0 | `dst` | Required |

## Supported Data Types

The GenIndex operation supports the following data type combinations.

| Src | Dst |
|:-------|:-------|
| f32 | s32 |
| bf16 | s32 |
| f16 | s32 |
49 changes: 49 additions & 0 deletions doc/graph/operations/GreaterEqual.md
@@ -0,0 +1,49 @@
GreaterEqual{#dev_guide_op_greaterequal}
========================================

## General
Review comment (Contributor), suggested change: replace `## General` with `## Overview`.

The GreaterEqual operation performs an element-wise greater-than-or-equal
comparison between two given tensors. This operation applies
the multi-directional broadcast rules to ensure compatibility between
the tensors of different shapes.

\f[ dst = \begin{cases} true & \text{if}\ src_0 \ge src_1 \\
false & \text{if}\ src_0 < src_1 \end{cases} \f]

## Operation Attributes

| Attribute Name | Description | Value Type | Supported Values | Required or Optional |
|:-------------------------------------------------------------|:-----------------------------------------------------------|:-----------|:-------------------------|:---------------------|
| [auto_broadcast](@ref dnnl::graph::op::attr::auto_broadcast) | Specifies rules used for auto-broadcasting of src tensors. | string | `none`,`numpy` (default) | Optional |

## Execution Arguments

### Input

| Index | Argument Name | Required or Optional |
|:------|:--------------|:---------------------|
| 0 | `src_0` | Required |
| 1 | `src_1` | Required |

@note If the `auto_broadcast` attribute is `none`, both src shapes must match
and no auto-broadcasting is performed. If it is `numpy`, the `src_0` and
`src_1` shapes may differ and are reconciled with NumPy-style broadcasting.

### Output

| Index | Argument Name | Required or Optional |
|:------|:--------------|:---------------------|
| 0 | `dst` | Required |

## Supported Data Types

The GreaterEqual operation supports the following data type combinations.

| Src_0 / Src_1 | Dst |
|:--------------|:---------|
| f32 | boolean |
| bf16 | boolean |
| f16 | boolean |
| s32 | boolean |
4 changes: 3 additions & 1 deletion include/oneapi/dnnl/dnnl_graph.hpp
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2020-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -846,6 +846,8 @@ class op : public op_handle {
TanhBackward = dnnl_graph_op_tanh_backward,
TypeCast = dnnl_graph_op_type_cast,
Wildcard = dnnl_graph_op_wildcard,
GenIndex = dnnl_graph_op_gen_index,
GreaterEqual = dnnl_graph_op_greater_equal,
// Sentinel
LastSymbol = dnnl_graph_op_last_symbol,
};
4 changes: 3 additions & 1 deletion include/oneapi/dnnl/dnnl_graph_types.h
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2020-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -256,6 +256,8 @@ typedef enum {
dnnl_graph_op_select,
dnnl_graph_op_pow,
dnnl_graph_op_group_norm,
dnnl_graph_op_gen_index,
dnnl_graph_op_greater_equal,
dnnl_graph_op_last_symbol,
} dnnl_graph_op_kind_t;

17 changes: 16 additions & 1 deletion src/graph/backend/dnnl/dnnl_op_def.hpp
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2021-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -774,6 +774,21 @@ DNNL_GRAPH_OP_SCHEMA(dnnl_eltwise_bwd, 1,
executable_creator<eltwise_bwd_executable_t>)
.SET_ARG_INDICES_GETTER(eltwise_bwd_executable_t))

DNNL_GRAPH_OP_SCHEMA(dnnl_gen_index, 1,
op_schema_t()
.set_num_inputs(1)
.set_num_outputs(1)
.set_input(0, "input")
.set_output(0, "output")
// Attributes inherited from front GenIndex ops
.set_attr(op_attr::axis, true, attribute_kind::i)
// Analysis rules
.set_shape_inference_function(infer_identity_output_shape)
.SET_LAYOUT_PROPAGATOR(layout_propagator_for_gen_index)
.SET_EXECUTABLE_CREATOR(
executable_creator<genindex_executable_t>)
.SET_ARG_INDICES_GETTER(genindex_executable_t))

DNNL_GRAPH_OP_SCHEMA(dnnl_shuffle, 1,
op_schema_t()
.set_num_inputs(1)
3 changes: 2 additions & 1 deletion src/graph/backend/dnnl/dnnl_opset.hpp
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2021-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -72,6 +72,7 @@ class dnnl_opset_t {
fn(get_op_schema<DNNL_GRAPH_OP_SCHEMA_CLASS_NAME(dnnl_eltwise, 1)>());
fn(get_op_schema<DNNL_GRAPH_OP_SCHEMA_CLASS_NAME(
dnnl_eltwise_bwd, 1)>());
fn(get_op_schema<DNNL_GRAPH_OP_SCHEMA_CLASS_NAME(dnnl_gen_index, 1)>());
fn(get_op_schema<DNNL_GRAPH_OP_SCHEMA_CLASS_NAME(dnnl_shuffle, 1)>());
fn(get_op_schema<DNNL_GRAPH_OP_SCHEMA_CLASS_NAME(dnnl_sum, 1)>());
fn(get_op_schema<DNNL_GRAPH_OP_SCHEMA_CLASS_NAME(dnnl_prelu, 1)>());
5 changes: 3 additions & 2 deletions src/graph/backend/dnnl/internal_ops.hpp
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2021-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -77,7 +77,8 @@ namespace op_kind {
X(dnnl_reorder, Dnnl_reorder) \
X(dnnl_convtranspose_bwd_data, Dnnl_convtranspose_bwd_data) \
X(dnnl_convtranspose_bwd_weights, Dnnl_convtranspose_bwd_weights) \
X(dnnl_groupnorm, Dnnl_groupnorm) \
X(dnnl_gen_index, Dnnl_gen_index)

enum kind_t {
kDNNL_INTERNAL_OP_STARTER = 0x1234,
154 changes: 154 additions & 0 deletions src/graph/backend/dnnl/kernels/gen_index.cpp
@@ -0,0 +1,154 @@
/*******************************************************************************
* Copyright 2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "graph/backend/dnnl/kernels/gen_index.hpp"

#include "graph/backend/dnnl/passes/compile_ops.hpp"
#include "graph/backend/dnnl/passes/constant_propagation.hpp"
#include "graph/backend/dnnl/passes/insert_ops.hpp"
#include "graph/backend/dnnl/passes/layout_propagation.hpp"
#include "graph/backend/dnnl/passes/lower.hpp"
#include "graph/backend/dnnl/passes/memory_planning.hpp"
#include "graph/backend/dnnl/passes/transform.hpp"
#include "graph/backend/dnnl/passes/utils.hpp"

#include "graph/backend/dnnl/op_executable.hpp"
namespace dnnl {
namespace impl {
namespace graph {
namespace dnnl_impl {

status_t genindex_t::compile_impl(const dnnl_partition_impl_t *part,
const engine_t *g_engine, const std::vector<logical_tensor_t> &inputs,
const std::vector<logical_tensor_t> &outputs) {
p_engine_ = make_dnnl_engine(*g_engine);
g_alloc_
= reinterpret_cast<graph::allocator_t *>(g_engine->get_allocator());

subgraph_ = std::make_shared<subgraph_t>(part->get_ops(), p_engine_,
part->get_fpmath_mode(), part->get_use_blocked_layout(), true);
BACKEND_DNNL_CHECK(set_given_inputs_outputs(subgraph_, inputs, outputs));

subgraph_visualizer_t vis(part->id(), [this](const value_t *val) {
return this->memory_planner_.get_memory_info(val);
});
pass_pipeline_t pipeline(vis);

BACKEND_DNNL_ADD_PASS(pipeline, lower_down);

pipeline.reset_visualize_arg(true, false);

BACKEND_DNNL_ADD_PASS(pipeline, layout_propagation);

// constant propagation
if (enabled_constant_cache()) {
BACKEND_DNNL_ADD_PASS(pipeline, constant_propagation);
}

// bind the memory for each op
auto memory_plan = [&](std::shared_ptr<subgraph_t> &sg) {
return memory_planner_.run(sg);
};
pipeline.reset_visualize_arg(true, true);
BACKEND_DNNL_ADD_PASS(pipeline, memory_plan);
BACKEND_DNNL_ADD_PASS(pipeline, compile_ops);

// Run the added passes
BACKEND_DNNL_CHECK(pipeline.run(subgraph_));

// fill information for outputs logical tensors
for (size_t i = 0; i < outputs.size(); i++) {
auto &out = const_cast<logical_tensor_t &>(outputs[i]);
out = subgraph_->outs_[i];
}

// generate a hash key for exec_args_mgr
resource_ctor_ = [this]() {
return this->memory_planner_.get_exec_args_set().clone();
};

return status::success;
}

void genindex_t::prepare_args_set(const execution_args_set_t *res,
const std::vector<tensor_t> &inputs,
const std::vector<tensor_t> &outputs, const scratchpad_t &scratchpad) {
// update the data of partition in/outputs args
for (const auto &mem_idx : res->get_mems_use_external_inputs()) {
mem_idx.first.set_data_handle(inputs[mem_idx.second].get_data_handle());
}
for (const auto &mem_idx : res->get_mems_use_external_outputs()) {
mem_idx.first.set_data_handle(
outputs[mem_idx.second].get_data_handle());
}

grantor_t var_grantor = memory_planner_.internal_temporary_grantor(
scratchpad.get_buffer());

for (auto &mem_offkey : res->get_mems_use_internal_temporary()) {
mem_offkey.first.set_data_handle(var_grantor.get(mem_offkey.second));
}
}

status_t genindex_t::execute_impl(const stream_t *g_stream,
const std::vector<tensor_t> &inputs,
const std::vector<tensor_t> &outputs) {
dnnl::stream p_stream = make_dnnl_stream(p_engine_, *g_stream);

// each thread's own local resource
thread_local_cache_t<execution_args_set_t> res_cache;
execution_args_set_t *res = res_cache.get_or_add(
reinterpret_cast<size_t>(this), resource_ctor_);

temporary_scratchpad_t scratchpad(
memory_planner_.total_internal_temporary_size(), p_engine_,
*g_alloc_);
assertm(scratchpad.size()
>= memory_planner_.total_internal_temporary_size(),
"no enough scratchpad memory");
prepare_args_set(res, inputs, outputs, scratchpad);

constant_cache_t::cached_t c_buffer;

for (size_t i = 0; i < subgraph_->execs_.size(); i++) {
if (subgraph_->is_constant_[i]) continue;
subgraph_->execs_[i]->execute(p_stream, res->get_exec_args()[i]);
}

return status::success;
}
#ifdef DNNL_WITH_SYCL
status_t genindex_t::sycl_execute_impl(const stream_t *g_stream,
const std::vector<tensor_t> &inputs,
const std::vector<tensor_t> &outputs,
const std::vector<::sycl::event> &sycl_deps,
::sycl::event *sycl_event) {
if (p_engine_.get_kind() == engine::kind::gpu) return status::unimplemented;
return execute_impl(g_stream, inputs, outputs);
}
#endif
#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
status_t genindex_t::ocl_execute_impl(const stream_t *g_stream,
const std::vector<tensor_t> &inputs,
const std::vector<tensor_t> &outputs,
const std::vector<cl_event> &ocl_deps, cl_event *ocl_event) {
// TODO: add support
return status::unimplemented;
}
#endif
} // namespace dnnl_impl
} // namespace graph
} // namespace impl
} // namespace dnnl