Adding model benchmarks #691

Open
wants to merge 10 commits into base: main_perf
19 changes: 17 additions & 2 deletions python/perf-kernels/flash-attention.py
@@ -28,6 +28,8 @@

import triton
import triton.language as tl
import model_benchmarking
import os


class MetaData():
@@ -282,7 +284,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
else:
if INT8_KV:
k = (k * k_descale).to(q.type.element_ty)
qk += tl.dot(q, k) * QK_SCALE
qk += tl.dot((q * QK_SCALE).to(q.type.element_ty), k)
Collaborator:
We had tried this before. Unfortunately, it does not work because of precision loss due to conversion. While the math on paper is the same, here it is first upcasting Q to f32, doing the scalar mult with qk_scale, then downcasting to q.dtype. This downcast affects performance in some cases.
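For illustration only, a minimal PyTorch sketch (not part of the diff; qk_scale is an arbitrary example value) of the rounding introduced by scaling q and downcasting, compared with scaling the fp32 dot-product result:

import torch

q = torch.randn(128, 64, dtype=torch.float16)
k = torch.randn(64, 128, dtype=torch.float16)
qk_scale = 0.125  # example scale value

# Previous form: fp32 dot product, scaled afterwards.
ref = (q.float() @ k.float()) * qk_scale
# Proposed form: upcast q, scale, downcast to fp16, then dot.
new = (q.float() * qk_scale).to(torch.float16).float() @ k.float()

print((ref - new).abs().max())  # non-zero: rounding error from the fp16 downcast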


if bias_ptrs is not None:
bias_offs_n = start_n + tl.arange(0, BLOCK_N) if MASK_STEPS else None
@@ -1870,6 +1872,11 @@ def varlen_benchmark_configs():
return configs


def model_benchmark_configs(batch_size, seq_len):
config_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "model_configs.json")
return model_benchmarking.get_FA_configs(batch_size=batch_size, seq_len=seq_len, config_file=config_file)


def run_benchmark(custom, args):

dtype = arg_to_torch_dtype[args.dtype]
@@ -1892,6 +1899,9 @@ def run_benchmark(custom, args):
else:
x_vals_list = nonvarlen_benchmark_configs()

if args.model:
x_vals_list = model_benchmark_configs(batch_size=args.b, seq_len=args.sq)

print_time = args.return_time
line_names = 'Time (ms)' if print_time else 'TFLOPS'
configs.append(
@@ -1976,6 +1986,11 @@ def parse_args():
prog="Benchmark FlashAttention",
allow_abbrev=False,
)

available_models = model_benchmarking.get_available_models() # Dynamically load model names
model_help = ("Model name to benchmark. Select from: [" + ", ".join(available_models) +
"]. Use 'all' to benchmark all models or leave blank for the default benchmark script.")
parser.add_argument("-model", type=str, default=None, help=model_help)
parser.add_argument("-b", type=int, default=0)
parser.add_argument("-hq", type=int, default=0)
parser.add_argument("-hk", type=int, default=0)
@@ -2006,7 +2021,7 @@ def main():
custom_config = False
assert args.layout == 'thd' or not args.equal_seqlens, \
"Equal sequence lengths arg must be used with the thd layout."
if args.b or args.hq or args.hk or args.sq or args.sk or args.d:
if args.hq or args.hk or args.sk or args.d:
custom_config = True
assert args.b and args.hq and args.sq and args.d, \
"If custom config is specified, please provide \
33 changes: 33 additions & 0 deletions python/perf-kernels/gemm.py
@@ -6,6 +6,9 @@
import pytest
import re

import model_benchmarking
import os


@triton.autotune(
configs=[
@@ -305,13 +308,26 @@ def benchmark(M, N, K, provider):


# TODO(vgokhale): Add more options to benchmarking
Collaborator:
Maybe you could delete this TODO since you addressed it :)



Collaborator:
Unintentional blank lines?

def parse_args():
parser = argparse.ArgumentParser(
prog="GEMM tutorial example",
allow_abbrev=False,
)

available_models = model_benchmarking.get_available_models() # Dynamically load model names
model_help = ("Model name to benchmark. Select from: [" + ", ".join(available_models) +
"]. Use 'all' to benchmark all models or leave blank for the default benchmark script.")

parser.add_argument("-v", action='store_true', default=False, help="Print out the best tuning config")
parser.add_argument("-b", type=int, default=None)
parser.add_argument("-sq", type=int, default=None)
parser.add_argument("-model", type=str, default=None, help=model_help)
parser.add_argument("-M", type=int, default=0)
parser.add_argument("-N", type=int, default=0)
parser.add_argument("-K", type=int, default=0)

args = parser.parse_args()

return args
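For example, hypothetical invocations of the new flags (run from python/perf-kernels so model_configs.json resolves):

python gemm.py -model llama3_8B -b 1 -sq 4096
python gemm.py -model all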
@@ -323,6 +339,23 @@ def main():
global verbose
args = parse_args()
verbose = args.v

if args.model:
batch_size = args.b if args.b is not None else 1
config_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "model_configs.json")
if args.model.lower() == "all":
# Benchmark all models
x_vals = model_benchmarking.get_mnk(batch_size=batch_size, seq_len=args.sq, config_file=config_file)
else:
# Benchmark a specific model
x_vals = model_benchmarking.get_mnk(batch_size=batch_size, config_file=config_file, seq_len=args.sq,
model_name=args.model)
benchmark.benchmarks.x_vals = x_vals

if args.M and args.N and args.K:
Collaborator:
Maybe add an assert that MNK and model cannot both be provided (since model already fixes MNK, the user likely made a mistake if they provided both)?
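A minimal sketch of such a guard, reusing the existing argparse flag names (exact placement in main() is an assumption):

assert not (args.model and (args.M or args.N or args.K)), \
    "Provide either -model or explicit -M/-N/-K, not both."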

x_vals = [(args.M, args.N, args.K)]
benchmark.benchmarks.x_vals = x_vals

benchmark.run(show_plots=True, print_data=True)


99 changes: 99 additions & 0 deletions python/perf-kernels/model_benchmarking.py
@@ -0,0 +1,99 @@
import json
import os


def load_model_config(config_file='model_configs.json'):
"""Load all model configurations from a JSON file."""
with open(config_file, 'r') as f:
return json.load(f)


def infer_mnk(model_name, batch_size, seq_len, config_file='model_configs.json'):
"""Infer M, N, and K dimensions for a given model, batch size, and sequence length."""
configs = load_model_config(config_file)
if model_name not in configs:
raise ValueError(f"Model '{model_name}' not found in {config_file}")

config = configs[model_name]
head_count = config["head_count"]
head_dimension = config["head_dimension"]

# Infer M, N, K based on the feedforward network (FFN) dimensions
M = batch_size * seq_len # Total tokens in a batch
K = head_dimension * head_count # Hidden size (d)
N = 4 * K # FFN dimension is typically 4× hidden size
Collaborator:
Hmm... I think this depends. On llama3-8b, the intermediate size (which is half the first FF output) is 14336. So the N dim would be 14336 x 2, or more generally 2 x intermediate_size.

The intermediate size is a config parameter, so we should read it from the JSON.
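A sketch of how infer_mnk could use that parameter, assuming an "intermediate_size" key is added to model_configs.json (key name borrowed from the HF config; not in the current JSON):

M = batch_size * seq_len  # total tokens in a batch
K = head_dimension * head_count  # hidden size (d)
# Read the FFN width from the config; the first FF output is gate+up,
# i.e. 2 * intermediate_size, rather than a fixed 4 * K.
N = 2 * config["intermediate_size"]
return M, N, K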


return M, N, K


def get_mnk(batch_size=1, seq_len=None, config_file='model_configs.json', model_name=None):
"""
Retrieve MNK dimensions for benchmarking. Can return:
- All models (default)
- A specific model if model_name is provided
"""
configs = load_model_config(config_file)
mnk_list = []

if model_name:
# Check if the model exists
if model_name not in configs:
raise ValueError(f"Model '{model_name}' not found in {config_file}")
# Handle a specific model
config = configs[model_name]
max_seq_len = config["max_seq_len"]
actual_seq_len = max_seq_len if seq_len is None else seq_len
M, N, K = infer_mnk(model_name, batch_size, actual_seq_len, config_file)
mnk_list.append((M, N, K))
else:
# Handle all models
for model_name, config in configs.items():
max_seq_len = config["max_seq_len"]
actual_seq_len = seq_len or max_seq_len
if actual_seq_len > max_seq_len:
raise ValueError(f"Sequence length {actual_seq_len} exceeds maximum {max_seq_len} for {model_name}")
M, N, K = infer_mnk(model_name, batch_size, actual_seq_len, config_file)
mnk_list.append((M, N, K))

return mnk_list
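For reference, a usage sketch against the llama3_8B entry in model_configs.json below (values follow the current 4*K FFN assumption):

mnk = get_mnk(batch_size=1, seq_len=8192, model_name="llama3_8B")
# -> [(8192, 16384, 4096)]  (M = 1*8192, K = 128*32, N = 4*K)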


def get_available_models(config_file='model_configs.json'):
"""Load model names from the configuration file."""
config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), config_file)
with open(config_path, 'r') as f:
configs = json.load(f)
return list(configs.keys())


def get_FA_configs(batch_size=1, seq_len=None, model_name=None, config_file='model_configs.json'):
"""
Retrieve Flash Attention configurations.
Args:
batch_size: Batch size for the configurations.
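seq_len: Sequence length; if None, the model's max_seq_len is used.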
model_name: Name of the model. If None, return all models.
config_file: Path to the model configuration file.
Returns:
List of Flash Attention configurations as tuples: (BATCH, HQ, HK, N_CTX_Q, N_CTX_K)
"""
configs = load_model_config(config_file)
fa_configs = []

if model_name:
# Check if the model exists
if model_name not in configs:
raise ValueError(f"Model '{model_name}' not found in {config_file}")
# Handle a specific model
config = configs[model_name]
HQ = HK = config["head_count"]
Collaborator:
This depends on the config as well. All models have a "num_attention_heads" parameter, but some also have "num_key_value_heads". If they list the latter, then the KV head count differs from the Q head count.
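A sketch of that fallback (the num_* key names follow the HF configs mentioned above, not the current JSON):

HQ = config["head_count"]  # query heads, current key
HK = config.get("num_key_value_heads", HQ)  # assumed key for GQA models; falls back to HQ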

max_seq_len = config["max_seq_len"]
N_CTX_Q = N_CTX_K = max_seq_len if seq_len is None else seq_len
fa_configs.append((batch_size, HQ, HK, N_CTX_Q, N_CTX_K))
else:
# Handle all models
for model_name, config in configs.items():
HQ = HK = config["head_count"]
N_CTX_Q = N_CTX_K = config["max_seq_len"]
fa_configs.append((batch_size, HQ, HK, N_CTX_Q, N_CTX_K))

return fa_configs
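For reference, a usage sketch against the llama3_8B entry below (run from python/perf-kernels so the default config_file resolves, or pass it explicitly):

fa = get_FA_configs(batch_size=2, seq_len=4096, model_name="llama3_8B")
# -> [(2, 32, 32, 4096, 4096)]  (BATCH, HQ, HK, N_CTX_Q, N_CTX_K)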
17 changes: 17 additions & 0 deletions python/perf-kernels/model_configs.json
@@ -0,0 +1,17 @@
{
"llama3_8B": {
"head_count": 32,
"head_dimension": 128,
Collaborator:
Usually this is backwards: the hidden_size and head_count are provided, and we work out the head_dimension from that.

Here is an example:

https://huggingface.co/unsloth/llama-3-8b/blob/main/config.json
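A sketch of deriving the head dimension from those fields instead of storing it (key names taken from the linked config):

head_dimension = config["hidden_size"] // config["num_attention_heads"]
# llama3-8B: 4096 // 32 == 128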

"max_seq_len": 8192
},
"llama3_70B": {
"head_count": 64,
"head_dimension": 128,
"max_seq_len": 8192
},
"llama3_405B": {
"head_count": 128,
"head_dimension": 128,
"max_seq_len": 8192
}
}