@LeiWang1999
Member

@LeiWang1999 LeiWang1999 commented Sep 18, 2025

Summary by CodeRabbit

  • New Features
    • Expanded GEMM and GEMM v2 to accept buffer-loaded inputs, enabling broader input compatibility without changing the public API.
  • Bug Fixes
    • Clearer, more specific error messages when unsupported argument types are provided.
  • Refactor
    • Internal updates to how GEMM v2 invokes its intrinsic, with no impact on external behavior.
    • Unified shape, stride, pointer, and offset handling across GEMM paths for improved consistency and reliability.

@coderabbitai
Contributor

coderabbitai bot commented Sep 18, 2025

Caution

Review failed

The pull request is closed.

Walkthrough

Adds tir.BufferLoad support to GEMM helper functions by deriving regions for shape, stride, pointer, and offset retrieval. Updates error messages to explicit ValueError variants. Applies changes to both gemm and gemm_v2; gemm_v2 switches final intrinsic to tl.gemm_py. Imports get_buffer_region_from_load.

Changes

Cohort / File(s) | Change Summary
--- | ---
BufferLoad handling across retrieval helpers (tilelang/language/gemm.py) | retrieve_shape/stride/ptr/offset now accept tir.BufferLoad by using get_buffer_region_from_load(...).region; compute shape from extents, strides from buffer, offset from region mins and strides; return access_ptr with computed offset.
Error handling updates (tilelang/language/gemm.py) | Replaced generic errors with explicit ValueError messages for unsupported types in each retrieve_* helper.
GEMM blocks adjustments (tilelang/language/gemm.py) | Applied new retrieval logic in gemm and gemm_v2; gemm_v2 switches intrinsic call from tl.gemm to tl.gemm_py.
Imports (tilelang/language/gemm.py) | Added from tilelang.utils.language import get_buffer_region_from_load.
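
To make the BufferLoad path described in the walkthrough concrete, here is a minimal pure-Python sketch of how a shape and the per-dimension start indices can be read off a region. `Range` and `shape_and_mins` are hypothetical stand-ins introduced for illustration, not TileLang or TVM APIs; the only thing taken from the PR is that each region entry exposes `min` and `extent`.

```python
from typing import List, NamedTuple, Tuple


class Range(NamedTuple):
    """Hypothetical stand-in for one dimension of a TIR region."""
    min: int
    extent: int


def shape_and_mins(region: List[Range]) -> Tuple[List[int], List[int]]:
    """Return (shape, start indices): extents give the shape, mins the starts."""
    shape = [r.extent for r in region]
    mins = [r.min for r in region]
    return shape, mins


# Assumed example: a 64x32 tile starting at row 128, column 0 of batch 2.
region = [Range(2, 1), Range(128, 64), Range(0, 32)]
shape, mins = shape_and_mins(region)
print(shape, mins)
```

In the real helpers the extents and mins are TIR expressions rather than Python integers, so the arithmetic is symbolic; the structure of the computation is the same.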

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant U as User Code
  participant G as gemm/gemm_v2
  participant H as retrieve_* helpers
  participant B as Buffer/Region
  participant I as tl.gemm / tl.gemm_py

  U->>G: call with A,B,C (may be tir.BufferLoad)
  G->>H: retrieve_shape/stride/ptr/offset(A,B,C)
  alt arg is tir.BufferLoad
    H->>B: get_buffer_region_from_load(arg)
    Note right of H: Derive region mins/extents<br/>Compute shape/strides/offset
    H-->>G: shapes, strides, ptr via access_ptr(offset)
  else arg is Buffer/var
    H-->>G: shapes, strides, ptr, offset (existing paths)
  end
  G->>I: invoke intrinsic (gemm_v2 -> gemm_py)
  I-->>G: result / side-effects
  G-->>U: completion

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Poem

I hop through tiles where buffers load,
Nibbling strides along the road.
From region mins to offsets true,
I stitch the shapes and pass them through.
Now gemm_v2 says “py!” with glee—
A carrot-counting matrix spree. 🥕

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 232782d and 2e59910.

📒 Files selected for processing (1)
  • tilelang/language/gemm.py (9 hunks)

@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run bash format.sh in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Summary of Changes

Hello @LeiWang1999, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a test failure by implementing comprehensive support for tir.BufferLoad nodes within the gemm module's helper functions. The modifications ensure that when T.gemm receives inputs that are BufferLoad objects, their shapes, strides, pointers, and offsets are correctly identified and handled, particularly when dealing with strided memory access patterns. This improves the robustness and correctness of the gemm operation under more complex input conditions.

Highlights

  • Enhanced gemm Utility Functions: The utility functions retrieve_shape, retrieve_stride, retrieve_ptr, and retrieve_offset within tilelang/language/gemm.py have been extended to correctly process tir.BufferLoad objects.
  • Support for Strided Buffer Inputs: This change specifically addresses scenarios where T.gemm inputs are tir.BufferLoad nodes that might represent strided buffers, ensuring accurate calculations for their properties.
  • New Import for BufferLoad Handling: The get_buffer_region_from_load utility function has been imported and integrated to facilitate the extraction of buffer region information from tir.BufferLoad objects.

@LeiWang1999 LeiWang1999 merged commit ebea77d into tile-ai:main Sep 18, 2025
4 of 5 checks passed
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request adds support for tir.BufferLoad nodes as inputs to T.gemm by extending the internal helper functions in both gemm and gemm_v2.

While the change correctly extends functionality, I've identified several areas for improvement:

  1. Potential AttributeError: The new logic in retrieve_shape, retrieve_ptr, and retrieve_offset does not handle the case where get_buffer_region_from_load returns None, which can lead to an AttributeError. I've added comments with suggestions to fix this potential bug.
  2. Code Duplication: There is significant code duplication within the helper functions (e.g., between tir.BufferRegion and tir.BufferLoad cases) and across gemm and gemm_v2 (the entire set of helper functions is duplicated).

I highly recommend refactoring to address the duplication. A good first step would be to move the helper functions (retrieve_shape, retrieve_stride, etc.) outside of gemm and gemm_v2 so they can be shared. This will greatly improve the code's maintainability and readability. My specific comments focus on the added code blocks and provide suggestions for immediate fixes and local refactoring.
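
One way to address both points at once is to route every BufferLoad lookup through a single guard, so the None check lives in one place. The sketch below is only an illustration under stated assumptions: `require_region` and `fake_get_region` are hypothetical names, the getter is passed in as a parameter so the example runs without TileLang installed, and it tests `is None` explicitly, which is a choice of this sketch rather than something the PR prescribes.

```python
from typing import Callable, Optional, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def require_region(load: T, get_region: Callable[[T], Optional[R]]) -> R:
    """Call get_region(load) and raise instead of handing back None."""
    buffer_region = get_region(load)
    if buffer_region is None:
        raise ValueError(f"Could not get BufferRegion from BufferLoad: {load}")
    return buffer_region


# Hypothetical getter: only "understands" loads whose text starts with "A".
def fake_get_region(load: str) -> Optional[str]:
    return f"region({load})" if load.startswith("A") else None


print(require_region("A[0:64, 0:32]", fake_get_region))
try:
    require_region("B[i, j]", fake_get_region)
except ValueError as err:
    print(err)
```

With a guard like this, retrieve_shape, retrieve_ptr, and retrieve_offset could each call it once and then read `.region`, instead of repeating the check in every branch.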

Comment on lines +70 to +75
elif isinstance(object, tir.BufferLoad):
    region = get_buffer_region_from_load(object).region
    shape = []
    for r in region:
        shape.append(r.extent)
    return shape

high

This new block to handle tir.BufferLoad introduces a potential AttributeError and duplicates code.

  1. Potential Bug: The function get_buffer_region_from_load can return None. If it does, get_buffer_region_from_load(object).region will raise an AttributeError. You should add a check for this.

  2. Code Duplication: The logic in this block is nearly identical to the logic in the elif isinstance(object, tir.BufferRegion) block. This makes the code harder to maintain. Ideally, the common logic should be extracted.

The suggestion below fixes the potential bug and simplifies the list creation using a list comprehension.

Suggested change
- elif isinstance(object, tir.BufferLoad):
-     region = get_buffer_region_from_load(object).region
-     shape = []
-     for r in region:
-         shape.append(r.extent)
-     return shape
+ elif isinstance(object, tir.BufferLoad):
+     buffer_region = get_buffer_region_from_load(object)
+     if not buffer_region:
+         raise ValueError(f"Could not get BufferRegion from BufferLoad: {object}")
+     return [r.extent for r in buffer_region.region]

Comment on lines +154 to +168
elif isinstance(object, tir.BufferLoad):
    buffer = object.buffer
    region = get_buffer_region_from_load(object).region
    indices = []
    for r in region:
        indices.append(r.min)
    strides = []
    stride = 1
    for s in reversed(buffer.shape):
        strides.insert(0, stride)
        stride *= s
    offset = 0
    for i in range(len(indices) - 2):
        offset += indices[i] * strides[i]
    return buffer.access_ptr(access_mask=access_type, offset=offset)

high

This large block of code has a few issues:

  1. Potential Bug: get_buffer_region_from_load(object) can return None, which would cause an AttributeError on the next line (...region). A None check is required.

  2. Code Duplication: This block is almost a copy-paste of the tir.BufferRegion case. Furthermore, it duplicates logic from retrieve_stride (for calculating strides) and retrieve_offset (for calculating indices).

This significant duplication makes the code difficult to maintain. I strongly recommend refactoring. You could start by fixing the bug by adding a None check after calling get_buffer_region_from_load. Then, consider a broader refactoring to reduce duplication, for example by extracting common parts into helper functions.
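
As a worked instance of the pointer-offset arithmetic in the block under review, the sketch below reproduces its final loop in plain Python. `leading_dims_offset` is a hypothetical name, and the strides are written out by hand for an assumed row-major buffer of shape (4, 128, 64). Note that, as in the quoted code, only the dimensions before the last two contribute to the offset.

```python
from typing import Sequence


def leading_dims_offset(mins: Sequence[int], strides: Sequence[int]) -> int:
    """Sum mins[i] * strides[i] over every dimension except the last two."""
    offset = 0
    for i in range(len(mins) - 2):
        offset += mins[i] * strides[i]
    return offset


# Row-major strides for an assumed shape (4, 128, 64): [128 * 64, 64, 1].
strides = [8192, 64, 1]

# Region starting at batch 2, row 32, column 0: only the batch index counts.
print(leading_dims_offset([2, 32, 0], strides))  # 16384

# A plain 2-D operand has no leading dimensions, so the offset is 0.
print(leading_dims_offset([32, 0], [64, 1]))  # 0
```

The starts of the last two dimensions are not lost: retrieve_offset returns the region mins separately.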

Comment on lines +183 to +188
elif isinstance(object, tir.BufferLoad):
    region = get_buffer_region_from_load(object).region
    indices = []
    for r in region:
        indices.append(r.min)
    return indices

high

Similar to retrieve_shape, this block has a potential AttributeError if get_buffer_region_from_load returns None, and it duplicates logic from the tir.BufferRegion case.

Please add a None check and consider refactoring to reduce code duplication. The suggestion below fixes the bug and uses a list comprehension for conciseness.

Suggested change
- elif isinstance(object, tir.BufferLoad):
-     region = get_buffer_region_from_load(object).region
-     indices = []
-     for r in region:
-         indices.append(r.min)
-     return indices
+ elif isinstance(object, tir.BufferLoad):
+     buffer_region = get_buffer_region_from_load(object)
+     if not buffer_region:
+         raise ValueError(f"Could not get BufferRegion from BufferLoad: {object}")
+     return [r.min for r in buffer_region.region]

Comment on lines +286 to +291
elif isinstance(object, tir.BufferLoad):
    region = get_buffer_region_from_load(object).region
    shape = []
    for r in region:
        shape.append(r.extent)
    return shape

high

This new block to handle tir.BufferLoad introduces a potential AttributeError and duplicates code.

  1. Potential Bug: The function get_buffer_region_from_load can return None. If it does, get_buffer_region_from_load(object).region will raise an AttributeError. You should add a check for this.

  2. Code Duplication: The logic in this block is nearly identical to the logic in the elif isinstance(object, tir.BufferRegion) block. This makes the code harder to maintain. Ideally, the common logic should be extracted.

The suggestion below fixes the potential bug and simplifies the list creation using a list comprehension.

Suggested change
- elif isinstance(object, tir.BufferLoad):
-     region = get_buffer_region_from_load(object).region
-     shape = []
-     for r in region:
-         shape.append(r.extent)
-     return shape
+ elif isinstance(object, tir.BufferLoad):
+     buffer_region = get_buffer_region_from_load(object)
+     if not buffer_region:
+         raise ValueError(f"Could not get BufferRegion from BufferLoad: {object}")
+     return [r.extent for r in buffer_region.region]

Comment on lines +370 to +384
elif isinstance(object, tir.BufferLoad):
    buffer = object.buffer
    region = get_buffer_region_from_load(object).region
    indices = []
    for r in region:
        indices.append(r.min)
    strides = []
    stride = 1
    for s in reversed(buffer.shape):
        strides.insert(0, stride)
        stride *= s
    offset = 0
    for i in range(len(indices) - 2):
        offset += indices[i] * strides[i]
    return buffer.access_ptr(access_mask=access_type, offset=offset)

high

This large block of code has a few issues:

  1. Potential Bug: get_buffer_region_from_load(object) can return None, which would cause an AttributeError on the next line (...region). A None check is required.

  2. Code Duplication: This block is almost a copy-paste of the tir.BufferRegion case. Furthermore, it duplicates logic from retrieve_stride (for calculating strides) and retrieve_offset (for calculating indices).

This significant duplication makes the code difficult to maintain. I strongly recommend refactoring. You could start by fixing the bug by adding a None check after calling get_buffer_region_from_load. Then, consider a broader refactoring to reduce duplication, for example by extracting common parts into helper functions.

Comment on lines +399 to +404
elif isinstance(object, tir.BufferLoad):
    region = get_buffer_region_from_load(object).region
    indices = []
    for r in region:
        indices.append(r.min)
    return indices

high

Similar to retrieve_shape, this block has a potential AttributeError if get_buffer_region_from_load returns None, and it duplicates logic from the tir.BufferRegion case.

Please add a None check and consider refactoring to reduce code duplication. The suggestion below fixes the bug and uses a list comprehension for conciseness.

Suggested change
- elif isinstance(object, tir.BufferLoad):
-     region = get_buffer_region_from_load(object).region
-     indices = []
-     for r in region:
-         indices.append(r.min)
-     return indices
+ elif isinstance(object, tir.BufferLoad):
+     buffer_region = get_buffer_region_from_load(object)
+     if not buffer_region:
+         raise ValueError(f"Could not get BufferRegion from BufferLoad: {object}")
+     return [r.min for r in buffer_region.region]

Comment on lines +96 to +103
elif isinstance(object, tir.BufferLoad):
    buffer = object.buffer
    strides = []
    stride = 1
    for s in reversed(buffer.shape):
        strides.insert(0, stride)
        stride *= s
    return strides

medium

This block for tir.BufferLoad duplicates the stride calculation logic that already exists for tir.Buffer and tir.BufferRegion. This increases maintenance overhead.

To improve maintainability, you could refactor retrieve_stride to extract this common logic into a helper function or restructure the conditional to handle all buffer-like types that have a .shape attribute in a single path.
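
A rough sketch of what that restructuring might look like is given below. Everything here is an assumption for illustration: `FakeBuffer`, `FakeBufferRegion`, and `FakeBufferLoad` are stand-in classes so the example runs without TVM, `row_major_strides` is a hypothetical helper name, and the error text is invented. The point is only that the type dispatch resolves the underlying buffer and the stride loop then exists once.

```python
from dataclasses import dataclass
from typing import List, Sequence, Tuple


@dataclass
class FakeBuffer:
    """Stand-in for a buffer: only the shape matters here."""
    shape: Tuple[int, ...]


@dataclass
class FakeBufferRegion:
    """Stand-in for a region over a buffer."""
    buffer: FakeBuffer


@dataclass
class FakeBufferLoad:
    """Stand-in for a load from a buffer."""
    buffer: FakeBuffer


def row_major_strides(shape: Sequence[int]) -> List[int]:
    """Same loop as in the reviewed code, extracted into one place."""
    strides: List[int] = []
    stride = 1
    for extent in reversed(shape):
        strides.insert(0, stride)
        stride *= extent
    return strides


def retrieve_stride(obj) -> List[int]:
    """Resolve the underlying buffer first, then compute strides once."""
    if isinstance(obj, FakeBuffer):
        buffer = obj
    elif isinstance(obj, (FakeBufferRegion, FakeBufferLoad)):
        buffer = obj.buffer
    else:
        # Illustrative message only; not the wording used in the PR.
        raise ValueError(f"Unsupported argument type: {type(obj)}")
    return row_major_strides(buffer.shape)


buf = FakeBuffer((4, 128, 64))
print(retrieve_stride(buf))
print(retrieve_stride(FakeBufferLoad(buf)))
```

Because gemm and gemm_v2 each carry their own copy of these helpers, moving a function like this to module level would let both share it.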

Comment on lines +312 to +319
elif isinstance(object, tir.BufferLoad):
    buffer = object.buffer
    strides = []
    stride = 1
    for s in reversed(buffer.shape):
        strides.insert(0, stride)
        stride *= s
    return strides

medium

This block for tir.BufferLoad duplicates the stride calculation logic that already exists for tir.Buffer and tir.BufferRegion. This increases maintenance overhead.

To improve maintainability, you could refactor retrieve_stride to extract this common logic into a helper function or restructure the conditional to handle all buffer-like types that have a .shape attribute in a single path.
