Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Moves llm functions into subclasses #111

Merged
merged 8 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 11 additions & 21 deletions src/adtl/autoparser/create_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@
from pathlib import Path
from typing import Literal

import google.generativeai as gemini
import numpy as np
import pandas as pd
from openai import OpenAI

from .gemini_calls import _map_fields as _map_fields_gemini
from .gemini_calls import _map_values as _map_values_gemini
from .openai_calls import _map_fields as _map_fields_openai
from .openai_calls import _map_values as _map_values_openai
from .util import DEFAULT_CONFIG, load_data_dict, read_config_schema, read_json
from .util import (
DEFAULT_CONFIG,
load_data_dict,
read_config_schema,
read_json,
setup_llm,
)


class Mapper:
Expand Down Expand Up @@ -57,20 +57,10 @@ def __init__(
self.schema = read_json(schema)
self.schema_properties = self.schema["properties"]
self.language = language
self.api_key = api_key
if llm is None:
self.client = None
elif llm == "openai": # pragma: no cover
self.client = OpenAI(api_key=self.api_key)
self.map_fields = _map_fields_openai
self.map_values = _map_values_openai
elif llm == "gemini": # pragma: no cover
gemini.configure(api_key=self.api_key)
self.client = gemini.GenerativeModel("gemini-1.5-flash")
self.map_fields = _map_fields_gemini
self.map_values = _map_values_gemini
self.model = None
else:
raise ValueError(f"Unsupported LLM: {llm}")
self.model = setup_llm(llm, api_key)

self.config = read_config_schema(
config or Path(Path(__file__).parent, DEFAULT_CONFIG)
Expand Down Expand Up @@ -192,7 +182,7 @@ def match_fields_to_schema(self) -> pd.DataFrame:
# english translated descriptions rather than names.
source_fields = list(self.data_dictionary.source_description)

mappings = self.map_fields(source_fields, self.target_fields, self.client)
mappings = self.model.map_fields(source_fields, self.target_fields)

mapping_dict = pd.DataFrame(
{
Expand Down Expand Up @@ -229,7 +219,7 @@ def match_values_to_schema(self) -> pd.DataFrame:
values_tuples.append((f, s, t))

# to LLM
value_pairs = self.map_values(values_tuples, self.language, self.client)
value_pairs = self.model.map_values(values_tuples, self.language)

value_mapping = {}

Expand Down
51 changes: 15 additions & 36 deletions src/adtl/autoparser/dict_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@
import argparse
from pathlib import Path

import google.generativeai as gemini
import numpy as np
import pandas as pd
from openai import OpenAI

from .gemini_calls import _get_definitions as _get_definitions_gemini
from .openai_calls import _get_definitions as _get_definitions_openai
from .util import DEFAULT_CONFIG, load_data_dict, read_config_schema, read_data
from .util import (
DEFAULT_CONFIG,
load_data_dict,
read_config_schema,
read_data,
setup_llm,
)


class DictWriter:
Expand All @@ -37,43 +39,19 @@ class DictWriter:
def __init__(
self,
config: Path | str | None = None,
llm: str | None = None,
api_key: str | None = None,
):
if isinstance(config, str):
config = Path(config)
self.config = read_config_schema(
config or Path(Path(__file__).parent, DEFAULT_CONFIG)
)

def _setup_llm(self, key: str, name: str):
"""
Setup the LLM to use to generate descriptions.

Separate from the __init__ method to allow for extra barrier between raw data &
LLM.

Parameters
----------
key
API key
name
Name of the LLM to use (currently only OpenAI and Gemini are supported)
"""
if key is None:
raise ValueError("API key required for generating descriptions")
else:
self.key = key

if name == "openai": # pragma: no cover
self.client = OpenAI(api_key=key)
self._get_descriptions = _get_definitions_openai

elif name == "gemini": # pragma: no cover
gemini.configure(api_key=key)
self.client = gemini.GenerativeModel("gemini-1.5-flash")
self._get_descriptions = _get_definitions_gemini

if llm and api_key:
self.model = setup_llm(llm, api_key)
else:
raise ValueError(f"Unsupported LLM: {name}")
self.model = None

def create_dict(self, data: pd.DataFrame | str) -> pd.DataFrame:
"""
Expand Down Expand Up @@ -183,11 +161,12 @@ def generate_descriptions(

df = load_data_dict(self.config, data_dict)

self._setup_llm(key, llm)
if not self.model:
self.model = setup_llm(llm, key)

headers = df.source_field

descriptions = self._get_descriptions(list(headers), language, self.client)
descriptions = self.model.get_definitions(list(headers), language)

descriptions = {d.field_name: d.translation for d in descriptions}
df_descriptions = pd.DataFrame(
Expand Down
108 changes: 0 additions & 108 deletions src/adtl/autoparser/gemini_calls.py

This file was deleted.

30 changes: 30 additions & 0 deletions src/adtl/autoparser/language_models/base_llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"Contains all functions that call OpenAI's API."

from __future__ import annotations


class LLMBase:
def __init__(self, api_key, model=None): # pragma: no cover
self.client = None
self.model = model

def get_definitions(self, headers, language): # pragma: no cover
"""
Get the definitions of the columns in the dataset.
"""
# subclasses should implement this method
raise NotImplementedError

def map_fields(self, source_fields, target_fields): # pragma: no cover
"""
Calls the OpenAI API to generate a draft mapping between two datasets.
"""
# subclasses should implement this method
raise NotImplementedError

def map_values(self, values, language): # pragma: no cover
"""
Calls the OpenAI API to generate a set of value mappings for the fields.
"""
# subclasses should implement this method
raise NotImplementedError
41 changes: 41 additions & 0 deletions src/adtl/autoparser/language_models/data_structures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Stores the data structures for using with LLM API's"""

from __future__ import annotations

from pydantic import BaseModel

# target classes for generating descriptions


class SingleField(BaseModel):
field_name: str
translation: str | None


class ColumnDescriptionRequest(BaseModel):
field_descriptions: list[SingleField]


# target classes for matching fields
class SingleMapping(BaseModel):
target_field: str
source_description: str | None


class MappingRequest(BaseModel):
targets_descriptions: list[SingleMapping]


# target classes for matching values to enum/boolean options
class ValueMapping(BaseModel):
source_value: str
target_value: str | None


class FieldMapping(BaseModel):
field_name: str
mapped_values: list[ValueMapping]


class ValuesRequest(BaseModel):
values: list[FieldMapping]
Loading
Loading