Adding model benchmarks #691
base: main_perf
Conversation
…bias[dtype0-True-True-2-4-7-16219-64] by adding the qk += (tl.dot(q, k) * QK_SCALE).to(q.type.element_ty) conversion
@@ -282,7 +284,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
     else:
         if INT8_KV:
             k = (k * k_descale).to(q.type.element_ty)
-        qk += tl.dot(q, k) * QK_SCALE
+        qk += tl.dot((q * QK_SCALE).to(q.type.element_ty), k)
We tried this before. Unfortunately, it does not work because of precision loss from the conversion. While the math on paper is the same, this version first upcasts Q to f32, does the scalar multiplication with qk_scale, and then downcasts back to q.dtype. This downcast affects performance in some cases.
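A minimal NumPy sketch (illustrative only, not the Triton kernel itself; the shapes and scale value are assumptions) of the precision concern: scaling the f32 accumulator after the dot versus scaling q and rounding it back to fp16 before the dot give slightly different results.

```python
import numpy as np

rng = np.random.default_rng(0)
q = rng.standard_normal((64, 128)).astype(np.float16)
k = rng.standard_normal((128, 64)).astype(np.float16)
qk_scale = 1.0 / np.sqrt(128) * 1.44269504  # example scale, 1/sqrt(d) * log2(e)

# Scale the f32 product after the dot (the behaviour in the current hunk).
ref = (q.astype(np.float32) @ k.astype(np.float32)) * qk_scale
# Scale q first, downcast back to fp16, then dot (the suggested change).
alt = (q.astype(np.float32) * qk_scale).astype(np.float16).astype(np.float32) @ k.astype(np.float32)

print(np.abs(ref - alt).max())  # nonzero: the fp16 round-trip of q * qk_scale loses bits
```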
@@ -305,13 +308,26 @@ def benchmark(M, N, K, provider):



# TODO(vgokhale): Add more options to benchmarking
Unintentional blank lines?
@@ -305,13 +308,26 @@ def benchmark(M, N, K, provider):



# TODO(vgokhale): Add more options to benchmarking
Maybe you could delete this TODO since you addressed it :)
        model_name=args.model)
    benchmark.benchmarks.x_vals = x_vals

    if args.M and args.N and args.K:
Maybe add an assert that MNK and model cannot be provided together (the model already fixes MNK, so the user likely made a mistake if they provided both)?
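A sketch of the suggested guard (the flag names follow the commands shown in this PR, but the exact parser setup here is an assumption, not the PR's code):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-model", type=str, default=None)
parser.add_argument("-M", type=int, default=0)
parser.add_argument("-N", type=int, default=0)
parser.add_argument("-K", type=int, default=0)
args = parser.parse_args()

# -model already fixes M, N and K, so passing both is almost certainly a mistake.
if args.model and (args.M or args.N or args.K):
    parser.error("Provide either -model or explicit -M/-N/-K, not both.")
```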
{
    "llama3_8B": {
        "head_count": 32,
        "head_dimension": 128,
Usually this is backwards: the hidden_size and head_count are provided, and we work out the head_dimension from that.
Here is an example:
https://huggingface.co/unsloth/llama-3-8b/blob/main/config.json
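A small sketch of deriving head_dimension from the fields a Hugging Face config actually exposes (the file path is illustrative; the commented values are the ones in the linked llama-3-8b config):

```python
import json

with open("config.json") as f:
    cfg = json.load(f)

head_count = cfg["num_attention_heads"]     # 32 for llama-3-8b
hidden_size = cfg["hidden_size"]            # 4096 for llama-3-8b
head_dimension = hidden_size // head_count  # 128
```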
# Infer M, N, K based on the feedforward network (FFN) dimensions
M = batch_size * seq_len  # Total tokens in a batch
K = head_dimension * head_count  # Hidden size (d)
N = 4 * K  # FFN dimension is typically 4x hidden size
Hmm... I think this depends. On llama3-8b, the intermediate size (which is half the first FF output) is 14336, so the N dim would be 14336 x 2, or more generally 2 x intermediate_size.
The intermediate size is a config parameter, so we should read it from the JSON.
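A sketch of reading the FFN width from the config instead of assuming 4 * K (file path and variable names are illustrative; the factor of 2 follows the point above about the fused gate/up projections of the first FF layer):

```python
import json

with open("config.json") as f:
    cfg = json.load(f)

K = cfg["hidden_size"]                   # 4096 for llama3-8b
intermediate = cfg["intermediate_size"]  # 14336 for llama3-8b
N = 2 * intermediate                     # gate + up projections of the first FF layer
```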
raise ValueError(f"Model '{model_name}' not found in {config_file}")
# Handle a specific model
config = configs[model_name]
HQ = HK = config["head_count"]
This depends on the config as well. All models have a "num_attention_heads" parameter, but some also have "num_key_value_heads". If they list the latter, then the KV head count is different from the Q head count.
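A sketch of the suggested handling; the keys are the standard Hugging Face config names, the variable names follow the PR snippet above, and the example values are illustrative of a GQA model:

```python
config = {"num_attention_heads": 32, "num_key_value_heads": 8}  # illustrative values

HQ = config["num_attention_heads"]
# Models without grouped-query attention omit num_key_value_heads, so fall back to HQ.
HK = config.get("num_key_value_heads", HQ)
```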
This PR adds advanced benchmarking for the kernels inside perf-kernels. The aim is to have more comparable benchmark results by taking shapes from actual LLMs. model_configs.json holds the configs for various models, which gemm.py and flash-attention.py can then read to get benchmarking shapes for these kernels, e.g.:
python python/perf-kernels/gemm.py -model llama3_8B
To run the GEMM kernel with a GEMM shape taken from the first layer of the FFN of a Llama 3 8B model.
python python/perf-kernels/flash-attention.py -model all -b 2 -sq 1024
To run the flash attention kernel with the shapes from all the models in model_configs.json (currently llama3_8B, llama3_70B, llama3_405B).
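A hedged sketch of how the GEMM benchmark could turn model_configs.json entries into (M, N, K) shapes; the batch and sequence values mirror the -b/-sq flags above, and the 4 * K FFN width follows the snippet quoted in the review, so this is an illustration of the idea rather than the PR's exact code:

```python
import json

with open("model_configs.json") as f:
    configs = json.load(f)

batch_size, seq_len = 2, 1024  # would normally come from -b / -sq
x_vals = []
for name, cfg in configs.items():
    K = cfg["head_count"] * cfg["head_dimension"]  # hidden size
    M = batch_size * seq_len                       # total tokens in a batch
    N = 4 * K                                      # FFN width per the quoted snippet
    x_vals.append((M, N, K))
```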