Merge branch 'master' into docs
arjunsuresh authored Aug 20, 2024
2 parents 98b945c + c7db1c3 commit 4940585
Showing 2 changed files with 41 additions and 5 deletions.
1 change: 1 addition & 0 deletions tools/submission/generate_final_report.py
@@ -111,6 +111,7 @@ def main():
'Latency (ms)',
'Samples/s',
'Queries/s',
'Tokens/s',
'millijoules',
'Watts',
]]
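For context, the list above appears to act as a whitelist of recognized result units in generate_final_report.py; adding 'Tokens/s' lets token-throughput results from the LLM benchmarks pass that filter. A minimal sketch of the idea, assuming the report rows live in a pandas DataFrame with a 'Units' column (the column name and the filtering code are assumptions for illustration, not taken from this diff):

    import pandas as pd

    # Hypothetical results table; only the 'Units' column matters for this sketch.
    df = pd.DataFrame({
        "Result": [12.3, 4567.0, 890.0],
        "Units": ["Latency (ms)", "Tokens/s", "Bogus/unit"],
    })

    allowed_units = [
        "Latency (ms)", "Samples/s", "Queries/s",
        "Tokens/s",  # newly recognized unit for token-based (LLM) benchmarks
        "millijoules", "Watts",
    ]

    # Rows whose unit is not in the whitelist are dropped from the final report.
    df = df[df["Units"].isin(allowed_units)]
    print(df)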
45 changes: 40 additions & 5 deletions tools/submission/submission_checker.py
@@ -962,12 +962,13 @@ def check_accuracy_dir(config, model, path, verbose):
is_valid = False
all_accuracy_valid = True
acc = None
result_acc = None
result_acc = {}
hash_val = None
target = config.get_accuracy_target(model)
acc_upper_limit = config.get_accuracy_upper_limit(model)
patterns = []
acc_targets = []
acc_types = []
if acc_upper_limit is not None:
acc_limits = []
up_patterns = []
@@ -981,10 +982,11 @@ def check_accuracy_dir(config, model, path, verbose):
acc_type, acc_target = target[i:i+2]
patterns.append(ACC_PATTERN[acc_type])
acc_targets.append(acc_target)
acc_types.append(acc_type)
acc_seen = [False for _ in acc_targets]
with open(os.path.join(path, "accuracy.txt"), "r", encoding="utf-8") as f:
for line in f:
for i, (pattern, acc_target) in enumerate(zip(patterns, acc_targets)):
for i, (pattern, acc_target, acc_type) in enumerate(zip(patterns, acc_targets, acc_types)):
m = re.match(pattern, line)
if m:
acc = m.group(1)
@@ -997,8 +999,8 @@ def check_accuracy_dir(config, model, path, verbose):
elif acc is not None:
all_accuracy_valid = False
log.warning("%s accuracy not met: expected=%f, found=%s", path, acc_target, acc)
if i == 0 and acc:
result_acc = acc
if acc:
result_acc[acc_type] = acc
acc = None
if acc_upper_limit is not None:
for i, (pattern, acc_limit) in enumerate(zip(up_patterns, acc_limits)):
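The change above turns result_acc from a single value (previously only the first metric, i == 0) into a dict keyed by accuracy-metric type, so benchmarks with several accuracy metrics keep all of them. A simplified, standalone sketch of the pattern, with made-up metric names, regex patterns, targets, and accuracy.txt lines (the real ones come from the benchmark config and ACC_PATTERN):

    import re

    # Hypothetical metrics; real patterns/targets come from the model config.
    patterns = [r"ROUGE1:\s*([\d.]+)", r"ROUGE2:\s*([\d.]+)"]
    acc_targets = [42.9, 20.1]
    acc_types = ["ROUGE1", "ROUGE2"]

    result_acc = {}  # one entry per accuracy metric, keyed by its type
    lines = ["ROUGE1: 43.0", "ROUGE2: 20.5"]  # stand-in for accuracy.txt
    for line in lines:
        for pattern, target, acc_type in zip(patterns, acc_targets, acc_types):
            m = re.match(pattern, line)
            if m:
                acc = float(m.group(1))
                if acc < target:
                    print(f"accuracy not met: expected={target}, found={acc}")
                result_acc[acc_type] = acc  # keep every matched metric, not just the first

    print(result_acc)  # {'ROUGE1': 43.0, 'ROUGE2': 20.5}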
@@ -1590,6 +1592,38 @@ def log_result(
if system_json.get("sw_notes"):
notes = notes + ". " if notes else ""
notes = notes + system_json.get("sw_notes")
special_unit_dict = {
"gptj-99": {
"SingleStream": "Latency (ms)",
"MultiStream": "Latency (ms)",
"Offline": "Tokens/s",
"Server": "Tokens/s",
},
"gptj-99.9": {
"SingleStream": "Latency (ms)",
"MultiStream": "Latency (ms)",
"Offline": "Tokens/s",
"Server": "Tokens/s",
},
"llama2-70b-99" : {
"SingleStream": "Latency (ms)",
"MultiStream": "Latency (ms)",
"Offline": "Tokens/s",
"Server": "Tokens/s",
},
"llama2-70b-99.9" : {
"SingleStream": "Latency (ms)",
"MultiStream": "Latency (ms)",
"Offline": "Tokens/s",
"Server": "Tokens/s",
},
"mixtral-8x7b" : {
"SingleStream": "Latency (ms)",
"MultiStream": "Latency (ms)",
"Offline": "Tokens/s",
"Server": "Tokens/s",
}
}
unit_dict = {
"SingleStream": "Latency (ms)",
"MultiStream": "Latency (ms)",
@@ -1602,7 +1636,7 @@ def log_result(
"Offline": "Watts",
"Server": "Watts",
}
unit = unit_dict[scenario_fixed]
unit = special_unit_dict.get(model_name, unit_dict)[scenario_fixed]
power_unit = power_unit_dict[scenario_fixed]

csv.write(
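The unit lookup above consults the per-model override table first and falls back to the generic per-scenario table, so the listed LLM benchmarks report Tokens/s for Offline and Server while other models keep their usual units. A minimal sketch of that dict.get fallback (the generic Offline/Server units shown here are assumptions, since they fall outside the visible diff):

    unit_dict = {
        "SingleStream": "Latency (ms)",
        "MultiStream": "Latency (ms)",
        "Offline": "Samples/s",   # assumed default
        "Server": "Queries/s",    # assumed default
    }
    special_unit_dict = {
        "llama2-70b-99": {
            "SingleStream": "Latency (ms)",
            "MultiStream": "Latency (ms)",
            "Offline": "Tokens/s",
            "Server": "Tokens/s",
        },
    }

    def result_unit(model_name, scenario):
        # Per-model override if present, otherwise the generic per-scenario unit.
        return special_unit_dict.get(model_name, unit_dict)[scenario]

    print(result_unit("llama2-70b-99", "Offline"))  # Tokens/s
    print(result_unit("resnet", "Offline"))         # Samples/s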
@@ -2012,6 +2046,7 @@ def log_result(
acc_path,
debug or is_closed_or_network,
)
acc = json.dumps(acc).replace(",", " ").replace('"', "").replace("{", "").replace("}", "")
if mlperf_model in REQUIRED_ACC_BENCHMARK:
if config.version in REQUIRED_ACC_BENCHMARK[mlperf_model]:
extra_files_pass, missing_files = check_extra_files(acc_path, REQUIRED_ACC_BENCHMARK[mlperf_model][config.version])
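Since the accuracy result is now a dict, it is flattened before being written into the comma-separated results row: json.dumps produces the string form, and the replace chain strips commas, quotes, and braces so the value cannot split the CSV line. A quick worked example with illustrative values:

    import json

    acc = {"ROUGE1": "43.0", "ROUGE2": "20.5"}
    flat = (json.dumps(acc)
            .replace(",", " ")
            .replace('"', "")
            .replace("{", "")
            .replace("}", ""))
    print(flat)  # ROUGE1: 43.0  ROUGE2: 20.5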
