From 74d1d62c53e5a41f396dfe258f9d495f7fc269a8 Mon Sep 17 00:00:00 2001 From: paperTII <2293564561@qq.com> Date: Mon, 20 Oct 2025 19:28:53 +0800 Subject: [PATCH] Performance testing tool based on the PyTest testing framework. --- test/README.md | 2 +- test/README_zh.md | 2 +- test/common/llmperf/__init__.py | 0 test/common/llmperf/run_inference.py | 149 +++++++ test/common/llmperf/utils/__init__.py | 0 test/common/llmperf/utils/common_metrics.py | 17 + test/common/llmperf/utils/models.py | 23 ++ .../utils/openai_chat_completions_client.py | 127 ++++++ test/common/llmperf/utils/sonnet.txt | 84 ++++ test/common/llmperf/utils/token_benchmark.py | 386 ++++++++++++++++++ test/common/llmperf/utils/utils.py | 171 ++++++++ test/config.yaml | 17 +- test/pytest.ini | 2 +- test/requirements.txt | 2 +- ...o_performance.py => test_demo_function.py} | 0 test/suites/E2E/test_uc_performance.py | 138 +++++++ test/test_ucm_dram.py | 250 ++++++++++++ 17 files changed, 1365 insertions(+), 5 deletions(-) create mode 100644 test/common/llmperf/__init__.py create mode 100644 test/common/llmperf/run_inference.py create mode 100644 test/common/llmperf/utils/__init__.py create mode 100644 test/common/llmperf/utils/common_metrics.py create mode 100644 test/common/llmperf/utils/models.py create mode 100644 test/common/llmperf/utils/openai_chat_completions_client.py create mode 100644 test/common/llmperf/utils/sonnet.txt create mode 100644 test/common/llmperf/utils/token_benchmark.py create mode 100644 test/common/llmperf/utils/utils.py rename test/suites/E2E/{test_demo_performance.py => test_demo_function.py} (100%) create mode 100644 test/suites/E2E/test_uc_performance.py create mode 100644 test/test_ucm_dram.py diff --git a/test/README.md b/test/README.md index 1e11da7e..2c6b3e26 100644 --- a/test/README.md +++ b/test/README.md @@ -176,4 +176,4 @@ api_config = config_utils.get_nested_config("easyPerf.api") 2. Apply appropriate tags 3. Naming: `test_*.py` 4. Use fixtures & marks for data management -5. Keep custom marks concise and aligned with overall goals \ No newline at end of file +5. Keep custom marks concise and aligned with overall goals diff --git a/test/README_zh.md b/test/README_zh.md index 26b0f393..518cc17c 100644 --- a/test/README_zh.md +++ b/test/README_zh.md @@ -179,4 +179,4 @@ api_config = config_utils.get_nested_config("easyPerf.api") 2. 使用适当的测试标记 3. 遵循命名规范:`test_*.py` 4. 使用 fixture 及mark 进行测试数据管理 -5. 自定义 mark 标签不易过细,应当与整体功能目标相符合 \ No newline at end of file +5. 自定义 mark 标签不易过细,应当与整体功能目标相符合 diff --git a/test/common/llmperf/__init__.py b/test/common/llmperf/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/common/llmperf/run_inference.py b/test/common/llmperf/run_inference.py new file mode 100644 index 00000000..8af5bc3b --- /dev/null +++ b/test/common/llmperf/run_inference.py @@ -0,0 +1,149 @@ +import json +import os +import random +from pathlib import Path +from typing import Any, Dict, List + +import yaml +from common.llmperf.utils.token_benchmark import run_token_benchmark +from common.llmperf.utils.utils import reset_prefill_cache + + +def run_test_cases(test_cases, timestamp_dir, model, server_url, tokenizer_path): + """ + Execute all test cases and return the list of failed case indices and hit_rate mapping for each case. 
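For illustration only (not shipped in this patch), the `test_cases` argument is the `llmperf_test_cases` list loaded from `config.yaml`; using just the keys this function reads, and falling back to its defaults for anything omitted, two entries might look like:

```python
# Hypothetical sketch of llmperf_test_cases entries; omitted keys fall back to the
# defaults used by the case.get(...) calls below.
test_cases = [
    {  # plain benchmark, no prefix-cache warm-up
        "mean_input_tokens": 6000,
        "mean_output_tokens": 200,
        "max_num_completed_requests": 16,
        "concurrent_requests": 8,
        "additional_sampling_params": "{}",
        "hit_rate": 0,
    },
    {  # same load, but 50% of each prompt is pre-filled first (hit_rate > 0 path)
        "mean_input_tokens": 6000,
        "mean_output_tokens": 200,
        "max_num_completed_requests": 16,
        "concurrent_requests": 8,
        "hit_rate": 50,
    },
]
```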
+ Parameters: + test_cases — List of test cases read from the configuration file + timestamp_dir — Directory Path to save results + model — Model name + server_url — Base URL of the service + tokenizer_path— Path to the tokenizer + Returns: + failed_cases — List of failed case indices + """ + print(f"[INFO] Total {len(test_cases)} test cases to be executed") + all_summaries = [] + failed_case = [] + + # Clear proxy environment variables + env = os.environ.copy() + env.pop("http_proxy", None) + env.pop("https_proxy", None) + + for i, case in enumerate(test_cases): + print(f"\n>>> Executing test case {i + 1} <<<") + reset_prefill_cache(env, server_url) + # Use a fixed random_seed for each test to control PC hit_rate + random_seed = random.randint(1, 100000) + summary = {} + + # Read parameters from configuration file + mean_input = case.get("mean_input_tokens", 5000) + stddev_input = case.get("stddev_input_tokens", 0) + mean_output = case.get("mean_output_tokens", 1000) + stddev_output = case.get("stddev_output_tokens", 0) + max_completed = case.get("max_num_completed_requests", 1) + concurrent = case.get("concurrent_requests", 1) + llm_api = case.get("llm_api", "openai") + additional_sampling_params = case.get("additional_sampling_params", "{}") + timeout = case.get("timeout", 60000) + hit_rate = case.get("hit_rate", 0) + + try: + # Determine if two runs are needed (PC hit_rate test) + if hit_rate == 0: + summary = run_token_benchmark( + llm_api=llm_api, + model=model, + test_timeout_s=timeout, + max_num_completed_requests=max_completed, + concurrent_requests=concurrent, + mean_input_tokens=mean_input, + stddev_input_tokens=stddev_input, + mean_output_tokens=mean_output, + stddev_output_tokens=stddev_output, + additional_sampling_params=additional_sampling_params, + results_dir=str(timestamp_dir), + random_seed=random_seed, + openai_api_base=server_url + "/v1", + tokenizer_path=tokenizer_path, + user_metadata={"case_idx": i}, + ) + else: + print( + f"[INFO] hit_rate > 0 detected, entering prefill mode, PC hit rate: {hit_rate} %" + ) + # hit_rate > 0: first prefill mode + prefill_mean_input = int(mean_input * hit_rate / 100) + print( + f"[INFO] Prefill execution: mean_input_tokens={prefill_mean_input}" + ) + run_token_benchmark( + llm_api=llm_api, + model=model, + test_timeout_s=timeout, + max_num_completed_requests=max_completed, + concurrent_requests=concurrent, + mean_input_tokens=prefill_mean_input, + stddev_input_tokens=stddev_input, + mean_output_tokens=2, + stddev_output_tokens=stddev_output, + additional_sampling_params=additional_sampling_params, + results_dir=str(timestamp_dir), + random_seed=random_seed, + openai_api_base=server_url + "/v1", + tokenizer_path=tokenizer_path, + user_metadata={"case_idx": i, "phase": "prefill"}, + ) + reset_prefill_cache(env, server_url) + # Then run normal mode + print("[INFO] Prefill completed, switching to normal mode execution") + summary = run_token_benchmark( + llm_api=llm_api, + model=model, + test_timeout_s=timeout, + max_num_completed_requests=max_completed, + concurrent_requests=concurrent, + mean_input_tokens=mean_input, + stddev_input_tokens=stddev_input, + mean_output_tokens=mean_output, + stddev_output_tokens=stddev_output, + additional_sampling_params=additional_sampling_params, + results_dir=str(timestamp_dir), + random_seed=random_seed, + openai_api_base=server_url + "/v1", + tokenizer_path=tokenizer_path, + user_metadata={"case_idx": i, "phase": "normal"}, + ) + all_summaries.append(summary) + except Exception as e: + 
failed_case.append(i) + + return all_summaries, failed_case + + +def inference_results(): + config_file = Path(__file__).parent.parent.parent / "config.yaml" + all_smmaries = {} + print("[INFO] Initialization complete, starting main process") + print(f"[INFO] Reading configuration file: {config_file}") + with open(config_file, "r", encoding="utf-8") as f: + config = yaml.safe_load(f) + model = config.get("llm_connection", {}).get("model", "") + server_url = config.get("llm_connection", {}).get("server_url", "") + tokenizer_path = config.get("llm_connection", {}).get("tokenizer_path", "") + test_cases = config.get("llmperf_test_cases", []) + timestamp_dir = Path("results") + timestamp_dir.mkdir(parents=True, exist_ok=True) + print(f"[INFO] Created results directory: {timestamp_dir}") + + all_summaries, failed_cases = run_test_cases( + test_cases, timestamp_dir, model, server_url, tokenizer_path + ) + total = len(test_cases) + print( + f"\n[INFO] All tests completed! Success: {total - len(failed_cases)}/{total}" + ) + if failed_cases: + print(f"[WARN] Failed case indices: {failed_cases}") + return all_summaries diff --git a/test/common/llmperf/utils/__init__.py b/test/common/llmperf/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/common/llmperf/utils/common_metrics.py b/test/common/llmperf/utils/common_metrics.py new file mode 100644 index 00000000..40e21124 --- /dev/null +++ b/test/common/llmperf/utils/common_metrics.py @@ -0,0 +1,17 @@ +# TODO (Avnishn): compute metrics in class +INTER_TOKEN_LAT = "inter_token_latency_s" +TTFT = "ttft_s" +E2E_LAT = "end_to_end_latency_s" +NUM_INPUT_TOKENS = "number_input_tokens" +NUM_OUTPUT_TOKENS = "number_output_tokens" +NUM_TOTAL_TOKENS = "number_total_tokens" +REQ_OUTPUT_THROUGHPUT = "request_output_throughput_token_per_s" +ERROR_MSG = "error_msg" +ERROR_CODE = "error_code" +ERROR_CODE_FREQ = "error_code_frequency" +NUM_ERRORS = "number_errors" +OUTPUT_THROUGHPUT = "mean_output_throughput_token_per_s" +NUM_COMPLETED_REQUESTS = "num_completed_requests" +COMPLETED_REQUESTS_PER_MIN = "num_completed_requests_per_min" +ERROR_RATE = "error_rate" +NUM_REQ_STARTED = "num_requests_started" diff --git a/test/common/llmperf/utils/models.py b/test/common/llmperf/utils/models.py new file mode 100644 index 00000000..1cbab628 --- /dev/null +++ b/test/common/llmperf/utils/models.py @@ -0,0 +1,23 @@ +from typing import Any, Dict, Optional, Tuple + +from pydantic import BaseModel + + +class RequestConfig(BaseModel): + """The configuration for a request to the LLM API. + + Args: + model: The model to use. + prompt: The prompt to provide to the LLM API. + sampling_params: Additional sampling parameters to send with the request. + For more information see the Router app's documentation for the completions + llm_api: The name of the LLM API to send the request to. + metadata: Additional metadata to attach to the request for logging or validation purposes. 
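As a rough usage sketch, not part of the patch itself: a request built the way `token_benchmark.py` builds them could look like the following, where the model name and endpoint are placeholders and `prompt` is a `(text, token_count)` tuple:

```python
# Hypothetical, self-contained sketch of driving the client with a RequestConfig.
from common.llmperf.utils.models import RequestConfig
from common.llmperf.utils.openai_chat_completions_client import (
    OpenAIChatCompletionsClient,
)

cfg = RequestConfig(
    model="qwen3",                                      # placeholder model name
    prompt=("Randomly stream lines from ...", 6000),    # (prompt text, token count)
    sampling_params={"max_tokens": 200},                # merged into the request body
    llm_api="openai",
    openai_api_base="http://127.0.0.1:9382/v1",         # placeholder endpoint
)
metrics, generated_text, _ = OpenAIChatCompletionsClient().llm_request(cfg)
```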
+ """ + + model: str + prompt: Tuple[str, int] + sampling_params: Optional[Dict[str, Any]] = None + llm_api: Optional[str] = None + metadata: Optional[Dict[str, Any]] = None + openai_api_base: Optional[str] = "" diff --git a/test/common/llmperf/utils/openai_chat_completions_client.py b/test/common/llmperf/utils/openai_chat_completions_client.py new file mode 100644 index 00000000..d559f9cd --- /dev/null +++ b/test/common/llmperf/utils/openai_chat_completions_client.py @@ -0,0 +1,127 @@ +import json +import os +import time +from typing import Any, Dict, Tuple + +import requests +from common.llmperf.utils import common_metrics +from common.llmperf.utils.models import RequestConfig + + +class OpenAIChatCompletionsClient: + """ + used for sending HTTP requests, receiving token streams, measuring latency, etc. + """ + + def llm_request( + self, request_config: RequestConfig + ) -> Tuple[Dict[str, Any], str, RequestConfig]: + prompt, prompt_len = request_config.prompt + + message = [ + {"role": "system", "content": ""}, + {"role": "user", "content": prompt}, + ] + model = request_config.model + body = { + "model": model, + "messages": message, + "stream": True, + "ignore_eos": True, + } + sampling_params = request_config.sampling_params + body.update(sampling_params or {}) + + time_to_next_token = [] + tokens_received = 0 + ttft = 0.0 + error_response_code = None + generated_text = "" + error_msg = "" + output_throughput = 0.0 + total_request_time = 0.0 + flag = False + + metrics: Dict[str, Any] = {} + + metrics[common_metrics.ERROR_CODE] = None + metrics[common_metrics.ERROR_MSG] = "" + + start_time = time.monotonic() + most_recent_received_token_time = start_time + + address = request_config.openai_api_base + + if not address: + raise ValueError("the environment variable OPENAI_API_BASE must be set.") + key = os.environ.get("OPENAI_API_KEY", "secret_abcdefg") + if not key: + raise ValueError("the environment variable OPENAI_API_KEY must be set.") + headers = {"Authorization": f"Bearer {key}"} + if not address.endswith("/"): + address = address + "/" + address += "chat/completions" + try: + with requests.post( + address, + json=body, + stream=True, + timeout=180, + headers=headers, + ) as response: + if response.status_code != 200: + error_msg = response.text + error_response_code = response.status_code + response.raise_for_status() + + for chunk in response.iter_lines(chunk_size=None): + if not chunk: + continue + stem = b"data: " + if chunk.startswith(stem): + chunk = chunk[len(stem) :] + # Data might already be bytes or str + if isinstance(chunk, bytes): + chunk = chunk.decode("utf-8", errors="ignore") + if chunk.strip() == "[DONE]": + continue + tokens_received += 1 + data = json.loads(chunk) + if "error" in data: + error_msg = data["error"]["message"] + error_response_code = data["error"]["code"] + raise RuntimeError(error_msg) + delta = data["choices"][0]["delta"] + content = delta.get("content", None) or delta.get( + "reasoning_content", "" + ) + if content: + if tokens_received != 0 and flag == False: + ttft = time.monotonic() - start_time + flag = True + else: + time_to_next_token.append( + time.monotonic() - most_recent_received_token_time + ) + most_recent_received_token_time = time.monotonic() + generated_text += content + + total_request_time = time.monotonic() - start_time + if total_request_time > 0: + output_throughput = tokens_received / total_request_time + + except Exception as e: + metrics[common_metrics.ERROR_MSG] = error_msg + metrics[common_metrics.ERROR_CODE] = 
error_response_code + print(f"Warning Or Error: {e}") + print(error_response_code) + + metrics[common_metrics.INTER_TOKEN_LAT] = sum(time_to_next_token) + metrics[common_metrics.TTFT] = ttft + metrics[common_metrics.E2E_LAT] = total_request_time + metrics[common_metrics.REQ_OUTPUT_THROUGHPUT] = output_throughput + metrics[common_metrics.NUM_TOTAL_TOKENS] = tokens_received + prompt_len + metrics[common_metrics.NUM_OUTPUT_TOKENS] = tokens_received + metrics[common_metrics.NUM_INPUT_TOKENS] = prompt_len + + return metrics, generated_text, request_config diff --git a/test/common/llmperf/utils/sonnet.txt b/test/common/llmperf/utils/sonnet.txt new file mode 100644 index 00000000..9f13ead4 --- /dev/null +++ b/test/common/llmperf/utils/sonnet.txt @@ -0,0 +1,84 @@ +Shall I compare thee to a summer's day? +Thou art more lovely and more temperate: +Rough winds do shake the darling buds of May, +And summer's lease hath all too short a date: +Sometime too hot the eye of heaven shines, +And often is his gold complexion dimm'd; +And every fair from fair sometime declines, +By chance or nature's changing course untrimm'd; +But thy eternal summer shall not fade +Nor lose possession of that fair thou owest; +Nor shall Death brag thou wander'st in his shade, +When in eternal lines to time thou growest: +So long as men can breathe or eyes can see, +So long lives this and this gives life to thee. +Then let not winter's ragged hand deface +In thee thy summer, ere thou be distill'd: +Make sweet some vial; treasure thou some place +With beauty's treasure, ere it be self-kill'd. +That use is not forbidden usury, +Which happies those that pay the willing loan; +That's for thyself to breed another thee, +Or ten times happier, be it ten for one; +Ten times thyself were happier than thou art, +If ten of thine ten times refigured thee: +Then what could death do, if thou shouldst depart, +Leaving thee living in posterity? +Be not self-will'd, for thou art much too fair +To be death's conquest and make worms thine heir. +Where art thou, Muse, that thou forget'st so long +To speak of that which gives thee all thy might? +Spend'st thou thy fury on some worthless song, +Darkening thy power to lend base subjects light? +Return, forgetful Muse, and straight redeem +In gentle numbers time so idly spent; +Sing to the ear that doth thy lays esteem +And gives thy pen both skill and argument. +Rise, resty Muse, my love's sweet face survey, +If Time have any wrinkle graven there; +If any, be a satire to decay, +And make Time's spoils despised every where. +Give my love fame faster than Time wastes life; +So thou prevent'st his scythe and crooked knife. +My glass shall not persuade me I am old, +So long as youth and thou are of one date; +But when in thee time's furrows I behold, +Then look I death my days should expiate. +For all that beauty that doth cover thee +Is but the seemly raiment of my heart, +Which in thy breast doth live, as thine in me: +How can I then be elder than thou art? +O, therefore, love, be of thyself so wary +As I, not for myself, but for thee will; +Bearing thy heart, which I will keep so chary +As tender nurse her babe from faring ill. +Presume not on thy heart when mine is slain; +Thou gavest me thine, not to give back again. +So am I as the rich, whose blessed key +Can bring him to his sweet up-locked treasure, +The which he will not every hour survey, +For blunting the fine point of seldom pleasure. 
+Therefore are feasts so solemn and so rare, +Since, seldom coming, in the long year set, +Like stones of worth they thinly placed are, +Or captain jewels in the carcanet. +So is the time that keeps you as my chest, +Or as the wardrobe which the robe doth hide, +To make some special instant special blest, +By new unfolding his imprison'd pride. +Blessed are you, whose worthiness gives scope, +Being had, to triumph, being lack'd, to hope. +If there be nothing new, but that which is +Hath been before, how are our brains beguiled, +Which, labouring for invention, bear amiss +The second burden of a former child! +O, that record could with a backward look, +Even of five hundred courses of the sun, +Show me your image in some antique book, +Since mind at first in character was done! +That I might see what the old world could say +To this composed wonder of your frame; +Whether we are mended, or whether better they, +Or whether revolution be the same. +O, sure I am, the wits of former days +To subjects worse have given admiring praise. \ No newline at end of file diff --git a/test/common/llmperf/utils/token_benchmark.py b/test/common/llmperf/utils/token_benchmark.py new file mode 100644 index 00000000..67553cf1 --- /dev/null +++ b/test/common/llmperf/utils/token_benchmark.py @@ -0,0 +1,386 @@ +import json +import logging +import random +import re +import time +from collections.abc import Iterable +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import pandas as pd +from common.llmperf.utils import common_metrics +from common.llmperf.utils.models import RequestConfig +from common.llmperf.utils.openai_chat_completions_client import ( + OpenAIChatCompletionsClient, +) +from common.llmperf.utils.utils import ( + LLMPerfResults, + randomly_sample_sonnet_lines_prompt, + sample_random_positive_int, +) +from transformers import AutoTokenizer + + +def get_token_throughput_latencies( + model: str, + mean_input_tokens: int, + stddev_input_tokens: int, + mean_output_tokens: int, + stddev_output_tokens: int, + additional_sampling_params: Optional[Dict[str, Any]] = None, + concurrent_requests: int = 1, + max_num_completed_requests: int = 500, + test_timeout_s=90, + llm_api="openai", + random_seed: int = None, + openai_api_base: str = "", + tokenizer_path: str = None, +) -> Tuple[Dict[str, Any], List[Dict[str, Any]], float, float]: + """Get the token throughput and latencies for the given model. + + Args: + model: The name of the model to query. + mean_input_tokens: The mean number of tokens to send in the prompt for the request. + stddev_input_tokens: The standard deviation of the number of tokens to send in the prompt for the request. + mean_output_tokens: The mean number of tokens to generate per request. + stddev_output_tokens: The standard deviation of the number of tokens to generate per request. + additional_sampling_params: Additional sampling parameters to send with the request. + For more information see the LLM APIs documentation for the completions + concurrent_requests: The number of concurrent requests to make. Increase + this to increase the amount of load and vice versa. + test_timeout_s: The amount of time to run the test for before reporting results. + llm_api: The name of the llm api to use. Either "openai" or "litellm". + + Returns: + A summary of the performance metrics collected across all completed requests + (e.g. throughput, latencies, etc.) + The individual metrics for each request. 
+ """ + random.seed(random_seed) + + print(f"Using tokenizer:{tokenizer_path}") + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + get_token_length = lambda text: len(tokenizer.encode(text)) + + if not additional_sampling_params: + additional_sampling_params = {} + + # 1. create prompts + prompts: List[Tuple[str, int]] = [] + num_output_tokens_list: List[int] = [] + for i in range(max_num_completed_requests): + num_output = sample_random_positive_int( + mean_output_tokens, stddev_output_tokens + ) + num_output_tokens_list.append(num_output) + prompts.append( + randomly_sample_sonnet_lines_prompt( + prompt_tokens_mean=mean_input_tokens, + prompt_tokens_stddev=stddev_input_tokens, + tokenizer=tokenizer, + ) + ) + start_time = time.monotonic() + completed_requests: List[Dict[str, Any]] = [] + incremental_time_delay = 0.0 + client = OpenAIChatCompletionsClient() + futures = [] + + # 2. Submitting tasks using a thread pool + with ThreadPoolExecutor(max_workers=concurrent_requests) as executor: + for idx in range(max_num_completed_requests): + sampling = {"max_tokens": num_output_tokens_list[idx]} + sampling.update(additional_sampling_params) + cfg = RequestConfig( + model=model, + prompt=prompts[idx], + sampling_params=sampling, + llm_api=llm_api, + openai_api_base=openai_api_base, + ) + futures.append(executor.submit(client.llm_request, cfg)) + # 3. Waiting for completion or timeout + for future in as_completed(futures, timeout=test_timeout_s): + try: + metrics, gen_text, req_cfg = future.result() + except Exception as e: + logging.warning(f"[WARN] Future raised exception: {e}") + continue + num_output_tokens = get_token_length(gen_text) + if num_output_tokens: + metrics[common_metrics.INTER_TOKEN_LAT] /= ( + (metrics[common_metrics.NUM_OUTPUT_TOKENS] - 1) + if (metrics[common_metrics.NUM_OUTPUT_TOKENS] - 1) + else 1 + ) + metrics[common_metrics.NUM_OUTPUT_TOKENS] = num_output_tokens + metrics[common_metrics.NUM_TOTAL_TOKENS] = ( + metrics[common_metrics.NUM_INPUT_TOKENS] + num_output_tokens + ) + try: + metrics[common_metrics.REQ_OUTPUT_THROUGHPUT] = ( + num_output_tokens / metrics[common_metrics.E2E_LAT] + ) + except ZeroDivisionError: + logging.error("Division by zero in throughput calculation.") + + completed_requests.append(metrics) + + incremental_time_delay += metrics.get( + common_metrics.INTER_TOKEN_LAT, 0.0 + ) + + end_time = time.monotonic() + + print(f"Results for token benchmark for {model} queried with the {llm_api} api.\n") + if mean_output_tokens == 2: + print(f"[INFO] First token sending pre-embedding completed\n") + return {}, [], 0.0, 0.0 + + ret = metrics_summary(completed_requests, start_time, end_time) + + metadata = { + "model": model, + "mean_input_tokens": mean_input_tokens, + "stddev_input_tokens": stddev_input_tokens, + "mean_output_tokens": mean_output_tokens, + "stddev_output_tokens": stddev_output_tokens, + "concurrent_requests": concurrent_requests, + "additional_sampling_params": additional_sampling_params, + } + + metadata["results"] = ret + elapsed_time = end_time - start_time + return metadata, completed_requests, elapsed_time, incremental_time_delay + + +def compute_throughput( + summary: Dict[str, Any], + completed_requests: List[Dict[str, Any]], + elapsed_time: float, + incremental_time_delay: float, +) -> Tuple[float, float]: + """ + Compute total_throughput (token/s) based on the metrics in summary. 
+ + Formula: (mean_output_tokens * num_completed_requests) / total_e2e_latency_s + + Args: + summary (Dict[str, Any]): A dictionary containing performance metrics. + + Returns: + float: The computed total throughput in tokens per second. Returns 0.0 if latency is zero. + """ + mean_output_tokens = summary.get("mean_output_tokens", 0) + + total_throughput = ( + (mean_output_tokens * len(completed_requests)) / elapsed_time + if elapsed_time > 0 + else 0.0 + ) + incremental_throughput = ( + (mean_output_tokens * len(completed_requests)) / incremental_time_delay + if incremental_time_delay > 0 + else 0.0 + ) + return round(total_throughput, 4), round(incremental_throughput, 4) + + +def metrics_summary( + metrics: List[Dict[str, Any]], start_time: int, end_time: int +) -> Dict[str, Any]: + """Generate a summary over metrics generated from potentially multiple instances of this client. + + Args: + metrics: The metrics to summarize. + start_time: The time the test started. + end_time: The time the test ended. + + Returns: + A summary with the following information: + - Overall throughput (generated tokens / total test time) + - Number of completed requests + - Error rate + - Error code frequency + - Quantiles (p25-p99) for the following metrics: + - Inter token latency + - Time to first token + - User total request time + - Number of tokens processed per request + - Number of tokens generated per request + - User throughput (tokens / s) + """ + ret = {} + + def flatten(item): + for sub_item in item: + if isinstance(sub_item, Iterable) and not isinstance(sub_item, str): + yield from flatten(sub_item) + else: + yield sub_item + + df = pd.DataFrame(metrics) + df_without_errored_req = df[df[common_metrics.ERROR_CODE].isna()] + + for key in [ + common_metrics.INTER_TOKEN_LAT, + common_metrics.TTFT, + common_metrics.E2E_LAT, + common_metrics.REQ_OUTPUT_THROUGHPUT, + common_metrics.NUM_INPUT_TOKENS, + common_metrics.NUM_OUTPUT_TOKENS, + ]: + print(key) + ret[key] = {} + series = pd.Series(list(flatten(df_without_errored_req[key]))).dropna() + series = series[series > 0] # Calculate non-zero values + quantiles = series.quantile([0.25, 0.5, 0.75, 0.9, 0.95, 0.99]).to_dict() + quantiles_reformatted_keys = {} + for quantile, value in quantiles.items(): + reformatted_key = f"p{int(quantile * 100)}" + print(f" {reformatted_key} = {value}") + quantiles_reformatted_keys[reformatted_key] = value + ret[key]["quantiles"] = quantiles_reformatted_keys + mean = series.mean() + print(f" mean = {mean}") + ret[key]["mean"] = mean + print(f" min = {series.min()}") + ret[key]["min"] = series.min() + print(f" max = {series.max()}") + ret[key]["max"] = series.max() + print(f" stddev = {series.std()}") + ret[key]["stddev"] = series.std() + + ret[common_metrics.NUM_REQ_STARTED] = len(metrics) + + error_codes = df[common_metrics.ERROR_CODE].dropna() + num_errors = len(error_codes) + ret[common_metrics.ERROR_RATE] = num_errors / len(metrics) if len(metrics) else 0 + ret[common_metrics.NUM_ERRORS] = num_errors + print(f"Number Of Errored Requests: {num_errors}") + error_code_frequency = dict(error_codes.value_counts()) + if num_errors: + error_code_frequency = dict(error_codes.value_counts()) + print("Error Code Frequency") + print(error_code_frequency) + ret[common_metrics.ERROR_CODE_FREQ] = str(error_code_frequency) + + overall_output_throughput = df_without_errored_req[ + common_metrics.NUM_OUTPUT_TOKENS + ].sum() / (end_time - start_time) + + print(f"Overall Output Throughput: {overall_output_throughput}") + 
ret[common_metrics.OUTPUT_THROUGHPUT] = overall_output_throughput + + num_completed_requests = len(df_without_errored_req) + num_completed_requests_per_min = ( + num_completed_requests / (end_time - start_time) * 60 + ) + print(f"Number Of Completed Requests: {num_completed_requests}") + print(f"Completed Requests Per Minute: {num_completed_requests_per_min}") + + ret[common_metrics.NUM_COMPLETED_REQUESTS] = num_completed_requests + ret[common_metrics.COMPLETED_REQUESTS_PER_MIN] = num_completed_requests_per_min + + return ret + + +def run_token_benchmark( + llm_api: str, + model: str, + test_timeout_s: int, + max_num_completed_requests: int, + concurrent_requests: int, + mean_input_tokens: int, + stddev_input_tokens: int, + mean_output_tokens: int, + stddev_output_tokens: int, + additional_sampling_params: str, + results_dir: str, + random_seed: int, + openai_api_base: str, + tokenizer_path: str, + user_metadata: Dict[str, Any], +): + """ + Args: + llm_api: The name of the llm api to use. + model: The name of the model to query. + max_num_completed_requests: The number of requests to complete before finishing the test. + test_timeout_s: The amount of time to run the test for before reporting results. + concurrent_requests: The number of concurrent requests to make. Increase + this to increase the amount of load and vice versa. + mean_input_tokens: The mean number of tokens to send in the prompt for the request. + stddev_input_tokens: The standard deviation of the number of tokens to send in the prompt for the request. + mean_output_tokens: The mean number of tokens to generate per request. + stddev_output_tokens: The standard deviation of the number of tokens to generate per request. + additional_sampling_params: Additional sampling parameters to send with the request. + For more information see the LLM APIs documentation for the completions. + results_dir: The directory to save the results to. + user_metadata: Additional metadata to include in the results. + """ + if mean_input_tokens < 40: + print( + "the minimum number of input tokens that will be sent is 41" + " because of the prompting logic right now" + ) + + summary, completed_requests, elapsed_time, incremental_time_delay = ( + get_token_throughput_latencies( + model=model, + llm_api=llm_api, + test_timeout_s=test_timeout_s, + max_num_completed_requests=max_num_completed_requests, + mean_input_tokens=mean_input_tokens, + stddev_input_tokens=stddev_input_tokens, + mean_output_tokens=mean_output_tokens, + stddev_output_tokens=stddev_output_tokens, + concurrent_requests=concurrent_requests, + additional_sampling_params=json.loads(additional_sampling_params), + random_seed=random_seed, + openai_api_base=openai_api_base, + tokenizer_path=tokenizer_path, + ) + ) + if mean_output_tokens == 2: + return summary, completed_requests, elapsed_time, incremental_time_delay + + timestamp = int(time.time() * 1000) + if results_dir: + filename = f"{model}_{mean_input_tokens}_{mean_output_tokens}_{timestamp}" + filename = re.sub(r"[^\w\d-]+", "-", filename) + filename = re.sub(r"-{2,}", "-", filename) + summary_filename = f"{filename}_summary" + + # Update to metadata. 
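To make the two throughput figures concrete (a worked example with made-up numbers, not part of the patch): `compute_throughput` divides the same token count by two different time bases, wall-clock elapsed time for `total_throughput` and the summed inter-token delays for `incremental_throughput`:

```python
# Illustrative numbers only.
mean_output_tokens = 200
num_completed_requests = 16
elapsed_time = 40.0            # wall-clock seconds for the whole batch
incremental_time_delay = 25.0  # sum of per-request inter-token latencies

total_throughput = (
    mean_output_tokens * num_completed_requests / elapsed_time
)  # 3200 tokens / 40 s = 80.0 tok/s
incremental_throughput = (
    mean_output_tokens * num_completed_requests / incremental_time_delay
)  # 3200 tokens / 25 s = 128.0 tok/s
```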
+ summary.update(user_metadata) + total_tp, req_tp = compute_throughput( + summary, completed_requests, elapsed_time, incremental_time_delay + ) + summary["num_completed_requests"] = len(completed_requests) + summary["elapsed_time"] = elapsed_time + summary["incremental_time_delay"] = incremental_time_delay + summary["total_throughput"] = total_tp + summary["incremental_throughput"] = req_tp + + results = LLMPerfResults(name=summary_filename, metadata=summary) + results_dir = Path(results_dir) + if not results_dir.exists(): + results_dir.mkdir(parents=True) + elif not results_dir.is_dir(): + raise ValueError(f"{results_dir} is not a directory") + + llmperf_dir = results_dir / "llmperf" + if not llmperf_dir.exists(): + llmperf_dir.mkdir(parents=True) + elif not llmperf_dir.is_dir(): + raise ValueError(f"{llmperf_dir} is not a directory") + + try: + with open(llmperf_dir / f"{summary_filename}.json", "w") as f: + json.dump(results.to_dict(), f, indent=4, default=str) + except Exception as e: + print(results.to_dict()) + raise e + return summary diff --git a/test/common/llmperf/utils/utils.py b/test/common/llmperf/utils/utils.py new file mode 100644 index 00000000..e2c27087 --- /dev/null +++ b/test/common/llmperf/utils/utils.py @@ -0,0 +1,171 @@ +import hashlib +import json +import math +import os +import pathlib +import random +import subprocess +import time +from typing import Any, Dict, Tuple + +from transformers import LlamaTokenizerFast + +RESULTS_VERSION = "2025-10-30" + + +class LLMPerfResults: + def __init__( + self, + name: str, + metadata: Dict[str, Any] = None, + ): + self.name = name + self.metadata = metadata or {} + self.timestamp = int(time.time()) + self.metadata["timestamp"] = self.timestamp + self.version = RESULTS_VERSION + + def to_dict(self): + data = { + "version": self.version, + "name": self.name, + } + data.update(self.metadata) + data = flatten_dict(data) + return data + + def json(self): + data = self.to_dict() + return json.dumps(data) + + +def upload_to_s3(results_path: str, s3_path: str) -> None: + """Upload the results to s3. + + Args: + results_path: The path to the results file. + s3_path: The s3 path to upload the results to. + + """ + + command = ["aws", "s3", "sync", results_path, f"{s3_path}/"] + result = subprocess.run(command) + if result.returncode == 0: + print("Files uploaded successfully!") + else: + print("An error occurred:") + print(result.stderr) + + +def randomly_sample_sonnet_lines_prompt( + prompt_tokens_mean: int = 550, + prompt_tokens_stddev: int = 250, + tokenizer: LlamaTokenizerFast = None, +) -> Tuple[str, int]: + """Generate a prompt that randomly samples lines from a the shakespeare sonnet at sonnet.txt. + + Args: + prompt_length_mean: The mean length of the prompt to generate. + prompt_len_stddev: The standard deviation of the length of the prompt to generate. + expect_output_tokens: The number of tokens to expect in the output. This is used to + determine the length of the prompt. The prompt will be generated such that the output + will be approximately this many tokens. + + Note: + tokens will be counted from the sonnet using the Llama tokenizer. Using one tokenizer + ensures a fairer comparison across different LLMs. For example, if gpt 3.5 tokenizes + a prompt in less tokens than Llama2, then this will be reflected in the results since + they will be fed identical prompts. + + Returns: + A tuple of the prompt and the length of the prompt. 
+ """ + get_token_length = lambda text: len(tokenizer.encode(text)) + + prompt = ( + "Randomly stream lines from the following text " + "Don't generate eos tokens:\n\n" + ) + # get a prompt length that is at least as long as the base + num_prompt_tokens = sample_random_positive_int( + prompt_tokens_mean, prompt_tokens_stddev + ) + while num_prompt_tokens < get_token_length(prompt): + num_prompt_tokens = sample_random_positive_int( + prompt_tokens_mean, prompt_tokens_stddev + ) + remaining_prompt_tokens = num_prompt_tokens - get_token_length(prompt) + sonnet_path = pathlib.Path(__file__).parent.resolve() / "sonnet.txt" + with open(sonnet_path, "r") as f: + sonnet_lines = f.readlines() + random.shuffle(sonnet_lines) + sampling_lines = True + while sampling_lines: + for line in sonnet_lines: + line_to_add = line + if remaining_prompt_tokens - get_token_length(line_to_add) < 0: + # This will cut off a line in the middle of a word, but that's ok since an + # llm should be able to handle that. + line_to_add = line_to_add[: int(math.ceil(remaining_prompt_tokens))] + sampling_lines = False + prompt += line_to_add + break + prompt += line_to_add + remaining_prompt_tokens -= get_token_length(line_to_add) + print(hashlib.sha256(prompt.encode("utf-8")).hexdigest()) + return (prompt, num_prompt_tokens) + + +def sample_random_positive_int(mean: int, stddev: int) -> int: + """Sample random numbers from a gaussian distribution until a positive number is sampled. + + Args: + mean: The mean of the gaussian distribution to sample from. + stddev: The standard deviation of the gaussian distribution to sample from. + + Returns: + A random positive integer sampled from the gaussian distribution. + """ + ret = -1 + while ret <= 0: + ret = int(random.gauss(mean, stddev)) + return ret + + +def flatten_dict(d, parent_key="", sep="_"): + items = [] + for k, v in d.items(): + new_key = f"{parent_key}{sep}{k}" if parent_key else k + if isinstance(v, dict): + items.extend(flatten_dict(v, new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) + + +def reset_prefill_cache(env, server_url): + """ + prefix cache / HBM + Param: + env + server_url + """ + reset_url = f"{server_url}/reset_prefix_cache" + print(f"[INFO] Resetting prefix cache: {reset_url}") + try: + result = subprocess.run( + ["curl", "-X", "POST", reset_url, "-s", "-f"], + env=env, + check=False, + capture_output=True, + text=True, + timeout=10, + ) + if result.returncode == 0: + print("[INFO] Prefix cache successfully reset") + else: + print( + f"[ERROR] Unsuccessfully reset prefix cache,error code: {result.returncode}" + ) + except Exception as e: + print(f"[ERROR] Exception in resetting prefix cache: {e}") diff --git a/test/config.yaml b/test/config.yaml index 88d00a61..766cfeb6 100644 --- a/test/config.yaml +++ b/test/config.yaml @@ -15,4 +15,19 @@ database: name: "ucm_pytest" user: "root" password: "123456" - charset: "utf8mb4" \ No newline at end of file + charset: "utf8mb4" + +# LLM Connection Configuration +llm_connection: + model: "qwen3" + server_url: "http://141.111.32.70:9382" + tokenizer_path: "/home/models/QwQ-32B" + +# Performance Test Configuration +llmperf_test_cases: + - mean_input_tokens: 6000 + mean_output_tokens: 200 + max_num_completed_requests: 16 + concurrent_requests: 8 + additional_sampling_params: "{}" + hit_rate: 0 \ No newline at end of file diff --git a/test/pytest.ini b/test/pytest.ini index 4be3cf47..76b0297a 100644 --- a/test/pytest.ini +++ b/test/pytest.ini @@ -22,4 +22,4 @@ markers = # -------- 
Features (Recommended) -------- feature: Feature tag platform(name): Platform tag(gpu/npu) -# end of markers \ No newline at end of file +# end of markers diff --git a/test/requirements.txt b/test/requirements.txt index 07635b24..41bcc9b3 100644 --- a/test/requirements.txt +++ b/test/requirements.txt @@ -3,4 +3,4 @@ pytest-html>=3.1.1 PyYAML>=6.0 # MySQL peewee>=3.14.5 -pymysql>=1.0.2 \ No newline at end of file +pymysql>=1.0.2 diff --git a/test/suites/E2E/test_demo_performance.py b/test/suites/E2E/test_demo_function.py similarity index 100% rename from test/suites/E2E/test_demo_performance.py rename to test/suites/E2E/test_demo_function.py diff --git a/test/suites/E2E/test_uc_performance.py b/test/suites/E2E/test_uc_performance.py new file mode 100644 index 00000000..60524e51 --- /dev/null +++ b/test/suites/E2E/test_uc_performance.py @@ -0,0 +1,138 @@ +import pytest +from common.capture_utils import export_vars +from common.llmperf.run_inference import inference_results + + +@pytest.mark.feature("uc_performance_test") +@export_vars +def test_performance(): + all_summaries = inference_results() + failed_cases = [] + + value_lists = { + "mean_input_tokens": [], + "mean_output_tokens": [], + "results_inter_token_latency_s_quantiles_p50": [], + "results_inter_token_latency_s_quantiles_p90": [], + "results_inter_token_latency_s_quantiles_p99": [], + "results_inter_token_latency_s_mean": [], + "results_ttft_s_quantiles_p50": [], + "results_ttft_s_quantiles_p90": [], + "results_ttft_s_quantiles_p99": [], + "results_ttft_s_mean": [], + "results_end_to_end_latency_s_quantiles_p50": [], + "results_end_to_end_latency_s_quantiles_p90": [], + "results_end_to_end_latency_s_quantiles_p99": [], + "results_end_to_end_latency_s_mean": [], + "num_completed_requests": [], + "elapsed_time": [], + "incremental_time_delay": [], + "total_throughput": [], + "incremental_throughput": [], + } + + for i, summary in enumerate(all_summaries): + mean_input_tokens = summary["mean_input_tokens"] + mean_output_tokens = summary["mean_output_tokens"] + + results_inter_token_latency_s_quantiles_p50 = summary["results"][ + "inter_token_latency_s" + ]["quantiles"]["p50"] + results_inter_token_latency_s_quantiles_p90 = summary["results"][ + "inter_token_latency_s" + ]["quantiles"]["p90"] + results_inter_token_latency_s_quantiles_p99 = summary["results"][ + "inter_token_latency_s" + ]["quantiles"]["p99"] + results_inter_token_latency_s_mean = summary["results"][ + "inter_token_latency_s" + ]["mean"] + + results_ttft_s_quantiles_p50 = summary["results"]["ttft_s"]["quantiles"]["p50"] + results_ttft_s_quantiles_p90 = summary["results"]["ttft_s"]["quantiles"]["p90"] + results_ttft_s_quantiles_p99 = summary["results"]["ttft_s"]["quantiles"]["p99"] + results_ttft_s_mean = summary["results"]["ttft_s"]["mean"] + + results_end_to_end_latency_s_quantiles_p50 = summary["results"][ + "end_to_end_latency_s" + ]["quantiles"]["p50"] + results_end_to_end_latency_s_quantiles_p90 = summary["results"][ + "end_to_end_latency_s" + ]["quantiles"]["p90"] + results_end_to_end_latency_s_quantiles_p99 = summary["results"][ + "end_to_end_latency_s" + ]["quantiles"]["p99"] + results_end_to_end_latency_s_mean = summary["results"]["end_to_end_latency_s"][ + "mean" + ] + + num_completed_requests = summary["num_completed_requests"] + elapsed_time = summary["elapsed_time"] + incremental_time_delay = summary["incremental_time_delay"] + total_throughput = summary["total_throughput"] + incremental_throughput = summary["incremental_throughput"] + + values = [ + 
mean_input_tokens, + mean_output_tokens, + results_inter_token_latency_s_quantiles_p50, + results_inter_token_latency_s_quantiles_p90, + results_inter_token_latency_s_quantiles_p99, + results_inter_token_latency_s_mean, + results_ttft_s_quantiles_p50, + results_ttft_s_quantiles_p90, + results_ttft_s_quantiles_p99, + results_ttft_s_mean, + results_end_to_end_latency_s_quantiles_p50, + results_end_to_end_latency_s_quantiles_p90, + results_end_to_end_latency_s_quantiles_p99, + results_end_to_end_latency_s_mean, + num_completed_requests, + elapsed_time, + incremental_time_delay, + total_throughput, + incremental_throughput, + ] + + for var_name, val in zip( + [ + "mean_input_tokens", + "mean_output_tokens", + "results_inter_token_latency_s_quantiles_p50", + "results_inter_token_latency_s_quantiles_p90", + "results_inter_token_latency_s_quantiles_p99", + "results_inter_token_latency_s_mean", + "results_ttft_s_quantiles_p50", + "results_ttft_s_quantiles_p90", + "results_ttft_s_quantiles_p99", + "results_ttft_s_mean", + "results_end_to_end_latency_s_quantiles_p50", + "results_end_to_end_latency_s_quantiles_p90", + "results_end_to_end_latency_s_quantiles_p99", + "results_end_to_end_latency_s_mean", + "num_completed_requests", + "elapsed_time", + "incremental_time_delay", + "total_throughput", + "incremental_throughput", + ], + values, + ): + value_lists[var_name].append(val) + if val is None: + failed_cases.append((i, var_name, "missing")) + + try: + assert val > 0, f"value <= 0" + except AssertionError as e: + failed_cases.append((i, var_name, str(e))) + + # Output final result + if failed_cases: + print(f"\n[WARNING] Assertion failed: {len(failed_cases)} abnormal cases found") + for i, key, reason in failed_cases: + print(f" Iteration={i + 1}, key='{key}' -> {reason}") + else: + print("\n[INFO] All values are greater than 0. Assertion passed!") + + return value_lists diff --git a/test/test_ucm_dram.py b/test/test_ucm_dram.py new file mode 100644 index 00000000..020405d1 --- /dev/null +++ b/test/test_ucm_dram.py @@ -0,0 +1,250 @@ +# +# MIT License +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
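A side note on the metric extraction in test_uc_performance.py above: each metric name is currently spelled out three times (local variable, `values` list, parallel name list). A possible compaction, untested and assuming the summary layout produced by `run_token_benchmark`, would drive everything from one table of `(name, path)` pairs:

```python
# Hypothetical helper; the paths mirror the nested summary dictionary.
METRIC_PATHS = {
    "results_ttft_s_quantiles_p50": ("results", "ttft_s", "quantiles", "p50"),
    "results_ttft_s_mean": ("results", "ttft_s", "mean"),
    "total_throughput": ("total_throughput",),
    # ... the remaining keys follow the same pattern
}

def extract(summary, path):
    value = summary
    for key in path:
        value = value[key]
    return value

for name, path in METRIC_PATHS.items():
    value_lists[name].append(extract(summary, path))
```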
+# + +import random +import unittest +import unittest.mock as mock +from contextlib import contextmanager +from typing import List +from unittest.mock import MagicMock + +import torch +from vllm.multimodal.inputs import MultiModalKwargs +from vllm.sampling_params import SamplingParams +from vllm.utils import sha256 +from vllm.v1.core.kv_cache_utils import hash_request_tokens +from vllm.v1.request import Request + + +@contextmanager +def mock_stream_context(stream=None): + yield + + +class MockStream: + def __init__(self, device=None): + self.device = device or torch.device("cpu") + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + def synchronize(self): + pass + + def record_event(self, event=None): + return event or MockEvent() + + def wait_stream(self, stream): + pass + + +class MockEvent: + def __init__(self, enable_timing=False): + self.enable_timing = enable_timing + + def record(self, stream=None): + pass + + def wait(self, stream=None): + pass + + def synchronize(self): + pass + + +def patch_cuda_for_cpu(): + mock.patch("torch.cuda.Stream", MockStream).start() + mock.patch("torch.cuda.Event", MockEvent).start() + mock.patch("torch.cuda.current_stream", return_value=MockStream()).start() + mock.patch("torch.cuda.synchronize", side_effect=lambda *a, **k: None).start() + mock.patch("torch.cuda.is_available", return_value=True).start() + mock.patch("torch.cuda.stream", mock_stream_context).start() + + +patch_cuda_for_cpu() +from ucm.store.dramstore.dramstore_connector import ( # isort: skip + DramTask, + UcmDramStore, +) + + +def make_request( + request_id, prompt_token_ids, mm_positions=None, mm_hashes=None, cache_salt=None +): + if mm_positions is None: + multi_modal_inputs = None + else: + multi_modal_inputs = [MultiModalKwargs({})] * len(mm_positions) + + return Request( + request_id=request_id, + prompt_token_ids=prompt_token_ids, + multi_modal_inputs=multi_modal_inputs, + multi_modal_hashes=mm_hashes, + multi_modal_placeholders=mm_positions, + sampling_params=SamplingParams(max_tokens=17), + pooling_params=None, + eos_token_id=100, + arrival_time=0, + lora_request=None, + cache_salt=cache_salt, + ) + + +class TestUcmDram(unittest.TestCase): + + @classmethod + def setUpClass(cls): + print("===> Before all tests (setUpClass)") + + @classmethod + def tearDownClass(cls): + print("===> After all tests (setUpClass)") + + def setUp(self): + self.config = {"block_size": 4} + self.scheduler_config = { + "role": "scheduler", + "max_cache_size": 1073741824, + "kv_block_size": 262144, + } + self.worker_config = { + "role": "worker", + "max_cache_size": 1073741824, + "kv_block_size": 262144, + } + + self.block_number = 4 + self.block_size = int(self.config["block_size"]) + self.scheduler_dram = UcmDramStore(self.scheduler_config) + self.worker_dram = UcmDramStore(self.worker_config) + random.seed(20250728) + self.request = make_request( + request_id=1, + prompt_token_ids=random.sample( + range(0, 10000), self.block_number * self.block_size + ), + mm_positions=None, + mm_hashes=None, + ) + block_hash_types = hash_request_tokens(sha256, self.block_size, self.request) + self.block_hashes: List[str] = [str(x.hash_value) for x in block_hash_types] + + def test_look_up_all_hit(self): + """ + Test for all blocks hitten in cache + """ + expected = [True] * len(self.block_hashes) + self.scheduler_dram.cached_blocks.update(self.block_hashes) + actual = self.scheduler_dram.lookup(self.block_hashes) + + self.assertEqual(actual, expected) + + def 
test_lookup_partial_hit(self):
+        """
+        Test for part of the blocks hit in the cache
+        """
+        partial_index = random.randint(0, len(self.block_hashes))
+        partial_hashes = self.block_hashes[:partial_index]
+        self.scheduler_dram.cached_blocks.update(partial_hashes)
+        actual = self.scheduler_dram.lookup(self.block_hashes)
+        expected = [True] * partial_index + [False] * (
+            len(self.block_hashes) - partial_index
+        )
+        self.assertEqual(actual, expected)
+
+    def test_lookup_none_hit(self):
+        """
+        Test for none of the blocks hit in the cache
+        """
+        actual = self.scheduler_dram.lookup(self.block_hashes)
+        expected = [False] * len(self.block_hashes)
+        self.assertEqual(actual, expected)
+
+    def test_load_success(self):
+        """
+        Test that blocks dumped to the cache can be loaded back intact
+        """
+        src_tensors = [
+            torch.randint(0, 100, (self.block_size,), dtype=torch.int8)
+            for _ in range(len(self.block_hashes))
+        ]
+        offsets = [i for i in range(len(self.block_hashes))]
+        dump_task = self.worker_dram.dump(self.block_hashes, offsets, src_tensors)
+        self.worker_dram.wait(dump_task)
+        dst_tensors = [
+            torch.zeros(self.block_size, dtype=torch.int8)
+            for _ in range(len(self.block_hashes))
+        ]
+        load_task = self.worker_dram.load(self.block_hashes, offsets, dst_tensors)
+
+        self.assertIsInstance(load_task, DramTask)
+        self.assertIsNotNone(load_task.event)
+        for i, (src_tensor, dst_tensor) in enumerate(zip(src_tensors, dst_tensors)):
+            self.assertEqual(dst_tensor.shape[0], self.block_size)
+            self.assertTrue(
+                torch.equal(src_tensor, dst_tensor),
+                f"Block {i} loaded data is different",
+            )
+
+    def test_dump_success(self):
+        """
+        Test that dumping blocks stores them in the cache unchanged
+        """
+        src_tensors = [
+            torch.randint(0, 100, (self.block_size,), dtype=torch.int8)
+            for _ in range(len(self.block_hashes))
+        ]
+        offsets = [i for i in range(len(self.block_hashes))]
+        original_data = [tensor.clone() for tensor in src_tensors]
+        dump_task = self.worker_dram.dump(self.block_hashes, offsets, src_tensors)
+        self.assertIsInstance(dump_task, DramTask)
+        self.assertIsNotNone(dump_task.event)
+        self.worker_dram.wait(dump_task)
+        for i, block_id in enumerate(self.block_hashes):
+            key = block_id + "_" + str(offsets[i])
+            cached_data = self.worker_dram.dram_cache[key]
+            self.assertEqual(cached_data.shape[0], self.block_size)
+            self.assertTrue(torch.equal(cached_data, original_data[i]))
+
+    def test_wait_success(self):
+        """
+        Test waiting on a task whose event completes successfully
+        """
+        task = DramTask()
+        task.event = MagicMock()
+        result = self.worker_dram.wait(task)
+        self.assertEqual(result, 0)
+        task.event.synchronize.assert_called_once()
+
+    def test_wait_failure(self):
+        """
+        Test waiting on a task that has no event attached
+        """
+        task = DramTask()
+        task.event = None
+        result = self.worker_dram.wait(task)
+        self.assertEqual(result, -1)
+
+
+if __name__ == "__main__":
+    unittest.main()
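Assuming the layout above and a `test/config.yaml` that points at a reachable OpenAI-compatible endpoint, the new suites could be driven from the `test/` directory roughly as follows (a sketch, not documented in the patch):

```python
# Sketch: run the llmperf E2E benchmark and the DRAM-store unit tests.
import pytest
import unittest

# E2E performance test (requires a live server configured in config.yaml)
pytest.main(["-v", "suites/E2E/test_uc_performance.py"])

# DRAM store tests run on CPU thanks to the torch.cuda mocks patched in above
suite = unittest.defaultTestLoader.discover(".", pattern="test_ucm_dram.py")
unittest.TextTestRunner(verbosity=2).run(suite)
```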