Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[webgpu] Optimize MatMulNBits for f16 Block32 prefill performance #23908

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

daijh
Copy link
Contributor

@daijh daijh commented Mar 6, 2025

Description

This commit improve the MatMulNBits f16 Block32 prefill performance, by increasing tiling size and enhancing memory efficiency. Achieved a +2x performance boost on Intel iGPUs for Phi-3.5-mini f16 model.

Motivation and Context

See above.

@daijh
Copy link
Contributor Author

daijh commented Mar 6, 2025

Tests:

model_benchmark.exe -i Phi-3.5-mini-instruct-onnx-web -l 1000
Prompt-1000 Prefill-default (tps) Prefill-opt (tps)
LNL 14.5829 327.627
MTL 61.4833 160.695
ADL 45.1106 101.871

@daijh
Copy link
Contributor Author

daijh commented Mar 6, 2025

@qjia7 @sushraja-msft @jchen10
Please take a look, thanks.

@daijh
Copy link
Contributor Author

daijh commented Mar 6, 2025

Add shader for easy review.

enable f16;
enable subgroups_f16;
enable subgroups;
const workgroup_size_x: u32 = 128;
const workgroup_size_y: u32 = 1;
const workgroup_size_z: u32 = 1;
@group(0) @binding(0) var<storage, read> input_a: array<vec4<f16>>;
@group(0) @binding(1) var<storage, read> input_b: array<vec4<u32>>;
@group(0) @binding(2) var<storage, read> scales: array<f16>;
@group(0) @binding(3) var<storage, read_write> output: array<f16>;
struct Uniforms {
  input_a_shape: vec3<u32>,
  input_a_stride: vec2<u32>,
  input_b_shape: vec3<u32>,
  input_b_stride: vec2<u32>,
  output_shape: vec3<u32>,
  output_stride: vec2<u32>,
  block_size: u32
};
@group(0) @binding(4) var<uniform> uniforms: Uniforms;

alias input_a_value_t = vec4<f16>;
alias input_a_indices_t = vec3<u32>;
fn i2o_input_a(indices : input_a_indices_t)->u32 {
  return indices[0] * uniforms.input_a_stride[0] + indices[1] * uniforms.input_a_stride[1] + indices[2];
}
fn get_input_a_by_indices(indices: input_a_indices_t)->input_a_value_t {
  return input_a[i2o_input_a(indices)];
}
alias input_b_value_t = vec4<u32>;
alias input_b_indices_t = vec3<u32>;
fn i2o_input_b(indices : input_b_indices_t)->u32 {
  return indices[0] * uniforms.input_b_stride[0] + indices[1] * uniforms.input_b_stride[1] + indices[2];
}
fn get_input_b_by_indices(indices: input_b_indices_t)->input_b_value_t {
  return input_b[i2o_input_b(indices)];
}
alias output_value_t = f16;
alias output_indices_t = vec3<u32>;
alias output_element_t = f16;
fn i2o_output(indices : output_indices_t)->u32 {
  return indices[0] * uniforms.output_stride[0] + indices[1] * uniforms.output_stride[1] + indices[2];
}
fn set_output_by_indices(indices: output_indices_t, value: output_value_t) {
  output[i2o_output(indices)]=value;
}

fn mm_read_a(batch : u32, row : u32, col : u32) -> input_a_value_t {
  if (row < uniforms.input_a_shape[1] && col < uniforms.input_a_shape[2]) {
    return get_input_a_by_indices(input_a_indices_t(batch, row, col));
  }
  return input_a_value_t(0);
}

fn mm_read_b(row : u32, col : u32) -> input_b_value_t {
  if (row < uniforms.input_b_shape[0] && col < uniforms.input_b_shape[1]) {
    return get_input_b_by_indices(input_b_indices_t(row, col, 0));
  }
  return input_b_value_t(0);
}

fn mm_read_scale(row : u32, col : u32) -> output_value_t {
  if (row < uniforms.input_b_shape[0] && col < uniforms.input_b_shape[1]) {
    return scales[row * uniforms.input_b_shape[1] + col];
  }
  return output_value_t(0);
}

fn mm_write_y(batch : u32, row : u32, col : u32, value : output_value_t) {
  if (row < uniforms.output_shape[1] && col < uniforms.output_shape[2]) {
    set_output_by_indices(output_indices_t(batch, row, col), value);
  }
}

const tile_m = 16u;
const tile_n = 128u;

var<workgroup> a_data_wg: array<array<input_a_value_t, 8u>, tile_m>;

@compute @workgroup_size(workgroup_size_x, workgroup_size_y, workgroup_size_z)
fn main(@builtin(global_invocation_id) global_id : vec3<u32>,
        @builtin(workgroup_id) workgroup_id : vec3<u32>,
        @builtin(local_invocation_index) local_idx : u32,
        @builtin(local_invocation_id) local_id : vec3<u32>,
        @builtin(subgroup_invocation_id) sg_id : u32,
        @builtin(subgroup_size) sg_size : u32,
        @builtin(num_workgroups) num_workgroups : vec3<u32>) {
  let workgroup_idx = workgroup_id.z * num_workgroups[0] * num_workgroups[1] + workgroup_id.y * num_workgroups[0] + workgroup_id.x;
  let global_idx = workgroup_idx * (workgroup_size_x * workgroup_size_y * workgroup_size_z) + local_idx;

  let batch = workgroup_id.z;
  let row = workgroup_id.y * tile_m;
  let col = workgroup_id.x * tile_n;

  let a_elements_per_col = uniforms.input_a_shape[2];
  // A block32 containing 8 elements of `a`.
  let a_blocks_per_col = (a_elements_per_col + 7u) / 8u;

  // f32 accumulator
  var results : array<f32, tile_m>;
  for (var a_block_idx = 0u; a_block_idx < a_blocks_per_col; a_block_idx++) {
    // load `a` elements into workgroup memory, TileM x 8(block32).
    let a_row_idx = local_idx / 8u;
    let a_col_idx = local_idx % 8u;
    a_data_wg[a_row_idx][a Effect_col_idx] = mm_read_a(batch, row + a_row_idx, a_block_idx * 8u + a_col_idx);
    workgroupBarrier();

    let b_row = col + local_idx;
    let b_col = a_block_idx;

    let b_data = mm_read_b(b_row, b_col);
    let scale = mm_read_scale(b_row, b_col);
    let zero_point = output_element_t(8.0);

    for (var b_idx = 0u; b_idx < 4u; b_idx++) {
      let b_value = b_data[b_idx];
      let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu);
      let b_value_upper = unpack4xU8((b_value >> 4u) & 0x0F0F0F0Fu);
      let b_quantized_values = mat2x4<output_element_t>(
          output_element_t(b_value_lower[0]), output_element_t(b_value_upper[0]),
          output_element_t(b_value_lower[1]), output_element_t(b_value_upper[1]),
          output_element_t(b_value_lower[2]), output_element_t(b_value_upper[2]),
          output_element_t(b_value_lower[3]), output_element_t(b_value_upper[3]));
      let b_dequantized_values =
          (b_quantized_values - mat2x4<output_element_t>(zero_point, zero_point,
                                                        zero_point, zero_point,
                                                        zero_point, zero_point,
                                                        zero_point, zero_point)) * scale;

      for (var m_idx = 0u; m_idx < tile_m; m_idx++) {
        let a_data0 = a_data_wg[m_idx][b_idx * 2u];
        let a_data1 = a_data_wg[m_idx][b_idx * 2u + 1u];

        results[m_idx] += f32(dot(a_data0, b_dequantized_values[0u])) +
                          f32(dot(a_data1, b_dequantized_values[1u]));
      }
    }

    workgroupBarrier();
  }

  // write the results
  for (var m_idx = 0u; m_idx < tile_m; m_idx++) {
    mm_write_y(batch, row + m_idx, col + local_idx, output_value_t(results[m_idx]));
  }

}

@guschmue guschmue added the ep:WebGPU ort-web webgpu provider label Mar 6, 2025
@guschmue
Copy link
Contributor

guschmue commented Mar 6, 2025

/azp run ONNX Runtime Web CI Pipeline,Windows GPU CI Pipeline,Linux Android Emulator QNN CI Pipeline

Copy link

Azure Pipelines successfully started running 2 pipeline(s).

@guschmue
Copy link
Contributor

guschmue commented Mar 6, 2025

/azp run Linux CPU CI Pipeline,Linux CPU Minimal Build E2E CI Pipeline,Linux GPU CI Pipeline,Linux GPU TensorRT CI Pipeline,Linux OpenVINO CI Pipeline,Linux QNN CI Pipeline,MacOS CI Pipeline,Windows ARM64 QNN CI Pipeline,Windows CPU CI Pipeline

@guschmue
Copy link
Contributor

guschmue commented Mar 6, 2025

/azp run Windows GPU TensorRT CI Pipeline,onnxruntime-binary-size-checks-ci-pipeline,orttraining-linux-ci-pipeline,orttraining-linux-gpu-ci-pipeline,orttraining-ortmodule-distributed,Windows x64 QNN CI Pipeline,Big Models

@guschmue
Copy link
Contributor

guschmue commented Mar 6, 2025

/azp run Windows GPU CUDA CI Pipeline,Windows GPU DML CI Pipeline,Windows GPU Doc Gen CI Pipeline, Win_TRT_Minimal_CUDA_Test_CI

Copy link

Azure Pipelines successfully started running 4 pipeline(s).

Copy link

Azure Pipelines successfully started running 9 pipeline(s).

Copy link

Azure Pipelines successfully started running 4 pipeline(s).

@@ -867,6 +978,40 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
return context.RunProgram(mul_program);
}

// Block32 prefill program
// This program is optimized for Block32 prefill using Tile16x128.
const bool use_block32_program = block_size == 32 && batch_count == 1 && !has_zero_points &&
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what was the impact to generation speed ? Should you restrict this shader with a M > kMinMForTileOptimization check.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you also support batch size other than 1 and zero points, in your shader perhaps relax that check. Okay to do in a follow up change.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what was the impact to generation speed ? Should you restrict this shader with a M > kMinMForTileOptimization check.

Thanks for the review.

The performance is similar with default shader when M ==1.
I only see performance improvement when M is greater than 2, according to the test results.

I will implement a restriction, M > kMinMForTileOptimization, to enforce this requirement."

// This program is optimized for Block32 prefill using Tile16x128.
const bool use_block32_program = block_size == 32 && batch_count == 1 && !has_zero_points &&
components_a == 4 && components_b == 4 && M > 1 &&
context.AdapterInfo().vendor == std::string_view{"intel"};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what about this shader makes it intel specific, can we remove this check.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I only got intel devices to test the performance at the moments.
After verification on broaden devices, it's surely can be removed in a follow up change.

@@ -867,6 +978,40 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
return context.RunProgram(mul_program);
}

// Block32 prefill program
// This program is optimized for Block32 prefill using Tile16x128.
const bool use_block32_program = block_size == 32 && batch_count == 1 && !has_zero_points &&
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you also support batch size other than 1 and zero points, in your shader perhaps relax that check. Okay to do in a follow up change.

@@ -781,6 +781,117 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
return Status::OK();
}

Status MatMulNBitsBlock32Program::GenerateShaderCode(ShaderHelper& shader) const {
Copy link
Contributor

@sushraja-msft sushraja-msft Mar 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thank you for working on this, some thoughts. Your shader looks to be a generation optimized shader with different tile size than the current one. As far as matmul goes there are 2 genres of shaders with each genre having variants for special ops they use.

Generation Optimized Shaders - these will keep only A in shared memory - pool all threads to load A into shared memory and then have each thread work on a B from that A.

Prefill Optimization Shaders - These should use co-operative matmul - https://www.khronos.org/assets/uploads/developers/presentations/Cooperative_Matrix_May22.pdf
They keep both a and b in shared memory. Pool all threads to load shared memory and then each subgroup within the workgroup works on a subtile. This results in parts of the loads required for a subtile to be shared with other subtiles and hence saves loads.

From what I can tell yours is a generation mode shader, if you are seeing good perf with this tile size -we should just replace the current generation shader with yours. Even better if we can make these shaders have the tile sizes as tunable.

Net, I think we should try to avoid having similar shaders that don't differ algorithmically. Please do share numbers for generation perf with your shader, perhaps we can replace the current generation shader with yours.

As to why you are seeing great prefill speed, its because our prefill fp16 shader is not based on co-operative matmul (we havent got around to rewriting that shader that way, if you can pick that up that would be amazing as well). The DP4A matmul shader is using techniques of co-operative matmul, and we are using that for many models by passing accuracy_level 4 with model_builder.py.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From what I can tell yours is a generation mode shader, if you are seeing good perf with this tile size -we should just replace the current generation shader with yours. Even better if we can make these shaders have the tile sizes as tunable.

This shader optimizes Input_A loading by leveraging workgroup-wide collective load operations within a workgroup(128), storing a 16x8 tile into shared memory with a single instruction.

This approach increases tiling size, resulting in performance improvement when the input matrix 'M' is sufficiently large.

Specifically, for M=1, the performance does not exceed that of the default decode shader.

@daijh
Copy link
Contributor Author

daijh commented Mar 7, 2025

From what I can tell yours is a generation mode shader, if you are seeing good perf with this tile size -we should just replace the current generation shader with yours. Even better if we can make these shaders have the tile sizes as tunable.

I'm trying to avoid making too many modifications in a single PR to keep it easier review, and comparable with previous shader.
If accepted, I'll subsequently integrate its improvements into the default shader prefill path (as decode performance is not improved).

What are your thoughts?

@daijh
Copy link
Contributor Author

daijh commented Mar 7, 2025

As to why you are seeing great prefill speed, its because our prefill fp16 shader is not based on co-operative matmul (we havent got around to rewriting that shader that way, if you can pick that up that would be amazing as well). The DP4A matmul shader is using techniques of co-operative matmul, and we are using that for many models by passing accuracy_level 4 with model_builder.py.

Yes, we observed quite good performance at accuracy level 4 using the DP4A shader. I'll investigate similar for f16.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ep:WebGPU ort-web webgpu provider
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants