Skip to content

Commit

Permalink
Test utils
Browse files Browse the repository at this point in the history
  • Loading branch information
20001LastOrder authored and oshoma committed Dec 19, 2023
1 parent dd8211b commit dc35650
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 0 deletions.
56 changes: 56 additions & 0 deletions src/sherpa_ai/test_utils/llms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import json
import re

import pytest
from langchain.base_language import BaseLanguageModel
from langchain.chat_models import ChatOpenAI
from langchain.llms import FakeListLLM

from sherpa_ai.models.chat_model_with_logging import ChatModelWithLogging
from sherpa_ai.test_utils.loggers import get_new_logger


def get_fake_llm(filename) -> FakeListLLM:
responses = []
with open(filename) as f:
for line in f:
responses.append(json.loads(line))
# restore new line characters
responses = [response["output"].replace("\\n", "\n") for response in responses]
return FakeListLLM(responses=responses)


def get_real_llm(
filename: str, model_name: str = "gpt-3.5-turbo", temperature: int = 0
) -> ChatModelWithLogging:
logger = get_new_logger(filename)
llm = ChatModelWithLogging(
llm=ChatOpenAI(model_name=model_name, temperature=temperature), logger=logger
)
return llm


def format_cache_name(filename: str, method_name: str) -> str:
filename = re.split("\\\\|/", filename)[-1].removesuffix(".py")
return f"{filename}_{method_name}.jsonl"


@pytest.fixture
def get_llm(external_api: bool):
if external_api:
import sherpa_ai.config

def get(
filename: str, method_name: str, folder: str = "data", **kwargs
) -> BaseLanguageModel:
path = re.split("(\\\\|/)tests", filename)[0]
folder = f"{path}/tests/{folder}"
print(folder)
filename = format_cache_name(filename, method_name)
filename = f"{folder}/{filename}"
if external_api:
return get_real_llm(filename, **kwargs)
else:
return get_fake_llm(filename)

return get
31 changes: 31 additions & 0 deletions src/sherpa_ai/test_utils/loggers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import copy
import sys
from typing import Callable

import pytest
from loguru import logger


def get_new_logger(filename, level="INFO") -> Callable[[str], type(logger)]:
"""
Get a new logger that logs to a file while keeping the global logger unchanged.
Reference: https://github.com/Delgan/loguru/issues/487
"""
logger.remove()
child_logger = copy.deepcopy(logger)

# add handler back
logger.add(sys.stderr)

child_logger.add(filename, level=level, format="{message}", mode="w")

return child_logger


@pytest.fixture
def config_logger_level() -> Callable[[str], None]:
def get(level="DEBUG"):
logger.remove()
logger.add(sys.stderr, level=level)

return get

0 comments on commit dc35650

Please sign in to comment.