Skip to content
This repository has been archived by the owner on May 22, 2023. It is now read-only.

[Layout] Add layout transformation analysis for PrimFunc #449

Closed
wants to merge 5 commits into from

Conversation

psrivas2
Copy link
Contributor

This change adds a PrimFunc level analysis to propose layout transformations to block and buffers in the PrimFunc based on the layout transformations to PrimFunc outputs. It analyzes buffer access to figure this out. It tries to preserve sequential access to buffers when it does this.

For example given the following PrimFunc and write buffer "relu" transformation lambda n, c, h, w: (n, h, w, c // 4, c % 4), it will suggest to make the following transformations.

  • Block transformation on "compute": lambda n, c, h, w: (n, h, w, c // 4, c % 4)
  • Buffer transformation on "relu": lambda n, c, h, w: (n, h, w, c // 4, c % 4)
  • Buffer transformation on "arg": lambda n, c, h, w: (n, h, w, c // 4, c % 4)
@T.prim_func
  def elemwise_relu(
      arg: T.Buffer((32, 64, 224, 224), "float32"),
      relu: T.Buffer((32, 64, 224, 224), "float32"),
  ):
      for i0, i1, i2, i3 in T.grid(32, 64, 224, 224):
          with T.block("compute"):
              v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
              T.reads(arg[v_i0, v_i1, v_i2, v_i3])
              T.writes(relu[v_i0, v_i1, v_i2, v_i3])
              relu[v_i0, v_i1, v_i2, v_i3] = T.max(arg[v_i0, v_i1, v_i2, v_i3], T.float32(0))

These transformations can then be applied on the PrimFunc to get the PrimFunc with new layout.

@T.prim_func
def elemwise_relu(
    arg: T.Buffer((32, 224, 224, 16, 4), "float32"),
    relu: T.Buffer((32, 224, 224, 16, 4), "float32"),
):
    for ax0, ax1, ax2, ax3, ax4 in T.grid(32, 224, 224, 16, 4):
        with T.block("compute"):
            v0, v1, v2, v3, v4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4])
            T.reads(arg[v0, v1, v2, v3, v4])
            T.writes(relu[v0, v1, v2, v3, v4])
            relu[v0, v1, v2, v3, v4] = T.max(arg[v0, v1, v2, v3, v4], T.float32(0))

This change adds a PrimFunc level analysis to suggest layout transformations to block and buffers in the PrimFunc based on the layout transformations to PrimFunc outputs.
@tqchen
Copy link
Contributor

tqchen commented Feb 18, 2023

#453, this PR should be part of PRs to send to unity. We can either directly transition to PR to unity, or continue review merge before transition, depending on what authors and reviewers want

@psrivas2 psrivas2 changed the title [WIP][Layout] Add layout transformation analysis for PrimFunc [Layout] Add layout transformation analysis for PrimFunc Feb 21, 2023
@psrivas2
Copy link
Contributor Author

cc @sunggg @masahi @YuchenJin

@psrivas2
Copy link
Contributor Author

Migrated to unity branch apache/tvm#14066

@psrivas2 psrivas2 closed this Feb 21, 2023
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants