Skip to content

Commit

Permalink
Merge pull request #56 from fixie-ai/juberti/spread
Browse files (browse the repository at this point in the history)
Rework parsing to handle non-str args
  • Loading branch information
juberti authored May 1, 2024
2 parents c496c5d + 18a4e7f commit 052c720
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 47 deletions.
1 change: 1 addition & 0 deletions llm_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,7 @@ async def main(args: argparse.Namespace):
med_index2 = len(results) // 2
median_latency = (results[med_index1].latency + results[med_index2].latency) / 2
if num_tokens > 0:
assert first_token_time
ttft = first_token_time - chosen.start_time
tps = min((num_tokens - 1) / (end_time - first_token_time), 999)
total_time = end_time - chosen.start_time
Expand Down
76 changes: 29 additions & 47 deletions llm_benchmark_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import json
import os
import random
import sys
from typing import Any, Dict, List, Optional, Tuple

import gcloud.aio.storage as gcs
Expand Down Expand Up @@ -83,23 +84,17 @@ class _Llm:
"""

def __init__(self, model: str, display_name: Optional[str] = None, **kwargs):
self.pass_argv = []
self.args = {
"model": model,
"display_name": display_name,
"format": "none",
**kwargs,
}

def apply(self, pass_argv: List[str], **kwargs):
self.pass_argv = pass_argv
self.args.update(kwargs)
return self

async def run(self, spread: float) -> asyncio.Task:
async def run(self, pass_argv: List[str], spread: float) -> asyncio.Task:
if spread:
await asyncio.sleep(spread)
full_argv = _dict_to_argv(self.args) + self.pass_argv
full_argv = _dict_to_argv(self.args) + pass_argv
return await llm_benchmark.run(full_argv)


Expand Down Expand Up @@ -361,7 +356,9 @@ class _Response:
results: List[Dict[str, Any]]


def _format_response(response: _Response, format: str, dlen: int) -> Tuple[str, str]:
def _format_response(
response: _Response, format: str, dlen: int = 0
) -> Tuple[str, str]:
if format == "json":
return json.dumps(vars(response), indent=2), "application/json"
else:
Expand Down Expand Up @@ -391,58 +388,43 @@ async def _store_response(gcp_bucket: str, key: str, text: str, content_type: st
await storage.close()


async def run(
mode: str = "text",
format: str = "text",
display_length: Optional[int] = DEFAULT_DISPLAY_LENGTH,
filter: Optional[str] = None,
spread: Optional[float] = None,
store: bool = False,
pass_argv: Optional[List[str]] = None,
**kwargs,
):
async def _run(argv: List[str]) -> Tuple[str, str]:
"""
This function is invoked either from the webapp or the main function below.
When invoked from the webapp, the arguments are passed as kwargs.
When invoked from the main function, the arguments are passed as a list of flags.
We'll give both to the _Llm.run function, which will turn them back into a
This function is invoked either from the webapp (via run) or from the main function below.
The args we recognize are parsed into args, and any unrecognized args are collected in pass_argv,
which we pass to the _Llm.run function, which in turn merges them back into a
single list of flags for consumption by the llm_benchmark.run function.
"""
time_start = datetime.datetime.now()
time_str = time_start.isoformat()
region = os.getenv("FLY_REGION", "local")
argv = _dict_to_argv(kwargs) + (pass_argv or [])
models = _get_models(mode, filter)
cmd = " ".join(argv)
args, pass_argv = parser.parse_known_args(argv)
models = _get_models(args.mode, args.filter)
tasks = []
for m in models:
m.apply(pass_argv or [], **kwargs)
delay = random.uniform(0, spread)
tasks.append(asyncio.create_task(m.run(delay)))
delay = random.uniform(0, args.spread)
tasks.append(asyncio.create_task(m.run(pass_argv, delay)))
await asyncio.gather(*tasks)
results = [t.result() for t in tasks if t.result() is not None]
elapsed = datetime.datetime.now() - time_start
elapsed_str = f"{elapsed.total_seconds():.2f}s"
response = _Response(time_str, elapsed_str, region, " ".join(argv), results)
if store:
path = f"{region}/{mode}/{time_str.split('T')[0]}.json"
json, content_type = _format_response(response, "json", display_length)
response = _Response(time_str, elapsed_str, region, cmd, results)
if args.store:
path = f"{region}/{args.mode}/{time_str.split('T')[0]}.json"
json, content_type = _format_response(response, "json")
await _store_response(DEFAULT_GCS_BUCKET, path, json, content_type)
return _format_response(response, format, display_length)


async def main(args: argparse.Namespace, pass_argv: List[str]):
text, _ = await run(
args.mode,
args.format,
args.display_length,
args.filter,
args.spread,
args.store,
pass_argv,
)
return _format_response(response, args.format, args.display_length)


async def run(params: Dict[str, Any]) -> Tuple[str, str]:
return await _run(_dict_to_argv(params))


async def main():
text, _ = await _run(sys.argv[1:])
print(text)


if __name__ == "__main__":
args, unk_args = parser.parse_known_args()
asyncio.run(main(args, unk_args))
asyncio.run(main())

0 comments on commit 052c720

Please sign in to comment.