Skip to content

Commit

Permalink
Remove jinja dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
PSU3D0 committed Oct 4, 2024
1 parent f201d76 commit 9ec94b1
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 53 deletions.
107 changes: 58 additions & 49 deletions docprompt/tasks/classification/anthropic.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
"""The antrhopic implementation of page level calssification."""
"""The anthropic implementation of page level classification."""

import re
from typing import Iterable, List

from jinja2 import Template
from pydantic import Field

from docprompt.tasks.message import OpenAIComplexContent, OpenAIImageURL, OpenAIMessage
Expand All @@ -16,52 +15,62 @@
ClassificationOutput,
)

PAGE_CLASSIFICATION_SYSTEM_PROMPT = Template(
"""
You are a classification expert. You are given a single page to perform a classification task on.
{% if input.instructions %}\
Task Instructions:
{{ input.instructions }}
{% endif %}\
{% if input.type == "binary" %}\
You must classify the page with a binary label:
"YES"/"NO"
{% else %}\
Classify the page as {% if input.type == 'multi_label' %}all labels that apply{% else %}one of the following{% endif %}:
{% for label in input.formatted_labels %}
- {{ label }}
{% endfor %}\
These are the only label values you may use when providing your classifications!
{% endif %}\
It is crucial that your response is accurate and provides a valid answer using \
{% if input.type == 'multi_label' %}\
the labels \
{% else %}\
one of the labels \
{% endif %}\
above. There are consequences for providing INVALID or INACCURATE labels.
Answer in the following format:
Reasoning: { your reasoning and analysis }
{% if input.type == "binary" %}\
Answer: { "YES" or "NO" }
{% elif input.type == "single_label" %}\
Answer: { "label-value" }
{% else %}\
Answer: { "label-value", "label-value", ... }
{% endif %}\
{% if input.confidence %}\
Confidence: { low, medium, high }
{% endif %}\
You MUST ONLY use the labels provided and described above. Do not use ANY additional labels.
""".strip()
)

def get_classification_system_prompt(input: ClassificationConfig) -> str:
prompt_parts = [
"You are a classification expert. You are given a single page to perform a classification task on.\n"
]

if input.instructions:
prompt_parts.append(f"Task Instructions:\n{input.instructions}\n\n")

if input.type == "binary":
prompt_parts.append(
'You must classify the page with a binary label:\n"YES"/"NO"\n'
)
else:
classification_task = (
"all labels that apply"
if input.type == "multi_label"
else "one of the following"
)
prompt_parts.append(f"Classify the page as {classification_task}:\n")
for label in input.formatted_labels:
prompt_parts.append(f"- {label}\n")
prompt_parts.append(
"\nThese are the only label values you may use when providing your classifications!\n"
)

prompt_parts.append(
"\nIt is crucial that your response is accurate and provides a valid answer using "
)
if input.type == "multi_label":
prompt_parts.append("the labels ")
else:
prompt_parts.append("one of the labels ")
prompt_parts.append(
"above. There are consequences for providing INVALID or INACCURATE labels.\n\n"
)

prompt_parts.append(
"Answer in the following format:\n\nReasoning: { your reasoning and analysis }\n"
)

if input.type == "binary":
prompt_parts.append('Answer: { "YES" or "NO" }\n')
elif input.type == "single_label":
prompt_parts.append('Answer: { "label-value" }\n')
else:
prompt_parts.append('Answer: { "label-value", "label-value", ... }\n')

if input.confidence:
prompt_parts.append("Confidence: { low, medium, high }\n")

prompt_parts.append(
"\nYou MUST ONLY use the labels provided and described above. Do not use ANY additional labels.\n"
)

return "".join(prompt_parts).strip()


class AnthropicPageClassificationOutputParser(BasePageClassificationOutputParser):
Expand Down Expand Up @@ -109,7 +118,7 @@ def _prepare_messages(
),
OpenAIComplexContent(
type="text",
text=PAGE_CLASSIFICATION_SYSTEM_PROMPT.render(input=config),
text=get_classification_system_prompt(config),
),
],
),
Expand Down
6 changes: 3 additions & 3 deletions pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ dependencies = [
"tenacity>=7.0.0",
"pypdfium2<5.0.0,>=4.28.0",
"filetype>=1.2.0",
"jinja2>=3.1.4",
"beautifulsoup4>=4.12.3",
"pypdf>=5.0.0"
]
Expand Down

0 comments on commit 9ec94b1

Please sign in to comment.