Skip to content

Commit

Permalink
Merge pull request #8 from sfc-gh-mkeralapura/testing
Browse files Browse the repository at this point in the history
Try to setup a pp>1 case
  • Loading branch information
sfc-gh-mkeralapura authored Aug 14, 2024
2 parents a8042bc + ead3a43 commit 42a7229
Showing 1 changed file with 28 additions and 6 deletions.
34 changes: 28 additions & 6 deletions tests/tracing/test_tracing.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import os
import threading
from concurrent import futures
Expand All @@ -14,8 +15,11 @@
OTEL_EXPORTER_OTLP_TRACES_INSECURE)

from vllm import LLM, SamplingParams
from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
from vllm.tracing import SpanAttributes

from ..utils import fork_new_process_for_each_test

FAKE_TRACE_SERVER_ADDRESS = "localhost:4317"

FieldName = Literal['bool_value', 'string_value', 'int_value', 'double_value',
Expand Down Expand Up @@ -67,7 +71,9 @@ def trace_service():
server.stop(None)


def test_traces(trace_service):
@fork_new_process_for_each_test
def test_traces(request):
trace_service = request.getfixturevalue('trace_service')
os.environ[OTEL_EXPORTER_OTLP_TRACES_INSECURE] = "true"

sampling_params = SamplingParams(temperature=0.01,
Expand Down Expand Up @@ -123,20 +129,36 @@ def test_traces(trace_service):
assert metrics.model_execute_time is None


def test_traces_with_detailed_steps(trace_service):
@pytest.mark.parametrize("use_pp", [True, False])
@fork_new_process_for_each_test
def test_traces_with_detailed_steps(request, use_pp):
    """Run the detailed-tracing scenario once with pipeline parallelism and
    once without, each in a freshly forked process.

    The fixture is pulled via ``request.getfixturevalue`` rather than as a
    normal parameter so it is resolved inside the forked child process.
    """
    # NOTE(review): decorator order matters — parametrize must wrap the
    # fork decorator so each parametrized case gets its own fork.
    svc = request.getfixturevalue('trace_service')
    coro = _test_traces_with_detailed_steps(svc, use_pp)
    asyncio.run(coro)


async def _test_traces_with_detailed_steps(trace_service, use_pp):
os.environ[OTEL_EXPORTER_OTLP_TRACES_INSECURE] = "true"

sampling_params = SamplingParams(temperature=0.01,
top_p=0.1,
max_tokens=256)
model = "facebook/opt-125m"
llm = LLM(
model = "facebook/opt-125m" if not use_pp else "JackFram/llama-160m"
pp_size = 1 if not use_pp else 2
engine_args = AsyncEngineArgs(
model=model,
otlp_traces_endpoint=FAKE_TRACE_SERVER_ADDRESS,
collect_detailed_traces="all",
pipeline_parallel_size=pp_size,
)
prompts = ["This is a short prompt"]
outputs = llm.generate(prompts, sampling_params=sampling_params)
async_engine = AsyncLLMEngine.from_engine_args(engine_args)
prompt = "This is a short prompt"
request_id = f"test_{use_pp}"
outputs_generator = async_engine.generate(prompt,
request_id=request_id,
sampling_params=sampling_params)
outputs = []
async for request_output in outputs_generator:
outputs.append(request_output)

timeout = 5
if not trace_service.evt.wait(timeout):
Expand Down

0 comments on commit 42a7229

Please sign in to comment.