Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OSS support added #265

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion docetl/operations/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from .base import BaseOperation
import importlib.metadata

from docetl.operations.cluster import ClusterOperation
from docetl.operations.code_operations import CodeFilterOperation, CodeMapOperation, CodeReduceOperation
from docetl.operations.equijoin import EquijoinOperation
Expand All @@ -11,6 +13,22 @@
from docetl.operations.sample import SampleOperation
from docetl.operations.unnest import UnnestOperation

__all__ = [
'BaseOperation',
'ClusterOperation',
'CodeFilterOperation',
'CodeMapOperation',
'CodeReduceOperation',
'EquijoinOperation',
'FilterOperation',
'GatherOperation',
'MapOperation',
'ReduceOperation',
'ResolveOperation',
'SplitOperation',
'SampleOperation',
'UnnestOperation',
]

mapping = {
"cluster": ClusterOperation,
Expand Down Expand Up @@ -47,4 +65,4 @@ def get_operations():
op.name: op.load()
for op in importlib.metadata.entry_points(group="docetl.operation")
})
return operations
return operations
57 changes: 7 additions & 50 deletions docetl/operations/base.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,16 @@
"""
The BaseOperation class is an abstract base class for all operations in the docetl framework. It provides a common structure and interface for various data processing operations.
The BaseOperation class is an abstract base class for all operations in the docetl framework.
"""

from abc import ABC, ABCMeta, abstractmethod
from typing import Dict, List, Optional, Tuple

from docetl.operations.utils import APIWrapper
from docetl.console import DOCETL_CONSOLE
from rich.console import Console
from rich.status import Status
import jsonschema
from pydantic import BaseModel


# FIXME: This should probably live in some utils module?
class classproperty(object):
def __init__(self, f):
self.f = f
Expand All @@ -32,7 +29,7 @@ def __new__(cls, *arg, **kw):
class BaseOperation(ABC, metaclass=BaseOperationMeta):
def __init__(
self,
runner: "ConfigWrapper",
runner: "ConfigWrapper", # Type hint as string to avoid circular import
config: Dict,
default_model: str,
max_threads: int,
Expand All @@ -41,23 +38,13 @@ def __init__(
is_build: bool = False,
**kwargs,
):
"""
Initialize the BaseOperation.

Args:
config (Dict): Configuration dictionary for the operation.
default_model (str): Default language model to use.
max_threads (int): Maximum number of threads for parallel processing.
console (Optional[Console]): Rich console for outputting logs. Defaults to None.
status (Optional[Status]): Rich status for displaying progress. Defaults to None.
"""
assert "name" in config, "Operation must have a name"
assert "type" in config, "Operation must have a type"
self.runner = runner
self.config = config
self.default_model = default_model
self.max_threads = max_threads
self.console = console or DOCETL_CONSOLE
self.console = console or Console()
self.manually_fix_errors = self.config.get("manually_fix_errors", False)
self.status = status
self.num_retries_on_validate_failure = self.config.get(
Expand All @@ -70,7 +57,6 @@ def __init__(
class schema(BaseModel, extra="allow"):
name: str
type: str
skip_on_error: bool = False

@classproperty
def json_schema(cls):
Expand All @@ -84,45 +70,16 @@ def json_schema(cls):

@abstractmethod
def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
"""
Execute the operation on the input data.

This method should be implemented by subclasses to perform the
actual operation on the input data.

Args:
input_data (List[Dict]): List of input data items.

Returns:
Tuple[List[Dict], float]: A tuple containing the processed data
and the total cost of the operation.
"""
"""Execute the operation on the input data."""
pass

@abstractmethod
def syntax_check(self) -> None:
"""
Perform syntax checks on the operation configuration.

This method should be implemented by subclasses to validate the
configuration specific to each operation type.

Raises:
ValueError: If the configuration is invalid.
"""
"""Perform syntax checks on the operation configuration."""
jsonschema.validate(instance=self.config, schema=self.json_schema)

def gleaning_check(self) -> None:
"""
Perform checks on the gleaning configuration.

This method validates the gleaning configuration if it's present
in the operation config.

Raises:
ValueError: If the gleaning configuration is invalid.
TypeError: If the gleaning configuration has incorrect types.
"""
"""Perform checks on the gleaning configuration."""
if "gleaning" not in self.config:
return
if "num_rounds" not in self.config["gleaning"]:
Expand All @@ -145,4 +102,4 @@ def gleaning_check(self) -> None:
if not self.config["gleaning"]["validation_prompt"].strip():
raise ValueError(
"'validation_prompt' in 'gleaning' configuration cannot be empty"
)
)
4 changes: 3 additions & 1 deletion docetl/operations/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .llm import LLMResult, InvalidOutputError, truncate_messages
from .progress import RichLoopBar, rich_as_completed
from .validation import safe_eval, convert_val, convert_dict_schema_to_list_schema, get_user_input_for_schema, strict_render
from .oss_llm import OSSLLMOperation

__all__ = [
'APIWrapper',
Expand All @@ -32,5 +33,6 @@
'convert_dict_schema_to_list_schema',
'get_user_input_for_schema',
'truncate_messages',
"strict_render"
'strict_render',
'OSSLLMOperation'
]
61 changes: 59 additions & 2 deletions docetl/operations/utils/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,50 @@
from docetl.utils import completion_cost

from rich import print as rprint
from .oss_llm import OSSLLMOperation

class APIWrapper(object):
class APIWrapper:
def __init__(self, runner):
self.runner = runner
self._oss_operations = {}

def _is_oss_model(self, model: str) -> bool:
"""Check if model is an OSS model"""
return any(prefix in model for prefix in ["local/", "huggingface/", "oss/"])

def _get_oss_operation(self, model: str, output_schema: Dict[str, Any]) -> OSSLLMOperation:
"""Get or create OSSLLMOperation instance"""
model_path = model.split("/", 1)[1]
if model_path not in self._oss_operations:
config = {
"name": f"oss_llm_{model_path}",
"type": "oss_llm",
"model_path": model_path,
"output_schema": output_schema,
"max_tokens": 4096 # Default value, can be overridden
}
self._oss_operations[model_path] = OSSLLMOperation(
config=config,
runner=self.runner
)
return self._oss_operations[model_path]

def _call_llm_with_cache(
self,
model: str,
op_type: str,
messages: List[Dict[str, str]],
output_schema: Dict[str, str],
tools: Optional[str] = None,
scratchpad: Optional[str] = None,
litellm_completion_kwargs: Dict[str, Any] = {},
) -> ModelResponse:
"""Existing function, just add OSS model handling at the start"""

# Add OSS model handling
if self._is_oss_model(model):
operation = self._get_oss_operation(model, output_schema)
return operation.process_messages(messages)

@freezeargs
def gen_embedding(self, model: str, input: List[str]) -> List[float]:
Expand Down Expand Up @@ -389,6 +429,23 @@ def _call_llm_with_cache(
Returns:
str: The response from the LLM.
"""
if any(prefix in model for prefix in ["local/", "huggingface/", "oss/"]):
try:
config = {
"name": f"oss_{op_type}",
"type": "oss_llm",
"model_path": model.split("/", 1)[1],
"output_schema": output_schema,
"max_tokens": litellm_completion_kwargs.get("max_tokens", 4096)
}

oss_op = self._get_oss_operation(config)
return oss_op.process_messages(messages).response

except Exception as e:
self.runner.console.print(f"[red]Error with OSS model: {str(e)}[/red]")
raise e

props = {key: convert_val(value) for key, value in output_schema.items()}
use_tools = True

Expand Down Expand Up @@ -612,7 +669,7 @@ def _parse_llm_response_helper(
for tool_call in tool_calls:
if response.choices[0].finish_reason == "content_filter":
raise InvalidOutputError(
"Content filter triggered by LLM provider.",
"Content filter triggered in LLM response",
"",
schema,
response.choices,
Expand Down
Loading
Loading