Skip to content

Commit

Permalink
Moves llm functions into subclasses (#111)
Browse files Browse the repository at this point in the history
* Create classes for LLM endpoints

* Test the openai + gemini classes
  • Loading branch information
pipliggins authored Dec 6, 2024
1 parent ba2fddc commit 96c49a1
Show file tree
Hide file tree
Showing 15 changed files with 686 additions and 343 deletions.
32 changes: 11 additions & 21 deletions src/adtl/autoparser/create_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@
from pathlib import Path
from typing import Literal

import google.generativeai as gemini
import numpy as np
import pandas as pd
from openai import OpenAI

from .gemini_calls import _map_fields as _map_fields_gemini
from .gemini_calls import _map_values as _map_values_gemini
from .openai_calls import _map_fields as _map_fields_openai
from .openai_calls import _map_values as _map_values_openai
from .util import DEFAULT_CONFIG, load_data_dict, read_config_schema, read_json
from .util import (
DEFAULT_CONFIG,
load_data_dict,
read_config_schema,
read_json,
setup_llm,
)


class Mapper:
Expand Down Expand Up @@ -57,20 +57,10 @@ def __init__(
self.schema = read_json(schema)
self.schema_properties = self.schema["properties"]
self.language = language
self.api_key = api_key
if llm is None:
self.client = None
elif llm == "openai": # pragma: no cover
self.client = OpenAI(api_key=self.api_key)
self.map_fields = _map_fields_openai
self.map_values = _map_values_openai
elif llm == "gemini": # pragma: no cover
gemini.configure(api_key=self.api_key)
self.client = gemini.GenerativeModel("gemini-1.5-flash")
self.map_fields = _map_fields_gemini
self.map_values = _map_values_gemini
self.model = None
else:
raise ValueError(f"Unsupported LLM: {llm}")
self.model = setup_llm(llm, api_key)

self.config = read_config_schema(
config or Path(Path(__file__).parent, DEFAULT_CONFIG)
Expand Down Expand Up @@ -192,7 +182,7 @@ def match_fields_to_schema(self) -> pd.DataFrame:
# english translated descriptions rather than names.
source_fields = list(self.data_dictionary.source_description)

mappings = self.map_fields(source_fields, self.target_fields, self.client)
mappings = self.model.map_fields(source_fields, self.target_fields)

mapping_dict = pd.DataFrame(
{
Expand Down Expand Up @@ -229,7 +219,7 @@ def match_values_to_schema(self) -> pd.DataFrame:
values_tuples.append((f, s, t))

# to LLM
value_pairs = self.map_values(values_tuples, self.language, self.client)
value_pairs = self.model.map_values(values_tuples, self.language)

value_mapping = {}

Expand Down
51 changes: 15 additions & 36 deletions src/adtl/autoparser/dict_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@
import argparse
from pathlib import Path

import google.generativeai as gemini
import numpy as np
import pandas as pd
from openai import OpenAI

from .gemini_calls import _get_definitions as _get_definitions_gemini
from .openai_calls import _get_definitions as _get_definitions_openai
from .util import DEFAULT_CONFIG, load_data_dict, read_config_schema, read_data
from .util import (
DEFAULT_CONFIG,
load_data_dict,
read_config_schema,
read_data,
setup_llm,
)


class DictWriter:
Expand All @@ -37,43 +39,19 @@ class DictWriter:
def __init__(
self,
config: Path | str | None = None,
llm: str | None = None,
api_key: str | None = None,
):
if isinstance(config, str):
config = Path(config)
self.config = read_config_schema(
config or Path(Path(__file__).parent, DEFAULT_CONFIG)
)

def _setup_llm(self, key: str, name: str):
"""
Setup the LLM to use to generate descriptions.
Separate from the __init__ method to allow for extra barrier between raw data &
LLM.
Parameters
----------
key
API key
name
Name of the LLM to use (currently only OpenAI and Gemini are supported)
"""
if key is None:
raise ValueError("API key required for generating descriptions")
else:
self.key = key

if name == "openai": # pragma: no cover
self.client = OpenAI(api_key=key)
self._get_descriptions = _get_definitions_openai

elif name == "gemini": # pragma: no cover
gemini.configure(api_key=key)
self.client = gemini.GenerativeModel("gemini-1.5-flash")
self._get_descriptions = _get_definitions_gemini

if llm and api_key:
self.model = setup_llm(llm, api_key)
else:
raise ValueError(f"Unsupported LLM: {name}")
self.model = None

def create_dict(self, data: pd.DataFrame | str) -> pd.DataFrame:
"""
Expand Down Expand Up @@ -183,11 +161,12 @@ def generate_descriptions(

df = load_data_dict(self.config, data_dict)

self._setup_llm(key, llm)
if not self.model:
self.model = setup_llm(llm, key)

headers = df.source_field

descriptions = self._get_descriptions(list(headers), language, self.client)
descriptions = self.model.get_definitions(list(headers), language)

descriptions = {d.field_name: d.translation for d in descriptions}
df_descriptions = pd.DataFrame(
Expand Down
108 changes: 0 additions & 108 deletions src/adtl/autoparser/gemini_calls.py

This file was deleted.

30 changes: 30 additions & 0 deletions src/adtl/autoparser/language_models/base_llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"Defines the base class that the LLM endpoint classes inherit from."

from __future__ import annotations


class LLMBase:
    """
    Provider-neutral base class for the LLM endpoint wrappers.

    Defines the three operations the rest of the package calls on a model
    object (``get_definitions``, ``map_fields`` and ``map_values``). None of
    them is implemented here; each one raises ``NotImplementedError`` until a
    subclass overrides it.
    """

    def __init__(self, api_key, model=None):  # pragma: no cover
        """
        Initialise the attributes shared by every wrapper.

        Parameters
        ----------
        api_key
            API key for the provider. It is accepted so that all wrappers share
            one constructor signature, but the base class does not store it.
        model
            Model identifier kept on ``self.model``; ``None`` when not given.
        """
        # No client exists at this level; presumably subclasses replace this
        # with their provider's client object.
        self.client = None
        self.model = model

    def get_definitions(self, headers, language):  # pragma: no cover
        """
        Get the definitions of the columns in the dataset.

        Parameters
        ----------
        headers
            Column headers to describe.
        language
            Language of the headers.

        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.
        """
        # subclasses should implement this method
        raise NotImplementedError(
            f"{type(self).__name__} must implement get_definitions()"
        )

    def map_fields(self, source_fields, target_fields):  # pragma: no cover
        """
        Generate a draft mapping between two datasets.

        Parameters
        ----------
        source_fields
            Fields (or their descriptions) from the source data.
        target_fields
            Fields of the target schema to map onto.

        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.
        """
        # subclasses should implement this method
        raise NotImplementedError(
            f"{type(self).__name__} must implement map_fields()"
        )

    def map_values(self, values, language):  # pragma: no cover
        """
        Generate a set of value mappings for the fields.

        Parameters
        ----------
        values
            Values to be mapped, grouped per field.
        language
            Language of the source values.

        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.
        """
        # subclasses should implement this method
        raise NotImplementedError(
            f"{type(self).__name__} must implement map_values()"
        )
41 changes: 41 additions & 0 deletions src/adtl/autoparser/language_models/data_structures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Stores the data structures for use with LLM APIs."""

from __future__ import annotations

from pydantic import BaseModel

# target classes for generating descriptions


# One entry of a column-description response: a field name and its translation.
class SingleField(BaseModel):
    # Name of the field this entry refers to.
    field_name: str
    # Translated text for the field; may be None. No default is declared.
    translation: str | None


# Top-level container for generated descriptions: a list of SingleField entries.
# NOTE(review): kept as `#` comments rather than a docstring, because pydantic
# includes a model's docstring in its JSON schema — presumably that schema is
# what the LLM receives; confirm before adding docstrings.
class ColumnDescriptionRequest(BaseModel):
    field_descriptions: list[SingleField]


# target classes for matching fields
# One field match: a target field and the source description paired with it.
class SingleMapping(BaseModel):
    # Name of the field on the target side of the mapping.
    target_field: str
    # Description of the matched source field; may be None
    # (presumably when nothing matches — confirm against the prompts).
    source_description: str | None


# Top-level container for field matching: a list of SingleMapping entries.
# NOTE(review): kept as `#` comments rather than a docstring, because pydantic
# includes a model's docstring in its JSON schema — presumably that schema is
# what the LLM receives; confirm before adding docstrings.
class MappingRequest(BaseModel):
    targets_descriptions: list[SingleMapping]


# target classes for matching values to enum/boolean options
# One value match: a source value and the target value it maps to.
class ValueMapping(BaseModel):
    # Value as it appears in the source data.
    source_value: str
    # Matching target option; may be None
    # (presumably when no option fits — confirm against the prompts).
    target_value: str | None


# All value matches for one field.
class FieldMapping(BaseModel):
    # Name of the field whose values are being mapped.
    field_name: str
    # One ValueMapping per source value of this field.
    mapped_values: list[ValueMapping]


# Top-level container for value matching: a list of FieldMapping entries.
# NOTE(review): kept as `#` comments rather than a docstring, because pydantic
# includes a model's docstring in its JSON schema — presumably that schema is
# what the LLM receives; confirm before adding docstrings.
class ValuesRequest(BaseModel):
    values: list[FieldMapping]
Loading

0 comments on commit 96c49a1

Please sign in to comment.