
Generalize perplexity logic for streaming end detection across OpenAI compatible models (#286)

* Use perplexity logic to check cerebras streaming end indication

* Check openai data contains [DONE] first

* Check response choices before accessing choices objects

* Ensure end of stream is reached in the last object

* Specific conditions for openai and mistral

* Add parameterized fixture for openai compatible tests

* Include the newline characters without using a backslash in the f-string

* Check for usage to determine reaching the end of cerebras stream

* Update openai tests to use gpt-4o-mini

* Add cerebras secret to ci test

* Remove usage condition
kxtran authored Aug 29, 2024
1 parent fde3e7d commit 98dbb2c
Showing 4 changed files with 96 additions and 45 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test.yml
@@ -51,6 +51,7 @@ jobs:
LAMINI_API_KEY: ${{ secrets.LAMINI_API_KEY }}
GOOGLE_API_KEY : ${{ secrets.GOOGLE_API_KEY }}
PERPLEXITYAI_API_KEY: ${{ secrets.PERPLEXITYAI_API_KEY }}
CEREBRAS_API_KEY: ${{ secrets.CEREBRAS_API_KEY }}
steps:
- uses: actions/checkout@v4
- name: Install poetry
49 changes: 36 additions & 13 deletions log10/_httpx_utils.py
@@ -462,7 +462,8 @@ def patch_streaming_log(self, duration: int, full_response: str):
"\r\n\r\n" if self.llm_client == LLM_CLIENTS.OPENAI and "perplexity" in self.host_header else "\n\n"
)
responses = full_response.split(separator)
response_json = self.parse_response_data(responses)
filter_responses = [r for r in responses if r]
response_json = self.parse_response_data(filter_responses)

self.log_row["response"] = json.dumps(response_json)
self.log_row["status"] = "finished"
@@ -507,30 +508,52 @@ def is_response_end_reached(self, text: str) -> bool:
if self.llm_client == LLM_CLIENTS.ANTHROPIC:
return self.is_anthropic_response_end_reached(text)
elif self.llm_client == LLM_CLIENTS.OPENAI:
if "perplexity" in self.host_header:
return self.is_perplexity_response_end_reached(text)
else:
return self.is_openai_response_end_reached(text)
return self.is_openai_response_end_reached(text)
else:
logger.debug("Currently logging is only available for async openai and anthropic.")
return False

def is_anthropic_response_end_reached(self, text: str):
return "event: message_stop" in text

def is_perplexity_response_end_reached(self, text: str):
def has_response_finished_with_stop_reason(self, text: str, parse_single_data_entry: bool = False):
json_strings = text.split("data: ")[1:]
# Parse the last JSON string
last_json_str = json_strings[-1].strip()
last_object = json.loads(last_json_str)
return last_object.get("choices", [{}])[0].get("finish_reason", "") == "stop"
try:
last_object = json.loads(last_json_str)
except json.JSONDecodeError:
logger.debug(f"Full response: {repr(text)}")
logger.debug(f"Failed to parse the last JSON string: {last_json_str}")
return False

if choices := last_object.get("choices", []):
choice = choices[0]
else:
return False

finish_reason = choice.get("finish_reason", "")
content = choice.get("delta", {}).get("content", "")

if finish_reason == "stop":
return not content if parse_single_data_entry else True
return False

def is_openai_response_end_reached(self, text: str):
def is_openai_response_end_reached(self, text: str, parse_single_data_entry: bool = False):
"""
In Perplexity, the last item in the responses is empty.
In OpenAI and Mistral, the last item in the responses is "data: [DONE]".
OpenAI, Mistral response end is reached when the data contains "data: [DONE]\n\n".
Perplexity, Cerebras response end is reached when the last JSON object contains finish_reason == stop.
The parse_single_data_entry argument is used to distinguish between a single data entry and multiple data entries.
The function is called in two contexts: first, to assess whether the entire accumulated response has completed when processing streaming data, and second, to verify if a single response object has finished processing during individual response handling.
"""
return not text or "data: [DONE]" in text
hosts = ["openai", "mistral"]

if any(p in self.host_header for p in hosts):
suffix = "data: [DONE]" + ("" if parse_single_data_entry else "\n\n")
if text.endswith(suffix):
return True

return self.has_response_finished_with_stop_reason(text, parse_single_data_entry)

def parse_anthropic_responses(self, responses: list[str]):
message_id = ""
@@ -628,7 +651,7 @@ def parse_openai_responses(self, responses: list[str]):
finish_reason = ""

for r in responses:
if self.is_openai_response_end_reached(r):
if self.is_openai_response_end_reached(r, parse_single_data_entry=True):
break

# loading the substring of response text after 'data: '.
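
A minimal sketch of the two end-of-stream conventions the logic above distinguishes, assuming illustrative payload shapes (the sample data: lines are not captured from real provider traffic): OpenAI and Mistral terminate the stream with a literal data: [DONE] sentinel, while Perplexity and Cerebras simply send a final JSON object whose first choice carries finish_reason == "stop".

import json

# Illustrative stream tails; payload shapes are assumptions, not captured traffic.
openai_tail = 'data: {"choices":[{"delta":{},"finish_reason":"stop"}]}\n\ndata: [DONE]\n\n'
perplexity_tail = 'data: {"choices":[{"delta":{"content":""},"finish_reason":"stop"}]}\n\n'

def ends_with_done(text: str) -> bool:
    # OpenAI and Mistral close the stream with a literal "data: [DONE]" sentinel.
    return text.endswith("data: [DONE]\n\n")

def ends_with_stop_reason(text: str) -> bool:
    # Perplexity and Cerebras omit the sentinel; the last JSON object carries
    # finish_reason == "stop" instead.
    last_json = text.split("data: ")[-1].strip()
    try:
        obj = json.loads(last_json)
    except json.JSONDecodeError:
        return False
    choices = obj.get("choices", [])
    return bool(choices) and choices[0].get("finish_reason") == "stop"

print(ends_with_done(openai_tail))             # True
print(ends_with_stop_reason(perplexity_tail))  # True
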
4 changes: 2 additions & 2 deletions tests/pytest.ini
@@ -1,7 +1,7 @@
[pytest]
addopts =
--openai_model=gpt-3.5-turbo
--openai_vision_model=gpt-4o
--openai_model=gpt-4o-mini
--openai_vision_model=gpt-4o-mini
--anthropic_model=claude-3-haiku-20240307
--anthropic_legacy_model=claude-2.1
--google_model=gemini-1.5-pro-latest
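
The --openai_model options above are custom pytest flags; the conftest.py that registers them is not part of this diff, so the sketch below only illustrates how such options are typically wired up (option and fixture names are assumptions):

# conftest.py sketch -- not part of this commit; option names mirror tests/pytest.ini.
import pytest

def pytest_addoption(parser):
    parser.addoption("--openai_model", action="store", default="gpt-4o-mini")
    parser.addoption("--openai_vision_model", action="store", default="gpt-4o-mini")

@pytest.fixture
def openai_model(request):
    # Tests pick up whichever model pytest.ini (or the command line) selected.
    return request.config.getoption("--openai_model")
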
87 changes: 57 additions & 30 deletions tests/test_openai_compatibility.py
@@ -11,19 +11,39 @@

log10(openai)

model_name = "llama-3.1-sonar-small-128k-chat"

if "PERPLEXITYAI_API_KEY" not in os.environ:
raise ValueError("Please set the PERPLEXITYAI_API_KEY environment variable.")

compatibility_config = {
"base_url": "https://api.perplexity.ai",
"api_key": os.environ.get("PERPLEXITYAI_API_KEY"),
}
# Define a fixture that provides parameterized api_key and base_url
@pytest.fixture(
params=[
{
"model_name": "llama-3.1-sonar-small-128k-chat",
"api_key": "PERPLEXITYAI_API_KEY",
"base_url": "https://api.perplexity.ai",
},
{"model_name": "open-mistral-nemo", "api_key": "MISTRAL_API_KEY", "base_url": "https://api.mistral.ai/v1"},
{"model_name": "llama3.1-8b", "api_key": "CEREBRAS_API_KEY", "base_url": "https://api.cerebras.ai/v1"},
]
)
def config(request):
api_environment_variable = request.param["api_key"]
if api_environment_variable not in os.environ:
raise ValueError(f"Please set the {api_environment_variable} environment variable.")

return {
"base_url": request.param["base_url"],
"api_key": request.param["api_key"],
"model_name": request.param["model_name"],
}


@pytest.mark.chat
def test_chat(session):
def test_chat(session, config):
compatibility_config = {
"base_url": config["base_url"],
"api_key": os.environ.get(config["api_key"]),
}
model_name = config["model_name"]

client = openai.OpenAI(**compatibility_config)
completion = client.chat.completions.create(
model=model_name,
@@ -46,7 +66,13 @@ def test_chat(session):


@pytest.mark.chat
def test_chat_not_given(session):
def test_chat_not_given(session, config):
compatibility_config = {
"base_url": config["base_url"],
"api_key": os.environ.get(config["api_key"]),
}
model_name = config["model_name"]

client = openai.OpenAI(**compatibility_config)
completion = client.chat.completions.create(
model=model_name,
@@ -69,23 +95,13 @@ def test_chat_not_given(session):
@pytest.mark.chat
@pytest.mark.async_client
@pytest.mark.asyncio(scope="module")
async def test_chat_async(session):
client = AsyncOpenAI(**compatibility_config)
completion = await client.chat.completions.create(
model=model_name,
messages=[{"role": "user", "content": "Say this is a test"}],
)
async def test_chat_async(session, config):
compatibility_config = {
"base_url": config["base_url"],
"api_key": os.environ.get(config["api_key"]),
}
model_name = config["model_name"]

content = completion.choices[0].message.content
assert isinstance(content, str)
await finalize()
_LogAssertion(completion_id=session.last_completion_id(), message_content=content).assert_chat_response()


@pytest.mark.chat
@pytest.mark.async_client
@pytest.mark.asyncio(scope="module")
async def test_perplexity_chat_async(session):
client = AsyncOpenAI(**compatibility_config)
completion = await client.chat.completions.create(
model=model_name,
@@ -100,7 +116,13 @@ async def test_perplexity_chat_async(session):

@pytest.mark.chat
@pytest.mark.stream
def test_chat_stream(session):
def test_chat_stream(session, config):
compatibility_config = {
"base_url": config["base_url"],
"api_key": os.environ.get(config["api_key"]),
}
model_name = config["model_name"]

client = openai.OpenAI(**compatibility_config)
response = client.chat.completions.create(
model=model_name,
@@ -111,17 +133,22 @@ def test_chat_stream(session):

output = ""
for chunk in response:
output += chunk.choices[0].delta.content
output += chunk.choices[0].delta.content or ""

_LogAssertion(completion_id=session.last_completion_id(), message_content=output).assert_chat_response()


@pytest.mark.async_client
@pytest.mark.stream
@pytest.mark.asyncio(scope="module")
async def test_chat_async_stream(session):
client = AsyncOpenAI(**compatibility_config)
async def test_chat_async_stream(session, config):
compatibility_config = {
"base_url": config["base_url"],
"api_key": os.environ.get(config["api_key"]),
}
model_name = config["model_name"]

client = AsyncOpenAI(**compatibility_config)
output = ""
stream = await client.chat.completions.create(
model=model_name,
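
One behavioral detail in the streaming tests above is the new or "" guard when accumulating chunks: the final chunk from OpenAI-compatible providers commonly carries delta.content set to None, so concatenating without a fallback raises a TypeError. A minimal sketch with assumed chunk shapes:

# Chunk dictionaries below are assumptions standing in for ChatCompletionChunk objects.
chunks = [
    {"choices": [{"delta": {"content": "Hello"}}]},
    {"choices": [{"delta": {"content": " world"}}]},
    {"choices": [{"delta": {"content": None}, "finish_reason": "stop"}]},
]

output = ""
for chunk in chunks:
    # Mirrors `output += chunk.choices[0].delta.content or ""` in the test.
    output += chunk["choices"][0]["delta"]["content"] or ""

print(output)  # "Hello world"
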
