Fix SDXL dispatch count regression from llvm integrate #19002

Open · IanWood1 opened this issue Nov 4, 2024 · 11 comments

IanWood1 (Contributor) commented Nov 4, 2024

The most recent LLVM integrate, #18987, introduced a minor regression in the SDXL CLIP dispatch count (1139 ⇾ 1141). I tracked it to llvm/llvm-project@df0d249 and was able to restore the dispatch count by locally reverting that single commit.

Here are the two additional dispatches after the LLVM integrate:

Command used:

iree-compile artifacts/sdxl_clip/model.mlirbc -o extra-dispatches.mlir --iree-hal-target-backends=rocm --iree-hip-target=gfx942 --iree-opt-const-eval=false --iree-global-opt-propagate-transposes=true --iree-dispatch-creation-enable-fuse-horizontal-contractions=true --iree-dispatch-creation-enable-aggressive-fusion=true --iree-opt-aggressively-propagate-transposes=true --iree-opt-outer-dim-concat=true --iree-llvmgpu-enable-prefetch=true --iree-opt-data-tiling=false --iree-codegen-gpu-native-math-precision=true --iree-codegen-llvmgpu-use-vector-distribution --iree-hip-waves-per-eu=2 --iree-execution-model=async-external --iree-scheduling-dump-statistics-format=json --iree-scheduling-dump-statistics-file=compilation_info.json '--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline,iree-preprocessing-pad-to-intrinsics)' --compile-to=dispatch-creation

MLIR:

  util.global private @__hoisted_tensor_64x768xf16_255 {stream.affinity.default = #hal.device.affinity<@__device_0>} : tensor<64x768xf16>
  util.initializer attributes {stream.affinity.default = #hal.device.affinity<@__device_0>} {
    %cst = arith.constant dense_resource<torch_tensor_1_77_torch.int64> : tensor<1x77xi64>
    %_params.text_encoder_model_1.text_model.embeddings.position_embedding.weight = util.global.load immutable @_params.text_encoder_model_1.text_model.embeddings.position_embedding.weight : tensor<77x768xf16>
    %0 = flow.dispatch.workgroups(%cst, %_params.text_encoder_model_1.text_model.embeddings.position_embedding.weight) : (tensor<1x77xi64>, tensor<77x768xf16>) -> tensor<64x768xf16> =
        (%arg0: !flow.dispatch.tensor<readonly:tensor<1x77xi64>>, %arg1: !flow.dispatch.tensor<readonly:tensor<77x768xf16>>, %arg2: !flow.dispatch.tensor<writeonly:tensor<64x768xf16>>) {
      %c77 = arith.constant 77 : index
      %c0_i64 = arith.constant 0 : i64
      %1 = flow.dispatch.tensor.load %arg1, offsets = [0, 0], sizes = [77, 768], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<77x768xf16>> -> tensor<77x768xf16>
      %2 = tensor.empty() : tensor<64x768xf16>
      %3 = flow.dispatch.tensor.load %arg0, offsets = [0, 0], sizes = [1, 64], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<1x77xi64>> -> tensor<64xi64>
      %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%3 : tensor<64xi64>) outs(%2 : tensor<64x768xf16>) {
      ^bb0(%in: i64, %out: f16):
        %5 = arith.index_cast %in : i64 to index
        %6 = linalg.index 1 : index
        %7 = arith.cmpi slt, %5, %c77 : index
        cf.assert %7, "index must be smaller than dim size"
        %8 = arith.cmpi sge, %in, %c0_i64 : i64
        cf.assert %8, "index must be larger or equal to 0"
        %extracted = tensor.extract %1[%5, %6] : tensor<77x768xf16>
        linalg.yield %extracted : f16
      } -> tensor<64x768xf16>
      flow.dispatch.tensor.store %4, %arg2, offsets = [0, 0], sizes = [64, 768], strides = [1, 1] : tensor<64x768xf16> -> !flow.dispatch.tensor<writeonly:tensor<64x768xf16>>
      flow.return
    } count() -> (index, index, index) {
      %x, %y, %z = flow.dispatch.workgroup_count_from_slice 
      flow.return %x, %y, %z : index, index, index
    }
    util.global.store %0, @__hoisted_tensor_64x768xf16_255 : tensor<64x768xf16>
    util.return
  }
  util.global private @__hoisted_tensor_64x1280xf16_256 {stream.affinity.default = #hal.device.affinity<@__device_0>} : tensor<64x1280xf16>
  util.initializer attributes {stream.affinity.default = #hal.device.affinity<@__device_0>} {
    %cst = arith.constant dense_resource<torch_tensor_1_77_torch.int64_1> : tensor<1x77xi64>
    %_params.text_encoder_model_2.text_model.embeddings.position_embedding.weight = util.global.load immutable @_params.text_encoder_model_2.text_model.embeddings.position_embedding.weight : tensor<77x1280xf16>
    %0 = flow.dispatch.workgroups(%cst, %_params.text_encoder_model_2.text_model.embeddings.position_embedding.weight) : (tensor<1x77xi64>, tensor<77x1280xf16>) -> tensor<64x1280xf16> =
        (%arg0: !flow.dispatch.tensor<readonly:tensor<1x77xi64>>, %arg1: !flow.dispatch.tensor<readonly:tensor<77x1280xf16>>, %arg2: !flow.dispatch.tensor<writeonly:tensor<64x1280xf16>>) {
      %c77 = arith.constant 77 : index
      %c0_i64 = arith.constant 0 : i64
      %1 = flow.dispatch.tensor.load %arg1, offsets = [0, 0], sizes = [77, 1280], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<77x1280xf16>> -> tensor<77x1280xf16>
      %2 = tensor.empty() : tensor<64x1280xf16>
      %3 = flow.dispatch.tensor.load %arg0, offsets = [0, 0], sizes = [1, 64], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<1x77xi64>> -> tensor<64xi64>
      %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%3 : tensor<64xi64>) outs(%2 : tensor<64x1280xf16>) {
      ^bb0(%in: i64, %out: f16):
        %5 = arith.index_cast %in : i64 to index
        %6 = linalg.index 1 : index
        %7 = arith.cmpi slt, %5, %c77 : index
        cf.assert %7, "index must be smaller than dim size"
        %8 = arith.cmpi sge, %in, %c0_i64 : i64
        cf.assert %8, "index must be larger or equal to 0"
        %extracted = tensor.extract %1[%5, %6] : tensor<77x1280xf16>
        linalg.yield %extracted : f16
      } -> tensor<64x1280xf16>
      flow.dispatch.tensor.store %4, %arg2, offsets = [0, 0], sizes = [64, 1280], strides = [1, 1] : tensor<64x1280xf16> -> !flow.dispatch.tensor<writeonly:tensor<64x1280xf16>>
      flow.return
    } count() -> (index, index, index) {
      %x, %y, %z = flow.dispatch.workgroup_count_from_slice 
      flow.return %x, %y, %z : index, index, index
    }
    util.global.store %0, @__hoisted_tensor_64x1280xf16_256 : tensor<64x1280xf16>
    util.return
  }

It appears these linalg.generic ops were getting CSE'd before the change but can't be anymore because of the cf.assert ops, which have side effects.
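
For intuition, here is a minimal hand-written repro of the CSE behavior (a sketch, not extracted from the model): the two gather generics below are identical, so CSE used to fold them into one, but with cf.assert in their bodies the ops carry side effects and both are kept after the LLVM change:

func.func @duplicate_gathers(%indices: tensor<4xi64>, %table: tensor<4xf16>) -> (tensor<4xf16>, tensor<4xf16>) {
  %c4 = arith.constant 4 : index
  %empty = tensor.empty() : tensor<4xf16>
  // First gather: checks the index, then extracts from %table.
  %0 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]}
      ins(%indices : tensor<4xi64>) outs(%empty : tensor<4xf16>) {
  ^bb0(%in: i64, %out: f16):
    %i = arith.index_cast %in : i64 to index
    %ok = arith.cmpi slt, %i, %c4 : index
    cf.assert %ok, "index must be smaller than dim size"
    %e = tensor.extract %table[%i] : tensor<4xf16>
    linalg.yield %e : f16
  } -> tensor<4xf16>
  // Second gather: identical to the first. Without the cf.assert, CSE
  // replaces %1 with %0; with it, both generics survive into dispatch
  // region formation, adding a dispatch.
  %1 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]}
      ins(%indices : tensor<4xi64>) outs(%empty : tensor<4xf16>) {
  ^bb0(%in: i64, %out: f16):
    %i = arith.index_cast %in : i64 to index
    %ok = arith.cmpi slt, %i, %c4 : index
    cf.assert %ok, "index must be smaller than dim size"
    %e = tensor.extract %table[%i] : tensor<4xf16>
    linalg.yield %e : f16
  } -> tensor<4xf16>
  return %0, %1 : tensor<4xf16>, tensor<4xf16>
}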

IanWood1 self-assigned this Nov 4, 2024
benvanik (Collaborator) commented Nov 4, 2024

ew, we should really not be seeing those asserts - what's adding those?

IanWood1 (Author) commented Nov 4, 2024

I think this is coming from torch-mlir's lowering of some op (not entirely sure which). Don't these get dropped somewhere around flow/stream anyway?

benvanik (Collaborator) commented Nov 4, 2024

Nope - they make it all the way to the runtime if they are outside of dispatches, and as you're seeing here they will have a bad influence on dispatch region formation/executable generation. Asserts should only be added explicitly by users unless a debug mode is enabled, IMO. Asserts inside of dispatches are no-ops today and get removed too late, so they just make compilation worse.

They could be used for int range analysis hints in a release build - but if that's the case we should probably absorb them into the int range ops at input time instead.

IanWood1 (Author) commented Nov 4, 2024

> Asserts should only be added explicitly by users unless a debug mode is enabled, IMO. Asserts inside of dispatches are no-ops today and get removed too late, so they just make compilation worse.

That makes sense; I think they are there to conform with the PyTorch op specs. We don't currently have a "debug/release mode", right?

@MaheshRavishankar do you have any suggestions on how to fix this?

benvanik (Collaborator) commented Nov 4, 2024

--iree-opt-strip-assertions can be used to strip them near input time (I forget if it walks into linalg ops, but it should).
As a middle-stage compiler, we want debug options for assertions that come in as user input to be controlled by the user creating the input - it's not possible to know whether an assert was added by the user, by a dialect conversion above us (like this), etc. If a dialect is inserting assertions, it'd be nice if it had an option to stop inserting them.

For now though, use --iree-opt-strip-assertions unless you're testing correctness (and even then, as seen here, we'll never report assertions inside of dispatches today - though we could in debug modes if it ever proved useful; it's just really tricky logic per backend).
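
To illustrate on the IR above (a sketch of the intended effect, reusing the dispatch body from this issue), stripping turns the asserted gather body back into a pure one that CSE can deduplicate:

// Before stripping (from the first dispatch above):
%7 = arith.cmpi slt, %5, %c77 : index
cf.assert %7, "index must be smaller than dim size"
%extracted = tensor.extract %1[%5, %6] : tensor<77x768xf16>

// After --iree-opt-strip-assertions: the cf.assert is erased and the
// now-dead cmpi folds away, leaving only the pure extract.
%extracted = tensor.extract %1[%5, %6] : tensor<77x768xf16>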

MaheshRavishankar (Contributor) commented

I thought that was on by default?

benvanik (Collaborator) commented Nov 4, 2024

Doesn't seem like it. It could be. I suspect only a fraction of users care about the assertions and more would be confused by how badly they mess up performance.

MaheshRavishankar (Contributor) commented

@IanWood1 maybe start by adding this flag to all the CI tests, and then a separate PR that turns it on by default.

IanWood1 (Author) commented Nov 5, 2024

I tried turning it on in #19014, but I didn't realize assertions don't get stripped until after hoisting, so there was no effect on the dispatch count. Should this pass be moved? There is a comment explaining why it needs to run after optimizations:

// Strip std.assert & co after we perform optimizations; prior to this we
// may use the assertions to derive information during analysis.
.addPredicatedPass(transformOptions.options.stripAssertions,
                   IREE::Util::createStripDebugOpsPass);

benvanik (Collaborator) commented Nov 5, 2024

Good catch. That may not be true anymore now that we have information coming from the frontend and util.assume - we could lower the assertions to those assume ops prior to removal as one of the first steps.
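
A rough sketch of what that lowering could look like on the gather body above (hypothetical: the exact util.assume.int assembly and attribute names are assumptions, not current IREE behavior):

// Hypothetical lowering: absorb the asserted range checks (0 <= %in < 77)
// into an integer-range assumption and drop both cf.assert ops; the
// util.assume.int syntax shown here is approximate.
%in_bounded = util.assume.int %in<umin = 0, umax = 76> : i64
%5 = arith.index_cast %in_bounded : i64 to index
%extracted = tensor.extract %1[%5, %6] : tensor<77x768xf16>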

benvanik (Collaborator) commented Nov 5, 2024

(oh and I'm pretty sure we don't derive information from the assertions today - so it'd be safe to move now!)

IanWood1 added a commit that referenced this issue Nov 14, 2024
This change _temporarily_ adds `iree-opt-strip-assertions` to SDXL CI
which should fix the regression discussed in
#19002. Assertions within
`linalg.generic` ops can mess with dispatch creation and lead to poorly
performing dispatches, despite the fact that they get stripped in later
pipelines.

---------

Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
Groverkss pushed a commit to Groverkss/iree that referenced this issue Dec 1, 2024
giacs-epic pushed a commit to giacs-epic/iree that referenced this issue Dec 4, 2024