Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

graph, benchdnn: support implicit causal mask in graph API #2330

Merged
merged 16 commits into from
Jan 9, 2025

Conversation

ElaineBao
Copy link
Contributor

@ElaineBao ElaineBao commented Jan 2, 2025

Description

The implementation of option 1.1 (top-left aligned causal mask with subgraph approach) in rfcs: graph api: support implicit causal mask in SDPA

@ElaineBao ElaineBao added component:graph-api Codeowner: @oneapi-src/onednn-graph component:tests Codeowner: @oneapi-src/onednn-arch labels Jan 2, 2025
@ElaineBao ElaineBao self-assigned this Jan 2, 2025
@ElaineBao ElaineBao requested review from a team as code owners January 2, 2025 04:27
@github-actions github-actions bot added the component:api Codeowner: @oneapi-src/onednn-arch label Jan 2, 2025
@ElaineBao ElaineBao force-pushed the graph-implicit-causal-mask branch from f444b05 to 4651fa6 Compare January 2, 2025 06:46
@gyhintel gyhintel force-pushed the graph-implicit-causal-mask branch from 4651fa6 to 2a9ab89 Compare January 2, 2025 07:50
@gyhintel gyhintel force-pushed the graph-implicit-causal-mask branch 2 times, most recently from ab92e00 to 28f56e6 Compare January 2, 2025 13:27
@ElaineBao
Copy link
Contributor Author

make test
enable benchdnn_nightly
disable benchdnn_all
enable benchdnn_graph

src/graph/backend/dnnl/kernels/gen_index.hpp Outdated Show resolved Hide resolved
src/graph/backend/dnnl/kernels/gen_index.hpp Show resolved Hide resolved
@@ -146,6 +146,37 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, float_sdp_fusion)
return std::make_shared<sdp_base_t<>>();
});

DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, float_sdp_implicit_mask_fusion)
.set_priority(21.0f)
.set_engine_kind(engine_kind::cpu)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why cpu?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the currently implementation of gen_index only focus on CPU

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you have any plan to implement it for GPU also?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but most likely in a separate PR, and not target v3.7.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that this pattern should work for optimized GPU version, I don't think it make sense to limit to engine kind.
I'd expect it returns unimplemented later in the flow (it won't pick up any other patterns anyway).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The concern is that: if we don't limit the engine kind here, then for GPU case, the unimplemented status returned to user will be at execution stage, which may be too late for them to make necessary modifications.

In contrast, if we adopt the current CPU handling (and plan to remove this constraint once GPU support is implemented), then unsupported partitions will be returned earlier at get_partitions stage, this approach allows users to handle these unsupported partitions earlier by themselves.

src/graph/backend/dnnl/utils.hpp Outdated Show resolved Hide resolved
tests/benchdnn/graph/custom_driver.cpp Show resolved Hide resolved
tests/gtests/graph/api/test_cpp_api_op.cpp Outdated Show resolved Hide resolved
@gyhintel gyhintel force-pushed the graph-implicit-causal-mask branch from 28f56e6 to 4975a13 Compare January 3, 2025 08:44
@ElaineBao ElaineBao requested a review from a team as a code owner January 4, 2025 07:35
@github-actions github-actions bot added the documentation A request to change/fix/improve the documentation. Codeowner: @oneapi-src/onednn-doc label Jan 4, 2025
@ElaineBao
Copy link
Contributor Author

make test
enable benchdnn_nightly
disable benchdnn_all
enable benchdnn_graph

doc/graph/operations/GenIndex.md Outdated Show resolved Hide resolved
doc/graph/operations/GreaterEqual.md Outdated Show resolved Hide resolved
@ElaineBao ElaineBao force-pushed the graph-implicit-causal-mask branch 2 times, most recently from 78c01cc to 0ac759c Compare January 6, 2025 08:29
@ElaineBao ElaineBao changed the title graph: api, interface, backend; benchdnn: graph: support implicit causal mask in graph API graph, benchdnn: support implicit causal mask in graph API Jan 6, 2025
Copy link
Contributor

@ranukund ranukund left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've left a few comments, please incorporate as you see fit, thanks!

doc/graph/fusion_patterns/sdpa.md Outdated Show resolved Hide resolved
doc/graph/fusion_patterns/sdpa.md Outdated Show resolved Hide resolved
doc/graph/fusion_patterns/sdpa.md Outdated Show resolved Hide resolved
doc/graph/fusion_patterns/sdpa.md Outdated Show resolved Hide resolved
doc/graph/operations/GenIndex.md Outdated Show resolved Hide resolved
doc/graph/operations/GreaterEqual.md Outdated Show resolved Hide resolved
doc/graph/operations/GreaterEqual.md Outdated Show resolved Hide resolved
doc/graph/operations/GreaterEqual.md Outdated Show resolved Hide resolved
doc/graph/operations/GreaterEqual.md Outdated Show resolved Hide resolved
doc/graph/operations/GreaterEqual.md Outdated Show resolved Hide resolved
@ElaineBao ElaineBao force-pushed the graph-implicit-causal-mask branch from 0ac759c to 4e81d51 Compare January 7, 2025 01:28
@gyhintel gyhintel force-pushed the graph-implicit-causal-mask branch from 3bbb517 to 1a555aa Compare January 8, 2025 03:14
@gyhintel
Copy link
Contributor

gyhintel commented Jan 8, 2025

make test
enable benchdnn_nightly
disable benchdnn_all
enable benchdnn_graph

1 similar comment
@ElaineBao
Copy link
Contributor Author

make test
enable benchdnn_nightly
disable benchdnn_all
enable benchdnn_graph

Copy link
Contributor

@ranukund ranukund left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Few more edits suggested, please incorporate, thanks!

doc/graph/fusion_patterns/sdpa.md Outdated Show resolved Hide resolved
doc/graph/fusion_patterns/sdpa.md Outdated Show resolved Hide resolved
doc/graph/fusion_patterns/sdpa.md Outdated Show resolved Hide resolved
doc/graph/fusion_patterns/sdpa.md Outdated Show resolved Hide resolved
doc/graph/fusion_patterns/sdpa.md Outdated Show resolved Hide resolved
doc/graph/operations/GreaterEqual.md Outdated Show resolved Hide resolved
Copy link
Contributor

@TaoLv TaoLv left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. I reviewed API, document, interface folder, part of the DNNL backend implementation.

@ElaineBao ElaineBao force-pushed the graph-implicit-causal-mask branch from e133094 to 027e58b Compare January 8, 2025 15:40
@ElaineBao
Copy link
Contributor Author

make test
enable benchdnn_nightly
disable benchdnn_all
enable benchdnn_graph

Copy link
Contributor

@ranukund ranukund left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor edits suggested. Rest looks good, thanks

doc/graph/operations/GenIndex.md Outdated Show resolved Hide resolved
GenIndex{#dev_guide_op_genindex}
================================

## General
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
## General
## Overview

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion, as other operations are also using General, I suggest to change them all in a separate PR.

GreaterEqual{#dev_guide_op_greaterequal}
========================================

## General
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
## General
## Overview

@ElaineBao ElaineBao force-pushed the graph-implicit-causal-mask branch from 027e58b to 9208bcf Compare January 8, 2025 16:03
@ElaineBao
Copy link
Contributor Author

make test
enable benchdnn_nightly
disable benchdnn_all
enable benchdnn_graph

@TaoLv TaoLv merged commit 8b017b0 into main Jan 9, 2025
16 of 18 checks passed
@TaoLv TaoLv deleted the graph-implicit-causal-mask branch January 9, 2025 07:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
component:api Codeowner: @oneapi-src/onednn-arch component:graph-api Codeowner: @oneapi-src/onednn-graph component:tests Codeowner: @oneapi-src/onednn-arch documentation A request to change/fix/improve the documentation. Codeowner: @oneapi-src/onednn-doc
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants