From dc35650233659a68db989dfdff47814f9b46e29f Mon Sep 17 00:00:00 2001
From: Boqi Chen <boqi.chen@example.com>
Date: Tue, 19 Dec 2023 11:25:52 -0500
Subject: [PATCH] Test utils

---
 src/sherpa_ai/test_utils/llms.py    | 61 +++++++++++++++++++++++++++++
 src/sherpa_ai/test_utils/loggers.py | 33 ++++++++++++++++
 2 files changed, 94 insertions(+)
 create mode 100644 src/sherpa_ai/test_utils/llms.py
 create mode 100644 src/sherpa_ai/test_utils/loggers.py

diff --git a/src/sherpa_ai/test_utils/llms.py b/src/sherpa_ai/test_utils/llms.py
new file mode 100644
index 00000000..5318b9d5
--- /dev/null
+++ b/src/sherpa_ai/test_utils/llms.py
@@ -0,0 +1,61 @@
+import json
+import re
+
+import pytest
+from langchain.base_language import BaseLanguageModel
+from langchain.chat_models import ChatOpenAI
+from langchain.llms import FakeListLLM
+
+from sherpa_ai.models.chat_model_with_logging import ChatModelWithLogging
+from sherpa_ai.test_utils.loggers import get_new_logger
+
+
+def get_fake_llm(filename: str) -> FakeListLLM:
+    """Create a fake LLM replaying responses cached in a JSONL file."""
+    responses = []
+    with open(filename) as f:
+        for line in f:
+            responses.append(json.loads(line))
+    # restore new line characters that were escaped when the cache was written
+    responses = [response["output"].replace("\\n", "\n") for response in responses]
+    return FakeListLLM(responses=responses)
+
+
+def get_real_llm(
+    filename: str, model_name: str = "gpt-3.5-turbo", temperature: float = 0
+) -> ChatModelWithLogging:
+    """Create a real ChatOpenAI model whose responses are logged to filename."""
+    logger = get_new_logger(filename)
+    llm = ChatModelWithLogging(
+        llm=ChatOpenAI(model_name=model_name, temperature=temperature), logger=logger
+    )
+    return llm
+
+
+def format_cache_name(filename: str, method_name: str) -> str:
+    """Build the cache name "<module>_<method>.jsonl" from a test file path."""
+    filename = re.split(r"\\|/", filename)[-1].removesuffix(".py")
+    return f"{filename}_{method_name}.jsonl"
+
+
+@pytest.fixture
+def get_llm(external_api: bool):
+    """Fixture returning a factory for test LLMs: a real, logged model when
+    external_api is set (recording responses to the cache file), otherwise a
+    fake model replaying the recorded responses from the same cache file."""
+    if external_api:
+        import sherpa_ai.config
+
+    def get(
+        filename: str, method_name: str, folder: str = "data", **kwargs
+    ) -> BaseLanguageModel:
+        path = re.split(r"(\\|/)tests", filename)[0]
+        folder = f"{path}/tests/{folder}"
+        filename = format_cache_name(filename, method_name)
+        filename = f"{folder}/{filename}"
+        if external_api:
+            return get_real_llm(filename, **kwargs)
+        else:
+            return get_fake_llm(filename)
+
+    return get
diff --git a/src/sherpa_ai/test_utils/loggers.py b/src/sherpa_ai/test_utils/loggers.py
new file mode 100644
index 00000000..140679c0
--- /dev/null
+++ b/src/sherpa_ai/test_utils/loggers.py
@@ -0,0 +1,33 @@
+import copy
+import sys
+from typing import Callable
+
+import pytest
+from loguru import logger
+
+
+def get_new_logger(filename: str, level: str = "INFO"):
+    """
+    Get a new logger that logs to a file while keeping the global logger unchanged.
+    Reference: https://github.com/Delgan/loguru/issues/487
+    """
+    logger.remove()
+    child_logger = copy.deepcopy(logger)
+
+    # add handler back
+    logger.add(sys.stderr)
+
+    child_logger.add(filename, level=level, format="{message}", mode="w")
+
+    return child_logger
+
+
+@pytest.fixture
+def config_logger_level() -> Callable[[str], None]:
+    """Fixture returning a function that reconfigures the global logger level."""
+
+    def get(level: str = "DEBUG") -> None:
+        logger.remove()
+        logger.add(sys.stderr, level=level)
+
+    return get