diff --git a/src/turnkeyml/llm/tools/huggingface_bench.py b/src/turnkeyml/llm/tools/huggingface_bench.py index 5bc94f7..4cee854 100644 --- a/src/turnkeyml/llm/tools/huggingface_bench.py +++ b/src/turnkeyml/llm/tools/huggingface_bench.py @@ -110,7 +110,10 @@ class HuggingfaceBench(Tool): def __init__(self): super().__init__(monitor_message="Benchmarking Huggingface LLM") - self.status_stats = [Keys.SECONDS_TO_FIRST_TOKEN, Keys.TOKEN_GENERATION_TOKENS_PER_SECOND] + self.status_stats = [ + Keys.SECONDS_TO_FIRST_TOKEN, + Keys.TOKEN_GENERATION_TOKENS_PER_SECOND, + ] @staticmethod def parser(parser: argparse.ArgumentParser = None, add_help: bool = True): @@ -287,7 +290,9 @@ def run( # Save performance data to stats state.save_stat(Keys.SECONDS_TO_FIRST_TOKEN, mean_time_to_first_token) - state.save_stat(Keys.TOKEN_GENERATION_TOKENS_PER_SECOND, token_generation_tokens_per_second) + state.save_stat( + Keys.TOKEN_GENERATION_TOKENS_PER_SECOND, token_generation_tokens_per_second + ) state.save_stat(Keys.PROMPT_TOKENS, input_ids.shape[1]) return state diff --git a/src/turnkeyml/llm/tools/ort_genai/oga_bench.py b/src/turnkeyml/llm/tools/ort_genai/oga_bench.py index 329a8b7..fed7f8c 100644 --- a/src/turnkeyml/llm/tools/ort_genai/oga_bench.py +++ b/src/turnkeyml/llm/tools/ort_genai/oga_bench.py @@ -146,11 +146,15 @@ def run( mean_time_to_first_token = statistics.mean(per_iteration_time_to_first_token) prefill_tokens_per_second = input_ids_len / mean_time_to_first_token - token_generation_tokens_per_second = statistics.mean(per_iteration_tokens_per_second) + token_generation_tokens_per_second = statistics.mean( + per_iteration_tokens_per_second + ) state.save_stat(Keys.SECONDS_TO_FIRST_TOKEN, mean_time_to_first_token) state.save_stat(Keys.PREFILL_TOKENS_PER_SECOND, prefill_tokens_per_second) - state.save_stat(Keys.TOKEN_GENERATION_TOKENS_PER_SECOND, token_generation_tokens_per_second) + state.save_stat( + Keys.TOKEN_GENERATION_TOKENS_PER_SECOND, token_generation_tokens_per_second + ) state.save_stat(Keys.PROMPT_TOKENS, input_ids_len) return state