
Task refactor #7

Draft · wants to merge 3 commits into main
33 changes: 33 additions & 0 deletions docprompt/contrib/litellm.py
@@ -0,0 +1,33 @@
from typing import Any, Dict, List

try:
    import litellm
except ImportError as err:
    raise ImportError(
        "litellm is required for this module. Install with `pip install docprompt[litellm]`"
    ) from err


def get_sync_litellm_callable(model: str, /, **kwargs):
    """Bind a model and default kwargs to a synchronous litellm completion call."""
    if "messages" in kwargs:
        raise ValueError("messages should only be passed at runtime")

    def wrapper(messages: List[Dict[str, Any]]):
        response = litellm.completion(model=model, messages=messages, **kwargs)

        return response.to_dict()

    return wrapper


def get_async_litellm_callable(model: str, /, **kwargs):
    """Bind a model and default kwargs to an asynchronous litellm completion call."""
    if "messages" in kwargs:
        raise ValueError("messages should only be passed at runtime")

    async def wrapper(messages: List[Dict[str, Any]]):
        response = await litellm.acompletion(model=model, messages=messages, **kwargs)

        return response.to_dict()

    return wrapper
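A minimal usage sketch for these helpers; the model name and prompt are illustrative, not taken from this PR:

# Bind the model and default options once, then supply messages at call time.
complete = get_sync_litellm_callable("gpt-4o-mini", temperature=0)
result = complete([{"role": "user", "content": "Summarize this page."}])
print(result["choices"][0]["message"]["content"])

# The async variant mirrors the sync one and must be awaited from async code:
# acomplete = get_async_litellm_callable("gpt-4o-mini", temperature=0)
# result = await acomplete([{"role": "user", "content": "Summarize this page."}])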
190 changes: 187 additions & 3 deletions docprompt/tasks/base.py
@@ -1,18 +1,30 @@
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    ClassVar,
    Coroutine,
    Dict,
    Generic,
    Iterable,
    List,
    Optional,
    Type,
    TypedDict,
    TypeVar,
    Union,
)

from pydantic import BaseModel, PrivateAttr, ValidationInfo, model_validator
from typing_extensions import Self
from pydantic import (
    BaseModel,
    Field,
    GetCoreSchemaHandler,
    PrivateAttr,
    ValidationInfo,
    model_validator,
)
from pydantic_core import core_schema
from typing_extensions import Annotated, Self

from docprompt._decorators import flexible_methods

@@ -21,7 +33,7 @@
from .util import _init_context_var, init_context

if TYPE_CHECKING:
    from docprompt.schema.pipeline import DocumentNode
    from docprompt.schema.pipeline import DocumentNode, PageNode


TTaskInput = TypeVar("TTaskInput") # What invoke requires
@@ -35,6 +47,16 @@
)


class NullSchema:
    def __get_pydantic_core_schema__(
        self, source: Type[Any], handler: GetCoreSchemaHandler
    ):
        def noop_validate(value: Any) -> Any:
            return value

        return core_schema.no_info_plain_validator_function(noop_validate)

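NullSchema is used further down in this diff (see ProviderAgnosticOAI) to let plain callables live on a pydantic model without schema generation; a minimal sketch of the pattern, using a hypothetical model that is not part of this PR:

# Hypothetical example: the NullSchema metadata makes pydantic pass the value
# through unchanged instead of trying to build a schema for the callable.
class CallableHolder(BaseModel):
    fn: Annotated[Callable[[int], int], NullSchema()] = Field(default=None, exclude=True)


holder = CallableHolder(fn=lambda x: x + 1)
assert holder.fn(1) == 2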

@flexible_methods(
    ("process_document_node", "aprocess_document_node"),
    ("_invoke", "_ainvoke"),
@@ -168,6 +190,168 @@ async def aprocess_document_node(
        raise NotImplementedError


class SupportsOpenAIMessages(BaseModel, Generic[TTaskInput]):
    """
    Mixin for task providers that can build OpenAI-format chat messages.
    """

    def get_openai_messages(self, input: TTaskInput, **kwargs) -> List[Dict[str, Any]]:
        raise NotImplementedError

    async def aget_openai_messages(
        self, input: TTaskInput, **kwargs
    ) -> List[Dict[str, Any]]:
        raise NotImplementedError


class SupportsParsing(BaseModel, Generic[TTaskResult]):
    """
    Mixin for task providers that support parsing.
    """

    def parse(self, response: str, **kwargs) -> TTaskResult:
        raise NotImplementedError

    async def aparse(self, response: str, **kwargs) -> TTaskResult:
        raise NotImplementedError


class SupportsPageNode(BaseModel, Generic[TTaskConfig, TPageResult]):
    """
    Mixin for task providers that support page processing.
    """

    def process_page_node(
        self,
        page_node: "PageNode",
        task_config: Optional[TTaskConfig] = None,
        **kwargs,
    ) -> TPageResult:
        raise NotImplementedError

    async def aprocess_page_node(
        self,
        page_node: "PageNode",
        task_config: Optional[TTaskConfig] = None,
        **kwargs,
    ) -> TPageResult:
        raise NotImplementedError


class SupportsDirectInvocation(
    BaseModel, Generic[TTaskInput, TTaskConfig, TTaskResult]
):
    """
    Mixin for task providers that support direct invocation on
    non-node based items.
    """

    def invoke(
        self, input: TTaskInput, config: Optional[TTaskConfig] = None, **kwargs
    ) -> TTaskResult:
        raise NotImplementedError

    async def ainvoke(
        self, input: TTaskInput, config: Optional[TTaskConfig] = None, **kwargs
    ) -> TTaskResult:
        raise NotImplementedError


class SupportsDocumentNode(BaseModel, Generic[TTaskConfig, TDocumentResult]):
    """
    Mixin for task providers that support document processing.
    """

    def process_document_node(
        self,
        document_node: "DocumentNode",
        task_config: Optional[TTaskConfig] = None,
        **kwargs,
    ) -> TDocumentResult:
        raise NotImplementedError

    async def aprocess_document_node(
        self,
        document_node: "DocumentNode",
        task_config: Optional[TTaskConfig] = None,
        **kwargs,
    ) -> TDocumentResult:
        raise NotImplementedError


@flexible_methods(
    ("process_image", "aprocess_image"),
)
class SupportsImage(BaseModel, Generic[TTaskInput, TTaskResult]):
    """
    Mixin for task providers that support image processing.
    """

    def process_image(self, input: TTaskInput, **kwargs) -> TTaskResult:
        raise NotImplementedError

    async def aprocess_image(self, input: TTaskInput, **kwargs) -> TTaskResult:
        raise NotImplementedError


class OpenAIMessageItem(TypedDict):
    content: str


class OpenAiChoiceItem(TypedDict):
    finish_reason: str
    index: int
    message: OpenAIMessageItem


class OpenAICompletionResponse(TypedDict):
    choices: List[OpenAiChoiceItem]


SyncOAICallable = Callable[[List[Dict[str, Any]]], OpenAICompletionResponse]
AsyncOAICallable = Callable[
    [List[Dict[str, Any]]], Coroutine[Any, Any, OpenAICompletionResponse]
]


class ProviderAgnosticOAI(BaseModel):
    sync_callable: Annotated[Optional[SyncOAICallable], NullSchema()] = Field(
        default=None, exclude=True
    )
    async_callable: Annotated[Optional[AsyncOAICallable], NullSchema()] = Field(
        default=None, exclude=True
    )

    @model_validator(mode="after")
    def validate_callable(self):
        if not self.sync_callable and not self.async_callable:
            raise ValueError(
                f"{self.__class__.__name__} must be initialized with `sync_callable`, `async_callable`, or both"
            )

        return self

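A sketch of how ProviderAgnosticOAI appears intended to pair with the litellm helpers added in docprompt/contrib/litellm.py; the model name is illustrative:

# Hypothetical wiring, assuming the contrib module from this PR:
from docprompt.contrib.litellm import (
    get_async_litellm_callable,
    get_sync_litellm_callable,
)

provider = ProviderAgnosticOAI(
    sync_callable=get_sync_litellm_callable("gpt-4o-mini"),
    async_callable=get_async_litellm_callable("gpt-4o-mini"),
)
response = provider.sync_callable([{"role": "user", "content": "Hello"}])
# The validator above rejects construction when neither callable is supplied.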

@flexible_methods(
    ("process_webpage", "aprocess_webpage"),
)
class SupportsWebPage(BaseModel, Generic[TTaskInput, TPageResult]):
    """
    Mixin for task providers that support webpage processing.
    """

    def process_webpage(self, input: TTaskInput, **kwargs) -> TPageResult:
        raise NotImplementedError

    async def aprocess_webpage(self, input: TTaskInput, **kwargs) -> TPageResult:
        raise NotImplementedError


class SupportsTaskConfig(BaseModel, Generic[TTaskConfig]):
    task_config: Optional[TTaskConfig] = None

    def get_config(self) -> TTaskConfig:
        return self.task_config

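These mixins are presumably meant to be composed into concrete providers; a hypothetical combination (assuming pydantic accepts multiple parametrized generic bases, which the mixin design implies):

# Hypothetical provider, not part of this PR: builds OpenAI messages from a
# string input and parses the raw completion text back out.
class EchoProvider(SupportsOpenAIMessages[str], SupportsParsing[str]):
    def get_openai_messages(self, input: str, **kwargs) -> List[Dict[str, Any]]:
        return [{"role": "user", "content": input}]

    def parse(self, response: str, **kwargs) -> str:
        return response.strip()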

class AbstractPageTaskProvider(AbstractTaskProvider):
    """
    A page task provider performs a specific, repeatable task on a page.
docprompt/tasks/classification/image.py
@@ -3,7 +3,7 @@
import re
from typing import Iterable, List

from pydantic import Field

Check failure (Ruff F401, GitHub Actions / test 3.11 and 3.12): docprompt/tasks/classification/image.py:6:22: `pydantic.Field` imported but unused

from docprompt.tasks.message import OpenAIComplexContent, OpenAIImageURL, OpenAIMessage
from docprompt.utils import inference
@@ -73,7 +73,7 @@
        return "".join(prompt_parts).strip()


class AnthropicPageClassificationOutputParser(BasePageClassificationOutputParser):
class ImagePageClassificationOutputParser(BasePageClassificationOutputParser):
    """The output parser for the page classification system."""

    def parse(self, text: str) -> ClassificationOutput:
@@ -131,20 +131,18 @@
class AnthropicClassificationProvider(BaseClassificationProvider):
    """The Anthropic implementation of unscored page classification."""

    name = "anthropic"

    anthropic_model_name: str = Field("claude-3-haiku-20240307")
    name = "image"

    async def _ainvoke(
        self, input: Iterable[bytes], config: ClassificationConfig = None, **kwargs
    ) -> List[ClassificationOutput]:
        messages = _prepare_messages(input, config)

        parser = AnthropicPageClassificationOutputParser.from_task_input(
        parser = ImagePageClassificationOutputParser.from_task_input(
            config, provider_name=self.name
        )

        model_name = kwargs.pop("model_name", self.anthropic_model_name)
        model_name = kwargs.pop("model_name")
        completions = await inference.run_batch_inference_anthropic(
            model_name, messages, **kwargs
        )
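With the anthropic_model_name default removed above, callers presumably must now supply model_name through kwargs; a hypothetical call, where provider, page_images, and classification_config are assumed to exist:

# Hypothetical usage, not part of this PR: without model_name in kwargs,
# kwargs.pop("model_name") would raise KeyError.
outputs = await provider._ainvoke(
    page_images,
    config=classification_config,
    model_name="claude-3-haiku-20240307",
)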
6 changes: 3 additions & 3 deletions docprompt/tasks/factory.py
@@ -168,15 +168,15 @@ def _validate_provider(self, info: ValidationInfo) -> Self:

    def get_page_classification_provider(self, **kwargs) -> TTaskProvider:
        """Get the page classification provider."""
        from docprompt.tasks.classification.anthropic import (
        from docprompt.tasks.classification.image import (
            AnthropicClassificationProvider,
        )

        return AnthropicClassificationProvider(invoke_kwargs=self._credentials.kwargs)

    def get_page_table_extraction_provider(self, **kwargs) -> TTaskProvider:
        """Get the page table extraction provider."""
        from docprompt.tasks.table_extraction.anthropic import (
        from docprompt.tasks.table_extraction.image_xml import (
            AnthropicTableExtractionProvider,
        )

@@ -186,7 +186,7 @@ def get_page_table_extraction_provider(self, **kwargs) -> TTaskProvider:

    def get_page_markerization_provider(self, **kwargs) -> TTaskProvider:
        """Get the page markerization provider."""
        from docprompt.tasks.markerize.anthropic import AnthropicMarkerizeProvider
        from docprompt.tasks.markerize.image import AnthropicMarkerizeProvider

        return AnthropicMarkerizeProvider(
            invoke_kwargs=self._credentials.kwargs, **kwargs