doc: graph: update sdpa document to include implicit causal mask
ElaineBao committed Jan 7, 2025
1 parent d9e07d7 commit 4e81d51
Showing 3 changed files with 23 additions and 8 deletions.
Binary file added doc/graph/fusion_patterns/images/sdpa-mask-3.png
Binary file modified doc/graph/fusion_patterns/images/sdpa.png
31 changes: 23 additions & 8 deletions doc/graph/fusion_patterns/sdpa.md
@@ -44,22 +44,36 @@ optional.
MatMul with a scaling factor. It can be constructed by [Multiply](@ref dev_guide_op_multiply)
or [Divide](@ref dev_guide_op_divide) operation in Graph API. The scaling
factor is given by users as an input of SDPA. \f$\sqrt{d_k}\f$ in the formula
is not considered as part of the SDPA pattern as it is constant.
is not considered part of the SDPA pattern because it is constant.
3. The Mask node is optional and is used to apply an attention mask to the
output of the previous Scale node. It can be constructed by [Add](@ref dev_guide_op_add)
output of the previous Scale node. There are two types of masks that can
be applied:

1. Explicit user-generated mask: You can explicitly create a mask tensor
and pass it to the library for the computation of SDPA. In this case, mask
can be constructed by [Add](@ref dev_guide_op_add)
or [Select](@ref dev_guide_op_select) operation in Graph API for different
mask policies (eg. causal mask or padding mask). When Add operation is used
to apply the mask, the input mask is usually an upper triangular matrix with
all the elements above the diagonal filled with `-inf` and zeroes elsewhere.
The `-inf` entries will become zero probability after Softmax is applied in
the next step. Alternately, a Select operation may be used. In this case, the
mask policies (for example, causal mask or padding mask). When the
Add operation is used to apply the mask, the input mask is usually an upper
triangular matrix with all the elements above the diagonal filled with
`-inf` and zeroes elsewhere. The `-inf` entries will become zero probability
after Softmax is applied in the next step.
Alternately, a Select operation may be used. In this case, the
input is a boolean tensor (for example, with `true` on and below the
diagonal, and `false` above the diagonal). A `false` element in the mask
forces the corresponding element of the scaled output to `-inf`, while a
`true` element leaves it unchanged.

![SDPA-mask-1](images/sdpa-mask-1.png) ![SDPA-mask-2](images/sdpa-mask-2.png)

2. Implicit library-generated mask: You can ask the library to generate
a mask by constructing a subgraph. Currently, Graph API supports generating
an implicit causal mask (top-left aligned) using operations of
[GenIndex](@ref dev_guide_op_genindex), [GreaterEqual](@ref dev_guide_op_greaterequal)
and [Select](@ref dev_guide_op_select).

![SDPA-mask-3](images/sdpa-mask-3.png)

4. The SoftMax operation takes the masked output and transforms it into
probabilities between 0 and 1. See [SoftMax](@ref dev_guide_op_softmax)
operation in Graph API.
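
The three ways of building a causal mask described in item 3 can be compared with a small pure-Python sketch. This is only an illustration of the arithmetic, not oneDNN Graph API code: all function names here are hypothetical, and the implicit variant assumes that "top-left aligned" means position `(row, col)` is kept when `row >= col`, with both indices starting at zero in the top-left corner.

```python
import math

NEG_INF = float("-inf")


def softmax(row):
    # Numerically stable softmax over one row; exp(-inf) evaluates to 0.0.
    peak = max(row)
    exps = [math.exp(x - peak) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def additive_causal_mask(s):
    # Explicit mask for Add: -inf above the diagonal, zero elsewhere.
    return [[NEG_INF if col > row else 0.0 for col in range(s)]
            for row in range(s)]


def boolean_causal_mask(s):
    # Explicit mask for Select: True on and below the diagonal.
    return [[col <= row for col in range(s)] for row in range(s)]


def implicit_causal_mask(s_q, s_kv):
    # Illustrative stand-in for the GenIndex + GreaterEqual subgraph
    # (assumption): build row and column index grids, then compare them.
    row_idx = [[r] * s_kv for r in range(s_q)]
    col_idx = [list(range(s_kv)) for _ in range(s_q)]
    return [[row_idx[r][c] >= col_idx[r][c] for c in range(s_kv)]
            for r in range(s_q)]


def apply_add(scores, mask):
    # Add the mask element-wise to the scaled scores.
    return [[x + m for x, m in zip(s_row, m_row)]
            for s_row, m_row in zip(scores, mask)]


def apply_select(scores, mask):
    # Keep the score where the mask is True, otherwise force it to -inf.
    return [[x if keep else NEG_INF for x, keep in zip(s_row, m_row)]
            for s_row, m_row in zip(scores, mask)]


# Hypothetical scaled Q x K^T scores for a sequence of length 3.
scores = [[0.5, 1.0, -0.3],
          [0.2, 0.1, 0.7],
          [1.5, -1.0, 0.0]]

via_add = apply_add(scores, additive_causal_mask(3))
via_select = apply_select(scores, boolean_causal_mask(3))
via_implicit = apply_select(scores, implicit_causal_mask(3, 3))
probs = [softmax(row) for row in via_add]
```

In this sketch all three paths produce the same masked scores, and every masked position ends up with zero probability after the softmax step.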
@@ -97,7 +111,8 @@ platforms follow the general description in @ref dev_guide_data_types.
softmax primitives. The reference implementation requires memory to store the
intermediate results of the dot products between Query and Key which takes
\f$O(S^2)\f$ memory. It may lead to out-of-memory error when computing long
sequence length input on platforms with limited memory.
sequence length input on platforms with limited memory. For implicit causal
mask, the reference implementation is only available on CPU.
2. The SDPA patterns functionally support all input shapes meeting the shape
requirements of each operation in the graph. For example, Add, Multiply,
Divide, and Select operations require the input tensors to have the same