fix for "dot() got an unexpected keyword argument 'trans_b'" error #232

Open · wants to merge 1 commit into main
Conversation

winglian

@tridao (Contributor) commented May 20, 2023

Which Triton version does this require?

@winglian (Author) commented May 20, 2023

Which Triton version does this require?

v2.0.0

I checked on my machine and this fixes the issue; I'm using the 2.0.0 release, not the nightly:

>>> import triton
>>> triton.__version__
'2.0.0'
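
For context, a hedged sketch of the API change this PR works around (not the exact diff in this PR): Triton 2.0 removed the trans_a/trans_b keywords from tl.dot(), so the operand has to be transposed explicitly, assuming tl.trans is available as in the 2.0.0 release.

# Sketch only: illustrates tl.dot without trans_b, not the PR's actual kernel code.
import torch
import triton
import triton.language as tl

@triton.jit
def qk_kernel(q_ptr, k_ptr, out_ptr, BLOCK: tl.constexpr, HEADDIM: tl.constexpr):
    offs_m = tl.arange(0, BLOCK)
    offs_d = tl.arange(0, HEADDIM)
    # Load one (BLOCK, HEADDIM) tile of q and of k.
    q = tl.load(q_ptr + offs_m[:, None] * HEADDIM + offs_d[None, :])
    k = tl.load(k_ptr + offs_m[:, None] * HEADDIM + offs_d[None, :])
    # Pre-2.0 style:    qk = tl.dot(q, k, trans_b=True)
    # Triton 2.0 style: transpose k explicitly instead.
    qk = tl.dot(q, tl.trans(k))
    tl.store(out_ptr + offs_m[:, None] * BLOCK + offs_m[None, :], qk)

q = torch.randn(64, 64, device="cuda", dtype=torch.float16)
k = torch.randn(64, 64, device="cuda", dtype=torch.float16)
out = torch.empty(64, 64, device="cuda", dtype=torch.float32)
qk_kernel[(1,)](q, k, out, BLOCK=64, HEADDIM=64)
assert torch.allclose(out, q.float() @ k.float().t(), atol=1e-2)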

@KeremTurgutlu commented May 22, 2023

@winglian what kind of setup do you have, and what steps did you take to make this change?

I am currently using the nvcr.io/nvidia/pytorch:23.04-py3 Docker image from NGC. I cloned the flash-attention repo, made the changes you have, and ran pip install -e . on an A100 80GB machine, with torch==2.1.0a0+fe05266 and triton==2.0.0.

However, the flash-attention Triton benchmark code gets stuck, as seen below:

### Flash Attn Benchmark
root@74a3003bb4ab:/workspace/workdir/flash-attention# PYTHONPATH=$PWD python benchmarks/benchmark_flash_attention.py
FlashAttention - Forward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fbecc074c10>
fn_amp(*inputs, **kwinputs)
  3.39 ms
  1 measurement, 30 runs , 12 threads
FlashAttention - Backward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fbecc074e80>
y.backward(grad, retain_graph=True)
  9.04 ms
  1 measurement, 30 runs , 12 threads
FlashAttention - Forward + Backward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fbec17250a0>
f(grad, *inputs, **kwinputs)
  12.26 ms
  1 measurement, 30 runs , 12 threads
PyTorch Standard Attention - Forward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fbecc074dc0>
fn_amp(*inputs, **kwinputs)
  18.00 ms
  1 measurement, 30 runs , 12 threads
PyTorch Standard Attention - Backward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fbecc074be0>
y.backward(grad, retain_graph=True)
  40.53 ms
  1 measurement, 30 runs , 12 threads
PyTorch Standard Attention - Forward + Backward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fbec1725100>
f(grad, *inputs, **kwinputs)
  57.74 ms
  1 measurement, 30 runs , 12 threads

### Causal Benchmark
root@74a3003bb4ab:/workspace/workdir/flash-attention# PYTHONPATH=$PWD python benchmarks/benchmark_causal.py 
FlashAttention - Forward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fbaf0d00fa0>
fn_amp(*inputs, **kwinputs)
  1.94 ms
  1 measurement, 30 runs , 12 threads
FlashAttention - Backward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fbaf0ceec40>
y.backward(grad, retain_graph=True)
  3.61 ms
  1 measurement, 30 runs , 12 threads
FlashAttention - Forward + Backward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fbaf0902c70>
f(grad, *inputs, **kwinputs)
  5.14 ms
  1 measurement, 30 runs , 12 threads
PyTorch Attention - Forward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fbaf0d17e80>
fn_amp(*inputs, **kwinputs)
  5.86 ms
  1 measurement, 30 runs , 12 threads
PyTorch Attention - Backward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fbaf09026a0>
y.backward(grad, retain_graph=True)
  6.06 ms
  1 measurement, 30 runs , 12 threads
PyTorch Attention - Forward + Backward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fbaf0cee6a0>
f(grad, *inputs, **kwinputs)
  12.26 ms
  1 measurement, 30 runs , 12 threads
FlashAttention Triton - Forward pass
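
A standalone call to the Triton interface can help tell a kernel compile/execution hang apart from a problem in the benchmark script. This is a hedged sketch: it assumes flash_attn/flash_attn_triton.py exposes flash_attn_func taking positional (q, k, v, bias, causal) arguments and (batch, seqlen, nheads, headdim) fp16 CUDA tensors.

# Hedged smoke test, not taken from the repo's benchmarks.
import torch
from flash_attn.flash_attn_triton import flash_attn_func

batch, seqlen, nheads, headdim = 2, 1024, 8, 64
q, k, v = (torch.randn(batch, seqlen, nheads, headdim, device="cuda",
                       dtype=torch.float16, requires_grad=True) for _ in range(3))

# Assumed argument order: q, k, v, bias, causal.
out = flash_attn_func(q, k, v, None, True)  # first call triggers the Triton JIT compile
out.sum().backward()                        # exercise the backward kernel too
print("forward + backward completed:", out.shape)

If this hangs on the first call, the problem is in kernel compilation/execution rather than in the benchmark script itself.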

@vchiley commented May 31, 2023

Tri and I have also tried it and it doesn't work... see here

@wptoux commented Jul 4, 2023

@jessiewiswjc

Which Triton version does this require?

v2.0.0

I checked on my machine and this fixes the issue; I'm using the 2.0.0 release, not the nightly:

>>> import triton
>>> triton.__version__
'2.0.0'

I am using triton==2.0.0 and your flash-attn Triton branch, but my process hangs...

@tridao (Contributor) commented Jul 31, 2023

@tridao hello, do you have plans to make flash_attn_triton.py run on triton==2.0.0? My other code is based on 2.0.0 and hits a compile error on 2.0.0.dev20221202.

I don't have the bandwidth right now to work on it. We welcome the community's contributions.

@tridao force-pushed the main branch 2 times, most recently from e9018eb to 5400fdc on September 16, 2023.