Updated python CLI for flash attention
Commit 2fccdc8 (parent 818d890). Showing 3 changed files with 130 additions and 44 deletions.
Diff for the flash attention compile script:

@@ -1,53 +1,110 @@
 import os
+import iree.compiler as ireec
+from tqdm import tqdm
+from multiprocessing import Pool, cpu_count
+import itertools
 
-# Function to generate MLIR content and write to a file
-def generate_mlir_file(B, H, S_Q, S_KV, DH, datatype):
+def generate_attention_shapes():
+    batch_sizes = [1, 2, 4]
+    head_counts = [12, 16, 24, 32, 40, 48, 64]
+    seq_lengths = [64, 128, 256, 384, 512, 768, 1024, 2048, 4096, 8192]
+    head_dims = [16, 32, 64, 128, 256]
+    datatypes = ["f16"]
+
+    shapes = []
+    for B, H, S, DH, datatype in itertools.product(batch_sizes, head_counts, seq_lengths, head_dims, datatypes):
+        S_Q = S
+        S_KV = S
+
+        shapes.append((B, H, S_Q, S_KV, DH, datatype))
+
+        if S_KV > 64:
+            shapes.append((B, H, S_KV // 2, S_KV, DH, datatype))
+        if S_Q > 64:
+            shapes.append((B, H, S_Q, S_Q // 2, DH, datatype))
+
+    return shapes
+
+def generate_mlir_content(B, H, S_Q, S_KV, DH, datatype):
     key_shape = f"[{B},{H},{S_KV},{DH}]"
     query_shape = f"[{B},{H},{S_Q},{DH}]"
     value_shape = f"[{B},{H},{S_KV},{DH}]"
     output_shape = f"[{B},{H},{S_Q},{DH}]"
 
     mlir_template = f"""
-func.func @main(%295 : !torch.vtensor<{key_shape},{datatype}>, %298 : !torch.vtensor<{query_shape},{datatype}>, %301 : !torch.vtensor<{value_shape},{datatype}>) -> !torch.vtensor<{output_shape},{datatype}> {{
-  %false_371 = torch.constant.bool false
-  %float0.000000e00 = torch.constant.float 0.000000e+00
-  %none_372 = torch.constant.none
-  %none_373 = torch.constant.none
-  %282:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%295, %298, %301, %float0.000000e00, %false_371, %none_372, %none_373) : (!torch.vtensor<{query_shape},{datatype}>, !torch.vtensor<{key_shape},{datatype}>, !torch.vtensor<{value_shape},{datatype}>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<{output_shape},{datatype}>, !torch.vtensor<[{B},{H},{S_Q}], f32>)
-  return %282#0 : !torch.vtensor<{output_shape},{datatype}>
+module {{
+  func.func @main_0(%295 : !torch.vtensor<{query_shape},{datatype}>, %298 : !torch.vtensor<{key_shape},{datatype}>, %301 : !torch.vtensor<{value_shape},{datatype}>) -> !torch.vtensor<{output_shape},{datatype}> {{
+    %false_371 = torch.constant.bool false
+    %float0.000000e00 = torch.constant.float 0.000000e+00
+    %none_372 = torch.constant.none
+    %none_373 = torch.constant.none
+    %282:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%295, %298, %301, %float0.000000e00, %false_371, %none_372, %none_373) : (!torch.vtensor<{query_shape},{datatype}>, !torch.vtensor<{key_shape},{datatype}>, !torch.vtensor<{value_shape},{datatype}>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<{output_shape},{datatype}>, !torch.vtensor<[{B},{H},{S_Q}], f32>)
+    return %282#0 : !torch.vtensor<{output_shape},{datatype}>
+  }}
 }}
 """
+    return mlir_template
+
+def compile_shape(shape):
+    B, H, S_Q, S_KV, DH, datatype = shape
+
+    # Generate MLIR content
+    mlir_content = generate_mlir_content(B, H, S_Q, S_KV, DH, datatype)
+
+    # Generate filenames
+    mlir_filename = f"attention/mlir/attention_B{B}_H{H}_SQ{S_Q}_SKV{S_KV}_DH{DH}_{datatype}.mlir"
+    vmfb_filename = f"attention/vmfb/attention_B{B}_H{H}_SQ{S_Q}_SKV{S_KV}_DH{DH}_{datatype}.vmfb"
+
+    # Write MLIR content to file
+    with open(mlir_filename, 'w') as f:
+        f.write(mlir_content)
+
+    # Compile MLIR to VMFB
+    compile_options = ireec.CompilerOptions()
+    compile_options.hal_target_backends = ["rocm"]
+    compile_options.rocm_target_chip = "gfx942"
+
+    try:
+        compiled_binary = ireec.compile_str(
+            mlir_content,
+            target_backends=compile_options.hal_target_backends,
+            input_type="torch",
+            output_format="FLATBUFFER_BINARY",
+            extra_args=[
+                f"--iree-rocm-target-chip={compile_options.rocm_target_chip}",
+                "--iree-global-opt-propagate-transposes=true",
+                "--iree-opt-outer-dim-concat=true",
+                "--iree-opt-const-eval=false",
+                "--iree-opt-data-tiling=false",
+                "--iree-rocm-waves-per-eu=2",
+                "--iree-vm-target-truncate-unsupported-floats",
+                "--iree-codegen-llvmgpu-use-vector-distribution",
+                "--iree-codegen-gpu-native-math-precision=true",
+                "--iree-flow-enable-aggressive-fusion",
+                f"--dump-compilation-phases-to=compile_phases_{B}_{H}_{S_Q}_{S_KV}_{DH}_{datatype}",
+            ]
+        )
+
+        # Write the compiled binary to the VMFB file
+        with open(vmfb_filename, 'wb') as f:
+            f.write(compiled_binary)
+
+        return f"Successfully compiled {mlir_filename} to {vmfb_filename}"
+    except Exception as e:
+        return f"Error compiling {mlir_filename}: {str(e)}"
 
-    # Generate unique filename based on parameters
-    filename = f"attention_B{B}_H{H}_SQ{S_Q}_SKV{S_KV}_DH{DH}_{datatype}.mlir"
-    with open(filename, 'w') as f:
-        f.write(mlir_template)
-
-    return filename
-
-# List of known attention shapes in popular LLM architectures
-known_shapes = [
-    (1, 42, 384, 64320, 64, "f16")
-    # (1, 12, 512, 512, 64, "f16"),  # Example shape for BERT base
-    # (1, 16, 1024, 1024, 64, "f16"),  # Example shape for GPT-3 small
-    # (1, 12, 384, 384, 64, "f16"),  # Example shape for some other model
-]
-
-# Function to add more shapes iteratively
-def add_more_shapes(shape_list):
-    # Example of adding more shapes
-    shape_list.append((1, 8, 256, 256, 32, "f16"))  # Custom shape 1
-    shape_list.append((1, 32, 2048, 2048, 128, "f32"))  # Custom shape 2
-    # Add more shapes as needed
-    return shape_list
-
-# Main script
 if __name__ == "__main__":
-    # Add more shapes to the known list
-    # known_shapes = add_more_shapes(known_shapes)
-
-    # Generate MLIR files for each shape in the list
-    for shape in known_shapes:
-        B, H, S_Q, S_KV, DH, datatype = shape
-        filename = generate_mlir_file(B, H, S_Q, S_KV, DH, datatype)
-        print(f"MLIR file '{filename}' generated successfully.")
+    shapes = generate_attention_shapes()
+    print(f"Generated {len(shapes)} attention shapes.")
+
+    num_cpus = max(1, cpu_count() - 20)
+    print(f"Using {num_cpus} CPUs for parallel processing.")
+
+    with Pool(num_cpus) as pool:
+        results = list(tqdm(pool.imap(compile_shape, shapes), total=len(shapes)))
+
+    for result in results:
+        if 'error' in result.lower():
+            print(result)
+
+    print("Compilation process completed.")