Updated python CLI for flash attention
suryajasper committed Jul 30, 2024
1 parent 818d890 commit 2fccdc8
Showing 3 changed files with 130 additions and 44 deletions.
16 changes: 13 additions & 3 deletions gemmbench/cli.py
@@ -91,9 +91,19 @@ def roofline(results=None, **kwargs):
         data.append(dict(serial=int(serial), **experiment_group.attrs))
 
     for item in data:
-        M, N, K = item['M'], item['N'], item['K']
-        flops = 2 * M * N * K
-        bytes = M * K + N * K + M * N
+        flops = 0
+        bytes = 1
+
+        if 'sharkfa' in result_file:
+            S_Q, S_KV, DH = item['M'], item['N'], item['K']
+            B, H = ord(item['A'][0]), ord(item['B'][0])
+            flops = 4 * S_Q * S_KV * DH * B * H
+            bytes = B * H * 2 * (2 * S_KV * DH + 2 * S_Q * DH + S_Q * S_KV)
+        else:
+            M, N, K = item['M'], item['N'], item['K']
+            flops = 2 * M * N * K
+            bytes = M * K + N * K + M * N
+
         item['arithmetic_intensity'] = flops / bytes
         item['tflops'] = (flops / 1e12) / (item['mean_microseconds'] / 1e6)

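As a sanity check on the new 'sharkfa' branch, the flash-attention roofline math can be reproduced standalone. A minimal sketch follows; the shape and runtime are illustrative values, not taken from any real result file, and only the formulas mirror the code above.

# Illustrative numbers only; formulas mirror the 'sharkfa' branch of roofline().
B, H, S_Q, S_KV, DH = 2, 16, 1024, 1024, 64
mean_microseconds = 250.0                                              # hypothetical measured runtime

flops = 4 * S_Q * S_KV * DH * B * H                                    # two matmuls per head: Q@K^T and P@V
bytes_moved = B * H * 2 * (2 * S_KV * DH + 2 * S_Q * DH + S_Q * S_KV)  # fp16 => 2 bytes per element

arithmetic_intensity = flops / bytes_moved
tflops = (flops / 1e12) / (mean_microseconds / 1e6)
print(f"AI = {arithmetic_intensity:.1f} FLOP/byte, {tflops:.2f} TFLOP/s")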
19 changes: 19 additions & 0 deletions gemmbench/problems.py
@@ -1,6 +1,7 @@
"""GEMM AI Performace problem suites."""

from gbm import Problem as GEMM, Configuration, Solution
+import itertools
 
 def is_compute_bound(M, N, K, bpe):
     """Is this GEMM compute (or memory) bound?"""
@@ -815,6 +816,24 @@ def unet():
     for m, n, k in UNET:
         yield GEMM("unet", m, n, k, tA, tB, dtype)
 
+def flash_attention():
+    batch_sizes = [1, 2, 4]
+    head_counts = [12, 16, 24, 32]
+    seq_lengths = [64, 128, 256, 384, 512, 1024, 2048, 4096, 8192]
+    head_dims = [16, 32, 64, 128]
+    datatypes = ["fp16"]
+
+    for B, H, S, DH, datatype in itertools.product(batch_sizes, head_counts, seq_lengths, head_dims, datatypes):
+        S_Q = S
+        S_KV = S
+
+        yield GEMM("flash_attention", S_Q, S_KV, DH, str(chr(B)), str(chr(H)), datatype)
+
+        if S_KV > 64:
+            yield GEMM("flash_attention", S_KV // 2, S_KV, DH, str(chr(B)), str(chr(H)), datatype)
+
+
+
 def all():
     yield from llama13bmatvec()
     yield from llama13bmatvecbf16()
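A note on the chr()/ord() trick: the GEMM problem has no dedicated fields for batch size and head count, so flash_attention() appears to pack them into the two transpose-flag slots as single characters, which the roofline pass in cli.py reads back through the 'A' and 'B' attributes with ord(). A minimal sketch of that round trip, with illustrative values:

# Round trip of batch size B and head count H through the GEMM tA/tB slots.
# Mirrors str(chr(.)) in problems.py and ord(.) in cli.py; values are illustrative.
B, H = 4, 32
encoded_A, encoded_B = str(chr(B)), str(chr(H))               # what flash_attention() yields
decoded_B, decoded_H = ord(encoded_A[0]), ord(encoded_B[0])   # what roofline() recovers
assert (decoded_B, decoded_H) == (B, H)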
139 changes: 98 additions & 41 deletions gemmbench/shark_fa2.py
@@ -1,53 +1,110 @@
 import os
+import iree.compiler as ireec
+from tqdm import tqdm
+from multiprocessing import Pool, cpu_count
+import itertools
 
-# Function to generate MLIR content and write to a file
-def generate_mlir_file(B, H, S_Q, S_KV, DH, datatype):
+def generate_attention_shapes():
+    batch_sizes = [1, 2, 4]
+    head_counts = [12, 16, 24, 32, 40, 48, 64]
+    seq_lengths = [64, 128, 256, 384, 512, 768, 1024, 2048, 4096, 8192]
+    head_dims = [16, 32, 64, 128, 256]
+    datatypes = ["f16"]
+
+    shapes = []
+    for B, H, S, DH, datatype in itertools.product(batch_sizes, head_counts, seq_lengths, head_dims, datatypes):
+        S_Q = S
+        S_KV = S
+
+        shapes.append((B, H, S_Q, S_KV, DH, datatype))
+
+        if S_KV > 64:
+            shapes.append((B, H, S_KV // 2, S_KV, DH, datatype))
+        if S_Q > 64:
+            shapes.append((B, H, S_Q, S_Q // 2, DH, datatype))
+
+    return shapes
+
+def generate_mlir_content(B, H, S_Q, S_KV, DH, datatype):
     key_shape = f"[{B},{H},{S_KV},{DH}]"
     query_shape = f"[{B},{H},{S_Q},{DH}]"
     value_shape = f"[{B},{H},{S_KV},{DH}]"
     output_shape = f"[{B},{H},{S_Q},{DH}]"
 
     mlir_template = f"""
-func.func @main(%295 : !torch.vtensor<{key_shape},{datatype}>, %298 : !torch.vtensor<{query_shape},{datatype}>, %301 : !torch.vtensor<{value_shape},{datatype}>) -> !torch.vtensor<{output_shape},{datatype}> {{
-  %false_371 = torch.constant.bool false
-  %float0.000000e00 = torch.constant.float 0.000000e+00
-  %none_372 = torch.constant.none
-  %none_373 = torch.constant.none
-  %282:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%295, %298, %301, %float0.000000e00, %false_371, %none_372, %none_373) : (!torch.vtensor<{query_shape},{datatype}>, !torch.vtensor<{key_shape},{datatype}>, !torch.vtensor<{value_shape},{datatype}>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<{output_shape},{datatype}>, !torch.vtensor<[{B},{H},{S_Q}], f32>)
-  return %282#0 : !torch.vtensor<{output_shape},{datatype}>
+module {{
+  func.func @main_0(%295 : !torch.vtensor<{query_shape},{datatype}>, %298 : !torch.vtensor<{key_shape},{datatype}>, %301 : !torch.vtensor<{value_shape},{datatype}>) -> !torch.vtensor<{output_shape},{datatype}> {{
+    %false_371 = torch.constant.bool false
+    %float0.000000e00 = torch.constant.float 0.000000e+00
+    %none_372 = torch.constant.none
+    %none_373 = torch.constant.none
+    %282:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%295, %298, %301, %float0.000000e00, %false_371, %none_372, %none_373) : (!torch.vtensor<{query_shape},{datatype}>, !torch.vtensor<{key_shape},{datatype}>, !torch.vtensor<{value_shape},{datatype}>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<{output_shape},{datatype}>, !torch.vtensor<[{B},{H},{S_Q}], f32>)
+    return %282#0 : !torch.vtensor<{output_shape},{datatype}>
+  }}
 }}
 """
+    return mlir_template
 
+def compile_shape(shape):
+    B, H, S_Q, S_KV, DH, datatype = shape
+
+    # Generate MLIR content
+    mlir_content = generate_mlir_content(B, H, S_Q, S_KV, DH, datatype)
+
+    # Generate filenames
+    mlir_filename = f"attention/mlir/attention_B{B}_H{H}_SQ{S_Q}_SKV{S_KV}_DH{DH}_{datatype}.mlir"
+    vmfb_filename = f"attention/vmfb/attention_B{B}_H{H}_SQ{S_Q}_SKV{S_KV}_DH{DH}_{datatype}.vmfb"
+
+    # Write MLIR content to file
+    with open(mlir_filename, 'w') as f:
+        f.write(mlir_content)
+
+    # Compile MLIR to VMFB
+    compile_options = ireec.CompilerOptions()
+    compile_options.hal_target_backends = ["rocm"]
+    compile_options.rocm_target_chip = "gfx942"
+
+    try:
+        compiled_binary = ireec.compile_str(
+            mlir_content,
+            target_backends=compile_options.hal_target_backends,
+            input_type="torch",
+            output_format="FLATBUFFER_BINARY",
+            extra_args=[
+                f"--iree-rocm-target-chip={compile_options.rocm_target_chip}",
+                "--iree-global-opt-propagate-transposes=true",
+                "--iree-opt-outer-dim-concat=true",
+                "--iree-opt-const-eval=false",
+                "--iree-opt-data-tiling=false",
+                "--iree-rocm-waves-per-eu=2",
+                "--iree-vm-target-truncate-unsupported-floats",
+                "--iree-codegen-llvmgpu-use-vector-distribution",
+                "--iree-codegen-gpu-native-math-precision=true",
+                "--iree-flow-enable-aggressive-fusion",
+                f"--dump-compilation-phases-to=compile_phases_{B}_{H}_{S_Q}_{S_KV}_{DH}_{datatype}",
+            ]
+        )
+
+        # Write the compiled binary to the VMFB file
+        with open(vmfb_filename, 'wb') as f:
+            f.write(compiled_binary)
+
+        return f"Successfully compiled {mlir_filename} to {vmfb_filename}"
+    except Exception as e:
+        return f"Error compiling {mlir_filename}: {str(e)}"
+
-    # Generate unique filename based on parameters
-    filename = f"attention_B{B}_H{H}_SQ{S_Q}_SKV{S_KV}_DH{DH}_{datatype}.mlir"
-    with open(filename, 'w') as f:
-        f.write(mlir_template)
-
-    return filename
-
-# List of known attention shapes in popular LLM architectures
-known_shapes = [
-    (1, 42, 384, 64320, 64, "f16")
-    # (1, 12, 512, 512, 64, "f16"), # Example shape for BERT base
-    # (1, 16, 1024, 1024, 64, "f16"), # Example shape for GPT-3 small
-    # (1, 12, 384, 384, 64, "f16"), # Example shape for some other model
-]
-
-# Function to add more shapes iteratively
-def add_more_shapes(shape_list):
-    # Example of adding more shapes
-    shape_list.append((1, 8, 256, 256, 32, "f16")) # Custom shape 1
-    shape_list.append((1, 32, 2048, 2048, 128, "f32")) # Custom shape 2
-    # Add more shapes as needed
-    return shape_list
-
-# Main script
 if __name__ == "__main__":
-    # Add more shapes to the known list
-    # known_shapes = add_more_shapes(known_shapes)
-
-    # Generate MLIR files for each shape in the list
-    for shape in known_shapes:
-        B, H, S_Q, S_KV, DH, datatype = shape
-        filename = generate_mlir_file(B, H, S_Q, S_KV, DH, datatype)
-        print(f"MLIR file '{filename}' generated successfully.")
+    shapes = generate_attention_shapes()
+    print(f"Generated {len(shapes)} attention shapes.")
+
+    num_cpus = max(1, cpu_count() - 20)
+    print(f"Using {num_cpus} CPUs for parallel processing.")
+
+    with Pool(num_cpus) as pool:
+        results = list(tqdm(pool.imap(compile_shape, shapes), total=len(shapes)))
+
+    for result in results:
+        if 'error' in result.lower():
+            print(result)
+
+    print("Compilation process completed.")
