
[gpu][rocm] Numeric mismatch with/out pad-to-intrinsics for convolution with trailing fusion #20024

Open
jerryyin opened this issue Feb 18, 2025 · 2 comments
Labels: bug 🐞 Something isn't working

jerryyin commented Feb 18, 2025

What happened?

We (@nirvedhmeshram, @Max191) discovered a numeric mismatch for SDXL dispatch 13, which is a convolution with an elementwise trailing fusion. Compiling with or without the pre-processing pass iree-preprocessing-pad-to-intrinsics yields mismatching kernel compute results. This shouldn't happen.

The failure looks like the following when testing the standalone kernel extracted from dispatch 13:

[FAILED] result[0]: element at index 59301 (11.1562) does not match the expected (11.1641); expected that the view is equal to contents of a view of 2x128x128x320xf16

Note: it is essential to include the full fused kernel in the test; neither a standalone convolution nor a standalone elementwise trailing kernel reproduces the value mismatch.

Steps to reproduce your issue

To reproduce, please refer to the following IR and scripts. Generate random input, weight, and elementwise vectors with Python's np.random, then run the bash script.

The dispatch_13 IR:

!INPUT_TYPE = tensor<2x130x130x4xf16>
!WEIGHT_TYPE = tensor<3x3x4x320xf16>
!ELEMENT_TYPE = tensor<320xf16>
!CONV_TYPE = tensor<2x128x128x320xf32>
!OUTPUT_TYPE = tensor<2x128x128x320xf16>

func.func @dispatch_13(%input : !INPUT_TYPE, %weight : !WEIGHT_TYPE, %7 : !ELEMENT_TYPE) -> !OUTPUT_TYPE {
    %c0 = arith.constant 0.0 : f32
    %empty = tensor.empty() : !CONV_TYPE
    %fill = linalg.fill ins(%c0 : f32) outs(%empty : !CONV_TYPE) -> !CONV_TYPE
    %conv = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%input, %weight : !INPUT_TYPE, !WEIGHT_TYPE) outs(%fill : !CONV_TYPE) -> !CONV_TYPE
    %9 = tensor.empty() : !OUTPUT_TYPE
    %13 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%conv, %7 : !CONV_TYPE, !ELEMENT_TYPE) outs(%9 : !OUTPUT_TYPE) {
    ^bb0(%in: f32, %in_0: f16, %out: f16):
      %14 = arith.extf %in_0 : f16 to f32
      %15 = arith.addf %in, %14 : f32
      %16 = arith.truncf %15 : f32 to f16
      linalg.yield %16 : f16
    } -> !OUTPUT_TYPE
    return %13 : !OUTPUT_TYPE
}
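For reference, here is a minimal NumPy sketch of what the fused dispatch computes (the helper name and looping strategy are my own; the f32 accumulation followed by an f16 truncation mirrors the linalg ops above):

```python
import numpy as np

def dispatch_13_ref(inp, weight, bias):
    """Reference for conv_2d_nhwc_hwcf + bias-add fusion: f32 accumulate, f16 result."""
    n, ih, iw, c = inp.shape
    kh, kw, _, f = weight.shape
    oh, ow = ih - kh + 1, iw - kw + 1            # stride 1, no dilation
    acc = np.zeros((n, oh, ow, f), dtype=np.float32)
    for i in range(kh):
        for j in range(kw):
            # contract the channel dim of each shifted input window in f32
            patch = inp[:, i:i + oh, j:j + ow, :].astype(np.float32)
            acc += patch @ weight[i, j].astype(np.float32)
    # trailing elementwise fusion: extf the bias, addf in f32, truncf to f16
    return (acc + bias.astype(np.float32)).astype(np.float16)
```

With the shapes from the IR (2x130x130x4 input, 3x3x4x320 filter, 320 bias), this produces the 2x128x128x320 f16 output that both compiled modules are checked against.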

Compile and test scripts:

#!/bin/bash

mlir_test_file="test_conv.mlir"
# Output module paths; these were referenced but left undefined in the
# original script, so the names below are assumed.
PAD="conv_pad.vmfb"
NO_PAD="conv_no_pad.vmfb"
CONV_INPUT="2x130x130x4xf16"
CONV_WEIGHTS="3x3x4x320xf16"
VECTOR="320xf16"
CONV_OUTPUT="2x128x128x320xf32"


# Function to compile
compile() {
    echo "Compiling modules..."
    iree-compile --iree-hal-target-backends=rocm --iree-hip-target=gfx942 $mlir_test_file -o ${PAD} \
--iree-preprocessing-pass-pipeline="builtin.module(util.func(iree-flow-canonicalize), \
iree-preprocessing-transpose-convolution-pipeline, \
iree-preprocessing-pad-to-intrinsics)" 

    iree-compile --iree-hal-target-backends=rocm --iree-hip-target=gfx942 $mlir_test_file -o ${NO_PAD} \
--iree-preprocessing-pass-pipeline="builtin.module(util.func(iree-flow-canonicalize), \
iree-preprocessing-transpose-convolution-pipeline)" 
}

# Function to run correctness tests
test() {
    echo "Running correctness tests..."
    iree-run-module --device=hip --module=${PAD} --input="$CONV_INPUT=@$CONV_INPUT.bin" --input="$CONV_WEIGHTS=@$CONV_WEIGHTS.bin" --input="$VECTOR=@$VECTOR.bin" --output=@with_pad0.npy --function=dispatch_13
    iree-run-module --device=hip --module=${NO_PAD} --input="$CONV_INPUT=@$CONV_INPUT.bin" --input="$CONV_WEIGHTS=@$CONV_WEIGHTS.bin" --input="$VECTOR=@$VECTOR.bin" --expected_output=@with_pad0.npy --function=dispatch_13
}

compile
test

What component(s) does this issue relate to?

No response

Version information

I'm using 5767be3, but should be able to reproduce on latest of main branch.

Additional context

No response

@jerryyin jerryyin added the bug 🐞 Something isn't working label Feb 18, 2025
@nirvedhmeshram
Contributor

Since this is an issue with truncation, I wanted to share some IR showing how this trunc is done in the two cases. @MaheshRavishankar, @qedawkins, any theories why these two truncations could differ? Here is the program level padded code and here is the dispatch level padded code.

@jerryyin
Member Author

The investigation so far indicates that the originally reported issue is a red herring.

MFMA grouping and value differences in compute

The two cases (with and without the pad-to-intrinsics pass) yield different GEMM problems and end up with different MFMA sequences:

  1. With pad-to-intrinsics, the input's and filter's channel dimension is padded from 4 to 16, so the gemmK dimension becomes 16x3x3 = 144.
  2. Without pad-to-intrinsics, the gemmK dimension is padded from 3x3x4 = 36 to 48 (the least multiple of 16).
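The two gemmK sizes above can be sketched with a small helper (hypothetical; intrinsic_k=16 matches the MFMA_F32_16x16x16_F16 intrinsic used below):

```python
import math

def padded_gemm_k(kh, kw, c, intrinsic_k=16, pad_to_intrinsics=True):
    # Hypothetical helper illustrating how the two paths size gemmK.
    if pad_to_intrinsics:
        # pad the channel dim to the intrinsic K first, then flatten via im2col
        c_padded = math.ceil(c / intrinsic_k) * intrinsic_k
        return kh * kw * c_padded                       # 3*3*16 = 144
    # flatten first (im2col), then pad the gemmK dim to the intrinsic K
    k = kh * kw * c                                     # 3*3*4 = 36
    return math.ceil(k / intrinsic_k) * intrinsic_k     # padded to 48
```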

For case 1: after PadToIntrinsicsPass, note the channel dimension is padded from 4 to 16, which later makes gemmK 144:

    %padded = tensor.pad %0 low[0, 0, 0, 0] high[0, 0, 0, 12] {
    ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
      tensor.yield %cst_0 : f16
    } : tensor<2x130x130x4xf16> to tensor<2x130x130x16xf16>
    %cst_1 = arith.constant 0.000000e+00 : f16
    %padded_2 = tensor.pad %1 low[0, 0, 0, 0] high[0, 0, 12, 0] {
    ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
      tensor.yield %cst_1 : f16
    } : tensor<3x3x4x320xf16> to tensor<3x3x16x320xf16>
    %5 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%padded, %padded_2 : tensor<2x130x130x16xf16>, tensor<3x3x16x320xf16>) outs(%4 : tensor<2x128x128x320xf32>) -> tensor<2x128x128x320xf32>
    ...
  %7 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3] m_offset = [0, 0] * [128, 1] k_offset = [0] * [1] batch_pos = [0] m_pos = [1, 2] k_pos = [3] 
    ins(%3 : tensor<2x130x130x16xf16>) outs(%6 : tensor<2x128x128x144xf16>) -> tensor<2x128x128x144xf16>
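The iree_linalg_ext.im2col op above is what flattens the (kh, kw, c) window into the trailing gemmK dimension, mapping the padded 2x130x130x16 input to 2x128x128x144. A minimal NumPy sketch of that flattening (function name and layout assumptions are mine):

```python
import numpy as np

def im2col_nhwc(inp, kh, kw):
    # Flatten each (kh, kw, c) input window into a trailing K dim, K = kh*kw*c.
    n, ih, iw, c = inp.shape
    oh, ow = ih - kh + 1, iw - kw + 1
    out = np.empty((n, oh, ow, kh * kw * c), dtype=inp.dtype)
    for i in range(kh):
        for j in range(kw):
            k0 = (i * kw + j) * c
            out[..., k0:k0 + c] = inp[:, i:i + oh, j:j + ow, :]
    return out
```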

For case 2: after GPUPadOperandsPass (without pad-to-intrinsics), note the gemmK dimension is padded from 36 to 48:

    %padded = tensor.pad %8 low[0, 0, 0, 0] high[0, 0, 0, 12] {
    ^bb0(%arg4: index, %arg5: index, %arg6: index, %arg7: index):
      tensor.yield %cst_2 : f16
    } : tensor<1x1x128x36xf16> to tensor<1x1x128x48xf16>
    %cst_3 = arith.constant 0.000000e+00 : f16
    %padded_4 = tensor.pad %extracted_slice_0 low[0, 0] high[12, 0] {
    ^bb0(%arg4: index, %arg5: index):
      tensor.yield %cst_3 : f16
    } : tensor<36x64xf16> to tensor<48x64xf16>
    %10 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], 
    iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} 
    ins(%padded, %padded_4 : tensor<1x1x128x48xf16>, tensor<48x64xf16>) 
    outs(%9 : tensor<1x1x128x64xf32>) attrs =  
    {lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, padding = [1, 1, 128, 64, 16], 
    promote_operands = [0, 1, 2], reduction = [0, 0, 0, 0, 1], subgroup = [1, 1, 8, 1, 0], workgroup = [1, 1, 128, 64, 0]}>} {
    ...

Accuracy confirmation with real SDXL data

@Max191 was able to confirm that the placement of the convert-filter-layout pass impacts the result values, but the results stay within the accuracy-verification threshold of the fp8 SDXL model with real weights.

So another explanation is that the originally recorded error is large because it is input/filter dependent: the randomly generated input/filter values may have larger exponents than the real weights, causing a larger rounding error.
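This is consistent with the reported mismatch: 11.1562 and 11.1641 are adjacent f16 values, exactly one f16 ulp (0.0078125) apart at this magnitude. Two f32 accumulations that differ only slightly, e.g. from a different MFMA reduction grouping, can truncate to adjacent f16 values when they straddle the rounding midpoint. A small illustration (the f32 inputs below are constructed for the demo, not taken from the kernel):

```python
import numpy as np

lo = np.float16(11.15625)     # the observed 11.1562 (exact f16 value 11.15625)
hi = np.float16(11.1640625)   # the expected 11.1641 (exact f16 value 11.1640625)
ulp = np.float32(hi) - np.float32(lo)   # one f16 ulp at this magnitude: 0.0078125

mid = np.float32(11.16015625)  # the f16 rounding boundary between lo and hi
s_a = mid - np.float32(1e-4)   # e.g. the f32 sum from one reduction grouping
s_b = mid + np.float32(1e-4)   # e.g. the f32 sum from the other grouping
# s_a and s_b differ by only ~2e-4 in f32, yet truncate to different f16 values.
```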
