Skip to content

Commit

Permalink
Move retries to int test, update to smoke test
Browse files Browse the repository at this point in the history
  • Loading branch information
hinthornw committed Jul 23, 2023
1 parent f71a178 commit 1940d0a
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 31 deletions.
83 changes: 80 additions & 3 deletions libs/langchain/tests/integration_tests/llms/test_openai.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Test OpenAI API wrapper."""

from pathlib import Path
from typing import Generator
from typing import Any, Generator
from unittest.mock import MagicMock, patch

import pytest

Expand All @@ -10,7 +10,10 @@
from langchain.llms.loading import load_llm
from langchain.llms.openai import OpenAI, OpenAIChat
from langchain.schema import LLMResult
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
from tests.unit_tests.callbacks.fake_callback_handler import (
FakeAsyncCallbackHandler,
FakeCallbackHandler,
)


def test_openai_call() -> None:
Expand Down Expand Up @@ -284,3 +287,77 @@ def test_chat_openai_get_num_tokens(model: str) -> None:
"""Test get_tokens."""
llm = ChatOpenAI(model=model)
assert llm.get_num_tokens("表情符号是\n🦜🔗") == _EXPECTED_NUM_TOKENS[model]


@pytest.fixture
def mock_completion() -> dict:
return {
"id": "cmpl-3evkmQda5Hu7fcZavknQda3SQ",
"object": "text_completion",
"created": 1689989000,
"model": "text-davinci-003",
"choices": [
{"text": "Bar Baz", "index": 0, "logprobs": None, "finish_reason": "length"}
],
"usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3},
}


@pytest.mark.requires("openai")
def test_openai_retries(mock_completion: dict) -> None:
llm = OpenAI()
mock_client = MagicMock()
completed = False
raised = False
import openai

def raise_once(*args: Any, **kwargs: Any) -> Any:
nonlocal completed, raised
if not raised:
raised = True
raise openai.error.APIError
completed = True
return mock_completion

mock_client.create = raise_once
callback_handler = FakeCallbackHandler()
with patch.object(
llm,
"client",
mock_client,
):
res = llm.predict("bar", callbacks=[callback_handler])
assert res == "Bar Baz"
assert completed
assert raised
assert callback_handler.retries == 1


@pytest.mark.requires("openai")
async def test_openai_async_retries(mock_completion: dict) -> None:
llm = OpenAI()
mock_client = MagicMock()
completed = False
raised = False
import openai

def raise_once(*args: Any, **kwargs: Any) -> Any:
nonlocal completed, raised
if not raised:
raised = True
raise openai.error.APIError
completed = True
return mock_completion

mock_client.create = raise_once
callback_handler = FakeAsyncCallbackHandler()
with patch.object(
llm,
"client",
mock_client,
):
res = llm.apredict("bar", callbacks=[callback_handler])
assert res == "Bar Baz"
assert completed
assert raised
assert callback_handler.retries == 1
70 changes: 69 additions & 1 deletion libs/langchain/tests/unit_tests/chat_models/test_openai.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
"""Test OpenAI Chat API wrapper."""

import json
from typing import Any
from unittest.mock import MagicMock, patch

import pytest

from langchain.chat_models.openai import (
ChatOpenAI,
_convert_dict_to_message,
)
from langchain.schema.messages import FunctionMessage
Expand All @@ -21,3 +25,67 @@ def test_function_message_dict_to_function_message() -> None:
assert isinstance(result, FunctionMessage)
assert result.name == name
assert result.content == content


@pytest.fixture
def mock_completion() -> dict:
return {
"id": "chatcmpl-7fcZavknQda3SQ",
"object": "chat.completion",
"created": 1689989000,
"model": "gpt-3.5-turbo-0613",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "Bar Baz",
},
"finish_reason": "stop",
}
],
}


@pytest.mark.requires("openai")
def test_openai_predict(mock_completion: dict) -> None:
llm = ChatOpenAI()
mock_client = MagicMock()
completed = False

def mock_create(*args: Any, **kwargs: Any) -> Any:
nonlocal completed
completed = True
return mock_completion

mock_client.create = mock_create
with patch.object(
llm,
"client",
mock_client,
):
res = llm.predict("bar")
assert res == "Bar Baz"
assert completed


@pytest.mark.requires("openai")
async def test_openai_apredict(mock_completion: dict) -> None:
llm = ChatOpenAI()
mock_client = MagicMock()
completed = False

def mock_create(*args: Any, **kwargs: Any) -> Any:
nonlocal completed
completed = True
return mock_completion

mock_client.create = mock_create
with patch.object(
llm,
"client",
mock_client,
):
res = llm.predict("bar")
assert res == "Bar Baz"
assert completed
38 changes: 11 additions & 27 deletions libs/langchain/tests/unit_tests/llms/test_openai.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import os
from typing import Any
from unittest.mock import MagicMock, patch

import pytest

from langchain.llms.openai import OpenAI
from tests.unit_tests.callbacks.fake_callback_handler import FakeAsyncCallbackHandler, FakeCallbackHandler

os.environ["OPENAI_API_KEY"] = "foo"

Expand Down Expand Up @@ -33,7 +33,7 @@ def test_openai_incorrect_field() -> None:
@pytest.fixture
def mock_completion() -> dict:
return {
"id": "cmpl-7evkmQda5HuNqR1i4QiDOcrmVvTpd",
"id": "cmpl-3evkmQda5Hu7fcZavknQda3SQ",
"object": "text_completion",
"created": 1689989000,
"model": "text-davinci-003",
Expand All @@ -45,60 +45,44 @@ def mock_completion() -> dict:


@pytest.mark.requires("openai")
def test_openai_retries(mock_completion: dict) -> None:
def test_openai_calls(mock_completion: dict) -> None:
llm = OpenAI()
mock_client = MagicMock()
completed = False
raised = False
import openai

def raise_once(*args, **kwargs):
nonlocal completed, raised
if not raised:
raised = True
raise openai.error.APIError

def raise_once(*args: Any, **kwargs: Any) -> Any:
nonlocal completed
completed = True
return mock_completion

mock_client.create = raise_once
callback_handler = FakeCallbackHandler()
with patch.object(
llm,
"client",
mock_client,
):
res = llm.predict("bar", callbacks=[callback_handler])
res = llm.predict("bar")
assert res == "Bar Baz"
assert completed
assert raised
assert callback_handler.retries == 1


@pytest.mark.requires("openai")
async def test_openai_async_retries(mock_completion: dict) -> None:
llm = OpenAI()
mock_client = MagicMock()
completed = False
raised = False
import openai

def raise_once(*args, **kwargs):
nonlocal completed, raised
if not raised:
raised = True
raise openai.error.APIError

def raise_once(*args: Any, **kwargs: Any) -> Any:
nonlocal completed
completed = True
return mock_completion

mock_client.create = raise_once
callback_handler = FakeAsyncCallbackHandler()
with patch.object(
llm,
"client",
mock_client,
):
res = llm.apredict("bar", callbacks=[callback_handler])
res = llm.apredict("bar")
assert res == "Bar Baz"
assert completed
assert raised
assert callback_handler.retries == 1

0 comments on commit 1940d0a

Please sign in to comment.