Skip to content

Commit

Permalink
fix: swap process_document_node to receive a DocumentNode instead of …
Browse files Browse the repository at this point in the history
…Document
  • Loading branch information
jmoffatt32 committed Jun 30, 2024
1 parent ff486ee commit a799c10
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 64 deletions.
2 changes: 1 addition & 1 deletion docprompt/tasks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def process_document_node(
) -> Dict[int, PageTaskResult]:
kwargs = {**(self.provider_kwargs or {}), **kwargs}
results = self.process_document_pages(
document_node.document,
document_node,
task_input=task_input,
start=start,
stop=stop,
Expand Down
58 changes: 6 additions & 52 deletions docprompt/tasks/classification/base.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,10 @@
from abc import abstractmethod
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import Any, List, Optional, Union

from pydantic import BaseModel, Field, model_validator
from typing_extensions import override

from docprompt.schema.pipeline import DocumentNode
from docprompt.tasks.base import AbstractPageTaskProvider, PageTaskResult

if TYPE_CHECKING:
from docprompt.schema.pipeline import DocumentNode

LabelType = Union[List[str], Enum, str]


Expand Down Expand Up @@ -120,48 +114,8 @@ class ClassificationOutput(PageTaskResult):
class BaseClassificationProvider(
AbstractPageTaskProvider[ClassificationInput, ClassificationOutput]
):
# NOTE: We override the method here for more accurate type-hinting
@override
async def aprocess_document_node(
self,
document_node: DocumentNode,
task_input: ClassificationInput,
start: int | None = None,
stop: int | None = None,
contribute_to_document: bool = True,
**kwargs,
) -> Dict[int, ClassificationOutput]:
# NOTE: We need to pass the document node to `process_document_pages` instead of `document_node.document`
# because the `DocumentNode` object contains an easy iterable of `PageNode` objects.
# The default implementation of `process_document_node` passes the `document_node.document` object to
# `process_document_pages` which is not iterable.

# Essentially an identical implementation as the `AbstractPageTaskProvider` class, but with
# `DocumentNode` instead of `Document` as the first argument to `process_document_pages`
kwargs = {**(self.provider_kwargs or {}), **kwargs}
results = await self.aprocess_document_pages(
document_node,
task_input,
start=start,
stop=stop,
**kwargs,
)

# We still need to make sure that this is reimplemented
if contribute_to_document:
for page_number, page_result in results.items():
page_result.contribute_to_document_node(document_node, page_number)

return results

@abstractmethod
@override
def process_document_pages(
self,
document_node: DocumentNode, # NOTE: We override here to DocumentNode, instead of Document
task_input: ClassificationInput,
start: int | None = None,
stop: int | None = None,
contribute_to_document: bool = True,
**kwargs,
) -> Dict[int, ClassificationOutput]: ...
"""
The base classification provider.
"""

pass
11 changes: 0 additions & 11 deletions docprompt/tasks/markerize/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
from typing import Dict

from docprompt.schema.pipeline import DocumentNode
from docprompt.tasks.base import AbstractPageTaskProvider, BasePageResult


Expand All @@ -12,11 +9,3 @@ class MarkerizeResult(BasePageResult):
class BaseMarkerizeProvider(AbstractPageTaskProvider[None, MarkerizeResult]):
class Meta:
abstract = True

def contribute_to_document_node(
self, document_node: DocumentNode, results: Dict[int, MarkerizeResult]
) -> None:
for page_number, page_result in results.items():
document_node.page_nodes[page_number - 1].extra["raw_markdown"] = (
page_result.raw_markdown
)

0 comments on commit a799c10

Please sign in to comment.