Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dialects: (stablehlo) add stablehlo.dot #3090

Merged
merged 4 commits into from
Aug 29, 2024
Merged

dialects: (stablehlo) add stablehlo.dot #3090

merged 4 commits into from
Aug 29, 2024

Conversation

superlopuh
Copy link
Member

This is the attribute used to specify along which dimensions the dot products happen.

CC @efferifick

@superlopuh superlopuh added the dialects Changes on the dialects label Aug 23, 2024
@superlopuh superlopuh self-assigned this Aug 23, 2024
Copy link

codecov bot commented Aug 23, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 89.90%. Comparing base (ca76864) to head (af863ff).
Report is 2 commits behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main    #3090   +/-   ##
=======================================
  Coverage   89.90%   89.90%           
=======================================
  Files         421      421           
  Lines       53264    53306   +42     
  Branches     8257     8263    +6     
=======================================
+ Hits        47885    47927   +42     
  Misses       4041     4041           
  Partials     1338     1338           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Collaborator

@alexarice alexarice left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small stylistic nits.

with printer.in_angle_brackets():
with printer.indented():
printer.print_string("\nlhs_batching_dimensions = [")
printer.print_list(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does directly printing the ArrayAttr not work here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It wants to print the types for each int, I'll look into a more elegant way of doing this

xdsl/dialects/stablehlo.py Outdated Show resolved Hide resolved
@superlopuh superlopuh requested a review from alexarice August 26, 2024 13:32
printer.print_string(f"\n{name} = [")
printer.print_list(
value.data,
lambda dim: printer.print_string(f"{dim.value.data}"),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
lambda dim: printer.print_string(f"{dim.value.data}"),
lambda dim: printer.print(dim.value.data),

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can ignore this if you want

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually prefer print_string to print, which is unfortunately pretty slow right now

@superlopuh superlopuh merged commit 37702db into main Aug 29, 2024
14 checks passed
@superlopuh superlopuh deleted the sasha/jax/dot branch August 29, 2024 12:20
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
dialects Changes on the dialects
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants