Skip to content

Commit

Permalink
feat: implement test suite over base task components, still need test…
Browse files Browse the repository at this point in the history
…s on individual tasks
  • Loading branch information
jmoffatt32 committed Jun 30, 2024
1 parent ca5a730 commit d4c674f
Show file tree
Hide file tree
Showing 24 changed files with 873 additions and 273 deletions.
4 changes: 4 additions & 0 deletions docprompt/provenance/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ class DocumentProvenanceLocator:

@classmethod
def from_document_node(cls, document_node: "DocumentNode"):
# TODO: See if we can remove the ocr_results attribute from the
# PageNode and instead use metadata.task_results["<provider>_ocr"],
# i.e. the result of the OCR task.

index = create_tantivy_document_wise_block_index()
block_mapping_dict = {}
geo_index_dict: DocumentProvenanceGeoMap = {}
Expand Down
11 changes: 4 additions & 7 deletions docprompt/schema/layout.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import List, Literal, Optional

from pydantic import BaseModel, Field, PlainSerializer
from pydantic import BaseModel, ConfigDict, Field, PlainSerializer
from typing_extensions import Annotated

SegmentLevels = Literal["word", "line", "block"]
Expand Down Expand Up @@ -33,8 +33,7 @@ class NormBBox(BaseModel):
x1: BoundedFloat
bottom: BoundedFloat

class Config:
json_encoders = {float: lambda v: round(v, 5)} # 1/10,000 increments is plenty
model_config: ConfigDict = {"json_encoders": {float: lambda v: round(v, 5)}}

def as_tuple(self):
return (self.x0, self.top, self.x1, self.bottom)
Expand Down Expand Up @@ -194,8 +193,7 @@ class Point(BaseModel):
Represents a normalized point with each coordinate in the range [0, 1]
"""

class Config:
json_encoders = {float: lambda v: round(v, 5)} # 1/10,000 increments is plenty
model_config: ConfigDict = {"json_encoders": {float: lambda v: round(v, 5)}}

x: BoundedFloat
y: BoundedFloat
Expand Down Expand Up @@ -229,8 +227,7 @@ class TextBlock(BaseModel):
is normalized to the page size.
"""

class Config:
json_encoders = {float: lambda v: round(v, 5)} # 1/10,000 increments is plenty
model_config: ConfigDict = {"json_encoders": {float: lambda v: round(v, 5)}}

text: str
type: SegmentLevels
Expand Down
5 changes: 4 additions & 1 deletion docprompt/schema/pipeline/node/page.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@

from docprompt.schema.pipeline.metadata import BaseMetadata
from docprompt.schema.pipeline.rasterizer import PageRasterizer
from docprompt.tasks.base import ResultContainer
from docprompt.tasks.ocr.result import OcrPageResult

# TODO: This dependency should be removed -- schema should be the lowest level.
# Can do this by moving the OCR results to the metadata.task_results
from docprompt.tasks.result import ResultContainer

from .base import BaseNode
from .typing import PageNodeMetadata

Expand Down
201 changes: 83 additions & 118 deletions docprompt/tasks/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from abc import ABC, abstractmethod
from datetime import datetime
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -10,96 +8,39 @@
List,
Optional,
TypeVar,
Union,
)

from pydantic import BaseModel, Field
from pydantic import BaseModel, PrivateAttr, ValidationInfo, model_validator
from typing_extensions import Self

from docprompt._decorators import flexible_methods
from docprompt.schema.document import Document

from .capabilities import DocumentLevelCapabilities, PageLevelCapabilities
from .result import BaseDocumentResult, BasePageResult
from .util import _init_context_var, init_context

if TYPE_CHECKING:
from docprompt.schema.pipeline import DocumentNode


class BaseResult(BaseModel):
provider_name: str = Field(
description="The name of the provider which produced the result"
)
when: datetime = Field(
default_factory=datetime.now, description="The time the result was produced"
)

task_name: ClassVar[str]

@property
def task_key(self):
return f"{self.provider_name}_{self.task_name}"

@abstractmethod
def contribute_to_document_node(
self, document_node: "DocumentNode", page_number: int = None
) -> None:
"""
Contribute this task result to the document node or a specific page node.
:param document_node: The DocumentNode to contribute to
:param page_number: If provided, contribute to a specific page. If None, contribute to the document.
"""
pass


class BaseDocumentResult(BaseResult):
document_name: str = Field(description="The name of the document")
file_hash: str = Field(description="The hash of the document")

def contribute_to_document_node(
self, document_node: "DocumentNode", page_number: int = None
) -> None:
document_node.metadata.task_results[self.task_key] = self


class BasePageResult(BaseResult):
page_number: int = Field(description="The page number")

def contribute_to_document_node(
self, document_node: "DocumentNode", page_number: int = None
) -> None:
assert page_number is not None, "Page number must be provided for a page result"
assert page_number > 0, "Page number must be greater than 0"

page_node = document_node.page_nodes[page_number - 1]
page_node.metadata.task_results[self.task_key] = self


TTaskInput = TypeVar("TTaskInput") # What invoke requires
TTaskConfig = TypeVar("TTaskConfig") # Task specific config like classification labels
PageTaskResult = TypeVar("PageTaskResult", bound=BasePageResult)
DocumentTaskResult = TypeVar("DocumentTaskResult", bound=BaseDocumentResult)
PageOrDocumentTaskResult = TypeVar("PageOrDocumentTaskResult", bound=BaseResult)
TPageResult = TypeVar("TPageResult", bound=BasePageResult)
TDocumentResult = TypeVar("TDocumentResult", bound=BaseDocumentResult)
TTaskResult = TypeVar("TTaskResult", bound=Union[BasePageResult, BaseDocumentResult])


class ResultContainer(BaseModel, Generic[PageOrDocumentTaskResult]):
"""
Represents a container for results of a task
"""

results: Dict[str, PageOrDocumentTaskResult] = Field(
description="The results of the task, keyed by provider", default_factory=dict
)

@property
def result(self):
return next(iter(self.results.values()), None)
Capabilites = TypeVar(
"Capabilities", bound=Union[DocumentLevelCapabilities, PageLevelCapabilities]
)


@flexible_methods(
("process_document_node", "aprocess_document_node"),
("invoke", "ainvoke"),
("_invoke", "_ainvoke"),
)
class AbstractPageTaskProvider(Generic[TTaskInput, TTaskConfig, PageTaskResult]):
class AbstractTaskProvider(BaseModel, Generic[TTaskInput, TTaskConfig, TTaskResult]):
"""
A task provider performs a specific, repeatable task on a document or its pages.
Expand All @@ -112,33 +53,71 @@ class AbstractPageTaskProvider(Generic[TTaskInput, TTaskConfig, PageTaskResult])
a flexible method pair, the other will automatically be generated and provided for you at runtime.
"""

name: str
capabilities: List[PageLevelCapabilities]
requires_input: bool
name: ClassVar[str]
capabilities: ClassVar[List[Capabilites]]

# TODO: Potentially utilize context here during instantiation from Factory??
_default_invoke_kwargs: Dict[str, str] = PrivateAttr()

class Meta:
"""The meta class is utilized by the flexible methods decorator.
_default_invoke_kwargs: Dict[str, Any]
For all classes that are not concrete implementations, we should set the
abstract attribute to True, which will prevent the check from failing when
the flexible methods decorator is looking for the implementation of the
methods.
"""

abstract = True

def __init__(self, invoke_kwargs: Dict[str, str] = None, **data):
with init_context({"invoke_kwargs": invoke_kwargs or {}}):
self.__pydantic_validator__.validate_python(
data,
self_instance=self,
context=_init_context_var.get(),
)

@model_validator(mode="before")
@classmethod
def with_kwargs(cls, **kwargs):
"""Create the provider with kwargs."""
obj = cls()
obj.provider_kwargs = kwargs
return obj
def validate_class_vars(cls, data: Any) -> Any:
"""
Ensure that the class has a name and capabilities defined.
"""

if not hasattr(cls, "name"):
raise ValueError("Task providers must have a name defined")

if not hasattr(cls, "capabilities"):
raise ValueError("Task providers must have capabilities defined")

if not cls.capabilities:
raise ValueError("Task providers must have at least one capability defined")

return data

@model_validator(mode="after")
def set_invoke_kwargs(self, info: ValidationInfo) -> Self:
"""
Set the default invoke kwargs for the task provider.
"""
self._default_invoke_kwargs = info.context["invoke_kwargs"]
return self

async def _ainvoke(
self,
input: Iterable[TTaskInput],
config: Optional[TTaskConfig] = None,
**kwargs,
) -> List[PageTaskResult]:
) -> List[TTaskResult]:
raise NotImplementedError

async def ainvoke(
self,
input: Iterable[TTaskInput],
config: Optional[TTaskConfig] = None,
**kwargs,
) -> List[PageTaskResult]:
) -> List[TTaskResult]:
invoke_kwargs = {
**self._default_invoke_kwargs,
**kwargs,
Expand All @@ -151,15 +130,15 @@ def _invoke(
input: Iterable[TTaskInput],
config: Optional[TTaskConfig] = None,
**kwargs,
) -> List[PageTaskResult]:
) -> List[TTaskResult]:
raise NotImplementedError

def invoke(
self,
input: Iterable[TTaskInput],
config: Optional[TTaskConfig] = None,
**kwargs,
) -> List[PageTaskResult]:
) -> List[TTaskResult]:
invoke_kwargs = {
**self._default_invoke_kwargs,
**kwargs,
Expand All @@ -175,7 +154,7 @@ def process_document_node(
stop: Optional[int] = None,
contribute_to_document: bool = True,
**kwargs,
) -> Dict[int, PageTaskResult]:
) -> Dict[int, TTaskResult]:
raise NotImplementedError

async def aprocess_document_node(
Expand All @@ -186,47 +165,33 @@ async def aprocess_document_node(
stop: Optional[int] = None,
contribute_to_document: bool = True,
**kwargs,
) -> Dict[int, PageTaskResult]:
) -> Dict[int, TTaskResult]:
raise NotImplementedError


class AbstractDocumentTaskProvider(ABC, Generic[TTaskInput, DocumentTaskResult]):
class AbstractPageTaskProvider(AbstractTaskProvider):
"""
A task provider performs a specific, repeatable task on a document
A page task provider performs a specific, repeatable task on a page.
"""

name: str
capabilities: List[DocumentLevelCapabilities]
capabilities: ClassVar[List[PageLevelCapabilities]]

# NOTE: Temporary solution to allow kwargs from the factory to providers who
# don't take arbitrary kwargs in their __init__ method
_provider_kwargs: Dict[str, Any]
# NOTE: We need the stubs defined here for the flexible decorators to work
# for now

@classmethod
def with_kwargs(cls, **kwargs):
"""Create the provider with kwargs."""
obj = cls()
obj.provider_kwargs = kwargs
return obj

@abstractmethod
def process_document(
self, document: Document, task_input: Optional[TTaskInput] = None, **kwargs
) -> DocumentTaskResult:
raise NotImplementedError
class Meta:
abstract = True

def process_document_node(
self,
document_node: "DocumentNode",
task_input: Optional[TTaskInput] = None,
contribute_to_document: bool = True,
**kwargs,
) -> DocumentTaskResult:
result = self.process_document(
document_node.document, task_input=task_input, **kwargs
)

if contribute_to_document:
result.contribute_to_document_node(document_node)
class AbstractDocumentTaskProvider(AbstractTaskProvider):
"""
A task provider performs a specific, repeatable task on a document.
"""

capabilities: ClassVar[List[DocumentLevelCapabilities]]

# NOTE: We need the stubs defined here for the flexible decorators to work
# for now

return result
class Meta:
abstract = True
2 changes: 1 addition & 1 deletion docprompt/tasks/classification/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ async def _prepare_messages(
class AnthropicClassificationProvider(BaseClassificationProvider):
"""The Anthropic implementation of unscored page classification."""

name: str = "anthropic"
name = "anthropic"

async def _ainvoke(
self, input: Iterable[bytes], config: ClassificationConfig = None
Expand Down
12 changes: 10 additions & 2 deletions docprompt/tasks/classification/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@
from pydantic import BaseModel, Field, model_validator

from docprompt import DocumentNode
from docprompt.tasks.base import AbstractPageTaskProvider, PageTaskResult
from docprompt.tasks.base import AbstractPageTaskProvider
from docprompt.tasks.parser import BaseOutputParser
from docprompt.tasks.result import BasePageResult

from ..capabilities import PageLevelCapabilities

LabelType = Union[List[str], Enum, str]

Expand Down Expand Up @@ -108,7 +111,7 @@ def formatted_labels(self):
yield from raw_labels


class ClassificationOutput(PageTaskResult):
class ClassificationOutput(BasePageResult):
type: ClassificationTypes
labels: LabelType
score: Optional[ConfidenceLevel] = Field(None)
Expand Down Expand Up @@ -184,6 +187,11 @@ class BaseClassificationProvider(
The base classification provider.
"""

capabilities = [PageLevelCapabilities.PAGE_CLASSIFICATION]

class Meta:
abstract = True

def process_document_node(
self,
document_node: "DocumentNode",
Expand Down
Loading

0 comments on commit d4c674f

Please sign in to comment.