Fold arith.exti operation into conv #17765

Merged
IanWood1 merged 1 commit into iree-org:main on Jul 2, 2024

Conversation

@IanWood1 (Contributor) commented Jun 27, 2024

@IanWood1 added the benchmarks:cuda, benchmarks:x86_64, benchmarks:comp-stats, benchmarks:android-cpu, benchmarks:android-gpu, and benchmarks:vulkan-nvidia labels on Jun 27, 2024
@IanWood1 self-assigned this on Jun 27, 2024
@IanWood1 marked this pull request as ready for review on June 28, 2024 15:41
@IanWood1 requested a review from hanhanW as a code owner on June 28, 2024 15:41

github-actions bot commented Jun 28, 2024

Abbreviated Benchmark Summary

@ commit b43ea7037a2fb5ed008173160049c9b31ccc6114 (vs. base dcba7c56799a7303c347baf7ec21e9ca07a56fec)

Data-Tiling Comparison Table

Name No-DT (baseline) DT-Only DT-UK
BertLargeTF(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[30-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 763.164 (1.0X) N/A 234.469 (3.3X)
DeepLabV3_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 7.343 (1.0X) N/A 8.690 (0.8X)
EfficientNetV2STF(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 35.903 (1.0X) N/A 35.450 (1.0X)
EfficientNet_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 5.851 (1.0X) N/A 5.081 (1.2X)
GPT2_117M_TF_1X1XI32(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 9.154 (1.0X) N/A 8.409 (1.1X)
GPT2_117M_TF_1X4XI32(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 10.995 (1.0X) N/A 8.887 (1.2X)
MiniLML12H384Uncased(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 11.779 (1.0X) N/A 14.052 (0.8X)
MobileBertSquad_fp16(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 34.303 (1.0X) N/A 62.902 (0.5X)
MobileBertSquad_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 34.433 (1.0X) N/A 63.718 (0.5X)
MobileBertSquad_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 68.459 (1.0X) N/A 64.768 (1.1X)
MobileNetV1_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 4.822 (1.0X) N/A 4.614 (1.0X)
MobileNetV2_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 3.788 (1.0X) N/A 4.926 (0.8X)
MobileNetV2_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 6.032 (1.0X) N/A 5.551 (1.1X)
MobileNetV3Small_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 2.876 (1.0X) N/A 2.780 (1.0X)
MobileSSD_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 8.525 (1.0X) N/A 10.018 (0.9X)
PersonDetect_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 0.766 (1.0X) N/A 0.585 (1.3X)
PoseNet_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 4.346 (1.0X) N/A 5.336 (0.8X)
matmul_256x256x2048_i8_i4_i32_tile_config_default(linalg) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] 7.579 (1.0X) N/A 7.604 (1.0X)
matmul_256x256x2048_i8_i8_i32_tile_config_default(linalg) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] 6.604 (1.0X) N/A 1.807 (3.7X)
BertForMaskedLMTF(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[30-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 224.417 (1.0X) N/A 109.265 (2.1X)
DeepLabV3_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 32.174 (1.0X) N/A 29.796 (1.1X)
EfficientNetV2STF(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 272.379 (1.0X) N/A 229.900 (1.2X)
EfficientNet_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 27.037 (1.0X) N/A 13.053 (2.1X)
GPT2_117M_TF_1X1XI32(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 70.579 (1.0X) N/A 40.378 (1.7X)
GPT2_117M_TF_1X4XI32(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 88.465 (1.0X) N/A 41.887 (2.1X)
MiniLML12H384Uncased(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 79.739 (1.0X) N/A 56.754 (1.4X)
MobileBertSquad_fp16(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 179.697 (1.0X) N/A 185.966 (1.0X)
MobileBertSquad_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 181.523 (1.0X) N/A 191.474 (0.9X)
MobileBertSquad_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 516.053 (1.0X) N/A 240.449 (2.1X)
MobileNetV1_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 25.178 (1.0X) N/A 18.085 (1.4X)
MobileNetV2_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 12.030 (1.0X) N/A 11.326 (1.1X)
MobileNetV2_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 21.693 (1.0X) N/A 11.812 (1.8X)
MobileNetV3Small_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] 2.818 (1.0X) N/A 2.666 (1.1X)
MobileSSD_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 34.454 (1.0X) N/A 31.606 (1.1X)
PersonDetect_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] 0.698 (1.0X) N/A 0.521 (1.3X)
PoseNet_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 17.816 (1.0X) N/A 19.413 (0.9X)
matmul_1x256x2048_i8_i4_i32_tile_config_default(linalg) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] 0.054 (1.0X) N/A 0.054 (1.0X)
matmul_1x256x2048_i8_i8_i32_tile_config_default(linalg) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] 0.042 (1.0X) N/A 0.021 (2.0X)
DeepLabV3_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 48.068 (1.0X) N/A 42.981 (1.1X)
DeepLabV3_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 49.752 (1.0X) N/A 43.135 (1.2X)
DeepLabV3_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 30.010 (1.0X) N/A 27.276 (1.1X)
GPT2_117M_TF_1X1XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 92.762 (1.0X) N/A 21.204 (4.4X)
GPT2_117M_TF_1X1XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 93.988 (1.0X) N/A 21.371 (4.4X)
GPT2_117M_TF_1X1XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 52.531 (1.0X) N/A 21.772 (2.4X)
GPT2_117M_TF_1X4XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 137.249 (1.0X) N/A 27.230 (5.0X)
GPT2_117M_TF_1X4XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 123.532 (1.0X) N/A 28.698 (4.3X)
GPT2_117M_TF_1X4XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 69.577 (1.0X) N/A 26.652 (2.6X)
MobileBertSquad_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 699.944 (1.0X) N/A 350.854 (2.0X)
MobileBertSquad_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 696.793 (1.0X) N/A 355.235 (2.0X)
MobileBertSquad_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 390.592 (1.0X) N/A 214.994 (1.8X)
MobileBertSquad_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 1060.606 (1.0X) N/A 276.603 (3.8X)
MobileBertSquad_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 1062.456 (1.0X) N/A 272.013 (3.9X)
MobileBertSquad_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 548.233 (1.0X) N/A 158.411 (3.5X)
Vit_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 2061.367 (1.0X) N/A 292.001 (7.1X)
Vit_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 2062.163 (1.0X) N/A 294.662 (7.0X)
Vit_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 1086.570 (1.0X) N/A 174.936 (6.2X)
matmul_1x256x2048_i8_i4_i32_tile_config_default(linalg) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 0.079 (1.0X) N/A 0.016 (5.1X)
matmul_1x256x2048_i8_i8_i32_tile_config_default(linalg) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 0.071 (1.0X) N/A 0.017 (4.3X)
matmul_256x256x2048_i8_i4_i32_tile_config_default(linalg) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 12.050 (1.0X) N/A 1.323 (9.1X)
matmul_256x256x2048_i8_i8_i32_tile_config_default(linalg) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 16.534 (1.0X) N/A 1.054 (15.7X)

Regressed Latencies 🚩

Benchmark Name Average Latency (ms) Median Latency (ms) Latency Standard Deviation (ms)
MobileBertSquad_fp16(tflite) [arm-valhall-vulkan_android31-vulkan_spirv][experimental-flags,fuse-padding,max-concurrency,demote-f32-to-f16] vulkan(none)[full-inference,default-flags] with default @ pixel-6-pro[gpu] 113.089 (vs. 90.388, 25.11%↑) 113.628 1.385
GPT2_117M_TF_1X4XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu][experimental-flags,no-dt] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 137.249 (vs. 124.680, 10.08%↑) 137.179 0.313
MobileBertSquad_fp16(tflite) [arm-valhall-vulkan_android31-vulkan_spirv][default-flags,demote-f32-to-f16] vulkan(none)[full-inference,default-flags] with default @ pixel-6-pro[gpu] 90.875 (vs. 82.686, 9.90%↑) 87.736 4.926

[Top 3 of 5 results shown]

Improved Latencies 🎉

Benchmark Name Average Latency (ms) Median Latency (ms) Latency Standard Deviation (ms)
GPT2_117M_TF_1X4XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu][experimental-flags,no-dt] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 123.532 (vs. 139.878, 11.69%↓) 123.411 0.570
GPT2_117M_TF_1X4XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu][experimental-flags,no-dt] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 69.577 (vs. 76.695, 9.28%↓) 69.559 0.164
GPT2_117M_TF_1X1XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu][default-flags,dt-uk] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 21.371 (vs. 22.939, 6.83%↓) 21.370 0.052

[Top 3 of 11 results shown]

No improved or regressed compilation metrics 🏖️

For more information:

Source Workflow Run

%2 = tensor.empty() : tensor<10x40xi32>
%3 = arith.constant 0 : i32
%4 = linalg.fill ins(%3 : i32) outs(%2 : tensor<10x40xi32>) -> tensor<10x40xi32>
%5 = linalg.matmul ins(%arg0, %1 : tensor<10x20xi32>, tensor<20x40xi32>)
@qedawkins (Contributor) commented Jun 28, 2024

If this were a linalg.matmul_unsigned, this pattern would be incorrect because it would change the extension semantics. This is why we need to handle each op one by one for integer extends: different named ops have different extension semantics, unlike float ops, which all imply extf.
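
To make the hazard concrete, here is a minimal sketch (hypothetical shapes and value names, not IR from this PR). linalg.matmul implicitly sign-extends a narrower integer operand in its body, so folding an explicit arith.extsi into it preserves semantics; the same fold into an op that implicitly zero-extends would not.

// Before the fold: the extension is explicit.
%lhs_ext = arith.extsi %lhs : tensor<10x20xi16> to tensor<10x20xi32>
%0 = linalg.matmul ins(%lhs_ext, %rhs : tensor<10x20xi32>, tensor<20x40xi32>)
                   outs(%acc : tensor<10x40xi32>) -> tensor<10x40xi32>

// After the fold: linalg.matmul's body sign-extends the i16 operand itself,
// so the result is unchanged.
%1 = linalg.matmul ins(%lhs, %rhs : tensor<10x20xi16>, tensor<20x40xi32>)
                   outs(%acc : tensor<10x40xi32>) -> tensor<10x40xi32>

// Folding the same extsi into linalg.matmul_unsigned would replace the
// explicit signed extension with that op's implicit extui, changing the
// result whenever %lhs holds negative values.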

@IanWood1 (PR author) replied:

I'm not quite sure how this doesn't work; printing out the body of the matmul_unsigned shows that the operand is extended as a signed integer:

module {
  util.func public @matmul_extsi(%arg0: tensor<10x20xi32>, %arg1: tensor<20x40xi16>) -> tensor<10x40xi32> {
    %c0_i32 = arith.constant 0 : i32
    %0 = tensor.empty() : tensor<10x40xi32>
    %1 = linalg.fill ins(%c0_i32 : i32) outs(%0 : tensor<10x40xi32>) -> tensor<10x40xi32>
    %2 = linalg.matmul_unsigned ins(%arg0, %arg1 : tensor<10x20xi32>, tensor<20x40xi16>) outs(%1 : tensor<10x40xi32>) -> tensor<10x40xi32>
    util.return %2 : tensor<10x40xi32>
  }
}

// Generic form
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
"builtin.module"() ({
  "util.func"() <{function_type = (tensor<10x20xi32>, tensor<20x40xi16>) -> tensor<10x40xi32>, sym_name = "matmul_extsi", sym_visibility = "public", tied_operands = [-1 : index]}> ({
  ^bb0(%arg0: tensor<10x20xi32>, %arg1: tensor<20x40xi16>):
    %0 = "arith.constant"() <{value = 0 : i32}> : () -> i32
    %1 = "tensor.empty"() : () -> tensor<10x40xi32>
    %2 = "linalg.fill"(%0, %1) <{operandSegmentSizes = array<i32: 1, 1>}> ({
    ^bb0(%arg5: i32, %arg6: i32):
      "linalg.yield"(%arg5) : (i32) -> ()
    }) : (i32, tensor<10x40xi32>) -> tensor<10x40xi32>
    %3 = "linalg.matmul_unsigned"(%arg0, %arg1, %2) <{operandSegmentSizes = array<i32: 2, 1>}> ({
    ^bb0(%arg2: i32, %arg3: i16, %arg4: i32):
      %4 = "arith.extsi"(%arg3) : (i16) -> i32
      %5 = "arith.muli"(%arg2, %4) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
      %6 = "arith.addi"(%arg4, %5) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
      "linalg.yield"(%6) : (i32) -> ()
    }) {linalg.memoized_indexing_maps = [#map, #map1, #map2]} : (tensor<10x20xi32>, tensor<20x40xi16>, tensor<10x40xi32>) -> tensor<10x40xi32>
    "util.return"(%3) : (tensor<10x40xi32>) -> ()
  }) : () -> ()
}) : () -> ()

@qedawkins (Contributor) commented Jun 28, 2024

If that's the case, then that's a bug. If I take the IR you just posted and run linalg-generalize-named-ops, I get an extui:

#map = affine_map<(d0, d1) -> ()>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map3 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map4 = affine_map<(d0, d1, d2) -> (d0, d1)>
module {
  util.func public @matmul_extsi(%arg0: tensor<10x20xi32>, %arg1: tensor<20x40xi16>) -> tensor<10x40xi32> {
    %c0_i32 = arith.constant 0 : i32
    %0 = tensor.empty() : tensor<10x40xi32>
    %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%c0_i32 : i32) outs(%0 : tensor<10x40xi32>) {
    ^bb0(%in: i32, %out: i32):
      linalg.yield %in : i32
    } -> tensor<10x40xi32>
    %2 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<10x20xi32>, tensor<20x40xi16>) outs(%1 : tensor<10x40xi32>) {
    ^bb0(%in: i32, %in_0: i16, %out: i32):
      %3 = arith.extui %in_0 : i16 to i32
      %4 = arith.muli %in, %3 : i32
      %5 = arith.addi %out, %4 : i32
      linalg.yield %5 : i32
    } -> tensor<10x40xi32>
    util.return %2 : tensor<10x40xi32>
  }
}

@IanWood1 (PR author) replied:

Yeah, seems like a bug. linalg-generalize-named-ops is producing different output than mlir-print-op-generic.

@qedawkins (Contributor) replied:

Ugh, yeah, I think the hidden region on linalg named ops is probably broken then. In all other contexts matmul_unsigned is assumed to mean extui. Looking at the list of linalg operations (https://mlir.llvm.org/docs/Dialects/Linalg/#operations), it looks like only matmul_unsigned has these unsigned extension semantics, so we can just add a special case in the pattern to fail when trying to fuse with a matmul_unsigned.

@IanWood1 (PR author) replied:

Okay, looked at this further. linalg-generalize-named-ops does print extsi, but I wasn't running it against the MLIR I posted here. I started with a generic (signed cast) -> matmul_unsigned and used the modified pass from this PR. That manually places the extsi inside the matmul_unsigned's region, but it doesn't get printed out. I'm not sure if changing regions like this is allowed/good.

Maybe it makes sense to only fold extsi into the signed variants.
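
As a sketch of that pairing rule (hypothetical IR, relying on the extui body shown in the generalized form above): fold an explicit extend only when it matches the named op's implicit extension, and bail out otherwise.

// Sound: extui feeding linalg.matmul_unsigned, whose body also zero-extends,
// so the explicit extend can be folded away.
%z = arith.extui %b : tensor<20x40xi16> to tensor<20x40xi32>
%0 = linalg.matmul_unsigned ins(%a, %z : tensor<10x20xi32>, tensor<20x40xi32>)
                            outs(%acc : tensor<10x40xi32>) -> tensor<10x40xi32>

// Unsound: extsi feeding linalg.matmul_unsigned. Dropping the extsi would
// swap signed extension for the op's implicit extui, so the pattern should
// fail to match here.
%s = arith.extsi %b : tensor<20x40xi16> to tensor<20x40xi32>
%1 = linalg.matmul_unsigned ins(%a, %s : tensor<10x20xi32>, tensor<20x40xi32>)
                            outs(%acc : tensor<10x40xi32>) -> tensor<10x40xi32>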

@qedawkins (Contributor) replied:

> I'm not sure if changing regions like this is allowed/good.

Agree :) Although I see it more as a function of named ops being over-engineered. It's probably because we're trying to use interfaces instead of just using the named op builders, but linalg is awkward...

@IanWood1 force-pushed the fold_exti_into_conv2d branch from 18d2460 to d37eaf1 on June 28, 2024 16:24
Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
@IanWood1 force-pushed the fold_exti_into_conv2d branch from d37eaf1 to abe6b8d on July 1, 2024 19:46
@IanWood1 (PR author) commented Jul 2, 2024

Mahesh added similar changes to shared/sdxl_quantized in 7eb8280. But as Quinn noted, we can't fuse extsi with the unsigned variants.

@IanWood1 IanWood1 requested a review from qedawkins July 2, 2024 16:28
@qedawkins (Contributor) left a review:

LGTM!

@IanWood1 merged commit f4447bd into iree-org:main on Jul 2, 2024 (57 of 60 checks passed)
@IanWood1 deleted the fold_exti_into_conv2d branch on July 2, 2024 16:44
@ScottTodd (Member) commented:

I'm still hooking tests up so they'll run on presubmit, but I think this fixed int8 punet compilation for llvm-cpu (the previous error was "error: One or more operations with large vector sizes (16384 bytes) were found").

Can't quite follow the paper trail here, but if this was what fixed it, thanks!

@MaheshRavishankar (Contributor) commented:

It's papering over backend issues, but it's a reasonable fix.

LLITCHEV pushed a commit to LLITCHEV/iree that referenced this pull request Jul 30, 2024
nod-ai/SHARK-ModelDev#755

Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
Signed-off-by: Lubo Litchev <lubol@google.com>