Skip to content

Commit

Permalink
feat: update classification and table_extraction tasks to new interface
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoffatt32 committed Jun 30, 2024
1 parent 1b8ed68 commit ca5a730
Show file tree
Hide file tree
Showing 6 changed files with 230 additions and 182 deletions.
156 changes: 39 additions & 117 deletions docprompt/tasks/classification/anthropic.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,18 @@
"""The antrhopic implementation of page level calssification."""

import re
from typing import Dict, List, Union
from typing import Iterable, List, Optional

from jinja2 import Template
from pydantic import Field

from docprompt.schema.pipeline import DocumentNode
from docprompt.tasks.message import OpenAIMessage
from docprompt.tasks.parser import BaseOutputParser
from docprompt.tasks.message import OpenAIComplexContent, OpenAIImageURL, OpenAIMessage
from docprompt.utils import inference

from .base import (
BaseClassificationProvider,
ClassificationInput,
BasePageClassificationOutputParser,
ClassificationConfig,
ClassificationOutput,
ClassificationTypes,
ConfidenceLevel,
LabelType,
)

PAGE_CLASSIFICATION_SYSTEM_PROMPT = Template(
Expand Down Expand Up @@ -65,63 +60,9 @@
)


class PageClassificationOutputParser(
BaseOutputParser[ClassificationInput, ClassificationOutput]
):
class AnthropicPageClassificationOutputParser(BasePageClassificationOutputParser):
"""The output parser for the page classification system."""

name: str = Field(...)
type: ClassificationTypes = Field(...)
labels: LabelType = Field(...)
confidence: bool = Field(False)

@classmethod
def from_task_input(cls, task_input: ClassificationInput, provider_name: str):
return cls(
type=task_input.type,
name=provider_name,
labels=task_input.labels,
confidence=task_input.confidence,
)

def resolve_match(self, _match: Union[re.Match, None]) -> LabelType:
"""Get the regex pattern for the output parser."""

if not _match:
raise ValueError("Could not find the answer in the text.")

val = _match.group(1)
match self.type:
case ClassificationTypes.BINARY:
if val not in self.labels:
raise ValueError(f"Invalid label: {val}")
return val

case ClassificationTypes.SINGLE_LABEL:
if val not in self.labels:
raise ValueError(f"Invalid label: {val}")
return val

case ClassificationTypes.MULTI_LABEL:
labels = val.split(", ")
for label in labels:
if label not in self.labels:
raise ValueError(f"Invalid label: {label}")
return labels

case _:
raise ValueError(f"Invalid classification type: {self.type}")

def resolve_confidence(self, _match: Union[re.Match, None]) -> ConfidenceLevel:
"""Get the confidence level from the text."""

if not _match:
return None

val = _match.group(1).lower()

return ConfidenceLevel(val)

def parse(self, text: str) -> ClassificationOutput:
"""Parse the results of the classification task."""
pattern = re.compile(r"Answer: (.+)")
Expand All @@ -146,68 +87,49 @@ def parse(self, text: str) -> ClassificationOutput:
)


async def classify_images(
image_uris: List[str], task_input: ClassificationInput, **kwargs
) -> List[ClassificationOutput]:
"""Classify a list of images with the given input."""

def _format_message(image_uri: str):
system = OpenAIMessage(
role="system",
content=PAGE_CLASSIFICATION_SYSTEM_PROMPT.render(input=task_input),
async def _prepare_messages(
document_images: Iterable[bytes],
config: ClassificationConfig,
start: Optional[int] = None,
stop: Optional[int] = None,
):
messages = []

for image_bytes in document_images:
messages.append(
[
OpenAIMessage(
role="user",
content=[
OpenAIComplexContent(
type="image_url",
image_url=OpenAIImageURL(url=image_bytes),
),
OpenAIComplexContent(
type="text",
text=PAGE_CLASSIFICATION_SYSTEM_PROMPT.render(input=config),
),
],
),
]
)

human = OpenAIMessage.from_image_uri(image_uri)
return [system, human]

messages = [_format_message(uri) for uri in image_uris]

model_name = kwargs.pop("model_name", "claude-3-haiku-20240307")

provider_name = kwargs.pop("provider_name", "anthropic")
parser = PageClassificationOutputParser.from_task_input(
task_input, provider_name=provider_name
)

completions = await inference.run_batch_inference_anthropic(
model_name, messages, **kwargs
)

labels = [parser.parse(res) for res in completions]

return labels
return messages


class AnthropicClassificationProvider(BaseClassificationProvider):
"""The Anthropic implementation of unscored page classification."""

name: str = "anthropic"

async def aprocess_document_pages(
self,
document_node: DocumentNode,
task_input: ClassificationInput,
start: int | None = None,
stop: int | None = None,
contribute_to_document: bool = True,
**kwargs,
) -> Dict[int, ClassificationOutput]:
start = start or 0
stop = stop or len(document_node.page_nodes)

assert (
0 <= start < stop <= len(document_node.page_nodes)
), f"Invalid start and stop values: {start}, {stop}"

image_uris = [
page.rasterizer.rasterize_to_data_uri("default")
for page in document_node.page_nodes[start:stop]
]

labels = await classify_images(
image_uris, task_input, **kwargs, provider_name=self.name
async def _ainvoke(
self, input: Iterable[bytes], config: ClassificationConfig = None
) -> List[ClassificationOutput]:
messages = _prepare_messages(input, config)
parser = AnthropicPageClassificationOutputParser.from_task_input(
config, provider_name=self.name
)

results = {i: label for i, label in zip(range(start, stop), labels)}
completions = await inference.run_batch_inference_anthropic(messages)

return results
return [parser.parse(res) for res in completions]
95 changes: 92 additions & 3 deletions docprompt/tasks/classification/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import re
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, List, Optional, Union

from pydantic import BaseModel, Field, model_validator

from docprompt import DocumentNode
from docprompt.tasks.base import AbstractPageTaskProvider, PageTaskResult
from docprompt.tasks.parser import BaseOutputParser

LabelType = Union[List[str], Enum, str]

Expand All @@ -22,7 +26,7 @@ class ClassificationTypes(str, Enum):
BINARY = "binary"


class ClassificationInput(BaseModel):
class ClassificationConfig(BaseModel):
type: ClassificationTypes
labels: LabelType
descriptions: Optional[List[str]] = Field(
Expand Down Expand Up @@ -111,11 +115,96 @@ class ClassificationOutput(PageTaskResult):
task_name: str = "classification"


class BasePageClassificationOutputParser(
ABC, BaseOutputParser[ClassificationConfig, ClassificationOutput]
):
"""The output parser for the page classification system."""

name: str = Field(...)
type: ClassificationTypes = Field(...)
labels: LabelType = Field(...)
confidence: bool = Field(False)

@classmethod
def from_task_input(cls, task_input: ClassificationConfig, provider_name: str):
return cls(
type=task_input.type,
name=provider_name,
labels=task_input.labels,
confidence=task_input.confidence,
)

def resolve_match(self, _match: Union[re.Match, None]) -> LabelType:
"""Get the regex pattern for the output parser."""

if not _match:
raise ValueError("Could not find the answer in the text.")

val = _match.group(1)
match self.type:
case ClassificationTypes.BINARY:
if val not in self.labels:
raise ValueError(f"Invalid label: {val}")
return val

case ClassificationTypes.SINGLE_LABEL:
if val not in self.labels:
raise ValueError(f"Invalid label: {val}")
return val

case ClassificationTypes.MULTI_LABEL:
labels = val.split(", ")
for label in labels:
if label not in self.labels:
raise ValueError(f"Invalid label: {label}")
return labels

case _:
raise ValueError(f"Invalid classification type: {self.type}")

def resolve_confidence(self, _match: Union[re.Match, None]) -> ConfidenceLevel:
"""Get the confidence level from the text."""

if not _match:
return None

val = _match.group(1).lower()

return ConfidenceLevel(val)

@abstractmethod
def parse(self, text: str) -> ClassificationOutput:
pass


class BaseClassificationProvider(
AbstractPageTaskProvider[ClassificationInput, ClassificationOutput]
AbstractPageTaskProvider[bytes, ClassificationConfig, ClassificationOutput]
):
"""
The base classification provider.
"""

pass
def process_document_node(
self,
document_node: "DocumentNode",
task_config: ClassificationConfig = None,
start: Optional[int] = None,
stop: Optional[int] = None,
contribute_to_document: bool = True,
**kwargs,
):
raster_bytes = []
for page_number in range(start or 1, (stop or len(document_node)) + 1):
image_bytes = document_node.page_nodes[
page_number - 1
].rasterizer.rasterize("default")
raster_bytes.append(image_bytes)

results = self._invoke(raster_bytes, config=task_config, **kwargs)

return {
i: res
for i, res in zip(
range(start or 1, (stop or len(document_node)) + 1), results
)
}
37 changes: 7 additions & 30 deletions docprompt/tasks/markerize/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from bs4 import BeautifulSoup

from docprompt.schema.pipeline.node.document import DocumentNode
from docprompt.tasks.message import OpenAIComplexContent, OpenAIImageURL, OpenAIMessage
from docprompt.utils import inference

Expand Down Expand Up @@ -54,36 +53,14 @@ async def _prepare_messages(
class AnthropicMarkerizeProvider(BaseMarkerizeProvider):
name = "anthropic"

def _invoke(
async def _ainvoke(
self, input: Iterable[bytes], config: Optional[None] = None
) -> List[MarkerizeResult]:
messages = _prepare_messages(input)

completions = inference.run_batch_inference_anthropic(messages)

return [_parse_result(x) for x in completions]

def process_document_node(
self,
document_node: "DocumentNode",
task_config: Optional[None] = None,
start: Optional[int] = None,
stop: Optional[int] = None,
contribute_to_document: bool = True,
**kwargs,
):
raster_bytes = []
for page_number in range(start or 1, (stop or len(document_node)) + 1):
image_bytes = document_node.page_nodes[
page_number - 1
].rasterizer.rasterize("default")
raster_bytes.append(image_bytes)

results = self._invoke(raster_bytes, config=task_config, **kwargs)

return {
i: MarkerizeResult(provider_name=self.name, raw_markdown=x)
for i, x in zip(
range(start or 1, (stop or len(document_node)) + 1), results
)
}
completions = await inference.run_batch_inference_anthropic(messages)

return [
MarkerizeResult(raw_markdown=_parse_result(x), provider_name=self.name)
for x in completions
]
Loading

0 comments on commit ca5a730

Please sign in to comment.