[webgpu] Optimize MatMulNBits for f16 Block32 prefill performance #23908
@@ -781,6 +781,117 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
  return Status::OK();
}

Status MatMulNBitsBlock32Program::GenerateShaderCode(ShaderHelper& shader) const {
  const auto& a = shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
  const auto& b = shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
  const auto& scales = shader.AddInput("scales", ShaderUsage::UseUniform);
  const auto& y = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias);

  const uint32_t workgroup_size = WorkgroupSizeX() * WorkgroupSizeY();

  // memory read/write helpers
  shader.AdditionalImplementation() << "fn mm_read_a(batch : u32, row : u32, col : u32) -> input_a_value_t {\n";
  shader.AdditionalImplementation() << "  if (row < uniforms.input_a_shape[1] && col < uniforms.input_a_shape[2]) {\n";
  shader.AdditionalImplementation() << "    return " << a.GetByIndices("input_a_indices_t(batch, row, col)") << ";\n";
  shader.AdditionalImplementation() << "  }\n";
  shader.AdditionalImplementation() << "  return input_a_value_t(0);\n";
  shader.AdditionalImplementation() << "}\n";

  shader.AdditionalImplementation() << "\n";
  shader.AdditionalImplementation() << "fn mm_read_b(row : u32, col : u32) -> input_b_value_t {\n";
  shader.AdditionalImplementation() << "  if (row < uniforms.input_b_shape[0] && col < uniforms.input_b_shape[1]) {\n";
  shader.AdditionalImplementation() << "    return " << b.GetByIndices("input_b_indices_t(row, col, 0)") << ";\n";
  shader.AdditionalImplementation() << "  }\n";
  shader.AdditionalImplementation() << "  return input_b_value_t(0);\n";
  shader.AdditionalImplementation() << "}\n";

  shader.AdditionalImplementation() << "\n";
  shader.AdditionalImplementation() << "fn mm_read_scale(row : u32, col : u32) -> output_value_t {\n";
  shader.AdditionalImplementation() << "  if (row < uniforms.input_b_shape[0] && col < uniforms.input_b_shape[1]) {\n";
  shader.AdditionalImplementation() << "    return " << scales.GetByOffset("row * uniforms.input_b_shape[1] + col") << ";\n";
  shader.AdditionalImplementation() << "  }\n";
  shader.AdditionalImplementation() << "  return output_value_t(0);\n";
  shader.AdditionalImplementation() << "}\n";

  shader.AdditionalImplementation() << "\n";
  shader.AdditionalImplementation() << "fn mm_write_y(batch : u32, row : u32, col : u32, value : output_value_t) {\n";
  shader.AdditionalImplementation() << "  if (row < uniforms.output_shape[1] && col < uniforms.output_shape[2]) {\n";
  shader.AdditionalImplementation() << "    " << y.SetByIndices("output_indices_t(batch, row, col)", "value") << "\n";
  shader.AdditionalImplementation() << "  }\n";
  shader.AdditionalImplementation() << "}\n";

  // declare const variables
  shader.AdditionalImplementation() << "\n";
  shader.AdditionalImplementation() << "const tile_m = " << workgroup_size / 8 << "u;\n";
  shader.AdditionalImplementation() << "const tile_n = " << workgroup_size << "u;\n";

  // declare workgroup memory
  shader.AdditionalImplementation() << "\n";
  shader.AdditionalImplementation() << "var<workgroup> a_data_wg: array<array<input_a_value_t, 8u>, tile_m>;\n";
  shader.AdditionalImplementation() << "\n";

  // main
  shader.MainFunctionBody() << R"MAIN_FN(
  let batch = workgroup_id.z;
  let row = workgroup_id.y * tile_m;
  let col = workgroup_id.x * tile_n;
  let a_elements_per_col = uniforms.input_a_shape[2];
  // A block32 containing 8 elements of `a`.
  let a_blocks_per_col = (a_elements_per_col + 7u) / 8u;
  // f32 accumulator
  var results : array<f32, tile_m>;
  for (var a_block_idx = 0u; a_block_idx < a_blocks_per_col; a_block_idx++) {
    // load `a` elements into workgroup memory, TileM x 8(block32)
    let a_row_idx = local_idx / 8u;
    let a_col_idx = local_idx % 8u;
    a_data_wg[a_row_idx][a_col_idx] = mm_read_a(batch, row + a_row_idx, a_block_idx * 8u + a_col_idx);
    workgroupBarrier();
    let b_row = col + local_idx;
    let b_col = a_block_idx;
    let b_data = mm_read_b(b_row, b_col);
    let scale = mm_read_scale(b_row, b_col);
    let zero_point = output_element_t(8.0);
    for (var b_idx = 0u; b_idx < 4u; b_idx++) {
      let b_value = b_data[b_idx];
      let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu);
      let b_value_upper = unpack4xU8((b_value >> 4u) & 0x0F0F0F0Fu);
      let b_quantized_values = mat2x4<output_element_t>(
          output_element_t(b_value_lower[0]), output_element_t(b_value_upper[0]),
          output_element_t(b_value_lower[1]), output_element_t(b_value_upper[1]),
          output_element_t(b_value_lower[2]), output_element_t(b_value_upper[2]),
          output_element_t(b_value_lower[3]), output_element_t(b_value_upper[3]));
      let b_dequantized_values =
          (b_quantized_values - mat2x4<output_element_t>(zero_point, zero_point,
                                                         zero_point, zero_point,
                                                         zero_point, zero_point,
                                                         zero_point, zero_point)) * scale;
      for (var m_idx = 0u; m_idx < tile_m; m_idx++) {
        let a_data0 = a_data_wg[m_idx][b_idx * 2u];
        let a_data1 = a_data_wg[m_idx][b_idx * 2u + 1u];
        results[m_idx] += f32(dot(a_data0, b_dequantized_values[0u])) +
                          f32(dot(a_data1, b_dequantized_values[1u]));
      }
    }
    workgroupBarrier();
  }
  // write the results
  for (var m_idx = 0u; m_idx < tile_m; m_idx++) {
    mm_write_y(batch, row + m_idx, col + local_idx, output_value_t(results[m_idx]));
  }
)MAIN_FN";

  return Status::OK();
}

Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
  const Tensor* a = context.Input(0);
  const Tensor* b = context.Input(1);

@@ -867,6 +978,40 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
  return context.RunProgram(mul_program);
}

  // Block32 prefill program
  // This program is optimized for Block32 prefill using Tile16x128.
  const bool use_block32_program = block_size == 32 && batch_count == 1 && !has_zero_points &&
                                   components_a == 4 && components_b == 4 && M > 1 &&
                                   context.AdapterInfo().vendor == std::string_view{"intel"};

Reviewer: What was the impact on generation speed? Should you restrict this shader with an M > kMinMForTileOptimization check?

Reviewer: You also support batch sizes other than 1 and zero points in your shader, so perhaps relax that check. Okay to do in a follow-up change.

Author: Thanks for the review. The performance is similar to the default shader when M == 1. I will add a restriction, M > kMinMForTileOptimization, to enforce this requirement.

Reviewer: What about this shader makes it Intel-specific? Can we remove this check?

Author: I only have Intel devices to test the performance on at the moment.

  if (use_block32_program) {
    // enforce components to 1.
    components = 1;

    constexpr uint32_t workgroup_size = 128;
    constexpr uint32_t tile_m = workgroup_size / 8;
    constexpr uint32_t tile_n = workgroup_size;

    MatMulNBitsBlock32Program program{};
    program.SetWorkgroupSize(workgroup_size);
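    // One workgroup produces a tile_m x tile_n (16 x 128) output tile, so the grid below is
    // ceil(N / tile_n) x ceil(M / tile_m) x batch_count workgroups; e.g. for hypothetical
    // sizes M = 1024 and N = 4096 this dispatches (32, 64, 1) workgroups.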
    program.SetDispatchGroupSize((N + tile_n - 1) / tile_n,
                                 (M + tile_m - 1) / tile_m,
                                 batch_count);
    program.CacheHint("Tile" + std::to_string(tile_m) + "x" + std::to_string(tile_n) + "_Block32");

    TensorShape reshaped_a_shape{batch_count, M, K / components_a};
    TensorShape reshaped_b_shape{N, n_blocks_per_col, blob_size_in_words / components_b};
    TensorShape reshaped_y_shape{batch_count, M, N / components};

    program
        .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, reshaped_a_shape, gsl::narrow<int>(components_a)},
                    {b, ProgramTensorMetadataDependency::TypeAndRank, reshaped_b_shape, gsl::narrow<int>(components_b * 4)},
                    {scales, ProgramTensorMetadataDependency::None}})
        .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, gsl::narrow<int>(components)})
        .AddUniformVariable({block_size});
    return context.RunProgram(program);
  }

  // Generic program
  // TODO: Support output_number > 1. Some cases are failed when output_number > 1.
  constexpr uint32_t output_number = 1;
  const uint32_t tile_m = M > kMinMForTileOptimization ? 4 : 1;

Reviewer: Thank you for working on this; some thoughts. Your shader looks to be a generation-optimized shader with a different tile size than the current one. As far as matmul goes, there are two genres of shaders, each with variants for the special ops they use.

Generation-optimized shaders keep only A in shared memory: they pool all threads to load A into shared memory, and then have each thread work on a B against that A.

Prefill-optimized shaders should use co-operative matmul (https://www.khronos.org/assets/uploads/developers/presentations/Cooperative_Matrix_May22.pdf). They keep both A and B in shared memory: all threads pool to load shared memory, and then each subgroup within the workgroup works on a subtile. Parts of the loads required for one subtile are shared with other subtiles, which saves loads.

From what I can tell, yours is a generation-mode shader. If you are seeing good performance with this tile size, we should just replace the current generation shader with yours; even better if we can make the tile sizes tunable. Net, I think we should avoid having similar shaders that don't differ algorithmically. Please share numbers for generation performance with your shader; perhaps we can replace the current generation shader with yours.

As to why you are seeing great prefill speed: our prefill fp16 shader is not based on co-operative matmul (we haven't gotten around to rewriting that shader that way; if you can pick that up, that would be amazing as well). The DP4A matmul shader does use co-operative-matmul techniques, and we use it for many models by passing accuracy_level 4 to model_builder.py.
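For reference, below is a minimal WGSL sketch of the shared-memory tiling idea that co-operative-matmul-style prefill shaders build on. It is not ONNX Runtime code: the f32 types, the 16x16 tile size, the buffer layout, and the per-thread (rather than per-subgroup) subtile computation are illustrative assumptions, and a real co-operative kernel would additionally use subgroup matrix operations. What it shows is the point above: both an A tile and a B tile are staged in workgroup memory, so each loaded element is reused across the whole tile.

// Illustrative shared-memory tiled matmul (C = A * B), f32, 16x16 tiles.
const TILE : u32 = 16u;

@group(0) @binding(0) var<storage, read> A : array<f32>;        // M x K, row-major
@group(0) @binding(1) var<storage, read> B : array<f32>;        // K x N, row-major
@group(0) @binding(2) var<storage, read_write> C : array<f32>;  // M x N, row-major
@group(0) @binding(3) var<uniform> dims : vec3<u32>;            // (M, K, N)

var<workgroup> a_tile : array<array<f32, TILE>, TILE>;
var<workgroup> b_tile : array<array<f32, TILE>, TILE>;

@compute @workgroup_size(16, 16, 1)
fn main(@builtin(workgroup_id) wg : vec3<u32>,
        @builtin(local_invocation_id) lid : vec3<u32>) {
  let M = dims.x;
  let K = dims.y;
  let N = dims.z;
  let row = wg.y * TILE + lid.y;  // output row this thread produces
  let col = wg.x * TILE + lid.x;  // output column this thread produces
  var acc = 0.0;

  let k_tiles = (K + TILE - 1u) / TILE;
  for (var t = 0u; t < k_tiles; t++) {
    // All threads cooperate to stage one A tile and one B tile.
    let a_col = t * TILE + lid.x;
    let b_row = t * TILE + lid.y;
    var a_val = 0.0;
    if (row < M && a_col < K) { a_val = A[row * K + a_col]; }
    a_tile[lid.y][lid.x] = a_val;
    var b_val = 0.0;
    if (b_row < K && col < N) { b_val = B[b_row * N + col]; }
    b_tile[lid.y][lid.x] = b_val;
    workgroupBarrier();

    // Each thread accumulates its output element from the staged tiles.
    for (var k = 0u; k < TILE; k++) {
      acc += a_tile[lid.y][k] * b_tile[k][lid.x];
    }
    workgroupBarrier();
  }

  if (row < M && col < N) {
    C[row * N + col] = acc;
  }
}

By contrast, the Block32 shader in this PR stages only the A tile in workgroup memory, which is why the reviewer classifies it as a generation-style shader.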
Author: This shader optimizes input_a loading by leveraging a workgroup-wide collective load: the 128 threads of a workgroup store a 16x8 tile into shared memory with a single instruction each. This increases the tiling size and improves performance when the input dimension M is sufficiently large. Specifically, for M = 1, the performance does not exceed that of the default decode shader.
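To make that load pattern concrete, the snippet below isolates the A-tile staging step from the shader in this PR (128 threads per workgroup, tile_m = 16, and input_a_value_t being a 4-component vector since components_a == 4); the names mirror the PR's WGSL, and the comments are added here for explanation.

// Each of the 128 threads loads one vector of the 16x8 A tile, so the whole
// tile_m x block32 tile lands in workgroup memory in one cooperative step.
let a_row_idx = local_idx / 8u;   // 0..15: which tile row this thread fills
let a_col_idx = local_idx % 8u;   // 0..7:  which vector within that row's block32
a_data_wg[a_row_idx][a_col_idx] = mm_read_a(batch, row + a_row_idx, a_block_idx * 8u + a_col_idx);
workgroupBarrier();               // make the tile visible to all threads before it is consumed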