55import sys
66from datetime import timedelta
77from enum import StrEnum , auto
8+ from pathlib import Path
89from typing import Annotated
910
1011import mlperf_loadgen as lg
1112from loguru import logger
12- from pydantic import BaseModel , Field , field_validator
13+ from pydantic import BaseModel , DirectoryPath , Field , field_validator
1314from pydantic_typer import Typer
1415from typer import Option
1516
@@ -74,6 +75,36 @@ def __init__(self, test_mode: TestMode) -> None:
7475 super ().__init__ (f"Unknown test mode: { test_mode } " )
7576
7677
78+ class LoggingMode (StrEnum ):
79+ """Specifies when logging should be sampled and stringified."""
80+
81+ ASYNC_POLL = auto ()
82+ """ Logs are serialized and output on an IOThread that polls for new logs
83+ at a fixed interval. This is the only mode currently implemented."""
84+
85+ END_OF_TEST_ONLY = auto ()
86+ """ Not implemented """
87+
88+ SYNCHRONOUS = auto ()
89+ """ Not implemented """
90+
91+ def to_lgtype (self ) -> lg .LoggingMode :
92+ """Convert logging mode to its corresponding LoadGen type."""
93+ match self :
94+ case LoggingMode .ASYNC_POLL :
95+ return lg .LoggingMode .AsyncPoll
96+ case _:
97+ raise UnknownLoggingModeValueError
98+
99+
100+ class UnknownLoggingModeValueError (ValueError ):
101+ """The exception raised when an unknown logging mode is encountered."""
102+
103+ def __init__ (self , test_mode : TestMode ) -> None :
104+ """Initialize the exception."""
105+ super ().__init__ (f"Unknown logging mode: { test_mode } " )
106+
107+
77108class TestSettings (BaseModel ):
78109 """The test settings for the MLPerf inference LoadGen."""
79110
@@ -102,27 +133,73 @@ class TestSettings(BaseModel):
102133 ),
103134 ] = 100
104135
136+ server_expected_qps : Annotated [
137+ float ,
138+ Field (
139+ description = "The expected QPS for the server scenario. "
140+ "Loadgen will try to send as many request as necessary "
141+ "to achieve this value." ,
142+ ),
143+ ] = 1
144+
145+ server_target_latency : Annotated [
146+ timedelta ,
147+ Field (description = """Expected latency constraint for Server scenario.
148+ This is a constraint that we expect depending
149+ on the argument server_expected_qps.
150+ When server_expected_qps increases, we expect the latency to also increase.
151+ When server_expected_qps decreases, we expect the latency to also decrease.""" ),
152+ ] = timedelta (seconds = 1 )
153+
154+ server_ttft_latency : Annotated [
155+ timedelta ,
156+ Field (description = """Time to First Token (TTFT)
157+ latency constraint result validation"
158+ (used when use_token_latencies is enabled).""" ),
159+ ] = timedelta (seconds = 1 )
160+
161+ server_tpot_latency : Annotated [
162+ timedelta ,
163+ Field (description = """Time per Output Token (TPOT)
164+ latency constraint result validation"
165+ (used when use_token_latencies is enabled).""" ),
166+ ] = timedelta (seconds = 1 )
167+
105168 min_duration : Annotated [
106169 timedelta ,
107170 Field (
108- description = (
109- "The minimum testing duration (in seconds or ISO 8601 format like"
110- " PT5S)."
111- ),
171+ description = """The minimum testing duration
172+ (in seconds or ISO 8601 format like PT5S).
173+ The benchmark runs until this value has been met.""" ,
112174 ),
113175 ] = timedelta (seconds = 5 )
114176
177+ min_query_count : Annotated [
178+ int ,
179+ Field (
180+ description = """The minimum testing query count.
181+ The benchmark runs until this value has been met.""" ,
182+ ),
183+ ] = 100
184+
115185 use_token_latencies : Annotated [
116186 bool ,
117187 Field (
118- description = "When set to True, LoadGen will track TTFT and TPOT." ,
188+ description = """By default,
189+ the Server scenario will use server_target_latency as the constraint.
190+ When set to True, the Server scenario will use server_ttft_latency
191+ and server_tpot_latency as the constraint.""" ,
119192 ),
120- ] = True
193+ ] = False
121194
122- @field_validator ("min_duration" , mode = "before" )
195+ @field_validator ("server_target_latency" ,
196+ "server_ttft_latency" ,
197+ "server_tpot_latency" ,
198+ "min_duration" ,
199+ mode = "before" )
123200 @classmethod
124- def parse_min_duration (cls , value : timedelta |
125- float | str ) -> timedelta | str :
201+ def parse_timedelta (cls , value : timedelta |
202+ float | str ) -> timedelta | str :
126203 """Parse timedelta from seconds (int/float/str) or ISO 8601 format."""
127204 if isinstance (value , timedelta ):
128205 return value
@@ -144,12 +221,133 @@ def to_lgtype(self) -> lg.TestSettings:
144221 settings .scenario = self .scenario .to_lgtype ()
145222 settings .mode = self .mode .to_lgtype ()
146223 settings .offline_expected_qps = self .offline_expected_qps
224+ settings .server_target_qps = self .server_expected_qps
225+ settings .server_target_latency_ns = round (
226+ self .server_target_latency .total_seconds () * 1e9 )
227+ settings .ttft_latency = round (
228+ self .server_ttft_latency .total_seconds () * 1e9 )
229+ settings .tpot_latency = round (
230+ self .server_tpot_latency .total_seconds () * 1e9 )
147231 settings .min_duration_ms = round (
148232 self .min_duration .total_seconds () * 1000 )
233+ settings .min_query_count = self .min_query_count
149234 settings .use_token_latencies = self .use_token_latencies
150235 return settings
151236
152237
238+ class LogOutputSettings (BaseModel ):
239+ """The test log output settings for the MLPerf inference LoadGen."""
240+ outdir : Annotated [
241+ DirectoryPath ,
242+ Field (
243+ description = "Where to save the output files from the benchmark." ,
244+ ),
245+ ] = DirectoryPath ("output" )
246+ prefix : Annotated [
247+ str ,
248+ Field (
249+ description = "Modify the filenames of the logs with a prefix." ,
250+ ),
251+ ] = "mlperf_log_"
252+ suffix : Annotated [
253+ str ,
254+ Field (
255+ description = "Modify the filenames of the logs with a suffix." ,
256+ ),
257+ ] = ""
258+ prefix_with_datetime : Annotated [
259+ bool ,
260+ Field (
261+ description = "Modify the filenames of the logs with a datetime." ,
262+ ),
263+ ] = False
264+ copy_detail_to_stdout : Annotated [
265+ bool ,
266+ Field (
267+ description = "Print details of performance test to stdout." ,
268+ ),
269+ ] = False
270+ copy_summary_to_stdout : Annotated [
271+ bool ,
272+ Field (
273+ description = "Print results of performance test to terminal." ,
274+ ),
275+ ] = True
276+
277+ @field_validator ("outdir" , mode = "before" )
278+ @classmethod
279+ def parse_directory_field (cls , value : str ) -> None :
280+ """Verify and create the output directory to store log files."""
281+ path = Path (value )
282+ path .mkdir (exist_ok = True )
283+ return path
284+
285+ def to_lgtype (self ) -> lg .LogOutputSettings :
286+ """Convert the log output settings to its corresponding LoadGen type."""
287+ log_output_settings = lg .LogOutputSettings ()
288+ log_output_settings .outdir = self .outdir .as_posix ()
289+ log_output_settings .prefix = self .prefix
290+ log_output_settings .suffix = self .suffix
291+ log_output_settings .prefix_with_datetime = self .prefix_with_datetime
292+ log_output_settings .copy_detail_to_stdout = self .copy_detail_to_stdout
293+ log_output_settings .copy_summary_to_stdout = self .copy_summary_to_stdout
294+ return log_output_settings
295+
296+
297+ class LogSettings (BaseModel ):
298+ """The test log settings for the MLPerf inference LoadGen."""
299+ log_output : Annotated [
300+ LogOutputSettings ,
301+ Field (
302+ description = "Log output settings" ,
303+ ),
304+ ] = LogOutputSettings
305+ log_mode : Annotated [
306+ LoggingMode ,
307+ Field (
308+ description = """How and when logging should be
309+ sampled and stringified at runtime""" ,
310+ ),
311+ ] = LoggingMode .ASYNC_POLL
312+ enable_trace : Annotated [
313+ bool ,
314+ Field (
315+ description = "Enable trace" ,
316+ ),
317+ ] = True
318+
319+ def to_lgtype (self ) -> lg .LogSettings :
320+ """Convert log settings to its corresponding LoadGen type."""
321+ log_settings = lg .LogSettings ()
322+ log_settings .log_output = self .log_output .to_lgtype ()
323+ log_settings .log_mode = self .log_mode .to_lgtype ()
324+ log_settings .enable_trace = self .enable_trace
325+ return log_settings
326+
327+
328+ class Settings (BaseModel ):
329+ """Combine the settings for the test and logging of LoadGen."""
330+ test : Annotated [
331+ TestSettings ,
332+ Field (
333+ description = "Test settings parameters." ,
334+ ),
335+ ] = TestSettings
336+
337+ logging : Annotated [
338+ LogSettings ,
339+ Field (
340+ description = "Test logging parameters" ,
341+ ),
342+ ] = LogSettings
343+
344+ def to_lgtype (self ) -> tuple [lg .TestSettings , lg .LogSettings ]:
345+ """Return test and log settings for LoadGen."""
346+ test_settings = self .test .to_lgtype ()
347+ log_settings = self .logging .to_lgtype ()
348+ return (test_settings , log_settings )
349+
350+
153351class Model (BaseModel ):
154352 """Specifies the model to use for the VL2L benchmark."""
155353
@@ -211,7 +409,7 @@ class Endpoint(BaseModel):
211409@app .command ()
212410def main (
213411 * ,
214- settings : TestSettings ,
412+ settings : Settings ,
215413 model : Model ,
216414 dataset : Dataset ,
217415 endpoint : Endpoint ,
@@ -234,17 +432,18 @@ def main(
234432 "Running VL2L benchmark with OpenAI API endpoint: {}" ,
235433 endpoint )
236434 logger .info ("Running VL2L benchmark with random seed: {}" , random_seed )
237- lg_settings = settings .to_lgtype ()
435+ test_settings , log_settings = settings .to_lgtype ()
238436 task = ShopifyGlobalCatalogue (
239437 dataset_cli = dataset ,
240438 model_cli = model ,
241439 endpoint_cli = endpoint ,
440+ scenario = settings .test .scenario ,
242441 random_seed = random_seed ,
243442 )
244443 sut = task .construct_sut ()
245444 qsl = task .construct_qsl ()
246445 logger .info ("Starting the VL2L benchmark with LoadGen..." )
247- lg .StartTest (sut , qsl , lg_settings )
446+ lg .StartTestWithLogSettings (sut , qsl , test_settings , log_settings )
248447 logger .info ("The VL2L benchmark with LoadGen completed." )
249448 lg .DestroyQSL (qsl )
250449 lg .DestroySUT (sut )
0 commit comments