Adding model benchmarks #691
base: main_perf
@@ -6,6 +6,9 @@
import pytest
import re

import model_benchmarking
import os


@triton.autotune(
    configs=[
@@ -305,13 +308,26 @@ def benchmark(M, N, K, provider):


# TODO(vgokhale): Add more options to benchmarking
Review comment: Maybe you could delete this TODO since you addressed it :)


Review comment: Unintentional blank lines?
def parse_args():
    parser = argparse.ArgumentParser(
        prog="GEMM tutorial example",
        allow_abbrev=False,
    )

    available_models = model_benchmarking.get_available_models()  # Dynamically load model names
    model_help = ("Model name to benchmark. Select from: [" + ", ".join(available_models) +
                  "]. Use 'all' to benchmark all models or leave blank for the default benchmark script.")

    parser.add_argument("-v", action='store_true', default=False, help="Print out the best tuning config")
    parser.add_argument("-b", type=int, default=None)
    parser.add_argument("-sq", type=int, default=None)
    parser.add_argument("-model", type=str, default=None, help=model_help)
    parser.add_argument("-M", type=int, default=0)
    parser.add_argument("-N", type=int, default=0)
    parser.add_argument("-K", type=int, default=0)

    args = parser.parse_args()

    return args
@@ -323,6 +339,23 @@ def main():
    global verbose
    args = parse_args()
    verbose = args.v

    if args.model:
        batch_size = args.b if args.b is not None else 1
        config_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "model_configs.json")
        if args.model.lower() == "all":
            # Benchmark all models
            x_vals = model_benchmarking.get_mnk(batch_size=batch_size, seq_len=args.sq, config_file=config_file)
        else:
            # Benchmark a specific model
            x_vals = model_benchmarking.get_mnk(batch_size=batch_size, config_file=config_file, seq_len=args.sq,
                                                model_name=args.model)
        benchmark.benchmarks.x_vals = x_vals

    if args.M and args.N and args.K:
Review comment: Maybe add an assert that MNK and model cannot be provided together (because model already fixes MNK, so the user likely made a mistake if they provided both)? A minimal sketch of such a check follows this hunk.
        x_vals = [(args.M, args.N, args.K)]
        benchmark.benchmarks.x_vals = x_vals

    benchmark.run(show_plots=True, print_data=True)
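For illustration, the mutual-exclusion check suggested in the comment above could be a short guard in main() right after parse_args() (a sketch only, not part of the PR's diff):

# Sketch: -model already fixes M, N, and K, so combining it with explicit
# -M/-N/-K is almost certainly a user mistake.
if args.model and (args.M or args.N or args.K):
    raise ValueError("Provide either -model or explicit -M/-N/-K, not both.")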
@@ -0,0 +1,99 @@
import json
import os


def load_model_config(config_file='model_configs.json'):
    """Load all model configurations from a JSON file."""
    with open(config_file, 'r') as f:
        return json.load(f)


def infer_mnk(model_name, batch_size, seq_len, config_file='model_configs.json'):
    """Infer M, N, and K dimensions for a given model, batch size, and sequence length."""
    configs = load_model_config(config_file)
    if model_name not in configs:
        raise ValueError(f"Model '{model_name}' not found in {config_file}")

    config = configs[model_name]
    head_count = config["head_count"]
    head_dimension = config["head_dimension"]

    # Infer M, N, K based on the feedforward network (FFN) dimensions
    M = batch_size * seq_len  # Total tokens in a batch
    K = head_dimension * head_count  # Hidden size (d)
    N = 4 * K  # FFN dimension is typically 4× hidden size
Review comment: Hmm, I think this depends. On llama3-8b, the intermediate size (which is half the first FF output) is 14336, so the N dim would be 14336 x 2 or, more generally, 2 x intermediate_size. The intermediate size is a config parameter, so we should read it from the JSON. (A rough sketch of reading it from the config follows this function.)

    return M, N, K

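For illustration, an intermediate_size-aware version of the N computation could read the value from the config as the reviewer suggests (a sketch; the "intermediate_size" key and the 2x factor for the fused gate/up projection are assumptions, not part of this PR):

# Sketch: prefer an explicit intermediate_size from model_configs.json when present,
# falling back to the 4x-hidden-size heuristic otherwise.
hidden_size = head_dimension * head_count
intermediate_size = config.get("intermediate_size")  # e.g. 14336 for llama3-8B
if intermediate_size is not None:
    N = 2 * intermediate_size  # fused gate + up projections of the SwiGLU FFN
else:
    N = 4 * hidden_size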
def get_mnk(batch_size=1, seq_len=None, config_file='model_configs.json', model_name=None):
    """
    Retrieve MNK dimensions for benchmarking. Can return:
    - All models (default)
    - A specific model if model_name is provided
    """
    configs = load_model_config(config_file)
    mnk_list = []

    if model_name:
        # Check if the model exists
        if model_name not in configs:
            raise ValueError(f"Model '{model_name}' not found in {config_file}")
        # Handle a specific model
        config = configs[model_name]
        max_seq_len = config["max_seq_len"]
        actual_seq_len = max_seq_len if seq_len is None else seq_len
        M, N, K = infer_mnk(model_name, batch_size, actual_seq_len, config_file)
        mnk_list.append((M, N, K))
    else:
        # Handle all models
        for model_name, config in configs.items():
            max_seq_len = config["max_seq_len"]
            actual_seq_len = seq_len or max_seq_len
            if actual_seq_len > max_seq_len:
                raise ValueError(f"Sequence length {actual_seq_len} exceeds maximum {max_seq_len} for {model_name}")
            M, N, K = infer_mnk(model_name, batch_size, actual_seq_len, config_file)
            mnk_list.append((M, N, K))

    return mnk_list


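As a usage sketch (not part of the PR), the numbers below follow from the llama3_8B entry in model_configs.json: M = batch_size * seq_len, K = head_dimension * head_count, N = 4 * K:

# Assumes the working directory contains model_configs.json (the benchmark script
# builds an absolute path for this; get_available_models resolves the path itself).
import model_benchmarking

print(model_benchmarking.get_available_models())
# ['llama3_8B', 'llama3_70B', 'llama3_405B']
print(model_benchmarking.get_mnk(model_name="llama3_8B"))
# [(8192, 16384, 4096)]  i.e. M = 1 * 8192, N = 4 * 4096, K = 128 * 32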
def get_available_models(config_file='model_configs.json'):
    """Load model names from the configuration file."""
    config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), config_file)
    with open(config_path, 'r') as f:
        configs = json.load(f)
    return list(configs.keys())


def get_FA_configs(batch_size=1, seq_len=None, model_name=None, config_file='model_configs.json'):
    """
    Retrieve Flash Attention configurations.
    Args:
        batch_size: Batch size for the configurations.
        seq_len: Sequence length to use; defaults to the model's max_seq_len.
        model_name: Name of the model. If None, return all models.
        config_file: Path to the model configuration file.
    Returns:
        List of Flash Attention configurations as tuples: (BATCH, HQ, HK, N_CTX_Q, N_CTX_K)
    """
    configs = load_model_config(config_file)
    fa_configs = []

    if model_name:
        # Check if the model exists
        if model_name not in configs:
            raise ValueError(f"Model '{model_name}' not found in {config_file}")
        # Handle a specific model
        config = configs[model_name]
        HQ = HK = config["head_count"]
Review comment: This depends on the config as well. All models have a "num_attention_heads" parameter, but some also have "num_key_value_heads". If they list the latter, then the KV head count differs from the Q head count. (A rough sketch of this is shown after the end of this file's diff.)
        max_seq_len = config["max_seq_len"]
        N_CTX_Q = N_CTX_K = max_seq_len if seq_len is None else seq_len
        fa_configs.append((batch_size, HQ, HK, N_CTX_Q, N_CTX_K))
    else:
        # Handle all models
        for model_name, config in configs.items():
            HQ = HK = config["head_count"]
            N_CTX_Q = N_CTX_K = config["max_seq_len"]
            fa_configs.append((batch_size, HQ, HK, N_CTX_Q, N_CTX_K))

    return fa_configs
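For illustration, separate Q/KV head counts as described in the comment above could be handled with a fallback (a sketch; the "num_attention_heads" and "num_key_value_heads" keys mirror Hugging Face config naming and are not in this PR's model_configs.json):

# Sketch: fall back to MHA (HK == HQ) when no GQA/MQA head count is listed.
HQ = config.get("num_attention_heads", config["head_count"])
HK = config.get("num_key_value_heads", HQ)  # e.g. 8 for llama3-8B, which uses GQA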
@@ -0,0 +1,17 @@
{
    "llama3_8B": {
        "head_count": 32,
        "head_dimension": 128,
Review comment: Usually this is backwards: the hidden_size and head_count are provided, and we work out the head_dimension from that. Here is an example: https://huggingface.co/unsloth/llama-3-8b/blob/main/config.json (A sketch of a config in that style follows the end of this file's diff.)
"max_seq_len": 8192 | ||
}, | ||
"llama3_70B": { | ||
"head_count": 64, | ||
"head_dimension": 128, | ||
"max_seq_len": 8192 | ||
}, | ||
"llama3_405B": { | ||
"head_count": 128, | ||
"head_dimension": 128, | ||
"max_seq_len": 8192 | ||
} | ||
} |
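For illustration, deriving head_dimension the way the reviewer describes could look like this (a sketch; the field names mirror the linked Hugging Face config, and the values come from the public llama3-8B config rather than this PR):

# Sketch: store hidden_size and head counts, then derive head_dimension from them.
config = {"hidden_size": 4096, "num_attention_heads": 32, "num_key_value_heads": 8,
          "intermediate_size": 14336, "max_seq_len": 8192}
head_dimension = config["hidden_size"] // config["num_attention_heads"]  # 4096 // 32 = 128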
Review comment: We had tried this before. Unfortunately, it does not work because of precision loss due to conversion. While the math on paper is the same, here it first upcasts Q to f32, does the scalar multiply with qk_scale, then downcasts to q.dtype. This downcast affects performance in some cases. (The two orderings are sketched below.)
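For illustration, the two orderings being contrasted look roughly like this in Triton-style pseudocode (a sketch, not the PR's actual kernel code; q, k, and qk_scale are placeholder names inside a kernel):

# (a) Ordering described above: upcast Q to f32, apply qk_scale, then downcast
#     back to q.dtype before the dot; the downcast is where behavior differs.
q_scaled = (q.to(tl.float32) * qk_scale).to(q.dtype)
acc = tl.dot(q_scaled, k)

# (b) Alternative: keep q in its original dtype and fold qk_scale into the
#     f32 accumulator after the dot.
acc = tl.dot(q, k) * qk_scale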