Commit 62b7638: add request timeout

Parent: c85cdf6

2 files changed: +40, -33 lines changed

multimodal/vl2l/src/mlperf_inference_multimodal_vl2l/schema.py

Lines changed: 36 additions & 26 deletions
@@ -120,6 +120,13 @@ class BaseModelWithAttributeDescriptionsFromDocstrings(BaseModel):
     """


+_DEFAULT_DATASET_SIZE = 48289
+_DEFAULT_MIN_DURATION = timedelta(minutes=10)
+_DEFAULT_OFFLINE_EXPECTED_QPS = (
+    _DEFAULT_DATASET_SIZE / _DEFAULT_MIN_DURATION.total_seconds()
+)
+
+
 class TestSettings(BaseModelWithAttributeDescriptionsFromDocstrings):
     """The test settings for the MLPerf inference LoadGen."""
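Note on the new defaults above: the derived offline QPS is simply the dataset size divided by the minimum run duration in seconds. A quick self-contained check of that arithmetic (not part of the commit; values copied from the diff):

```python
from datetime import timedelta

# Values copied from the new module-level defaults in schema.py.
dataset_size = 48289
min_duration = timedelta(minutes=10)

# 48289 samples / 600 s ~= 80.5 QPS, replacing the previous hard-coded 100.
offline_expected_qps = dataset_size / min_duration.total_seconds()
print(round(offline_expected_qps, 1))  # 80.5
```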

@@ -131,18 +138,18 @@ class TestSettings(BaseModelWithAttributeDescriptionsFromDocstrings):
     evaluation.
     """

-    offline_expected_qps: float = 100
+    offline_expected_qps: float = _DEFAULT_OFFLINE_EXPECTED_QPS
     """The expected QPS for the offline scenario."""

-    sample_concatenate_permutation: bool = True
-    """Affects the order in which the samples of the dataset are chosen.
-    If `False`, it concatenates a single permutation of the dataset (or part
-    of it depending on `performance_sample_count_override`) several times up to the
-    number of samples requested.
-    If `True`, it concatenates a multiple permutation of the dataset (or a
-    part of it depending on `performance_sample_count_override`) several times
-    up to the number of samples requested.
-    """
+    # sample_concatenate_permutation: bool = True  # noqa: ERA001
+    # """Affects the order in which the samples of the dataset are chosen.
+    # If `False`, it concatenates a single permutation of the dataset (or part
+    # of it depending on `performance_sample_count_override`) several times up to the
+    # number of samples requested.
+    # If `True`, it concatenates a multiple permutation of the dataset (or a
+    # part of it depending on `performance_sample_count_override`) several times
+    # up to the number of samples requested.
+    # """

     server_expected_qps: float = 10
     """The expected QPS for the server scenario. Loadgen will try to send as many
@@ -166,12 +173,12 @@ class TestSettings(BaseModelWithAttributeDescriptionsFromDocstrings):
     use_token_latencies is enabled).
     """

-    min_duration: timedelta = timedelta(minutes=10)
+    min_duration: timedelta = _DEFAULT_MIN_DURATION
     """The minimum testing duration (in seconds or ISO 8601 format like `PT5S`). The
     benchmark runs until this value has been met.
     """

-    min_query_count: int = 48289
+    min_query_count: int = _DEFAULT_DATASET_SIZE
     """The minimum testing query count. The benchmark runs until this value has been
     met. If min_query_count is less than the total number of samples in the dataset,
     only the first min_query_count samples will be used during testing.
@@ -191,7 +198,7 @@ class TestSettings(BaseModelWithAttributeDescriptionsFromDocstrings):
             f"{MAX_NUM_ESTIMATION_PERFORMANCE_SAMPLES} samples (at most), and then"
             " use this estimation as the value P.",
         ),
-    ] = 48289
+    ] = _DEFAULT_DATASET_SIZE

     use_token_latencies: bool = False
     """By default, the Server scenario will use `server_target_latency` as the
@@ -207,8 +214,7 @@ class TestSettings(BaseModelWithAttributeDescriptionsFromDocstrings):
         mode="before",
     )
     @classmethod
-    def parse_timedelta(cls, value: timedelta | float |
-                        str) -> timedelta | str:
+    def parse_timedelta(cls, value: timedelta | float | str) -> timedelta | str:
         """Parse timedelta from seconds (int/float/str) or ISO 8601 format."""
         if isinstance(value, timedelta):
             return value
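For context, the reflowed `parse_timedelta` validator accepts either a number of seconds or an ISO 8601 duration string; only its first lines appear in this hunk. A minimal self-contained sketch of the same idea (the rest of the body is an assumption, not the project's actual code):

```python
from datetime import timedelta


def parse_timedelta(value: timedelta | float | str) -> timedelta | str:
    """Return a timedelta for plain seconds; pass other strings through unchanged."""
    if isinstance(value, timedelta):
        return value
    try:
        # Numbers (or numeric strings) are interpreted as seconds.
        return timedelta(seconds=float(value))
    except ValueError:
        # Anything else (e.g. "PT5S") is left for downstream ISO 8601 parsing.
        return value


assert parse_timedelta(600) == timedelta(minutes=10)
assert parse_timedelta("PT5S") == "PT5S"
```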
@@ -234,12 +240,9 @@ def to_lgtype(self) -> lg.TestSettings:
         settings.server_target_latency_ns = round(
             self.server_target_latency.total_seconds() * 1e9,
         )
-        settings.ttft_latency = round(
-            self.server_ttft_latency.total_seconds() * 1e9)
-        settings.tpot_latency = round(
-            self.server_tpot_latency.total_seconds() * 1e9)
-        settings.min_duration_ms = round(
-            self.min_duration.total_seconds() * 1000)
+        settings.ttft_latency = round(self.server_ttft_latency.total_seconds() * 1e9)
+        settings.tpot_latency = round(self.server_tpot_latency.total_seconds() * 1e9)
+        settings.min_duration_ms = round(self.min_duration.total_seconds() * 1000)
         settings.min_query_count = self.min_query_count
         settings.performance_sample_count_override = (
             self.performance_sample_count_override
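The reflowed assignments above convert `timedelta` settings into the integer nanoseconds and milliseconds that LoadGen expects. A small illustrative check of those conversions (example values assumed):

```python
from datetime import timedelta

server_ttft_latency = timedelta(milliseconds=100)  # assumed example value
min_duration = timedelta(minutes=10)

assert round(server_ttft_latency.total_seconds() * 1e9) == 100_000_000  # nanoseconds
assert round(min_duration.total_seconds() * 1000) == 600_000            # milliseconds
```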
@@ -343,7 +346,9 @@ class Dataset(BaseModelWithAttributeDescriptionsFromDocstrings):
     """The token to access the HuggingFace repository of the dataset."""

     revision: str | None = None
-    """The revision of the dataset."""
+    """The revision of the dataset. If not provided, the default revision (i.e., usually
+    `main`) will be used.
+    """

     split: list[str] = ["train", "test"]
     """Dataset splits to use for the benchmark, e.g., "train" and "test". You can add
@@ -388,6 +393,13 @@ class Endpoint(BaseModelWithAttributeDescriptionsFromDocstrings):
     ```json ... ``` code block).
     """

+    request_timeout: timedelta = timedelta(hours=2)
+    """The timeout for the inference request to the endpoint. The default value for
+    OpenAI API client is 10 minutes
+    (https://github.com/openai/openai-python?tab=readme-ov-file#timeouts) which might
+    not be sufficient for the offline scenario.
+    """
+

 class EndpointToDeploy(Endpoint):
     """Specifies the endpoint to deploy for the VL2L benchmark."""
@@ -438,8 +450,7 @@ def __init__(self, flag: str) -> None:
 class BlacklistedVllmCliFlagError(ValueError):
     """The exception raised when a blacklisted vllm CLI flag is encountered."""

-    BLACKLIST: ClassVar[list[str]] = [
-        "--model", "--host", "--port", "--api-key"]
+    BLACKLIST: ClassVar[list[str]] = ["--model", "--host", "--port", "--api-key"]

     def __init__(self, flag: str) -> None:
         """Initialize the exception."""
@@ -492,6 +503,5 @@ def ensure_content_is_list(
             == "pydantic_core._pydantic_core"
             and message["content"].__class__.__name__ == "ValidatorIterator"
         ):
-            message["content"] = list(
-                message["content"])  # type: ignore[arg-type]
+            message["content"] = list(message["content"])  # type: ignore[arg-type]
         return messages

multimodal/vl2l/src/mlperf_inference_multimodal_vl2l/task.py

Lines changed: 4 additions & 7 deletions
@@ -66,6 +66,7 @@ def __init__(
             base_url=endpoint.url,
             http_client=DefaultAioHttpClient(),
             api_key=endpoint.api_key,
+            timeout=endpoint.request_timeout.total_seconds(),
         )
         self.event_loop, self.event_loop_thread = (
             self._create_event_loop_in_separate_thread()
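This is the core of the commit: the new `request_timeout` setting is forwarded to the async OpenAI client as seconds. A minimal sketch of the same wiring in isolation (endpoint values are placeholders; per the openai-python docs the `timeout` argument accepts a float number of seconds):

```python
from datetime import timedelta

from openai import AsyncOpenAI

request_timeout = timedelta(hours=2)  # default introduced by this commit

client = AsyncOpenAI(
    base_url="http://localhost:8000/v1",  # placeholder endpoint URL
    api_key="EMPTY",                       # placeholder API key
    # openai-python's default timeout is 10 minutes; 2 hours gives long offline
    # requests more headroom, as the new Endpoint docstring explains.
    timeout=request_timeout.total_seconds(),
)
```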
@@ -152,9 +153,7 @@ def estimated_num_performance_samples(self) -> int:
         """
         estimation_indices = random.sample(
             range(self.total_num_samples),
-            k=min(
-                MAX_NUM_ESTIMATION_PERFORMANCE_SAMPLES,
-                self.total_num_samples),
+            k=min(MAX_NUM_ESTIMATION_PERFORMANCE_SAMPLES, self.total_num_samples),
         )
         estimation_samples = [
             self.formulate_loaded_sample(
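The reflowed `k=min(...)` above caps the estimation sample count at the dataset size; without the cap, `random.sample` raises `ValueError` when `k` exceeds the population. A small self-contained illustration (the constant's value is assumed):

```python
import random

MAX_NUM_ESTIMATION_PERFORMANCE_SAMPLES = 1000  # assumed value for illustration
total_num_samples = 250

estimation_indices = random.sample(
    range(total_num_samples),
    k=min(MAX_NUM_ESTIMATION_PERFORMANCE_SAMPLES, total_num_samples),
)
assert len(estimation_indices) == 250  # capped at the dataset size
```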
@@ -219,8 +218,7 @@ def _unload_samples_from_ram(query_sample_indices: list[int]) -> None:
             _unload_samples_from_ram,
         )

-    async def _query_endpoint_async_batch(
-            self, query_sample: lg.QuerySample) -> None:
+    async def _query_endpoint_async_batch(self, query_sample: lg.QuerySample) -> None:
         """Query the endpoint through the async OpenAI API client."""
         try:
             sample = self.loaded_samples[query_sample.index]
@@ -297,8 +295,7 @@ async def _query_endpoint_async_batch(
             ],
         )

-    async def _query_endpoint_async_stream(
-            self, query_sample: lg.QuerySample) -> None:
+    async def _query_endpoint_async_stream(self, query_sample: lg.QuerySample) -> None:
         """Query the endpoint through the async OpenAI API client."""
         ttft_set = False
         try:
