graph, benchdnn: support implicit causal mask in graph API #2330
Merged
Commits (16):
- e8e9dfd benchdnn: graph: styling (dzarukin)
- 07589b9 benchdnn: graph: move custom op on f32 completely (dzarukin)
- 173c983 graph: api: support GenIndex and GreaterEqual operations (ElaineBao)
- 2c64f55 graph: interface: support GenIndex and GreaterEqual operations (ElaineBao)
- 28441d5 graph: backend: dnnl: add internal op of dnnl_gen_index (gyhintel)
- 2290e02 graph: dnnl: enable gen_index in dnnl backend (gyhintel)
- 524a847 graph: dnnl: enable greater_equal in dnnl backend (gyhintel)
- b481d78 benchdnn: graph: support genindex and greaterequal op in benchdnn graph (gyhintel)
- 1272246 benchdnn: graph: inputs: add gen_index and greater_equal cases (gyhintel)
- 69181fc graph: utils: pm: support multi-consumers input for repetition (ElaineBao)
- df15f6b graph: backend: dnnl: add pattern for sdp with implicit causal mask (ElaineBao)
- a8c89aa benchdnn: graph: inputs: add a sdpa implicit causal mask case (ElaineBao)
- c9dab3b gtests: graph: api: add gtest for GenIndex and GreaterEqual (gyhintel)
- f30023b doc: graph: add document for GenIndex and GreaterEqual (ElaineBao)
- 6ee19e0 doc: graph: update sdpa document to include implicit causal mask (ElaineBao)
- 9208bcf benchdnn: graph: skip unimplemented GenIndex op on gpu (ElaineBao)
GenIndex{#dev_guide_op_genindex}
================================

## General

The GenIndex operation creates an index tensor along a specified axis of
an input tensor. The resulting index tensor has the same shape as the
input tensor, with each element representing the index along the
specified axis.

## Operation Attributes

| Attribute Name                           | Description                                                      | Value Type | Supported Values                                            | Required or Optional |
|:-----------------------------------------|:-----------------------------------------------------------------|:-----------|:-------------------------------------------------------------|:---------------------|
| [axis](@ref dnnl::graph::op::attr::axis) | Specifies the dimension along which index values are generated.  | s64        | An s64 value in the range of [-r, r-1], where r = rank(src)  | Required             |

## Execution Arguments

### Input

| Index | Argument Name | Required or Optional |
|:------|:--------------|:---------------------|
| 0     | `src`         | Required             |

### Output

| Index | Argument Name | Required or Optional |
|:------|:--------------|:---------------------|
| 0     | `dst`         | Required             |

## Supported Data Types

The GenIndex operation supports the following data type combinations.

| Src  | Dst |
|:-----|:----|
| f32  | s32 |
| bf16 | s32 |
| f16  | s32 |
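For illustration (not part of the PR diff), here is a minimal sketch of building a single-op graph containing GenIndex through the C++ graph API; the op IDs, shapes, and `axis` value are arbitrary choices for this example:

```cpp
#include "oneapi/dnnl/dnnl_graph.hpp"

using namespace dnnl::graph;

int main() {
    // f32 source and s32 destination of the same shape, per the
    // data type table above.
    logical_tensor src {0, logical_tensor::data_type::f32, {2, 3},
            logical_tensor::layout_type::strided};
    logical_tensor dst {1, logical_tensor::data_type::s32, {2, 3},
            logical_tensor::layout_type::strided};

    // With axis = 1, each output element holds its column index:
    // dst = {{0, 1, 2}, {0, 1, 2}}.
    op genindex {2, op::kind::GenIndex, {src}, {dst}, "genindex"};
    genindex.set_attr<int64_t>(op::attr::axis, 1);

    graph g {dnnl::engine::kind::cpu};
    g.add_op(genindex);
    g.finalize();

    // The returned partitions would then be compiled and executed as usual.
    auto partitions = g.get_partitions();
    return partitions.empty() ? 1 : 0;
}
```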
GreaterEqual{#dev_guide_op_greaterequal}
========================================

## General

The GreaterEqual operation performs an element-wise greater-than-or-equal
comparison between two given tensors. This operation applies the
multi-directional broadcast rules to ensure compatibility between tensors
of different shapes.

\f[ dst = \begin{cases} true & \text{if}\ src_0 \ge src_1 \\
    false & \text{if}\ src_0 < src_1 \end{cases} \f]

## Operation Attributes

| Attribute Name                                                | Description                                                 | Value Type | Supported Values          | Required or Optional |
|:--------------------------------------------------------------|:-------------------------------------------------------------|:-----------|:---------------------------|:---------------------|
| [auto_broadcast](@ref dnnl::graph::op::attr::auto_broadcast)  | Specifies rules used for auto-broadcasting of src tensors.   | string     | `none`, `numpy` (default)  | Optional             |

## Execution Arguments

### Input

| Index | Argument Name | Required or Optional |
|:------|:--------------|:---------------------|
| 0     | `src_0`       | Required             |
| 1     | `src_1`       | Required             |

@note If the `auto_broadcast` attribute is `none`, both src shapes must
match and no auto-broadcasting is performed. If it is `numpy`, `src_0` and
`src_1` may have different shapes, and broadcasting is performed according
to the multi-directional (NumPy) rules.

### Output

| Index | Argument Name | Required or Optional |
|:------|:--------------|:---------------------|
| 0     | `dst`         | Required             |

## Supported Data Types

The GreaterEqual operation supports the following data type combinations.

| Src_0 / Src_1 | Dst     |
|:--------------|:--------|
| f32           | boolean |
| bf16          | boolean |
| f16           | boolean |
| s32           | boolean |
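Again purely as an illustration, a sketch of a single GreaterEqual op using `numpy` broadcasting, where a {1, 4} tensor broadcasts against a {3, 4} one (IDs and shapes are arbitrary for this example):

```cpp
#include <string>

#include "oneapi/dnnl/dnnl_graph.hpp"

using namespace dnnl::graph;

int main() {
    logical_tensor src0 {0, logical_tensor::data_type::f32, {3, 4},
            logical_tensor::layout_type::strided};
    // src_1 has shape {1, 4}; under `numpy` rules it broadcasts
    // across the first dimension of src_0.
    logical_tensor src1 {1, logical_tensor::data_type::f32, {1, 4},
            logical_tensor::layout_type::strided};
    logical_tensor dst {2, logical_tensor::data_type::boolean, {3, 4},
            logical_tensor::layout_type::strided};

    op ge {3, op::kind::GreaterEqual, {src0, src1}, {dst}, "greater_equal"};
    // `numpy` is already the default; set explicitly here for clarity.
    ge.set_attr<std::string>(op::attr::auto_broadcast, "numpy");

    graph g {dnnl::engine::kind::cpu};
    g.add_op(ge);
    g.finalize();
    return g.get_partitions().empty() ? 1 : 0;
}
```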
```cpp
/*******************************************************************************
* Copyright 2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "graph/backend/dnnl/kernels/gen_index.hpp"

#include "graph/backend/dnnl/passes/compile_ops.hpp"
#include "graph/backend/dnnl/passes/constant_propagation.hpp"
#include "graph/backend/dnnl/passes/insert_ops.hpp"
#include "graph/backend/dnnl/passes/layout_propagation.hpp"
#include "graph/backend/dnnl/passes/lower.hpp"
#include "graph/backend/dnnl/passes/memory_planning.hpp"
#include "graph/backend/dnnl/passes/transform.hpp"
#include "graph/backend/dnnl/passes/utils.hpp"

#include "graph/backend/dnnl/op_executable.hpp"

namespace dnnl {
namespace impl {
namespace graph {
namespace dnnl_impl {

status_t genindex_t::compile_impl(const dnnl_partition_impl_t *part,
        const engine_t *g_engine, const std::vector<logical_tensor_t> &inputs,
        const std::vector<logical_tensor_t> &outputs) {
    p_engine_ = make_dnnl_engine(*g_engine);
    g_alloc_
            = reinterpret_cast<graph::allocator_t *>(g_engine->get_allocator());

    subgraph_ = std::make_shared<subgraph_t>(part->get_ops(), p_engine_,
            part->get_fpmath_mode(), part->get_use_blocked_layout(), true);
    BACKEND_DNNL_CHECK(set_given_inputs_outputs(subgraph_, inputs, outputs));

    subgraph_visualizer_t vis(part->id(), [this](const value_t *val) {
        return this->memory_planner_.get_memory_info(val);
    });
    pass_pipeline_t pipeline(vis);

    BACKEND_DNNL_ADD_PASS(pipeline, lower_down);

    pipeline.reset_visualize_arg(true, false);

    BACKEND_DNNL_ADD_PASS(pipeline, layout_propagation);

    // constant propagation
    if (enabled_constant_cache()) {
        BACKEND_DNNL_ADD_PASS(pipeline, constant_propagation);
    }

    // bind the memory for each op
    auto memory_plan = [&](std::shared_ptr<subgraph_t> &sg) {
        return memory_planner_.run(sg);
    };
    pipeline.reset_visualize_arg(true, true);
    BACKEND_DNNL_ADD_PASS(pipeline, memory_plan);
    BACKEND_DNNL_ADD_PASS(pipeline, compile_ops);

    // Run the added passes
    BACKEND_DNNL_CHECK(pipeline.run(subgraph_));

    // fill information for outputs logical tensors
    for (size_t i = 0; i < outputs.size(); i++) {
        auto &out = const_cast<logical_tensor_t &>(outputs[i]);
        out = subgraph_->outs_[i];
    }

    // generate a hash key for exec_args_mgr
    resource_ctor_ = [this]() {
        return this->memory_planner_.get_exec_args_set().clone();
    };

    return status::success;
}

void genindex_t::prepare_args_set(const execution_args_set_t *res,
        const std::vector<tensor_t> &inputs,
        const std::vector<tensor_t> &outputs, const scratchpad_t &scratchpad) {
    // update the data of partition in/outputs args
    for (const auto &mem_idx : res->get_mems_use_external_inputs()) {
        mem_idx.first.set_data_handle(inputs[mem_idx.second].get_data_handle());
    }
    for (const auto &mem_idx : res->get_mems_use_external_outputs()) {
        mem_idx.first.set_data_handle(
                outputs[mem_idx.second].get_data_handle());
    }

    grantor_t var_grantor = memory_planner_.internal_temporary_grantor(
            scratchpad.get_buffer());

    for (auto &mem_offkey : res->get_mems_use_internal_temporary()) {
        mem_offkey.first.set_data_handle(var_grantor.get(mem_offkey.second));
    }
}

status_t genindex_t::execute_impl(const stream_t *g_stream,
        const std::vector<tensor_t> &inputs,
        const std::vector<tensor_t> &outputs) {
    dnnl::stream p_stream = make_dnnl_stream(p_engine_, *g_stream);

    // each thread's own local resource
    thread_local_cache_t<execution_args_set_t> res_cache;
    execution_args_set_t *res = res_cache.get_or_add(
            reinterpret_cast<size_t>(this), resource_ctor_);

    temporary_scratchpad_t scratchpad(
            memory_planner_.total_internal_temporary_size(), p_engine_,
            *g_alloc_);
    assertm(scratchpad.size()
                    >= memory_planner_.total_internal_temporary_size(),
            "not enough scratchpad memory");
    prepare_args_set(res, inputs, outputs, scratchpad);

    constant_cache_t::cached_t c_buffer;

    for (size_t i = 0; i < subgraph_->execs_.size(); i++) {
        if (subgraph_->is_constant_[i]) continue;
        subgraph_->execs_[i]->execute(p_stream, res->get_exec_args()[i]);
    }

    return status::success;
}

#ifdef DNNL_WITH_SYCL
status_t genindex_t::sycl_execute_impl(const stream_t *g_stream,
        const std::vector<tensor_t> &inputs,
        const std::vector<tensor_t> &outputs,
        const std::vector<::sycl::event> &sycl_deps,
        ::sycl::event *sycl_event) {
    if (p_engine_.get_kind() == engine::kind::gpu) return status::unimplemented;
    return execute_impl(g_stream, inputs, outputs);
}
#endif

#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
status_t genindex_t::ocl_execute_impl(const stream_t *g_stream,
        const std::vector<tensor_t> &inputs,
        const std::vector<tensor_t> &outputs,
        const std::vector<cl_event> &ocl_deps, cl_event *ocl_event) {
    // TODO: add support
    return status::unimplemented;
}
#endif

} // namespace dnnl_impl
} // namespace graph
} // namespace impl
} // namespace dnnl
```
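Taken together, the two new ops let an SDPA graph express a causal mask implicitly instead of passing an explicit mask tensor. Below is a rough, hypothetical sketch of the masking subgraph only; IDs, shapes, and the exact op arrangement are illustrative assumptions, and the authoritative pattern is the one in the updated SDPA document from this PR:

```cpp
#include <vector>

#include "oneapi/dnnl/dnnl_graph.hpp"

using namespace dnnl::graph;

// Sketch: keep score[i][j] only where i >= j (the lower triangle),
// selecting -inf elsewhere, so softmax zeroes out future positions.
int main() {
    using dt = logical_tensor::data_type;
    using lt = logical_tensor::layout_type;
    const std::vector<int64_t> shape {1, 16, 128, 128};

    logical_tensor score {0, dt::f32, shape, lt::strided};
    logical_tensor row_idx {1, dt::s32, shape, lt::strided};
    logical_tensor col_idx {2, dt::s32, shape, lt::strided};
    logical_tensor cond {3, dt::boolean, shape, lt::strided};
    logical_tensor neg_inf {4, dt::f32, {1}, lt::strided}; // -inf constant
    logical_tensor masked {5, dt::f32, shape, lt::strided};

    // Row and column indices generated from the score tensor itself.
    op gen_row {10, op::kind::GenIndex, {score}, {row_idx}, "gen_row"};
    gen_row.set_attr<int64_t>(op::attr::axis, -2);
    op gen_col {11, op::kind::GenIndex, {score}, {col_idx}, "gen_col"};
    gen_col.set_attr<int64_t>(op::attr::axis, -1);

    // cond[i][j] = (i >= j); Select keeps score there, -inf elsewhere.
    op ge {12, op::kind::GreaterEqual, {row_idx, col_idx}, {cond}, "mask_ge"};
    op sel {13, op::kind::Select, {cond, score, neg_inf}, {masked}, "mask_sel"};

    graph g {dnnl::engine::kind::cpu};
    for (auto *o : {&gen_row, &gen_col, &ge, &sel}) g.add_op(*o);
    g.finalize();
    return g.get_partitions().empty() ? 1 : 0;
}
```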
Author's reply to a review suggestion on the "## General" heading: Thanks for the suggestion. As other operations also use "General", I suggest changing them all in a separate PR.