
Fused Bwd #137

Merged
merged 37 commits into from
May 3, 2025
37 commits
050e98a
Fused with Good perf and stride fixed
micmelesse Apr 18, 2025
7b32e6b
new grid
micmelesse Apr 22, 2025
0abd905
BLK_SLICE_FACTOR = 1
micmelesse Apr 22, 2025
18a3e57
add tflops
micmelesse Apr 22, 2025
fc9565d
new commit
micmelesse Apr 22, 2025
b9045bf
test in parrallel
micmelesse Apr 23, 2025
c589ec9
strides added by jusson
micmelesse Apr 23, 2025
3abfeeb
disable alibi
micmelesse Apr 23, 2025
99a8cf8
fix bugs again
micmelesse Apr 23, 2025
0c7fa0f
default to fused
micmelesse Apr 23, 2025
63439be
add bwd options for varlen
micmelesse Apr 23, 2025
31e2ba9
backend filter
micmelesse Apr 23, 2025
64a81c1
default to jingning and batch 4
micmelesse Apr 23, 2025
29d79d8
best fwd config
micmelesse Apr 23, 2025
fb78555
fix TRITON_PRINT_AUTOTUNING flag bug
micmelesse Apr 23, 2025
afbb34c
tune
micmelesse Apr 24, 2025
6efea74
Tuning fwd prefill
azaidy Apr 24, 2025
4d0e861
add if else
micmelesse Apr 24, 2025
dcf115b
use flag
micmelesse Apr 24, 2025
6acf41a
Minor mask fix
azaidy Apr 24, 2025
8c694ea
FLIP GRID
micmelesse Apr 24, 2025
0560afa
use best config for default
micmelesse Apr 24, 2025
46323c1
print when autotuning
micmelesse Apr 24, 2025
64d0dc5
test bfloat16
micmelesse Apr 25, 2025
6c55632
fix k and v stride bugs
micmelesse Apr 28, 2025
cb636f7
skip bfloat16
micmelesse Apr 28, 2025
619ad31
test kvpacked
micmelesse Apr 28, 2025
42cf911
disable internal tests
micmelesse Apr 28, 2025
ad75f76
pick default config based on arch
micmelesse Apr 28, 2025
f67bdde
Add alibi in the new bwd kernel (#139)
micmelesse May 1, 2025
bb79cbb
Update amd_tests.yml
micmelesse May 1, 2025
23e1383
upgrad to triton==3.3.0
micmelesse May 1, 2025
1023eae
increase shm
micmelesse May 1, 2025
34952b2
use 64 x 64 for now
micmelesse May 1, 2025
28eaaa4
save
micmelesse May 1, 2025
a06147f
handle 1d alibi
micmelesse May 2, 2025
a512564
Add fp8 to fused kernel (#140)
micmelesse May 3, 2025
10 changes: 5 additions & 5 deletions .github/workflows/amd_nightly.yml
@@ -21,7 +21,7 @@ jobs:
     timeout-minutes: 720 # self hosted runners can run jobs for longer than the default of 360 minutes
     container:
       image: rocm/pytorch:latest
-      options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root
+      options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --shm-size 16G --group-add video --user root
     steps:
       - name: Checkout
         uses: actions/checkout@v4
@@ -38,7 +38,7 @@ jobs:
       - name: Install Triton
         run: |
-          pip install triton==3.2.0
+          pip install triton==3.3.0
       - name: Show Triton version
         run: |
@@ -50,15 +50,15 @@ jobs:
       - name: Install dependencies for bench and misc
         run: |
-          pip install numpy==1.24 matplotlib pandas tabulate
+          pip install matplotlib pandas tabulate
       - name: AMD Internal Tests
         run: |
           FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=0 pytest flash_attn/flash_attn_triton_amd/test.py
       - name: Flash Attention Tests
         run: |
-          FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=0 pytest tests/test_flash_attn_triton_amd.py
+          FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=0 pytest -n 8 tests/test_flash_attn_triton_amd.py
       - name: AMD Bench
         run: |
@@ -90,7 +90,7 @@ jobs:
       - name: Install Triton
         run: |
-          pip install triton==3.2.0
+          pip install triton==3.3.0
       - name: Show Triton version
         run: |
8 changes: 4 additions & 4 deletions .github/workflows/amd_tests.yml
@@ -19,7 +19,7 @@ jobs:
     timeout-minutes: 720 # self hosted runners can run jobs for longer than the default of 360 minutes
     container:
       image: rocm/pytorch:latest
-      options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root
+      options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --shm-size 16G --group-add video --user root
     steps:
       - name: Checkout
         uses: actions/checkout@v4
@@ -36,7 +36,7 @@ jobs:
       - name: Install Triton
         run: |
-          pip install triton==3.2.0
+          pip install triton==3.3.0
       - name: Show Triton version
         run: |
@@ -48,15 +48,15 @@ jobs:
       - name: Install dependencies for bench and misc
         run: |
-          pip install numpy==1.24 matplotlib pandas tabulate
+          pip install matplotlib pandas tabulate
       - name: AMD Internal Tests
         run: |
           FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=0 pytest flash_attn/flash_attn_triton_amd/test.py
       - name: Flash Attention Tests
         run: |
-          FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=0 pytest tests/test_flash_attn_triton_amd.py
+          FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=0 pytest -n 8 tests/test_flash_attn_triton_amd.py
       - name: AMD Bench
         run: |
4 changes: 2 additions & 2 deletions README.md
@@ -154,7 +154,7 @@ To get started with the triton backend for AMD, follow the steps below.
First install the recommended Triton version

```
-pip install triton==3.2.0
+pip install triton==3.3.0
```
Then install Flash Attention with the flag `FLASH_ATTENTION_TRITON_AMD_ENABLE` set to `"TRUE"`.

@@ -182,7 +182,7 @@ FROM rocm/pytorch:latest
WORKDIR /workspace
# install triton
-RUN pip install triton==3.2.0
+RUN pip install triton==3.3.0
# install flash attention
ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
2 changes: 1 addition & 1 deletion flash_attn/flash_attn_triton_amd/Dockerfile
@@ -3,7 +3,7 @@ FROM rocm/pytorch:latest
WORKDIR /workspace

# install triton
-RUN pip install triton==3.2.0
+RUN pip install triton==3.3.0

# install flash attention
ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
4 changes: 2 additions & 2 deletions flash_attn/flash_attn_triton_amd/README.md
@@ -28,7 +28,7 @@ To get started with the triton backend for AMD, follow the steps below.
First install the recommended Triton version

```
-pip install triton==3.2.0
+pip install triton==3.3.0
```
Then install Flash Attention with the flag `FLASH_ATTENTION_TRITON_AMD_ENABLE` set to `"TRUE"`.

@@ -56,7 +56,7 @@ FROM rocm/pytorch:latest
WORKDIR /workspace
# install triton
-RUN pip install triton==3.2.0
+RUN pip install triton==3.3.0
# install flash attention
ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
155 changes: 116 additions & 39 deletions flash_attn/flash_attn_triton_amd/bench.py
@@ -80,7 +80,7 @@ class EnvVariableConfig:
     backend: Optional[Literal["triton", "ck"]] = None

 ENV_VARIABLE_CONFIGS : List[EnvVariableConfig] = [
-    EnvVariableConfig(key="BWD_MODE", values=["split", "fused", "jingning"], backend="triton"),
+    # EnvVariableConfig(key="BWD_MODE", values=["split", "fused", "jingning"], backend="triton"),
 ]

 class FunctionConfig:
@@ -871,8 +871,8 @@ def load_flash_attn_module(backend: Literal["triton", "ck"], env_configs: Dict =
     # set environment variable for the desired backend
     if backend == "triton":
         os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE"
-        os.environ["FLASH_ATTENTION_TRITON_AMD_AUTOTUNE"] = "0"
         os.environ["FLASH_ATTENTION_TRITON_AMD_DEBUG"] = "0"
+        os.environ["FLASH_ATTENTION_TRITON_AMD_AUTOTUNE"] = "1"
     elif backend == "ck":
         os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "FALSE"
     else:
@@ -1016,15 +1016,30 @@ def get_input_config_set(config_type):
         # batch, hq, hk, sq, sk, d_head, causal, dropout
         input_configs = [
             # LLaMA 3 8B
-            (1, 32, 8, 8192, 8192, 128, True, 0.0),
+            (4, 32, 8, 8192, 8192, 128, True, 0.0),
             # LLaMA 3 70B
-            (1, 64, 8, 8192, 8192, 128, True, 0.0),
+            (4, 64, 8, 8192, 8192, 128, True, 0.0),
         ]
     else:
         raise ValueError(f"Unknown input config: {config_type}")

     return input_configs

+def filter_backends(requested_backends, supported_backends, fn_name):
+    if requested_backends:
+        selected = []
+        for be in requested_backends:
+            if be in supported_backends:
+                selected.append(be)
+            else:
+                warning(
+                    f"backend '{be}' requested but not supported by "
+                    f"function '{fn_name}'. skipping this back-end."
+                )
+        return selected
+    else:
+        return supported_backends


 def process_args():
     """
@@ -1052,6 +1067,14 @@ def process_args():
         default=None,
         help=f"Benchmarking mode(s) to run. If omitted, runs all supported modes for each function.",
     )
+    parser.add_argument(
+        "--backend",
+        type=str,
+        nargs='*',
+        choices=["triton", "ck"],
+        default=None,
+        help="Back-end(s) to run (triton, ck). Omit to run every back-end that is both available and supported by the function.",
+    )
     # config
     parser.add_argument("-b", type=int, default=None, help="Batch size")
     parser.add_argument("-hq", type=int, default=None, help="Q Number of heads")
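As a minimal sketch of how the new `--backend` option parses, the snippet below rebuilds only that one argument on a fresh parser (the real `process_args` parser has many more options, which are omitted here). With `nargs='*'` the value is a list when the flag is given and stays at `default=None` when it is omitted.

```python
import argparse

# Stand-alone parser containing only the --backend option from the hunk above.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--backend",
    type=str,
    nargs='*',
    choices=["triton", "ck"],
    default=None,
    help="Back-end(s) to run (triton, ck).",
)

no_flag = parser.parse_args([]).backend                          # flag omitted
one = parser.parse_args(["--backend", "triton"]).backend         # single back-end
two = parser.parse_args(["--backend", "triton", "ck"]).backend   # both back-ends

print(no_flag, one, two)
```

A value outside `choices` makes `argparse` exit with a usage error before any benchmark code runs.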
@@ -1067,7 +1090,8 @@ def process_args():

     # parse function args
     benchmark_fns = args.benchmark_fn
-    requested_modes = args.mode
+    requested_modes = args.mode
+    requested_backends = args.backend

     # generate function configurations and input configurations separately
     all_function_configs = []
@@ -1101,9 +1125,17 @@ def process_args():
         if not modes_to_run:
             warning(f"No valid modes to run for function '{fn_name}' based on request and function support. Skipping this function.")
             continue

+        # filter by backend
+        backends_to_run = filter_backends(requested_backends,
+                                          supported_backends,
+                                          fn_name)
+        if not backends_to_run:
+            warning(f"no valid back-ends left for '{fn_name}'. skipping.")
+            continue

         # create a function config for each backend and dtype combination
-        for backend in supported_backends:
+        for backend in backends_to_run:
             for dtype in supported_dtypes:
                 for mode in modes_to_run:
                     for env_config in supported_env_configs[backend]:
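The back-end filtering used in this hunk can be tried in isolation. The sketch below restates `filter_backends` with the same behavior (written with an early return) and a stand-in `warning` function; bench.py's own `warning` helper is not shown in this diff, so the one here is an assumption, and `flash_attn_func` is only an illustrative label for `fn_name`.

```python
import warnings


def warning(msg):
    # Stand-in (assumption) for bench.py's `warning` helper, which this diff does not show.
    warnings.warn(msg)


def filter_backends(requested_backends, supported_backends, fn_name):
    # No explicit request (None or an empty list): keep every supported back-end.
    if not requested_backends:
        return supported_backends
    selected = []
    for be in requested_backends:
        if be in supported_backends:
            selected.append(be)
        else:
            warning(
                f"backend '{be}' requested but not supported by "
                f"function '{fn_name}'. skipping this back-end."
            )
    return selected


# Nothing requested: all supported back-ends are returned.
both = filter_backends(None, ["triton", "ck"], "flash_attn_func")
# One supported back-end requested: only that one is kept.
only_triton = filter_backends(["triton"], ["triton", "ck"], "flash_attn_func")
# Unsupported request: a warning is emitted and the result is empty,
# which is the case the caller skips with `continue`.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    none_left = filter_backends(["ck"], ["triton"], "flash_attn_func")

print(both, only_triton, none_left)
```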
@@ -1124,6 +1156,52 @@ def check_environment_variables():
         if key in os.environ:
             raise ValueError(f"Running with {key} environment variable is not recommended for the benching script. Use --help to see how to use the benching script.")

+def compute_flops(batch, hq, hk, sq, sk, d_head, causal):
+    # 2 FLOPs per multiply‑add
+    if causal:
+        valid_pairs = ((sk * (sk + 1)) // 2 if sq > sk else
+                       sq * sk - (sq * (sq - 1)) // 2)
+    else:
+        valid_pairs = sq * sk
+    return 2 * batch * hq * valid_pairs * d_head

+# see ref, https://github.com/ROCm/aiter/blob/jukorhon/mha-bwd/op_benchmarks/triton/bench_mha.py
+def _flops_single_row(row: pd.Series, mode: str) -> float:
+    b, hq, d_head = int(row["BATCH"]), int(row["HQ"]), int(row["D_HEAD"])
+    sq, sk = int(row["N_CTX_Q"]), int(row["N_CTX_K"])
+    causal = bool(row["CAUSAL"])

+    # -------- number of (query, key) products per head ----------------
+    if not causal:
+        valid_pairs = sq * sk
+    else: # triangular mask
+        if sq > sk:
+            valid_pairs = sk * (sk + 1) // 2 + (sq - sk) * sk
+        else: # sq <= sk
+            valid_pairs = sq * (sq + 1) // 2

+    # one matmul FLOPs (mul + add) = 2 · m · n · k
+    flops_per_matmul = 2.0 * b * hq * valid_pairs * d_head
+    total_flops = 2.0 * flops_per_matmul # 2 matmuls in forward

+    if mode == "fwd":
+        pass
+    elif mode == "bwd":
+        total_flops *= 2.5 # 2·bwd + 0.5·recompute
+    elif mode == "full":
+        total_flops *= 3.5 # fwd + bwd
+    else:
+        raise ValueError(f"unknown mode {mode}")

+    return total_flops

+def add_tflops_columns(df: pd.DataFrame, func_cfg: FunctionConfig) -> pd.DataFrame:
+    ms_col = func_cfg.column_name()
+    tf_col = ms_col.replace("_ms", "_tflops")
+    flops = df.apply(_flops_single_row, axis=1, mode=func_cfg.mode)
+    df[tf_col] = flops / df[ms_col] * 1e-9
+    return df

 def main():
     """
     Main function to run benchmarks.
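To make the TFLOPs arithmetic in this hunk concrete, here is a pandas-free sketch that mirrors `_flops_single_row` and the unit conversion in `add_tflops_columns`. The function names `attention_flops` and `to_tflops` and the 10 ms timing are hypothetical and chosen for illustration only; the input shape is the LLaMA 3 8B row from the benchmark configs. The `1e-9` factor follows from FLOPs/ms × 1e3 = FLOPs/s and FLOPs/s × 1e-12 = TFLOPs/s.

```python
def attention_flops(b, hq, sq, sk, d_head, causal, mode):
    # Number of (query, key) pairs per head, same branches as _flops_single_row.
    if not causal:
        valid_pairs = sq * sk
    elif sq > sk:
        valid_pairs = sk * (sk + 1) // 2 + (sq - sk) * sk
    else:  # sq <= sk
        valid_pairs = sq * (sq + 1) // 2

    # One matmul costs 2 FLOPs (mul + add) per element; forward has 2 matmuls.
    flops_per_matmul = 2.0 * b * hq * valid_pairs * d_head
    total_flops = 2.0 * flops_per_matmul

    # Mode multipliers taken from the diff: bwd = 2.5 x fwd, full = 3.5 x fwd.
    scale = {"fwd": 1.0, "bwd": 2.5, "full": 3.5}
    if mode not in scale:
        raise ValueError(f"unknown mode {mode}")
    return total_flops * scale[mode]


def to_tflops(flops, ms):
    # FLOPs per millisecond -> TFLOPs per second.
    return flops / ms * 1e-9


# LLaMA 3 8B config from the diff: batch 4, 32 query heads, 8192 x 8192, d_head 128, causal.
fwd = attention_flops(4, 32, 8192, 8192, 128, True, "fwd")
bwd = attention_flops(4, 32, 8192, 8192, 128, True, "bwd")

# 10 ms is a made-up timing used only to show the conversion.
print(f"fwd FLOPs: {fwd:.4e}, TFLOPs at 10 ms: {to_tflops(fwd, 10.0):.2f}")
print(f"bwd FLOPs: {bwd:.4e}")
```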
@@ -1137,27 +1215,30 @@ def main():
     # process args to get function configs and input configs
     function_configs, all_input_configs = process_args()

-    # Check if we have multiple function configurations
-    has_multiple_func_configs = len(function_configs) > 1
-    combined_df = None

     # run benchmarks for each function configuration
+    combined_ms_df = None
+    combined_tf_df = None
+    input_cols = ["BATCH", "HQ", "HK", "N_CTX_Q", "N_CTX_K", "D_HEAD", "CAUSAL", "DROPOUT"]
     for func_config in function_configs:
         # run benchmark with the input configs for this function config
         input_configs = all_input_configs[func_config]
         df = run_benchmark(func_config, input_configs)
+        df = add_tflops_columns(df, func_config)

-        # Define the columns that represent input configurations
-        input_config_cols = ["BATCH", "HQ", "HK", "N_CTX_Q", "N_CTX_K", "D_HEAD", "CAUSAL", "DROPOUT"]

-        # merge into one final dataframe
-        if combined_df is None:
-            combined_df = df
+        # add to combined table
+        ms_cols = [c for c in df.columns if c.endswith('_ms')]
+        tf_cols = [c for c in df.columns if c.endswith('_tflops')]

+        ms_df = df[input_cols + ms_cols]
+        tf_df = df[input_cols + tf_cols]

+        if combined_ms_df is None:
+            combined_ms_df = ms_df
+            combined_tf_df = tf_df
         else:
             # Ensure we're joining on input configuration columns
-            combined_df = combined_df.merge(df, on=input_config_cols, how="outer")
+            combined_ms_df = combined_ms_df.merge(ms_df, on=input_cols, how="outer")
+            combined_tf_df = combined_tf_df.merge(tf_df, on=input_cols, how="outer")


     # print new line to separate the combined data information from the benchmark specific information
     print()
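The loop in this hunk splits each result frame into a wall-time table and a throughput table, then outer-merges them on the input columns. The sketch below reproduces that pattern for the `_ms` table on two tiny hand-made frames; the column names `fn_a_ms`, `fn_b_ms` and their values are hypothetical, and the input column list is shortened for readability.

```python
import pandas as pd

# Shortened stand-in for the input_cols list used in main().
input_cols = ["BATCH", "HQ", "N_CTX_Q"]

# Two made-up per-function result frames; the second one covers only one input config.
df_a = pd.DataFrame({
    "BATCH": [4, 4], "HQ": [32, 64], "N_CTX_Q": [8192, 8192],
    "fn_a_ms": [10.0, 20.0], "fn_a_tflops": [1.0, 2.0],
})
df_b = pd.DataFrame({
    "BATCH": [4], "HQ": [32], "N_CTX_Q": [8192],
    "fn_b_ms": [12.0], "fn_b_tflops": [0.8],
})

combined_ms = None
for df in (df_a, df_b):
    # Keep the input columns plus the timing column(s) of this frame.
    ms_cols = [c for c in df.columns if c.endswith("_ms")]
    ms_df = df[input_cols + ms_cols]
    if combined_ms is None:
        combined_ms = ms_df
    else:
        # Outer join keeps input configs that only one function was run on.
        combined_ms = combined_ms.merge(ms_df, on=input_cols, how="outer")

print(combined_ms)
```

With `how="outer"`, the input config that only the first frame contains is kept and its `fn_b_ms` entry is filled with NaN rather than the row being dropped.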

@@ -1166,6 +1247,7 @@ def main():
     print(f"Total time for all benchmarks: {total_elapsed_time:.2f} seconds")

     # save combined data and make comparisons if we have multiple function configs
+    has_multiple_func_configs = False # len(function_configs) > 1
     if has_multiple_func_configs:
         if len(function_configs) == 2:
             func1 = function_configs[0]
@@ -1194,30 +1276,25 @@ def main():
ratio_col = f"ck_to_triton_ratio"

# Calculate ratio: ck_time / triton_time (values > 1 mean triton is faster)
combined_df[ratio_col] = combined_df[ck_col] / combined_df[triton_col]
combined_ms_df[ratio_col] = combined_ms_df[ck_col] / combined_ms_df[triton_col]

# print explanation
print(f"Comparison Results (triton vs ck):")
print(f"Ratio values: values > 1 mean triton is faster (by that factor), values < 1 mean ck is faster")
elif False:
# For other comparisons, use the standard approach
ratio_col = f"{func1}_to_{func2}_ratio"

# Calculate the ratio
combined_df[ratio_col] = combined_df[col2] / combined_df[col1]

# print explanation
print(f"Comparison Results ({func1} vs {func2}):")
print(f"Ratio values: values > 1 mean {func1} is faster than {func2} (by that factor), values < 1 mean slower")

print(f"Combined data:")
print(combined_df)

# save csv & markdown
combined_filename = f"benchmark_combined"
combined_df.to_csv(f"{combined_filename}.csv", index=False)
with open(f"{combined_filename}.md", 'w') as f:
f.write(combined_df.to_markdown(index=False, floatfmt=".2f"))

if combined_ms_df is not None:
print("\nCombined wall‑time (ms) table:")
print(combined_ms_df)
combined_ms_df.to_csv("benchmark_ms.csv", index=False)
with open("benchmark_ms.md", 'w') as f:
f.write(combined_ms_df.to_markdown(index=False, floatfmt=".2f"))

if combined_tf_df is not None:
print("\nCombined throughput (TFLOPs) table:")
print(combined_tf_df)
combined_tf_df.to_csv("benchmark_tflops.csv", index=False)
with open("benchmark_tflops.md", 'w') as f:
f.write(combined_tf_df.to_markdown(index=False, floatfmt=".2f"))

if __name__ == "__main__":
main()