Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: attention refactor and base doc #216

Merged
merged 21 commits into from
Jan 17, 2023
Merged

fix: attention refactor and base doc #216

merged 21 commits into from
Jan 17, 2023

Conversation

gaetansnl
Copy link
Contributor

@gaetansnl gaetansnl commented Dec 14, 2022

Linear is partially refactored because input name is fixed by early_config_prune
I avoided important changes except refactor to linear:

  • Removed
rematerialize rm and rn to save registers
rm = m_block_idx * BLOCK_M + tl.arange(0, BLOCK_M)
rn = n_block_idx * BLOCK_N + tl.arange(0, BLOCK_N)
  • used variable initialized with tl.max_contiguous(tl.multiple_of everywhere

fixes #229

@github-actions github-actions bot added the fix hurrah, bug fixed! label Dec 14, 2022
@github-actions github-actions bot added fix hurrah, bug fixed! and removed fix hurrah, bug fixed! labels Jan 3, 2023
@gaetansnl gaetansnl marked this pull request as ready for review January 3, 2023 15:55
@gaetansnl
Copy link
Contributor Author

gaetansnl commented Jan 4, 2023

full test is passing, i suggest re-run on 3090 just in case

@github-actions github-actions bot added fix hurrah, bug fixed! and removed fix hurrah, bug fixed! labels Jan 4, 2023
@github-actions github-actions bot added fix hurrah, bug fixed! and removed fix hurrah, bug fixed! labels Jan 4, 2023
Copy link
Collaborator

@jonathlela jonathlela left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a bit hard to review but the new naming convention make the code easier to read.
There is still some omissions in the whole renaming. I've pointed out 3-4 of them in the feedback, not sure if there is others remaining.

Suffixes list:
- `_idx` integer representing an index
- `_size` integer representing a size
- `_off` integer reoresenting an offset
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

representing

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

## Dimension naming
- Dimension is singular
- If dimension follows variable name from a formula. You can use this name. Example MNK for matmul
- Use `col` or `row` singular if you don't have a new for the last two dimensions
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't have a name ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

output_row_stride,
output_col_stride,
a_row_stride,
a_col_stride,
N,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we use n_size and block_n_size ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

pid_n = (pid % width) // (group_size)
group_idx = program_idx // width
group_size = min(grid_m - group_idx * GROUP_M, GROUP_M)
m_block_idx = group_idx * GROUP_M + (program_idx % group_size)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should it be block_m_idx and block_n_idx ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

BLOCK_N: tl.constexpr,
BLOCK_M_SIZE: tl.constexpr,
BLOCK_DHEAD_SIZE: tl.constexpr,
BLOCK_N_SIZE: tl.constexpr,
):
# Index of the block on M axis (M axis is the rows of matrix K)
n_block_idx = tl.program_id(0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should it be block_n_idx and block_m_idx ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

ptrs_q = Q + offs_q
ptrs_k = K + offs_k
ptrs_v = V + offs_v
ptrs_q = q_ptr + offs_q
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

q_ptrs, k_ptrs and v_ptrs ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Copy link
Member

@pommedeterresautee pommedeterresautee left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tried to review but it's really hard.
Did it passed tests? (full pytest)
Added few remarks on top of those of @jonathlela

@@ -0,0 +1,41 @@
# Naming conventions
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

README in a more visible location like contribution folder? With a good title / name so we know its purpose.

@@ -186,21 +186,21 @@ def _fwd_kernel(
┌────────────┐
│ │ │
M Dimension│ ├────────────┤ ┌───┐
size_m │ │ │ │ │ BLOCK_M
m_size │ │ │ │ │ BLOCK_M
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BLOCK_M_SIZE ? same fo N

@gaetansnl
Copy link
Contributor Author

@pommedeterresautee no got an error at the end, I will try to reproduce

Copy link
Contributor

@white-gorilla white-gorilla left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Markown is optimized but the lack of line breaks makes it less readable in .md for maintainers and not readable once converted to HTML on the site.

  • Don't hesitate to put line breaks after the titles and especially before the lists (otherwise they will be displayed inline).
  • Tip: stay consistent on punctuation, especially the . at the end of the line and the : (either you put them everywhere or you don't).

Otherwise it's fine.

Capture d’écran 2023-01-17 à 14 01 17

Capture d’écran 2023-01-17 à 13 58 43

Capture d’écran 2023-01-17 à 14 00 00

@pommedeterresautee
Copy link
Member

❯ pytest test/test_torchdynamo.py 
...
============================================================================================== 270 passed in 3488.03s (0:58:08) ==============================================================================================

Copy link
Member

@pommedeterresautee pommedeterresautee left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

quite hard to review line by line but from what I checked, it's ok.
plus test pass

@white-gorilla
Copy link
Contributor

@gaetansnl Still missing a line break.

Capture d’écran 2023-01-17 à 17 01 37

@gaetansnl gaetansnl merged commit 46e8a58 into main Jan 17, 2023
@gaetansnl gaetansnl deleted the fix/refactor-kernels branch January 17, 2023 16:37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
fix hurrah, bug fixed!
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add conventions to triton kernel writing
4 participants