-
Notifications
You must be signed in to change notification settings - Fork 45
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
dd8211b
commit dc35650
Showing
2 changed files
with
87 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
import json | ||
import re | ||
|
||
import pytest | ||
from langchain.base_language import BaseLanguageModel | ||
from langchain.chat_models import ChatOpenAI | ||
from langchain.llms import FakeListLLM | ||
|
||
from sherpa_ai.models.chat_model_with_logging import ChatModelWithLogging | ||
from sherpa_ai.test_utils.loggers import get_new_logger | ||
|
||
|
||
def get_fake_llm(filename) -> FakeListLLM: | ||
responses = [] | ||
with open(filename) as f: | ||
for line in f: | ||
responses.append(json.loads(line)) | ||
# restore new line characters | ||
responses = [response["output"].replace("\\n", "\n") for response in responses] | ||
return FakeListLLM(responses=responses) | ||
|
||
|
||
def get_real_llm( | ||
filename: str, model_name: str = "gpt-3.5-turbo", temperature: int = 0 | ||
) -> ChatModelWithLogging: | ||
logger = get_new_logger(filename) | ||
llm = ChatModelWithLogging( | ||
llm=ChatOpenAI(model_name=model_name, temperature=temperature), logger=logger | ||
) | ||
return llm | ||
|
||
|
||
def format_cache_name(filename: str, method_name: str) -> str: | ||
filename = re.split("\\\\|/", filename)[-1].removesuffix(".py") | ||
return f"{filename}_{method_name}.jsonl" | ||
|
||
|
||
@pytest.fixture | ||
def get_llm(external_api: bool): | ||
if external_api: | ||
import sherpa_ai.config | ||
|
||
def get( | ||
filename: str, method_name: str, folder: str = "data", **kwargs | ||
) -> BaseLanguageModel: | ||
path = re.split("(\\\\|/)tests", filename)[0] | ||
folder = f"{path}/tests/{folder}" | ||
print(folder) | ||
filename = format_cache_name(filename, method_name) | ||
filename = f"{folder}/{filename}" | ||
if external_api: | ||
return get_real_llm(filename, **kwargs) | ||
else: | ||
return get_fake_llm(filename) | ||
|
||
return get |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import copy | ||
import sys | ||
from typing import Callable | ||
|
||
import pytest | ||
from loguru import logger | ||
|
||
|
||
def get_new_logger(filename, level="INFO") -> Callable[[str], type(logger)]: | ||
""" | ||
Get a new logger that logs to a file while keeping the global logger unchanged. | ||
Reference: https://github.com/Delgan/loguru/issues/487 | ||
""" | ||
logger.remove() | ||
child_logger = copy.deepcopy(logger) | ||
|
||
# add handler back | ||
logger.add(sys.stderr) | ||
|
||
child_logger.add(filename, level=level, format="{message}", mode="w") | ||
|
||
return child_logger | ||
|
||
|
||
@pytest.fixture | ||
def config_logger_level() -> Callable[[str], None]: | ||
def get(level="DEBUG"): | ||
logger.remove() | ||
logger.add(sys.stderr, level=level) | ||
|
||
return get |