Skip to content

Commit

Permalink
doc: graph: update sdpa document to include implicit causal mask
Browse files Browse the repository at this point in the history
  • Loading branch information
ElaineBao committed Jan 6, 2025
1 parent df94bcc commit 78c01cc
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 2 deletions.
Binary file added doc/graph/fusion_patterns/images/sdpa-mask-3.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified doc/graph/fusion_patterns/images/sdpa.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
18 changes: 16 additions & 2 deletions doc/graph/fusion_patterns/sdpa.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,12 @@ optional.
factor is given by users as an input of SDPA. \f$\sqrt{d_k}\f$ in the formula
is not considered as part of the SDPA pattern as it is constant.
3. The Mask node is optional and is used to apply an attention mask to the
output of the previous Scale node. It can be constructed by [Add](@ref dev_guide_op_add)
output of the previous Scale node. There are two types of masks that can
be applied:

1. Explicit user-generated mask: users can explicitly create a mask tensor
and pass it to the library for the computation of SDPA. In this case, mask
can be constructed by [Add](@ref dev_guide_op_add)
or [Select](@ref dev_guide_op_select) operation in Graph API for different
mask policies (eg. causal mask or padding mask). When Add operation is used
to apply the mask, the input mask is usually an upper triangular matrix with
Expand All @@ -60,6 +65,14 @@ optional.

![SDPA-mask-1](images/sdpa-mask-1.png) ![SDPA-mask-2](images/sdpa-mask-2.png)

2. Implicit library-generated mask: users can ask the library to generate
a mask by constructing a subgraph. Currently, Graph API supports generating
an implicit causal mask (top-left aligned) using operations of
[GenIndex](@ref dev_guide_op_genindex), [GreaterEqual](@ref dev_guide_op_greaterequal)
and [Select](@ref dev_guide_op_select).

![SDPA-mask-3](images/sdpa-mask-3.png)

4. The SoftMax operation takes the masked output and transforms it into
probabilities between 0 and 1. See [SoftMax](@ref dev_guide_op_softmax)
operation in Graph API.
Expand Down Expand Up @@ -97,7 +110,8 @@ platforms follow the general description in @ref dev_guide_data_types.
softmax primitives. The reference implementation requires memory to store the
intermediate results of the dot products between Query and Key which takes
\f$O(S^2)\f$ memory. It may lead to out-of-memory error when computing long
sequence length input on platforms with limited memory.
sequence length input on platforms with limited memory. For implicit causal
mask, the reference implementation is only available on CPU.
2. The SDPA patterns functionally supports all input shapes meeting the shape
requirements of each operation in the graph. For example, Add, Multiply,
Divide, and Select operations require the input tensors to have the same
Expand Down

0 comments on commit 78c01cc

Please sign in to comment.