Skip to content

Commit

Permalink
feat: fix annotations from Document to DocumentNode
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoffatt32 committed Jun 30, 2024
1 parent a799c10 commit 6938fe2
Showing 1 changed file with 20 additions and 6 deletions.
26 changes: 20 additions & 6 deletions docprompt/tasks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
if TYPE_CHECKING:
from docprompt.schema.pipeline import DocumentNode

TDocumentNode = TypeVar("TDocumentNode", bound="DocumentNode")


class BaseResult(BaseModel):
provider_name: str = Field(
Expand Down Expand Up @@ -89,7 +91,15 @@ def result(self):
)
class AbstractPageTaskProvider(ABC, Generic[TTaskInput, PageTaskResult]):
"""
A task provider performs a specific, repeatable task on a document or its pages
A task provider performs a specific, repeatable task on a document or its pages.
NOTE: Either the `process_document_pages` or `aprocess_document_pages` method must be implemented in
a valid subclass. The `process_document_pages` method is explicitly defined, while the `aprocess_document_pages`
method is an async version of the same method.
If you wish to provide separate implementations for sync and async, you can define both methods individually, and
they will each use their own custom implementation when called. Otherwise, if you only implement one or the other of
a flexible method pair, the other will automatically be generated and provided for you at runtime.
"""

name: str
Expand All @@ -107,27 +117,31 @@ def with_kwargs(cls, **kwargs):

async def aprocess_document_pages(
self,
document: Document,
document: TDocumentNode,
task_input: Optional[TTaskInput] = None,
start: Optional[int] = None,
stop: Optional[int] = None,
**kwargs,
):
raise NotImplementedError
raise NotImplementedError(
"`process_document_pages` or `aprocess_document_pages` must be implemented."
)

def process_document_pages(
self,
document: Document,
document: TDocumentNode,
task_input: Optional[TTaskInput] = None,
start: Optional[int] = None,
stop: Optional[int] = None,
**kwargs,
) -> Dict[int, PageTaskResult]:
raise NotImplementedError
raise NotImplementedError(
"`process_document_pages` or `aprocess_document_pages` must be implemented."
)

def process_document_node(
self,
document_node: "DocumentNode",
document_node: TDocumentNode,
task_input: Optional[TTaskInput] = None,
start: Optional[int] = None,
stop: Optional[int] = None,
Expand Down

0 comments on commit 6938fe2

Please sign in to comment.