From 9c4ad22e6866dcc546d6067a0fa89d0089f529a0 Mon Sep 17 00:00:00 2001 From: ramkrishna2910 Date: Fri, 13 Dec 2024 13:52:37 -0800 Subject: [PATCH] lint fix --- src/turnkeyml/llm/tools/huggingface_bench.py | 9 +++++++-- src/turnkeyml/llm/tools/ort_genai/oga_bench.py | 8 ++++++-- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/src/turnkeyml/llm/tools/huggingface_bench.py b/src/turnkeyml/llm/tools/huggingface_bench.py index 5bc94f7..4cee854 100644 --- a/src/turnkeyml/llm/tools/huggingface_bench.py +++ b/src/turnkeyml/llm/tools/huggingface_bench.py @@ -110,7 +110,10 @@ class HuggingfaceBench(Tool): def __init__(self): super().__init__(monitor_message="Benchmarking Huggingface LLM") - self.status_stats = [Keys.SECONDS_TO_FIRST_TOKEN, Keys.TOKEN_GENERATION_TOKENS_PER_SECOND] + self.status_stats = [ + Keys.SECONDS_TO_FIRST_TOKEN, + Keys.TOKEN_GENERATION_TOKENS_PER_SECOND, + ] @staticmethod def parser(parser: argparse.ArgumentParser = None, add_help: bool = True): @@ -287,7 +290,9 @@ def run( # Save performance data to stats state.save_stat(Keys.SECONDS_TO_FIRST_TOKEN, mean_time_to_first_token) - state.save_stat(Keys.TOKEN_GENERATION_TOKENS_PER_SECOND, token_generation_tokens_per_second) + state.save_stat( + Keys.TOKEN_GENERATION_TOKENS_PER_SECOND, token_generation_tokens_per_second + ) state.save_stat(Keys.PROMPT_TOKENS, input_ids.shape[1]) return state diff --git a/src/turnkeyml/llm/tools/ort_genai/oga_bench.py b/src/turnkeyml/llm/tools/ort_genai/oga_bench.py index 329a8b7..fed7f8c 100644 --- a/src/turnkeyml/llm/tools/ort_genai/oga_bench.py +++ b/src/turnkeyml/llm/tools/ort_genai/oga_bench.py @@ -146,11 +146,15 @@ def run( mean_time_to_first_token = statistics.mean(per_iteration_time_to_first_token) prefill_tokens_per_second = input_ids_len / mean_time_to_first_token - token_generation_tokens_per_second = statistics.mean(per_iteration_tokens_per_second) + token_generation_tokens_per_second = statistics.mean( + per_iteration_tokens_per_second + ) 
state.save_stat(Keys.SECONDS_TO_FIRST_TOKEN, mean_time_to_first_token) state.save_stat(Keys.PREFILL_TOKENS_PER_SECOND, prefill_tokens_per_second) - state.save_stat(Keys.TOKEN_GENERATION_TOKENS_PER_SECOND, token_generation_tokens_per_second) + state.save_stat( + Keys.TOKEN_GENERATION_TOKENS_PER_SECOND, token_generation_tokens_per_second + ) state.save_stat(Keys.PROMPT_TOKENS, input_ids_len) return state