From e38a675f60cfa2039d67494f878d6d5e37772945 Mon Sep 17 00:00:00 2001 From: "Bao, Yixin" Date: Wed, 1 Jan 2025 22:35:35 -0800 Subject: [PATCH] benchdnn: graph: inputs: add a sdpa implicit causal mask case --- .../graph/complex_fusion/harness_mha_all | 1 + .../graph/complex_fusion/harness_mha_ci | 1 + ...a-plain-implicit-causal-mask-fp32-bs1.json | 530 ++++++++++++++++++ 3 files changed, 532 insertions(+) create mode 100644 tests/benchdnn/inputs/graph/complex_fusion/mha/sdpa-plain-implicit-causal-mask-fp32-bs1.json diff --git a/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_all b/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_all index 05399f458fd..13d8e7ccd6d 100644 --- a/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_all +++ b/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_all @@ -14,6 +14,7 @@ --reset --dt=f32,bf16,f16 --case=complex_fusion/mha/GQA-fp16.json --reset --dt=f32,bf16,f16 --case=complex_fusion/mha/sdpa-plain-wo-mask-f16.json --reset --dt=f32,bf16,f16 --case=complex_fusion/mha/sdpa-plain-scale-by-mul-f16.json +--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/sdpa-plain-implicit-causal-mask-fp32-bs1.json # int8 graphs --reset --case=complex_fusion/mha/MHA-GPT-inf-int8-bs1.json diff --git a/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_ci b/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_ci index 8b86b687abc..d593483363e 100644 --- a/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_ci +++ b/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_ci @@ -12,6 +12,7 @@ --reset --dt=f32,bf16,f16 --case=complex_fusion/mha/GQA-fp16.json --reset --dt=f32,bf16,f16 --case=complex_fusion/mha/sdpa-plain-wo-mask-f16.json --reset --dt=f32,bf16,f16 --case=complex_fusion/mha/sdpa-plain-scale-by-mul-f16.json +--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/sdpa-plain-implicit-causal-mask-fp32-bs1.json # int8 graphs --reset --case=complex_fusion/mha/MHA-GPT-inf-int8-bs1.json diff --git a/tests/benchdnn/inputs/graph/complex_fusion/mha/sdpa-plain-implicit-causal-mask-fp32-bs1.json b/tests/benchdnn/inputs/graph/complex_fusion/mha/sdpa-plain-implicit-causal-mask-fp32-bs1.json new file mode 100644 index 00000000000..b16217b31ca --- /dev/null +++ b/tests/benchdnn/inputs/graph/complex_fusion/mha/sdpa-plain-implicit-causal-mask-fp32-bs1.json @@ -0,0 +1,530 @@ +{ + "version": "3.7.0", + "engine_kind": "cpu", + "fpmath_mode": "strict", + "fpmath_mode_apply_to_int": "false", + "input_ports": [ + 0, + 1, + 3, + 8, + 11 + ], + "output_ports": [ + 12 + ], + "graph": [ + { + "id": 0, + "name": "matmul_qk", + "kind": "MatMul", + "attrs": { + "transpose_a": { + "type": "bool", + "value": 0 + }, + "transpose_b": { + "type": "bool", + "value": 1 + } + }, + "inputs": [ + { + "id": 0, + "dtype": "f32", + "shape": [ + 1, + 16, + 384, + 64 + ], + "stride": [ + 393216, + 24576, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + }, + { + "id": 1, + "dtype": "f32", + "shape": [ + 1, + 16, + 384, + 64 + ], + "stride": [ + 393216, + 24576, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 2, + "dtype": "f32", + "shape": [ + 1, + 16, + 384, + 384 + ], + "stride": [ + 2359296, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 1, + "name": "scale_mul", + "kind": "Multiply", + "attrs": { + "auto_broadcast": { + "type": "string", + "value": "numpy" + } + }, + "inputs": [ + { + "id": 2, + "dtype": "f32", + "shape": [ + 1, + 16, + 384, + 384 + ], + "stride": [ + 2359296, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + }, + { + "id": 3, + "dtype": "f32", + "shape": [ + 1 + ], + "stride": [ + 1 + ], + "layout_type": "strided", + "property_type": "constant" + } + ], + "outputs": [ + { + "id": 4, + "dtype": "f32", + "shape": [ + 1, + 16, + 384, + 384 + ], + "stride": [ + 2359296, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 2, + "name": "genindex_row", + "kind": "GenIndex", + "attrs": { + "axis": { + "type": "s64", + "value": 2 + } + }, + "inputs": [ + { + "id": 4, + "dtype": "f32", + "shape": [ + 1, + 16, + 384, + 384 + ], + "stride": [ + 2359296, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 5, + "dtype": "s32", + "shape": [ + 1, + 16, + 384, + 384 + ], + "stride": [ + 2359296, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 3, + "name": "genindex_col", + "kind": "GenIndex", + "attrs": { + "axis": { + "type": "s64", + "value": 3 + } + }, + "inputs": [ + { + "id": 4, + "dtype": "f32", + "shape": [ + 1, + 16, + 384, + 384 + ], + "stride": [ + 2359296, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 6, + "dtype": "s32", + "shape": [ + 1, + 16, + 384, + 384 + ], + "stride": [ + 2359296, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 4, + "name": "mask_greater_equal", + "kind": "GreaterEqual", + "attrs": { + "auto_broadcast": { + "type": "string", + "value": "numpy" + } + }, + "inputs": [ + { + "id": 5, + "dtype": "s32", + "shape": [ + 1, + 16, + 384, + 384 + ], + "stride": [ + 2359296, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + }, + { + "id": 6, + "dtype": "s32", + "shape": [ + 1, + 16, + 384, + 384 + ], + "stride": [ + 2359296, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 7, + "dtype": "boolean", + "shape": [ + 1, + 16, + 384, + 384 + ], + "stride": [ + 2359296, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 5, + "name": "Select", + "kind": "Select", + "attrs": { + "auto_broadcast": { + "type": "string", + "value": "numpy" + } + }, + "inputs": [ + { + "id": 7, + "dtype": "boolean", + "shape": [ + 1, + 16, + 384, + 384 + ], + "stride": [ + 2359296, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + }, + { + "id": 4, + "dtype": "f32", + "shape": [ + 1, + 16, + 384, + 384 + ], + "stride": [ + 2359296, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + }, + { + "id": 8, + "dtype": "f32", + "shape": [ + 1 + ], + "stride": [ + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 9, + "dtype": "f32", + "shape": [ + 1, + 16, + 384, + 384 + ], + "stride": [ + 2359296, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 6, + "name": "softmax", + "kind": "SoftMax", + "attrs": { + "axis": { + "type": "s64", + "value": -1 + } + }, + "inputs": [ + { + "id": 9, + "dtype": "f32", + "shape": [ + 1, + 16, + 384, + 384 + ], + "stride": [ + 2359296, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 10, + "dtype": "f32", + "shape": [ + 1, + 16, + 384, + 384 + ], + "stride": [ + 2359296, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 7, + "name": "matmul_v", + "kind": "MatMul", + "attrs": { + "transpose_a": { + "type": "bool", + "value": 0 + }, + "transpose_b": { + "type": "bool", + "value": 0 + } + }, + "inputs": [ + { + "id": 10, + "dtype": "f32", + "shape": [ + 1, + 16, + 384, + 384 + ], + "stride": [ + 2359296, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + }, + { + "id": 11, + "dtype": "f32", + "shape": [ + 1, + 16, + 384, + 64 + ], + "stride": [ + 393216, + 24576, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 12, + "dtype": "f32", + "shape": [ + 1, + 16, + 384, + 64 + ], + "stride": [ + 393216, + 24576, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + } + ] +} \ No newline at end of file