Merge pull request #1059 from AI-Hypercomputer:patemotter_acc_eval
PiperOrigin-RevId: 700018264
maxtext authors committed Nov 25, 2024
2 parents ea717d9 + 7eca5f9 commit ee01072
Showing 2 changed files with 214 additions and 14 deletions.
196 changes: 196 additions & 0 deletions MaxText/inference_mlperf/evaluate-accuracy-fast.py
@@ -0,0 +1,196 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import evaluate
import json
import nltk
import numpy as np
import pandas as pd
import tqdm

from multiprocessing import Pool, cpu_count
from functools import partial
from transformers import AutoTokenizer
from typing import List, Dict, Tuple


def split_data(preds: List[str], refs: List[str], num_chunks: int) -> List[Tuple[List[str], List[str]]]:
  """Split predictions and references into roughly equal chunks"""
  chunk_size = len(preds) // num_chunks + (1 if len(preds) % num_chunks else 0)
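  # Ceiling division: at most num_chunks chunks, each the same size except possibly the last.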
  chunks = []

  for i in range(0, len(preds), chunk_size):
    chunk_preds = preds[i : i + chunk_size]
    chunk_refs = refs[i : i + chunk_size]
    chunks.append((chunk_preds, chunk_refs))

  return chunks


def compute_rouge_chunk(chunk: Tuple[List[str], List[str]], metric) -> Dict:
  """Compute ROUGE scores for a chunk of data"""
  preds, refs = chunk
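  # use_aggregator=False returns per-example scores, so chunk results can be merged before averaging.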
  return metric.compute(predictions=preds, references=refs, use_stemmer=True, use_aggregator=False)


def aggregate_rouge_scores(chunk_results: List[Dict]) -> Dict:
  """Aggregate ROUGE scores from chunks"""
  # Concatenate all scores
  all_scores = {}
  for scores in chunk_results:
    for metric, values in scores.items():
      if metric not in all_scores:
        all_scores[metric] = []
      all_scores[metric].extend(values)

  return {metric: round(np.mean(values) * 100, 4) for metric, values in all_scores.items()}


def get_args():
  parser = argparse.ArgumentParser()
  parser.add_argument("--checkpoint-path", required=True, help="Path to Llama2-70b-hf-chat checkpoint")
  parser.add_argument("--mlperf-accuracy-file", required=True, help="path to mlperf_log_accuracy.json")
  parser.add_argument("--dataset-file", required=True, help="path to processed openorca validation set")
  parser.add_argument("--verbose", action="store_true", help="verbose messages")
  parser.add_argument("--dtype", default="int64", help="dtype of the accuracy log", choices=["int32", "int64", "float"])
  parser.add_argument("--num-workers", type=int, default=None, help="Number of worker processes (default: CPU count)")
  args = parser.parse_args()
  return args


def get_groundtruth(processed_dataset_file):
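  # The processed OpenOrca pickle stores the reference outputs in its "output" column.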
  data = pd.read_pickle(processed_dataset_file)
  return data["output"]


def process_batch(batch, tokenizer, eval_dtype):
  """Process a batch of predictions"""
  preds_token_ids = []
  seen = set()
  gen_tok_len = 0
  target_indices = []

  for pred in batch:
    qsl_idx = pred["qsl_idx"]
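    # LoadGen can log a sample more than once; keep only the first entry per qsl_idx.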
    if qsl_idx in seen:
      continue

    seen.add(qsl_idx)
    target_indices.append(qsl_idx)

    pred_data = np.frombuffer(bytes.fromhex(pred["data"]), eval_dtype)
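    # pred["data"] holds the generated token ids as a hex string of raw bytes.
    # A leading token outside Llama-2's 32000-token vocab is patched to BOS (id 1) below.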
    if pred_data[0] > 32000 or pred_data[0] < 0:
      pred_data = np.concatenate([[1], pred_data[1:]])

    gen_tok_len += len(pred_data)
    preds_token_ids.append(pred_data)

  # Batch decode predictions
  preds_decoded = tokenizer.batch_decode(preds_token_ids, skip_special_tokens=True)
  return preds_decoded, target_indices, gen_tok_len


def postprocess_text(pred, target):
  """Process a single prediction-target pair"""
  pred = pred.strip()
  target = target.strip()

  # rougeLSum expects newline after each sentence
  pred = "\n".join(nltk.sent_tokenize(pred))
  target = "\n".join(nltk.sent_tokenize(target))

  return pred, target


def chunk_list(lst, n):
  """Split list into n roughly equal chunks"""
  chunk_size = len(lst) // n + (1 if len(lst) % n else 0)
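  # Same ceiling-division scheme as split_data, so at most n chunks are produced.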
  return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]


def main():
  args = get_args()
  num_workers = args.num_workers or cpu_count()
  print(f"Using {num_workers} worker processes")

  print(f"Loading checkpoint from {args.checkpoint_path}")
  tokenizer = AutoTokenizer.from_pretrained(
      args.checkpoint_path,
      model_max_length=2048,
      padding_side="left",
      use_fast=False,
  )

  metric = evaluate.load("rouge")
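  # postprocess_text relies on nltk.sent_tokenize, which needs the punkt models.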
  nltk.download("punkt", quiet=True)

  print(f"Getting groundtruth from {args.dataset_file}")
  targets = get_groundtruth(args.dataset_file)

  eval_dtype = {"int32": np.int32, "int64": np.int64, "float": np.float32}[args.dtype]

  print(f"Loading accuracy log from {args.mlperf_accuracy_file}")
  with open(args.mlperf_accuracy_file, "r") as f:
    results = json.load(f)
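  # Each record carries a sample's qsl_idx plus its hex-encoded output tokens.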

  # Split results into chunks for parallel processing
  result_chunks = chunk_list(results, num_workers)

  # Process predictions in parallel
  process_func = partial(process_batch, tokenizer=tokenizer, eval_dtype=eval_dtype)
  total_gen_tok_len = 0
  all_preds = []
  all_target_indices = []

  print("Processing predictions...")
  with Pool(num_workers) as pool:
    for preds, target_indices, gen_tok_len in tqdm.tqdm(pool.imap(process_func, result_chunks), total=len(result_chunks)):
      all_preds.extend(preds)
      all_target_indices.extend(target_indices)
      total_gen_tok_len += gen_tok_len

  target_required = [targets[idx] for idx in all_target_indices]

  # Parallel postprocessing of texts
  print("Post-processing texts...")
  with Pool(num_workers) as pool:
    processed_pairs = list(tqdm.tqdm(pool.starmap(postprocess_text, zip(all_preds, target_required)), total=len(all_preds)))
  preds, refs = zip(*processed_pairs)

  # Split data into chunks for parallel ROUGE computation
  print("Computing ROUGE scores...")
  data_chunks = split_data(preds, refs, num_workers)
  with Pool(num_workers) as pool:
    chunk_results = list(
        tqdm.tqdm(pool.imap(partial(compute_rouge_chunk, metric=metric), data_chunks), total=len(data_chunks))
    )
  rouge_scores = aggregate_rouge_scores(chunk_results)

  prediction_lens = [len(pred) for pred in preds]
  gen_num = len(preds)
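  # gen_len totals characters of the decoded text; gen_tok_len totals generated tokens.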
  result = {
      **rouge_scores,
      "gen_len": np.sum(prediction_lens),
      "gen_num": gen_num,
      "gen_tok_len": total_gen_tok_len,
      "tokens_per_sample": round(total_gen_tok_len / gen_num, 1),
  }

  print("\nResults\n")
  print(result)


if __name__ == "__main__":
  main()
32 changes: 18 additions & 14 deletions MaxText/inference_mlperf/llama_offline_run.sh
@@ -14,17 +14,18 @@ enable_profiler=false
performance=true
audit=false
accuracy=false
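# fast_eval switches the accuracy run to the parallel evaluate-accuracy-fast.py (-f flag).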
fast_eval=false


while getopts "ntspdafr:" opt
do
  case "$opt" in
    n ) dry_run=true ;;
    t ) test_run=true ;;
    s ) skip_warmup=true ;;
    p ) enable_profiler=true ;;
    d ) audit=true ;;
    a ) accuracy=true ;;
    f ) fast_eval=true ;;
    r ) run_name="$OPTARG" ;;
    ? ) helpFunction ;; # Print helpFunction in case parameter is non-existent
  esac
@@ -101,13 +102,11 @@ export JAX_COMPILATION_CACHE_DIR="/tmp/jax_cache2"
export LIBTPU_INIT_ARGS

run_loadgen() {
  OUTPUT_LOG_ID=llama70b-${run_name}-${DATASET_TYPE}-${LOADGEN_RUN_TYPE}-${LOADGEN_RUN_TYPE}_${LOADGEN_RUN_TIMESTAMP}
  OUTPUT_LOG_DIR=${DATA_DISK_DIR}/logs/${OUTPUT_LOG_ID}
  mkdir -p ${OUTPUT_LOG_DIR} && cp ${USER_CONFIG} ${OUTPUT_LOG_DIR}
  OUTPUT_ACCURACY_JSON_PATH=${OUTPUT_LOG_DIR}/mlperf_log_accuracy.json

  echo "LOADGEN_RUN_TIMESTAMP: ${LOADGEN_RUN_TIMESTAMP}"
  echo "DATASET_PATH: ${DATASET_PATH}"
  echo "TOTAL_SAMPLE_COUNT: ${TOTAL_SAMPLE_COUNT}"
@@ -118,19 +117,18 @@ run_loadgen() {

  ${cmd} python -m offline_mode \
    --mlperf_test_mode=${TEST_MODE} \
    --input_mode tokenized \
    --output_mode tokenized \
    --mlperf_conf $BASEDIR/mlperf.conf \
    --user_conf ${USER_CONFIG} \
    --audit_conf ${AUDIT_CONF} \
    --total_sample_count ${TOTAL_SAMPLE_COUNT} \
    --dataset_path ${DATASET_PATH} \
    --prefill_lengths_and_batch_sizes ${BATCH_AND_PREFILL_LEN} \
    --maxengine_args "${MAXENGINE_ARGS}" \
    --output_log_dir ${OUTPUT_LOG_DIR} \
    --tok_outlen_multiplier ${TOK_OUTLEN_MULTIPLIER} \
    ${SKIP_WARMUP_OPTION} ${PROFILER_OPTION} 2>&1 | tee ${OUTPUT_LOG_DIR}/${LOADGEN_RUN_TYPE}_log.log
}

run_loadgen_performance () {
Expand All @@ -155,7 +153,13 @@ run_loadgen_accuracy () {

  # Eval Run
  if [ -e ${OUTPUT_ACCURACY_JSON_PATH} ]; then
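    # -f (or FAST_EVAL=true) selects the parallel evaluate-accuracy-fast.py evaluator,
    # which spreads detokenization and ROUGE scoring across CPU processes.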
    if [ "${FAST_EVAL:-false}" = "true" ] || [ "$fast_eval" = "true" ]; then
      EVAL_SCRIPT="evaluate-accuracy-fast.py"
    else
      EVAL_SCRIPT="evaluate-accuracy.py"
    fi

    ${CMD} python3 ${EVAL_SCRIPT} \
      --checkpoint-path meta-llama/Llama-2-70b-chat-hf \
      --mlperf-accuracy-file ${OUTPUT_ACCURACY_JSON_PATH} \
      --dataset-file ${DATASET_PATH} 2>&1 | tee ${OUTPUT_LOG_DIR}/evaluate_offline_accuracy_log.log
