Skip to content

Commit

Permalink
task_list to tasks (#343)
Browse files (browse the repository at this point in the history)
  • Loading branch information
HDCharles authored Jun 11, 2024
1 parent 61fef69 commit 950a893
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions scripts/hf_eval.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -16,7 +16,7 @@
torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.fx_graph_cache = True

def run_evaluation(repo_id, task_list, limit, device, precision, quantization, compile, batch_size, max_length):
def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compile, batch_size, max_length):

tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id).to(device="cpu", dtype=precision)
Expand All @@ -41,7 +41,7 @@ def run_evaluation(repo_id, task_list, limit, device, precision, quantization, c
tokenizer=tokenizer,
batch_size=batch_size,
max_length=max_length),
get_task_dict(task_list),
get_task_dict(tasks),
limit = limit,
)
for task, res in result["results"].items():
Expand All @@ -52,7 +52,7 @@ def run_evaluation(repo_id, task_list, limit, device, precision, quantization, c
import argparse
parser = argparse.ArgumentParser(description='Run HF Model Evaluation')
parser.add_argument('--repo_id', type=str, default="meta-llama/Meta-Llama-3-8B", help='Repository ID to download from HF.')
parser.add_argument('--task_list', nargs='+', type=str, default=["wikitext"], help='List of lm-eluther tasks to evaluate usage: --tasks task1 task2')
parser.add_argument('--tasks', nargs='+', type=str, default=["wikitext"], help='List of lm-eluther tasks to evaluate usage: --tasks task1 task2')
parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate')
parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use')
parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
Expand All @@ -62,4 +62,4 @@ def run_evaluation(repo_id, task_list, limit, device, precision, quantization, c
parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time')

args = parser.parse_args()
run_evaluation(args.repo_id, args.task_list, args.limit, args.device, args.precision, args.quantization, args.compile, args.batch_size, args.max_length)
run_evaluation(args.repo_id, args.tasks, args.limit, args.device, args.precision, args.quantization, args.compile, args.batch_size, args.max_length)

0 comments on commit 950a893

Please sign in to comment.