From 96c49a121b12b7e82f74f63c58fe31c430f9f358 Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Fri, 6 Dec 2024 13:43:20 +0000 Subject: [PATCH] Moves llm functions into subclasses (#111) * Create classes for LLM endpoints * Test the openai + gemini classes --- src/adtl/autoparser/create_mapping.py | 32 ++-- src/adtl/autoparser/dict_writer.py | 51 ++----- src/adtl/autoparser/gemini_calls.py | 108 ------------- .../autoparser/language_models/base_llm.py | 30 ++++ .../language_models/data_structures.py | 41 +++++ src/adtl/autoparser/language_models/gemini.py | 109 +++++++++++++ src/adtl/autoparser/language_models/openai.py | 118 +++++++++++++++ src/adtl/autoparser/openai_calls.py | 120 --------------- src/adtl/autoparser/util.py | 58 +++---- tests/test_autoparser/test_dict_writer.py | 26 ++-- tests/test_autoparser/test_gemini.py | 123 +++++++++++++++ tests/test_autoparser/test_mapper.py | 23 ++- tests/test_autoparser/test_openai.py | 143 ++++++++++++++++++ tests/test_autoparser/test_utils.py | 17 ++- tests/test_autoparser/testing_data_animals.py | 30 +++- 15 files changed, 686 insertions(+), 343 deletions(-) delete mode 100644 src/adtl/autoparser/gemini_calls.py create mode 100644 src/adtl/autoparser/language_models/base_llm.py create mode 100644 src/adtl/autoparser/language_models/data_structures.py create mode 100644 src/adtl/autoparser/language_models/gemini.py create mode 100644 src/adtl/autoparser/language_models/openai.py delete mode 100644 src/adtl/autoparser/openai_calls.py create mode 100644 tests/test_autoparser/test_gemini.py create mode 100644 tests/test_autoparser/test_openai.py diff --git a/src/adtl/autoparser/create_mapping.py b/src/adtl/autoparser/create_mapping.py index e469bb2..977e132 100644 --- a/src/adtl/autoparser/create_mapping.py +++ b/src/adtl/autoparser/create_mapping.py @@ -9,16 +9,16 @@ from pathlib import Path from typing import Literal -import google.generativeai as gemini import numpy as np import pandas as pd -from openai import OpenAI -from .gemini_calls import _map_fields as _map_fields_gemini -from .gemini_calls import _map_values as _map_values_gemini -from .openai_calls import _map_fields as _map_fields_openai -from .openai_calls import _map_values as _map_values_openai -from .util import DEFAULT_CONFIG, load_data_dict, read_config_schema, read_json +from .util import ( + DEFAULT_CONFIG, + load_data_dict, + read_config_schema, + read_json, + setup_llm, +) class Mapper: @@ -57,20 +57,10 @@ def __init__( self.schema = read_json(schema) self.schema_properties = self.schema["properties"] self.language = language - self.api_key = api_key if llm is None: - self.client = None - elif llm == "openai": # pragma: no cover - self.client = OpenAI(api_key=self.api_key) - self.map_fields = _map_fields_openai - self.map_values = _map_values_openai - elif llm == "gemini": # pragma: no cover - gemini.configure(api_key=self.api_key) - self.client = gemini.GenerativeModel("gemini-1.5-flash") - self.map_fields = _map_fields_gemini - self.map_values = _map_values_gemini + self.model = None else: - raise ValueError(f"Unsupported LLM: {llm}") + self.model = setup_llm(llm, api_key) self.config = read_config_schema( config or Path(Path(__file__).parent, DEFAULT_CONFIG) @@ -192,7 +182,7 @@ def match_fields_to_schema(self) -> pd.DataFrame: # english translated descriptions rather than names. source_fields = list(self.data_dictionary.source_description) - mappings = self.map_fields(source_fields, self.target_fields, self.client) + mappings = self.model.map_fields(source_fields, self.target_fields) mapping_dict = pd.DataFrame( { @@ -229,7 +219,7 @@ def match_values_to_schema(self) -> pd.DataFrame: values_tuples.append((f, s, t)) # to LLM - value_pairs = self.map_values(values_tuples, self.language, self.client) + value_pairs = self.model.map_values(values_tuples, self.language) value_mapping = {} diff --git a/src/adtl/autoparser/dict_writer.py b/src/adtl/autoparser/dict_writer.py index 9eab539..c6ae712 100644 --- a/src/adtl/autoparser/dict_writer.py +++ b/src/adtl/autoparser/dict_writer.py @@ -7,14 +7,16 @@ import argparse from pathlib import Path -import google.generativeai as gemini import numpy as np import pandas as pd -from openai import OpenAI -from .gemini_calls import _get_definitions as _get_definitions_gemini -from .openai_calls import _get_definitions as _get_definitions_openai -from .util import DEFAULT_CONFIG, load_data_dict, read_config_schema, read_data +from .util import ( + DEFAULT_CONFIG, + load_data_dict, + read_config_schema, + read_data, + setup_llm, +) class DictWriter: @@ -37,6 +39,8 @@ class DictWriter: def __init__( self, config: Path | str | None = None, + llm: str | None = None, + api_key: str | None = None, ): if isinstance(config, str): config = Path(config) @@ -44,36 +48,10 @@ def __init__( config or Path(Path(__file__).parent, DEFAULT_CONFIG) ) - def _setup_llm(self, key: str, name: str): - """ - Setup the LLM to use to generate descriptions. - - Separate from the __init__ method to allow for extra barrier between raw data & - LLM. - - Parameters - ---------- - key - API key - name - Name of the LLM to use (currently only OpenAI and Gemini are supported) - """ - if key is None: - raise ValueError("API key required for generating descriptions") - else: - self.key = key - - if name == "openai": # pragma: no cover - self.client = OpenAI(api_key=key) - self._get_descriptions = _get_definitions_openai - - elif name == "gemini": # pragma: no cover - gemini.configure(api_key=key) - self.client = gemini.GenerativeModel("gemini-1.5-flash") - self._get_descriptions = _get_definitions_gemini - + if llm and api_key: + self.model = setup_llm(llm, api_key) else: - raise ValueError(f"Unsupported LLM: {name}") + self.model = None def create_dict(self, data: pd.DataFrame | str) -> pd.DataFrame: """ @@ -183,11 +161,12 @@ def generate_descriptions( df = load_data_dict(self.config, data_dict) - self._setup_llm(key, llm) + if not self.model: + self.model = setup_llm(llm, key) headers = df.source_field - descriptions = self._get_descriptions(list(headers), language, self.client) + descriptions = self.model.get_definitions(list(headers), language) descriptions = {d.field_name: d.translation for d in descriptions} df_descriptions = pd.DataFrame( diff --git a/src/adtl/autoparser/gemini_calls.py b/src/adtl/autoparser/gemini_calls.py deleted file mode 100644 index 748eac9..0000000 --- a/src/adtl/autoparser/gemini_calls.py +++ /dev/null @@ -1,108 +0,0 @@ -"Contains all functions that call Google's Gemini API." - -from __future__ import annotations - -import json - -import google.generativeai as gemini - -from .util import ColumnDescriptionRequest, MappingRequest, ValuesRequest - - -def _get_definitions( - headers: list[str], language: str, model: gemini.GenerativeModel -) -> dict[str, str]: - """ - Get the definitions of the columns in the dataset. - """ - result = model.generate_content( - [ - ( - "You are an expert at structured data extraction. " - "The following is a list of headers from a data file in " - f"{language}, some containing shortened words or abbreviations. " - "Translate them to english. " - "Return a list of (original header, translation) pairs, using the given structure." # noqa - "Preserve special characters such as accented letters and hyphens." - ), - f"{headers}", - ], - generation_config=gemini.GenerationConfig( - response_mime_type="application/json", - response_schema=ColumnDescriptionRequest, - ), - ) - descriptions = ColumnDescriptionRequest.model_validate( - json.loads(result.text) - ).field_descriptions - return descriptions - - -def _map_fields( - source_fields: list[str], target_fields: list[str], model: gemini.GenerativeModel -) -> MappingRequest: - """ - Calls the Gemini API to generate a draft mapping between two datasets. - """ - result = model.generate_content( - [ - ( - "You are an expert at structured data extraction. " - "You will be given two lists of phrases, one is the headers for a " - "target data file, and the other a set of descriptions for columns " - "of source data. " - "Match each target header to the best matching source description, " - "but match a header to None if a good match does not exist. " - "Preserve special characters such as accented letters and hyphens." - "Return the matched target headers and source descriptions using the provided structure." # noqa - ), - ( - f"These are the target headers: {target_fields}\n" - f"These are the source descriptions: {source_fields}" - ), - ], - generation_config=gemini.GenerationConfig( - response_mime_type="application/json", - response_schema=MappingRequest, - ), - ) - return MappingRequest.model_validate(json.loads(result.text)) - - -def _map_values( - values: list[tuple[set[str], set[str], list[str]]], - language: str, - model: gemini.GenerativeModel, -) -> ValuesRequest: - """ - Calls the Gemini API to generate a set of value mappings for the fields. - """ - result = model.generate_content( - [ - ( - "You are an expert at structured data extraction. " - "You will be given a list of tuples, where each tuple contains " - "three sets of string values. " - "The first set contains field names for a dataset." - "The second set contains values from a source dataset in " - f"{language}, and the third set contains target values for an " - "english-language transformed dataset. " - "Match all the values in the second set to the appropriate values " - "in the third set. " - "Return a list of dictionaries, where each dictionary contains the " - "field name as a key, and a dictionary containing " - "source values as keys, and the target text as values, " - "as the values. For example, the result should look like this: " - "[{'field_name_1': {'source_value_a': 'target_value_a', " - "'source_value_b': 'target_value_b'}, 'field_name_2':{...}]" - "using the provided structure." - "Preserve special characters such as accented letters and hyphens." - ), - f"These are the field, source, target value sets: {values}", - ], - generation_config=gemini.GenerationConfig( - response_mime_type="application/json", - response_schema=ValuesRequest, - ), - ) - return ValuesRequest.model_validate(json.loads(result.text)) diff --git a/src/adtl/autoparser/language_models/base_llm.py b/src/adtl/autoparser/language_models/base_llm.py new file mode 100644 index 0000000..19b4e51 --- /dev/null +++ b/src/adtl/autoparser/language_models/base_llm.py @@ -0,0 +1,30 @@ +"Contains all functions that call OpenAI's API." + +from __future__ import annotations + + +class LLMBase: + def __init__(self, api_key, model=None): # pragma: no cover + self.client = None + self.model = model + + def get_definitions(self, headers, language): # pragma: no cover + """ + Get the definitions of the columns in the dataset. + """ + # subclasses should implement this method + raise NotImplementedError + + def map_fields(self, source_fields, target_fields): # pragma: no cover + """ + Calls the OpenAI API to generate a draft mapping between two datasets. + """ + # subclasses should implement this method + raise NotImplementedError + + def map_values(self, values, language): # pragma: no cover + """ + Calls the OpenAI API to generate a set of value mappings for the fields. + """ + # subclasses should implement this method + raise NotImplementedError diff --git a/src/adtl/autoparser/language_models/data_structures.py b/src/adtl/autoparser/language_models/data_structures.py new file mode 100644 index 0000000..8a7987f --- /dev/null +++ b/src/adtl/autoparser/language_models/data_structures.py @@ -0,0 +1,41 @@ +"""Stores the data structures for using with LLM API's""" + +from __future__ import annotations + +from pydantic import BaseModel + +# target classes for generating descriptions + + +class SingleField(BaseModel): + field_name: str + translation: str | None + + +class ColumnDescriptionRequest(BaseModel): + field_descriptions: list[SingleField] + + +# target classes for matching fields +class SingleMapping(BaseModel): + target_field: str + source_description: str | None + + +class MappingRequest(BaseModel): + targets_descriptions: list[SingleMapping] + + +# target classes for matching values to enum/boolean options +class ValueMapping(BaseModel): + source_value: str + target_value: str | None + + +class FieldMapping(BaseModel): + field_name: str + mapped_values: list[ValueMapping] + + +class ValuesRequest(BaseModel): + values: list[FieldMapping] diff --git a/src/adtl/autoparser/language_models/gemini.py b/src/adtl/autoparser/language_models/gemini.py new file mode 100644 index 0000000..81b10d3 --- /dev/null +++ b/src/adtl/autoparser/language_models/gemini.py @@ -0,0 +1,109 @@ +"Contains all functions that call Google's Gemini API." + +from __future__ import annotations + +import json + +import google.generativeai as gemini + +from .base_llm import LLMBase +from .data_structures import ColumnDescriptionRequest, MappingRequest, ValuesRequest + + +class GeminiLanguageModel(LLMBase): + def __init__(self, api_key, model: str = "gemini-1.5-flash"): + gemini.configure(api_key=api_key) + self.client = gemini.GenerativeModel(model) + self.model = model + + def get_definitions(self, headers: list[str], language: str) -> dict[str, str]: + """ + Get the definitions of the columns in the dataset using the Gemini API. + """ + result = self.client.generate_content( + [ + ( + "You are an expert at structured data extraction. " + "The following is a list of headers from a data file in " + f"{language}, some containing shortened words or abbreviations. " + "Translate them to english. " + "Return a list of (original header, translation) pairs, using the given structure." # noqa + "Preserve special characters such as accented letters and hyphens." + ), + f"{headers}", + ], + generation_config=gemini.GenerationConfig( + response_mime_type="application/json", + response_schema=ColumnDescriptionRequest, + ), + ) + descriptions = ColumnDescriptionRequest.model_validate( + json.loads(result.text) + ).field_descriptions + return descriptions + + def map_fields( + self, source_fields: list[str], target_fields: list[str] + ) -> MappingRequest: + """ + Calls the Gemini API to generate a draft mapping between two datasets. + """ + result = self.client.generate_content( + [ + ( + "You are an expert at structured data extraction. " + "You will be given two lists of phrases, one is the headers for a " + "target data file, and the other a set of descriptions for columns " + "of source data. " + "Match each target header to the best matching source description, " + "but match a header to None if a good match does not exist. " + "Preserve special characters such as accented letters and hyphens." + "Return the matched target headers and source descriptions using the provided structure." # noqa + ), + ( + f"These are the target headers: {target_fields}\n" + f"These are the source descriptions: {source_fields}" + ), + ], + generation_config=gemini.GenerationConfig( + response_mime_type="application/json", + response_schema=MappingRequest, + ), + ) + return MappingRequest.model_validate(json.loads(result.text)) + + def map_values( + self, values: list[tuple[str, set[str], list[str | None] | None]], language: str + ) -> ValuesRequest: + """ + Calls the Gemini API to generate a set of value mappings for the fields. + """ + result = self.client.generate_content( + [ + ( + "You are an expert at structured data extraction. " + "You will be given a list of tuples, where each tuple contains " + "three sets of string values. " + "The first set contains field names for a dataset." + "The second set contains values from a source dataset in " + f"{language}, and the third set contains target values for an " + "english-language transformed dataset. " + "Match all the values in the second set to the appropriate values " + "in the third set. " + "Return a list of dictionaries, where each dictionary contains the " + "field name as a key, and a dictionary containing " + "source values as keys, and the target text as values, " + "as the values. For example, the result should look like this: " + "[{'field_name_1': {'source_value_a': 'target_value_a', " + "'source_value_b': 'target_value_b'}, 'field_name_2':{...}]" + "using the provided structure." + "Preserve special characters such as accented letters and hyphens." + ), + f"These are the field, source, target value sets: {values}", + ], + generation_config=gemini.GenerationConfig( + response_mime_type="application/json", + response_schema=ValuesRequest, + ), + ) + return ValuesRequest.model_validate(json.loads(result.text)) diff --git a/src/adtl/autoparser/language_models/openai.py b/src/adtl/autoparser/language_models/openai.py new file mode 100644 index 0000000..2fb2525 --- /dev/null +++ b/src/adtl/autoparser/language_models/openai.py @@ -0,0 +1,118 @@ +"Contains all functions that call OpenAI's API." + +from __future__ import annotations + +from openai import OpenAI + +from .base_llm import LLMBase +from .data_structures import ColumnDescriptionRequest, MappingRequest, ValuesRequest + + +class OpenAILanguageModel(LLMBase): + def __init__(self, api_key, model: str = "gpt-4o-mini"): + self.client = OpenAI(api_key=api_key) + self.model = model + + def get_definitions(self, headers: list[str], language: str) -> dict[str, str]: + """ + Get the definitions of the columns in the dataset. + """ + completion = self.client.beta.chat.completions.parse( + model=self.model, + messages=[ + { + "role": "system", + "content": ( + "You are an expert at structured data extraction. " + "The following is a list of headers from a data file in " + f"{language}, some containing shortened words or abbreviations. " # noqa + "Translate them to english. " + "Return a list of (original header, translation) pairs, using the given structure." # noqa + ), + }, + {"role": "user", "content": f"{headers}"}, + ], + response_format=ColumnDescriptionRequest, + ) + descriptions = completion.choices[0].message.parsed.field_descriptions + + return descriptions + + def map_fields( + self, source_fields: list[str], target_fields: list[str] + ) -> MappingRequest: + """ + Calls the OpenAI API to generate a draft mapping between two datasets. + """ + field_mapping = self.client.beta.chat.completions.parse( + model=self.model, + messages=[ + { + "role": "system", + "content": ( + "You are an expert at structured data extraction. " + "You will be given two lists of phrases, one is the headers " + "for a target data file, and the other a set of descriptions " + "for columns of source data. " + "Match each target header to the best matching source " + "description, but match a header to None if a good match does " + "not exist. " + "Return the matched target headers and source descriptions using the provided structure." # noqa + ), + }, + { + "role": "user", + "content": ( + f"These are the target headers: {target_fields}\n" + f"These are the source descriptions: {source_fields}" + ), + }, + ], + response_format=MappingRequest, + ) + mappings = field_mapping.choices[0].message.parsed + + return mappings + + def map_values( + self, values: list[tuple[str, set[str], list[str | None] | None]], language: str + ) -> ValuesRequest: + """ + Calls the OpenAI API to generate a set of value mappings for the fields. + """ + value_mapping = self.client.beta.chat.completions.parse( + model=self.model, + messages=[ + { + "role": "system", + "content": ( + "You are an expert at structured data extraction. " + "You will be given a list of tuples, where each tuple contains " + "three sets of string values. " + "The first set contains field names for a dataset." + "The second set contains values from a source dataset in " + f"{language}, and the third set contains target values for an " + "english-language transformed dataset. " + "Match all the values in the second set to the appropriate " + "values in the third set. " + "Return a list of dictionaries, where each dictionary contains " + "the field name as a key, and a dictionary containing " + "source values as keys, and the target text as values, " + "as the values. For example, the result should look like this: " + "[{'field_name_1': {'source_value_a': 'target_value_a', " + "'source_value_b': 'target_value_b'}, 'field_name_2':{...}]" + "using the provided structure." + ), + }, + { + "role": "user", + "content": ( + f"These are the field, source, target value sets: {values}" + ), + }, + ], + response_format=ValuesRequest, + ) + mappings = value_mapping.choices[0].message.parsed + + return mappings diff --git a/src/adtl/autoparser/openai_calls.py b/src/adtl/autoparser/openai_calls.py deleted file mode 100644 index 90a25d3..0000000 --- a/src/adtl/autoparser/openai_calls.py +++ /dev/null @@ -1,120 +0,0 @@ -"Contains all functions that call OpenAI's API." - -from __future__ import annotations - -from openai import OpenAI - -from .util import ColumnDescriptionRequest, MappingRequest, ValuesRequest - - -def _get_definitions( - headers: list[str], language: str, client: OpenAI -) -> dict[str, str]: - """ - Get the definitions of the columns in the dataset. - """ - completion = client.beta.chat.completions.parse( - model="gpt-4o-mini", - messages=[ - { - "role": "system", - "content": ( - "You are an expert at structured data extraction. " - "The following is a list of headers from a data file in " - f"{language}, some containing shortened words or abbreviations. " - "Translate them to english. " - # "Return a dictionary where the keys are the original headers, " - # "and the values the translations, using the given structure." - "Return a list of (original header, translation) pairs, using the given structure." # noqa - ), - }, - {"role": "user", "content": f"{headers}"}, - ], - response_format=ColumnDescriptionRequest, - ) - descriptions = completion.choices[0].message.parsed.field_descriptions - - return descriptions - - -def _map_fields( - source_fields: list[str], target_fields: list[str], client: OpenAI -) -> MappingRequest: - """ - Calls the OpenAI API to generate a draft mapping between two datasets. - """ - field_mapping = client.beta.chat.completions.parse( - model="gpt-4o-mini", - messages=[ - { - "role": "system", - "content": ( - "You are an expert at structured data extraction. " - "You will be given two lists of phrases, one is the headers for a " - "target data file, and the other a set of descriptions for columns " - "of source data. " - "Match each target header to the best matching source description, " - "but match a header to None if a good match does not exist. " - # "Return the target headers and descriptions as a dictionary of " - # "key-value pairs, where the header is the key and the description, " # noqa - # "or None, is the value, using the provided structure." - "Return the matched target headers and source descriptions using the provided structure." # noqa - ), - }, - { - "role": "user", - "content": ( - f"These are the target headers: {target_fields}\n" - f"These are the source descriptions: {source_fields}" - ), - }, - ], - response_format=MappingRequest, - ) - mappings = field_mapping.choices[0].message.parsed - - return mappings - - -def _map_values( - values: list[tuple[set[str], set[str], list[str]]], language: str, client: OpenAI -) -> ValuesRequest: - """ - Calls the OpenAI API to generate a set of value mappings for the fields. - """ - value_mapping = client.beta.chat.completions.parse( - model="gpt-4o-mini", - messages=[ - { - "role": "system", - "content": ( - "You are an expert at structured data extraction. " - "You will be given a list of tuples, where each tuple contains " - "three sets of string values. " - "The first set contains field names for a dataset." - "The second set contains values from a source dataset in " - f"{language}, and the third set contains target values for an " - "english-language transformed dataset. " - "Match all the values in the second set to the appropriate values " - "in the third set. " - "Return a list of dictionaries, where each dictionary contains the " - "field name as a key, and a dictionary containing " - "source values as keys, and the target text as values, " - "as the values. For example, the result should look like this: " - "[{'field_name_1': {'source_value_a': 'target_value_a', " - "'source_value_b': 'target_value_b'}, 'field_name_2':{...}]" - "using the provided structure." - ), - }, - { - "role": "user", - "content": ( - f"These are the field, source, target value sets: {values}" - ), - }, - ], - response_format=ValuesRequest, - ) - mappings = value_mapping.choices[0].message.parsed - - return mappings diff --git a/src/adtl/autoparser/util.py b/src/adtl/autoparser/util.py index ed53ad4..a106698 100644 --- a/src/adtl/autoparser/util.py +++ b/src/adtl/autoparser/util.py @@ -10,7 +10,9 @@ import pandas as pd import tomli -from pydantic import BaseModel + +from adtl.autoparser.language_models.gemini import GeminiLanguageModel +from adtl.autoparser.language_models.openai import OpenAILanguageModel DEFAULT_CONFIG = "config/autoparser.toml" @@ -106,40 +108,26 @@ def load_data_dict( return data_dict -# Data structures for llm calls -------------------------- - -# target classes for generating descriptions - - -class SingleField(BaseModel): - field_name: str - translation: str | None - - -class ColumnDescriptionRequest(BaseModel): - field_descriptions: list[SingleField] - +def setup_llm(provider, api_key): + """ + Setup the LLM to use to generate descriptions. -# target classes for matching fields -class SingleMapping(BaseModel): - target_field: str - source_description: str | None + Separate from the __init__ method to allow for extra barrier between raw data & + LLM. + Parameters + ---------- + key + API key + name + Name of the LLM to use (currently only OpenAI and Gemini are supported) + """ + if api_key is None: + raise ValueError("API key required to set up an LLM") -class MappingRequest(BaseModel): - targets_descriptions: list[SingleMapping] - - -# target classes for matching values to enum/boolean options -class ValueMapping(BaseModel): - source_value: str - target_value: str | None - - -class FieldMapping(BaseModel): - field_name: str - mapped_values: list[ValueMapping] - - -class ValuesRequest(BaseModel): - values: list[FieldMapping] + if provider == "openai": # pragma: no cover + return OpenAILanguageModel(api_key=api_key) + elif provider == "gemini": # pragma: no cover + return GeminiLanguageModel(api_key=api_key) + else: + raise ValueError(f"Unsupported LLM provider: {provider}") diff --git a/tests/test_autoparser/test_dict_writer.py b/tests/test_autoparser/test_dict_writer.py index 8a423eb..838f436 100644 --- a/tests/test_autoparser/test_dict_writer.py +++ b/tests/test_autoparser/test_dict_writer.py @@ -5,28 +5,17 @@ import pandas as pd import pytest -from testing_data_animals import get_definitions +from testing_data_animals import TestLLM import adtl.autoparser as autoparser from adtl.autoparser.dict_writer import DictWriter +from adtl.autoparser.language_models.openai import OpenAILanguageModel CONFIG_PATH = "tests/test_autoparser/test_config.toml" SOURCES = "tests/test_autoparser/sources/" SCHEMAS = "tests/test_autoparser/schemas/" -class DictWriterTest(DictWriter): - def __init__( - self, - config: Path | None = None, - ): - super().__init__(config) - - def _setup_llm(self, key, name): - self.client = None - self._get_descriptions = get_definitions - - def test_unsupported_data_format_txt(): writer = DictWriter(config=CONFIG_PATH) @@ -71,7 +60,8 @@ def test_dictionary_creation_no_descrip_excel_dataframe(): def test_dictionary_description(): - writer = DictWriterTest(config=Path(CONFIG_PATH)) + writer = DictWriter(config=Path(CONFIG_PATH)) + writer.model = TestLLM() # check descriptions aren't generated without a dictionary with pytest.raises(ValueError, match="No data dictionary found"): @@ -92,7 +82,13 @@ def test_missing_key_error(): def test_wrong_llm_error(): - with pytest.raises(ValueError, match="Unsupported LLM: fish"): + with pytest.raises(ValueError, match="Unsupported LLM provider: fish"): DictWriter(config=Path(CONFIG_PATH)).generate_descriptions( "fr", SOURCES + "animals_dd.csv", key="a12b3c", llm="fish" ) + + +def test_init_with_llm(): + # test no errors occur + writer = DictWriter(config=Path(CONFIG_PATH), api_key="1234", llm="openai") + assert isinstance(writer.model, OpenAILanguageModel) diff --git a/tests/test_autoparser/test_gemini.py b/tests/test_autoparser/test_gemini.py new file mode 100644 index 0000000..fab354a --- /dev/null +++ b/tests/test_autoparser/test_gemini.py @@ -0,0 +1,123 @@ +"Tests the OpenAILanguageModel class." + +from google.generativeai import protos +from google.generativeai.types import GenerateContentResponse +from testing_data_animals import get_definitions, map_fields, map_values + +from adtl.autoparser.language_models.gemini import GeminiLanguageModel + + +def test_init(): + model = GeminiLanguageModel("1234") + + assert model.client is not None + assert model.model == "gemini-1.5-flash" + + +def test_get_definitions(monkeypatch): + model = GeminiLanguageModel("1234") + + # Define test inputs + headers = ["foo", "bar", "baz"] + language = "fr" + + # Define the mocked response + def mock_generate_content(*args, **kwargs): + json_str = '{"field_descriptions": [{"field_name": "Identité", "translation": "Identity"}, {"field_name": "Province", "translation": "Province"}, {"field_name": "DateNotification", "translation": "Notification Date"}, {"field_name": "Classicfication ", "translation": "Classification"}, {"field_name": "Nom complet ", "translation": "Full Name"}, {"field_name": "Date de naissance", "translation": "Date of Birth"}, {"field_name": "AgeAns", "translation": "Age in Years"}, {"field_name": "AgeMois ", "translation": "Age in Months"}, {"field_name": "Sexe", "translation": "Gender"}, {"field_name": "StatusCas", "translation": "Case Status"}, {"field_name": "DateDec", "translation": "Date of Death"}, {"field_name": "ContSoins ", "translation": "Care Contact"}, {"field_name": "ContHumain Autre", "translation": "Other Human Contact"}, {"field_name": "AutreContHumain", "translation": "Other Human Contact"}, {"field_name": "ContactAnimal", "translation": "Animal Contact"}, {"field_name": "Micropucé", "translation": "Microchipped"}, {"field_name": "AnimalDeCompagnie", "translation": "Pet Animal"}]}' # noqa + res = protos.GenerateContentResponse( + candidates=[ + protos.Candidate( + content=protos.Content( + parts=[protos.Part(text=json_str)], role="model" + ), + finish_reason="STOP", + ) + ] + ) + + return GenerateContentResponse(done=True, iterator=None, result=res, chunks=[]) + + # Mock the parse method using monkeypatch + monkeypatch.setattr(model.client, "generate_content", mock_generate_content) + + # Call the function + result = model.get_definitions(headers, language) + + # Assert the expected output + assert result == get_definitions() + + +def test_map_fields(monkeypatch): + model = GeminiLanguageModel("1234") + + # Define test inputs + source_fields = ["nom", "âge", "localisation"] + target_fields = ["name", "age", "location"] + + # Define the mocked response + def mock_generate_content(*args, **kwargs): + json_str = '{"targets_descriptions": [{"source_description": "Identity", "target_field": "identity"}, {"source_description": "Full Name", "target_field": "name"}, {"source_description": "Province", "target_field": "loc_admin_1"}, {"source_description": null, "target_field": "country_iso3"}, {"source_description": "Notification Date", "target_field": "notification_date"}, {"source_description": "Classification", "target_field": "classification"}, {"source_description": "Case Status", "target_field": "case_status"}, {"source_description": "Death Date", "target_field": "date_of_death"}, {"source_description": "Age in Years", "target_field": "age_years"}, {"source_description": "Age in Months", "target_field": "age_months"}, {"source_description": "Gender", "target_field": "sex"}, {"source_description": "Pet Animal", "target_field": "pet"}, {"source_description": "Microchipped", "target_field": "chipped"}, {"source_description": null, "target_field": "owner"}]}' # noqa + res = protos.GenerateContentResponse( + candidates=[ + protos.Candidate( + content=protos.Content( + parts=[protos.Part(text=json_str)], role="model" + ), + finish_reason="STOP", + ) + ] + ) + + return GenerateContentResponse(done=True, iterator=None, result=res, chunks=[]) + + # Mock the parse method using monkeypatch + monkeypatch.setattr(model.client, "generate_content", mock_generate_content) + + # Call the function + result = model.map_fields(source_fields, target_fields) + + # Assert the expected output + assert result == map_fields() + + +def test_map_values(monkeypatch): + model = GeminiLanguageModel("1234") + + # Define test inputs + fields = ["loc", "status", "pet"] + source_values = [ + {"orientale", "katanga", "kinshasa", "equateur"}, + {"vivant", "décédé"}, + {"oui", "non"}, + ] + target_values = [ + None, + ["alive", "dead", "unknown", None], + ["True", "False", "None"], + ] + values = list(zip(fields, source_values, target_values)) + + # Define the mocked response + def mock_generate_content(*args, **kwargs): + json_str = '{"values": [{"field_name": "classification", "mapped_values": [{"source_value": "mammifère", "target_value": "mammal"}, {"source_value": "fish", "target_value": "fish"}, {"source_value": "poisson", "target_value": "fish"}, {"source_value": "amphibie", "target_value": "amphibian"}, {"source_value": "oiseau", "target_value": "bird"}, {"source_value": "autre", "target_value": null}, {"source_value": "rept", "target_value": "reptile"}]}, {"field_name": "case_status", "mapped_values": [{"source_value": "vivant", "target_value": "alive"}, {"source_value": "décédé", "target_value": "dead"}]}, {"field_name": "sex", "mapped_values": [{"source_value": "m", "target_value": "male"}, {"source_value": "f", "target_value": "female"}, {"source_value": "inconnu", "target_value": null}]}, {"field_name": "pet", "mapped_values": [{"source_value": "oui", "target_value": "True"}, {"source_value": "non", "target_value": "False"}]}, {"field_name": "chipped", "mapped_values": [{"source_value": "oui", "target_value": "True"}, {"source_value": "non", "target_value": "False"}]}]}' # noqa + res = protos.GenerateContentResponse( + candidates=[ + protos.Candidate( + content=protos.Content( + parts=[protos.Part(text=json_str)], role="model" + ), + finish_reason="STOP", + ) + ] + ) + + return GenerateContentResponse(done=True, iterator=None, result=res, chunks=[]) + + # Mock the parse method using monkeypatch + monkeypatch.setattr(model.client, "generate_content", mock_generate_content) + + # Call the function + result = model.map_values(values, "fr") + + # Assert the expected output + assert result == map_values() diff --git a/tests/test_autoparser/test_mapper.py b/tests/test_autoparser/test_mapper.py index a42bb58..46dc894 100644 --- a/tests/test_autoparser/test_mapper.py +++ b/tests/test_autoparser/test_mapper.py @@ -7,9 +7,10 @@ import numpy.testing as npt import pandas as pd import pytest -from testing_data_animals import map_fields, map_values +from testing_data_animals import TestLLM from adtl.autoparser.create_mapping import Mapper +from adtl.autoparser.language_models.openai import OpenAILanguageModel class MapperTest(Mapper): @@ -29,9 +30,7 @@ def __init__( None, ) - # overwrite the LLM API's with dummy functions containing base data - self.map_fields = map_fields - self.map_values = map_values + self.model = TestLLM() ANIMAL_MAPPER = MapperTest( @@ -217,11 +216,12 @@ def test_common_values_mapped_fields_error(): def test_mapper_class_init_raises(): - with pytest.raises(ValueError, match="Unsupported LLM: fish"): + with pytest.raises(ValueError, match="Unsupported LLM provider: fish"): Mapper( Path("tests/test_autoparser/schemas/animals.schema.json"), "tests/test_autoparser/sources/animals_dd_described.csv", "fr", + api_key="1234", llm="fish", ) @@ -235,13 +235,24 @@ def test_mapper_class_init(): ) assert mapper.language == "fr" - assert mapper.client is None + assert mapper.model is None npt.assert_array_equal( mapper.data_dictionary.columns, ["source_field", "source_description", "source_type", "common_values"], ) +def test_mapper_class_init_with_llm(): + mapper = Mapper( + Path("tests/test_autoparser/schemas/animals.schema.json"), + "tests/test_autoparser/sources/animals_dd_described.csv", + "fr", + api_key="abcd", + ) + + assert isinstance(mapper.model, OpenAILanguageModel) + + def test_match_fields_to_schema_dummy_data(): mapper = ANIMAL_MAPPER diff --git a/tests/test_autoparser/test_openai.py b/tests/test_autoparser/test_openai.py new file mode 100644 index 0000000..b3b1a73 --- /dev/null +++ b/tests/test_autoparser/test_openai.py @@ -0,0 +1,143 @@ +"Tests the OpenAILanguageModel class." + +import datetime + +from openai.types.chat.parsed_chat_completion import ( + ParsedChatCompletion, + ParsedChatCompletionMessage, + ParsedChoice, +) +from testing_data_animals import get_definitions, map_fields, map_values + +from adtl.autoparser.language_models.data_structures import ColumnDescriptionRequest +from adtl.autoparser.language_models.openai import OpenAILanguageModel + + +def test_init(): + model = OpenAILanguageModel("1234") + + assert model.client is not None + assert model.model == "gpt-4o-mini" + + +def test_get_definitions(monkeypatch): + model = OpenAILanguageModel("1234") + + # Define test inputs + headers = ["foo", "bar", "baz"] + language = "fr" + + # Define the mocked response + def mock_parse(*args, **kwargs): + return ParsedChatCompletion( + id="foo", + model="gpt-4o-mini", + object="chat.completion", + choices=[ + ParsedChoice( + message=ParsedChatCompletionMessage( + content='{"field_descriptions":[{"field_name":"Identité","translation":"Identity"},{"field_name":"Province","translation":"Province"},{"field_name":"DateNotification","translation":"Notification Date"},{"field_name":"Classicfication ","translation":"Classification"},{"field_name":"Nom complet ","translation":"Full Name"},{"field_name":"Date de naissance","translation":"Date of Birth"},{"field_name":"AgeAns","translation":"Age Years"},{"field_name":"AgeMois ","translation":"Age Months"},{"field_name":"Sexe","translation":"Sex"},{"field_name":"StatusCas","translation":"Case Status"},{"field_name":"DateDec","translation":"Date of Death"},{"field_name":"ContSoins ","translation":"Care Contact"},{"field_name":"ContHumain Autre","translation":"Other Human Contact"},{"field_name":"AutreContHumain","translation":"Other Human Contact"},{"field_name":"ContactAnimal","translation":"Animal Contact"},{"field_name":"Micropucé","translation":"Microchipped"},{"field_name":"AnimalDeCompagnie","translation":"Pet"}]}', # noqa + role="assistant", + parsed=ColumnDescriptionRequest( + field_descriptions=get_definitions() + ), + ), + finish_reason="stop", + index=0, + ) + ], + created=int(datetime.datetime.now().timestamp()), + ) + + # Mock the parse method using monkeypatch + monkeypatch.setattr(model.client.beta.chat.completions, "parse", mock_parse) + + # Call the function + result = model.get_definitions(headers, language) + + # Assert the expected output + assert result == get_definitions() + + +def test_map_fields(monkeypatch): + model = OpenAILanguageModel("1234") + + # Define test inputs + source_fields = ["nom", "âge", "localisation"] + target_fields = ["name", "age", "location"] + + # Define the mocked response + def mock_parse(*args, **kwargs): + return ParsedChatCompletion( + id="foo", + model="gpt-4o-mini", + object="chat.completion", + choices=[ + ParsedChoice( + message=ParsedChatCompletionMessage( + content="", # noqa + role="assistant", + parsed=map_fields(), + ), + finish_reason="stop", + index=0, + ) + ], + created=int(datetime.datetime.now().timestamp()), + ) + + # Mock the parse method using monkeypatch + monkeypatch.setattr(model.client.beta.chat.completions, "parse", mock_parse) + + # Call the function + result = model.map_fields(source_fields, target_fields) + + # Assert the expected output + assert result == map_fields() + + +def test_map_values(monkeypatch): + model = OpenAILanguageModel("1234") + + # Define test inputs + fields = ["loc", "status", "pet"] + source_values = [ + {"orientale", "katanga", "kinshasa", "equateur"}, + {"vivant", "décédé"}, + {"oui", "non"}, + ] + target_values = [ + None, + ["alive", "dead", "unknown", None], + ["True", "False", "None"], + ] + values = list(zip(fields, source_values, target_values)) + + # Define the mocked response + def mock_parse(*args, **kwargs): + return ParsedChatCompletion( + id="foo", + model="gpt-4o-mini", + object="chat.completion", + choices=[ + ParsedChoice( + message=ParsedChatCompletionMessage( + content="", # noqa + role="assistant", + parsed=map_values(), + ), + finish_reason="stop", + index=0, + ) + ], + created=int(datetime.datetime.now().timestamp()), + ) + + # Mock the parse method using monkeypatch + monkeypatch.setattr(model.client.beta.chat.completions, "parse", mock_parse) + + # Call the function + result = model.map_values(values, "fr") + + # Assert the expected output + assert result == map_values() diff --git a/tests/test_autoparser/test_utils.py b/tests/test_autoparser/test_utils.py index 955f63b..3cbb3d1 100644 --- a/tests/test_autoparser/test_utils.py +++ b/tests/test_autoparser/test_utils.py @@ -6,7 +6,12 @@ import pandas as pd import pytest -from adtl.autoparser.util import load_data_dict, parse_choices, read_config_schema +from adtl.autoparser.util import ( + load_data_dict, + parse_choices, + read_config_schema, + setup_llm, +) CONFIG = read_config_schema(Path("tests/test_autoparser/test_config.toml")) @@ -86,3 +91,13 @@ def test_load_data_dict(): with pytest.raises(ValueError, match="Unsupported format"): load_data_dict(CONFIG, "tests/test_autoparser/sources/animals.txt") + + +def test_setup_llm_no_key(): + with pytest.raises(ValueError, match="API key required to set up an LLM"): + setup_llm("openai", None) + + +def test_setup_llm_bad_provider(): + with pytest.raises(ValueError, match="Unsupported LLM provider: fish"): + setup_llm("fish", "abcd") diff --git a/tests/test_autoparser/testing_data_animals.py b/tests/test_autoparser/testing_data_animals.py index f918f5a..74251ae 100644 --- a/tests/test_autoparser/testing_data_animals.py +++ b/tests/test_autoparser/testing_data_animals.py @@ -2,7 +2,8 @@ from __future__ import annotations -from adtl.autoparser.util import ( +from adtl.autoparser.language_models.base_llm import LLMBase +from adtl.autoparser.language_models.data_structures import ( ColumnDescriptionRequest, FieldMapping, MappingRequest, @@ -116,3 +117,30 @@ def map_values(*args): mapping = ValuesRequest(values=vm) return mapping + + +class TestLLM(LLMBase): + def __init__(self): + self.client = None + self.model = None + + def get_definitions(self, headers, language): + """ + Get the definitions of the columns in the dataset. + """ + translated_fields = get_definitions(headers, language) + return translated_fields + + def map_fields(self, source_fields, target_fields): + """ + Calls the OpenAI API to generate a draft mapping between two datasets. + """ + mapping = map_fields(source_fields, target_fields) + return mapping + + def map_values(self, values, language): + """ + Calls the OpenAI API to generate a set of value mappings for the fields. + """ + value_mapping = map_values(values, language) + return value_mapping