149 changes: 143 additions & 6 deletions docling_core/experimental/idoctags.py
@@ -1,7 +1,7 @@
"""Define classes for DocTags serialization."""

from enum import Enum
from typing import Any, Final, Optional
from typing import Any, Final, Optional, Tuple
from xml.dom.minidom import parseString

from pydantic import BaseModel
@@ -38,7 +38,10 @@
TabularChartMetaField,
)
from docling_core.types.doc.labels import DocItemLabel
from docling_core.types.doc.tokens import DocumentToken
from docling_core.types.doc.tokens import (
_CodeLanguageToken,
_PictureClassificationToken,
)

DOCTAGS_VERSION: Final = "1.0.0"

@@ -61,6 +64,127 @@ class IDocTagsTableToken(str, Enum):
OTSL_RHED = "<rhed/>" # - row header cell,
OTSL_SROW = "<srow/>" # - section row cell

@classmethod
def get_special_tokens(
cls,
):
"""Return all table-related special tokens.

Includes the opening/closing OTSL tags and each enum token value.
"""
special_tokens: list[str] = ["<otsl>", "</otsl>"]
for token in cls:
special_tokens.append(f"{token.value}")

return special_tokens
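
A minimal usage sketch for orientation (import path taken from the file header above; only tokens visible in this hunk are asserted):

# Sketch: exercise the classmethod defined above.
from docling_core.experimental.idoctags import IDocTagsTableToken

table_tokens = IDocTagsTableToken.get_special_tokens()
# The wrapper tags come first, followed by every member value.
assert table_tokens[:2] == ["<otsl>", "</otsl>"]
assert "<rhed/>" in table_tokens and "<srow/>" in table_tokens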


class IDocTagsToken(str, Enum):
"""IDocTagsToken."""

_LOC_PREFIX = "loc_"
_SECTION_HEADER_PREFIX = "section_header_level_"

DOCUMENT = "doctag"
VERSION = "version"

OTSL = "otsl"
ORDERED_LIST = "ordered_list"
UNORDERED_LIST = "unordered_list"

PAGE_BREAK = "page_break"

CAPTION = "caption"
FOOTNOTE = "footnote"
FORMULA = "formula"
LIST_ITEM = "list_item"
PAGE_FOOTER = "page_footer"
PAGE_HEADER = "page_header"
PICTURE = "picture"
SECTION_HEADER = "section_header"
TABLE = "table"
TEXT = "text"
TITLE = "title"
DOCUMENT_INDEX = "document_index"
CODE = "code"
CHECKBOX_SELECTED = "checkbox_selected"
CHECKBOX_UNSELECTED = "checkbox_unselected"
FORM = "form"
EMPTY_VALUE = "empty_value" # used for empty value fields in fillable forms

@classmethod
def get_special_tokens(
cls,
*,
page_dimension: Tuple[int, int] = (500, 500),
include_location_tokens: bool = True,
include_code_class: bool = False,
include_picture_class: bool = False,
):
"""Function to get all special document tokens."""
special_tokens: list[str] = []
for token in cls:
if not token.value.endswith("_"):
special_tokens.append(f"<{token.value}>")
special_tokens.append(f"</{token.value}>")

for i in range(6):
special_tokens += [
f"<{IDocTagsToken._SECTION_HEADER_PREFIX.value}{i}>",
f"</{IDocTagsToken._SECTION_HEADER_PREFIX.value}{i}>",
]

special_tokens.extend(IDocTagsTableToken.get_special_tokens())

if include_picture_class:
special_tokens.extend([t.value for t in _PictureClassificationToken])

if include_code_class:
special_tokens.extend([t.value for t in _CodeLanguageToken])

if include_location_tokens:
# Add dynamically generated location tokens
for i in range(0, max(page_dimension[0], page_dimension[1])):
special_tokens.append(f"<{IDocTagsToken._LOC_PREFIX.value}{i}/>")

return special_tokens
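
A hedged sketch of how the keyword arguments above shape the returned vocabulary (import path taken from the file header; exact list contents also depend on enum members not shown in this hunk):

# Sketch only: token names are derived from the enum values defined above.
from docling_core.experimental.idoctags import IDocTagsToken

doc_tokens = IDocTagsToken.get_special_tokens(
    page_dimension=(500, 500),
    include_location_tokens=True,
)
assert "<text>" in doc_tokens and "</text>" in doc_tokens
assert "<section_header_level_0>" in doc_tokens
assert "</section_header_level_5>" in doc_tokens
# Location tokens run from 0 up to max(page_dimension) - 1.
assert "<loc_0/>" in doc_tokens and "<loc_499/>" in doc_tokens
assert "<loc_500/>" not in doc_tokens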

@classmethod
def create_token_name_from_doc_item_label(cls, label: str, level: int = 1) -> str:
"""Get token corresponding to passed doc item label."""
doc_token_by_item_label = {
DocItemLabel.CAPTION: IDocTagsToken.CAPTION,
DocItemLabel.FOOTNOTE: IDocTagsToken.FOOTNOTE,
DocItemLabel.FORMULA: IDocTagsToken.FORMULA,
DocItemLabel.LIST_ITEM: IDocTagsToken.LIST_ITEM,
DocItemLabel.PAGE_FOOTER: IDocTagsToken.PAGE_FOOTER,
DocItemLabel.PAGE_HEADER: IDocTagsToken.PAGE_HEADER,
DocItemLabel.PICTURE: IDocTagsToken.PICTURE,
DocItemLabel.TABLE: IDocTagsToken.TABLE,
DocItemLabel.TEXT: IDocTagsToken.TEXT,
DocItemLabel.TITLE: IDocTagsToken.TITLE,
DocItemLabel.DOCUMENT_INDEX: IDocTagsToken.DOCUMENT_INDEX,
DocItemLabel.CODE: IDocTagsToken.CODE,
DocItemLabel.CHECKBOX_SELECTED: IDocTagsToken.CHECKBOX_SELECTED,
DocItemLabel.CHECKBOX_UNSELECTED: IDocTagsToken.CHECKBOX_UNSELECTED,
DocItemLabel.FORM: IDocTagsToken.FORM,
# Fallback mappings for labels without dedicated tokens in IDocTagsToken
DocItemLabel.KEY_VALUE_REGION: IDocTagsToken.TEXT,
DocItemLabel.PARAGRAPH: IDocTagsToken.TEXT,
DocItemLabel.REFERENCE: IDocTagsToken.TEXT,
DocItemLabel.CHART: IDocTagsToken.PICTURE,
}

res: str
if label == DocItemLabel.SECTION_HEADER:
res = f"{IDocTagsToken._SECTION_HEADER_PREFIX}{level}"
else:
try:
res = doc_token_by_item_label[DocItemLabel(label)].value
except KeyError as e:
raise RuntimeError(f"Unexpected DocItemLabel: {label}") from e
return res
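
A short sketch of the mapping (uses the .value-based prefix formatting above and DocItemLabel members referenced in the table):

# Sketch: map document item labels to the DocTags token names defined above.
from docling_core.experimental.idoctags import IDocTagsToken
from docling_core.types.doc.labels import DocItemLabel

assert (
    IDocTagsToken.create_token_name_from_doc_item_label(DocItemLabel.SECTION_HEADER, level=2)
    == "section_header_level_2"
)
# Labels without a dedicated token fall back to broader ones, e.g. CHART -> picture.
assert IDocTagsToken.create_token_name_from_doc_item_label(DocItemLabel.CHART) == "picture"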


class IDocTagsParams(DocTagsParams):
"""DocTags-specific serialization parameters."""
@@ -187,6 +311,8 @@ def serialize(
otsl_content = temp_table.export_to_otsl(
temp_doc,
add_cell_location=False,
# Suppress chart cell text if global content is off
add_cell_text=params.add_content,
self_closing=params.do_self_closing,
table_token=IDocTagsTableToken,
)
@@ -200,7 +326,7 @@

text_res = "".join([r.text for r in res_parts])
if text_res:
token = DocumentToken.create_token_name_from_doc_item_label(
token = IDocTagsToken.create_token_name_from_doc_item_label(
label=DocItemLabel.CHART if is_chart else DocItemLabel.PICTURE,
)
text_res = _wrap(text=text_res, wrap_tag=token)
@@ -238,12 +364,20 @@ def serialize_doc(
text_res = delim.join([p.text for p in parts if p.text])

if self.params.add_page_break:
page_sep = f"<{DocumentToken.PAGE_BREAK.value}{'/' if self.params.do_self_closing else ''}>"
page_sep = f"<{IDocTagsToken.PAGE_BREAK.value}{'/' if self.params.do_self_closing else ''}>"
for full_match, _, _ in self._get_page_breaks(text=text_res):
text_res = text_res.replace(full_match, page_sep)

wrap_tag = DocumentToken.DOCUMENT.value
text_res = f"<{wrap_tag}><version>{DOCTAGS_VERSION}</version>{text_res}{delim}</{wrap_tag}>"
# print(f"text-res-v1: {text_res}")

tmp = f"<{IDocTagsToken.DOCUMENT.value}>"
tmp += f"<{IDocTagsToken.VERSION.value}>{DOCTAGS_VERSION}</{IDocTagsToken.VERSION.value}>"
# text_res += f"{text_res}{delim}"
tmp += f"{text_res}"
tmp += f"</{IDocTagsToken.DOCUMENT.value}>"

# print(f"text-res-v2: {tmp}")
text_res = tmp

if self.params.pretty_indentation and (
my_root := parseString(text_res).documentElement
@@ -252,4 +386,7 @@
text_res = "\n".join(
[line for line in text_res.split("\n") if line.strip()]
)

print(f"text-res-v3:\n{text_res}")

return create_ser_result(text=text_res, span_source=parts)
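
For orientation, a minimal sketch of the envelope serialize_doc assembles (the body content, delimiter, and page-break placement are assumed for illustration; only the wrapper structure comes from the code above):

# Illustrative only.
version = "1.0.0"  # DOCTAGS_VERSION
body = "<title>Report</title><page_break/><text>Hello</text>"  # assumed item output
envelope = f"<doctag><version>{version}</version>{body}</doctag>"
print(envelope)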
140 changes: 106 additions & 34 deletions docling_core/transforms/serializer/doctags.py
@@ -106,10 +106,16 @@ def serialize(
"""Serializes the passed item."""
my_visited = visited if visited is not None else set()
params = DocTagsParams(**kwargs)
wrap_tag: Optional[str] = DocumentToken.create_token_name_from_doc_item_label(
label=item.label,
**({"level": item.level} if isinstance(item, SectionHeaderItem) else {}),
# Decide wrapping up-front so ListItem never gets wrapped here
wrap_tag_token: Optional[str] = (
DocumentToken.create_token_name_from_doc_item_label(
label=item.label,
**(
{"level": item.level} if isinstance(item, SectionHeaderItem) else {}
),
)
)
wrap_tag: Optional[str] = None if isinstance(item, ListItem) else wrap_tag_token
parts: list[str] = []

if item.meta:
@@ -152,8 +158,6 @@
text_part = f"{language_token}{text_part}"
else:
text_part = text_part.strip()
if isinstance(item, ListItem):
wrap_tag = None # deferring list item tags to list handling

if text_part:
parts.append(text_part)
@@ -203,7 +207,8 @@ def serialize(
otsl_text = item.export_to_otsl(
doc=doc,
add_cell_location=params.add_table_cell_location,
add_cell_text=params.add_table_cell_text,
# Suppress cell text when global content is disabled
add_cell_text=(params.add_table_cell_text and params.add_content),
xsize=params.xsize,
ysize=params.ysize,
visited=visited,
@@ -452,22 +457,87 @@
"""Serializes the passed item."""
my_visited = visited if visited is not None else set()
params = DocTagsParams(**kwargs)
parts = doc_serializer.get_parts(
item=item,
list_level=list_level + 1,
is_inline_scope=is_inline_scope,
visited=my_visited,
**kwargs,
)
delim = _get_delim(params=params)
if parts:
text_res = delim.join(
[
t
for p in parts
if (t := _wrap(text=p.text, wrap_tag=DocumentToken.LIST_ITEM.value))
]

# Build list children explicitly. Requirements:
# 1) <ordered_list>/<unordered_list> can be children of lists.
# 2) Do NOT wrap nested lists into <list_item>, even if they are
# children of a ListItem in the logical structure.
# 3) Still ensure structural wrappers are preserved even when
# content is suppressed (e.g., add_content=False).
item_results: list[SerializationResult] = []
child_results_wrapped: list[str] = []

excluded = doc_serializer.get_excluded_refs(**kwargs)
for child_ref in item.children:
child = child_ref.resolve(doc)

# If a nested list group is present directly under this list group,
# emit it as a sibling (no <list_item> wrapper).
if isinstance(child, ListGroup):
if child.self_ref in my_visited or child.self_ref in excluded:
continue
my_visited.add(child.self_ref)
sub_res = doc_serializer.serialize(
item=child,
list_level=list_level + 1,
is_inline_scope=is_inline_scope,
visited=my_visited,
**kwargs,
)
if sub_res.text:
child_results_wrapped.append(sub_res.text)
item_results.append(sub_res)
continue

# Normal case: ListItem under ListGroup
if not isinstance(child, ListItem):
continue
if child.self_ref in my_visited or child.self_ref in excluded:
continue

my_visited.add(child.self_ref)

# Serialize the list item content (DocTagsTextSerializer will not wrap it)
child_res = doc_serializer.serialize(
item=child,
list_level=list_level + 1,
is_inline_scope=is_inline_scope,
visited=my_visited,
**kwargs,
)
item_results.append(child_res)

# Wrap the content into <list_item>, without any nested list content.
child_text_wrapped = _wrap(
text=f"{child_res.text}",
wrap_tag=DocumentToken.LIST_ITEM.value,
)
child_results_wrapped.append(child_text_wrapped)

# After the <list_item>, append any nested lists (children of this ListItem)
# as siblings at the same level (not wrapped in <list_item>).
for subref in child.children:
sub = subref.resolve(doc)
if (
isinstance(sub, ListGroup)
and sub.self_ref not in my_visited
and sub.self_ref not in excluded
):
my_visited.add(sub.self_ref)
sub_res = doc_serializer.serialize(
item=sub,
list_level=list_level + 1,
is_inline_scope=is_inline_scope,
visited=my_visited,
**kwargs,
)
if sub_res.text:
child_results_wrapped.append(sub_res.text)
item_results.append(sub_res)

delim = _get_delim(params=params)
if child_results_wrapped:
text_res = delim.join(child_results_wrapped)
text_res = f"{text_res}{delim}"
wrap_tag = (
DocumentToken.ORDERED_LIST.value
@@ -477,7 +547,8 @@
text_res = _wrap(text=text_res, wrap_tag=wrap_tag)
else:
text_res = ""
return create_ser_result(text=text_res, span_source=parts)

return create_ser_result(text=text_res, span_source=item_results)
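
A hedged sketch of the list shape this produces for a two-level structure (tag names assumed from the DocumentToken values used above; delimiters omitted for readability):

# Assumed logical structure:
#   ListGroup -> [ListItem "a" -> ListGroup -> [ListItem "a.1"], ListItem "b"]
# Per the rules above, the nested list is emitted as a sibling of its parent's
# <list_item>, never wrapped inside it.
expected = (
    "<unordered_list>"
    "<list_item>a</list_item>"
    "<unordered_list><list_item>a.1</list_item></unordered_list>"
    "<list_item>b</list_item>"
    "</unordered_list>"
)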


class DocTagsInlineSerializer(BaseInlineSerializer):
@@ -636,18 +707,19 @@ def serialize_captions(
results: list[SerializationResult] = []
if item.captions:
cap_res = super().serialize_captions(item, **kwargs)
if cap_res.text:
if params.add_location:
for caption in item.captions:
if caption.cref not in self.get_excluded_refs(**kwargs):
if isinstance(cap := caption.resolve(self.doc), DocItem):
loc_txt = cap.get_location_tokens(
doc=self.doc,
xsize=params.xsize,
ysize=params.ysize,
self_closing=params.do_self_closing,
)
results.append(create_ser_result(text=loc_txt))
if cap_res.text and params.add_location:
for caption in item.captions:
if caption.cref not in self.get_excluded_refs(**kwargs):
if isinstance(cap := caption.resolve(self.doc), DocItem):
loc_txt = cap.get_location_tokens(
doc=self.doc,
xsize=params.xsize,
ysize=params.ysize,
self_closing=params.do_self_closing,
)
results.append(create_ser_result(text=loc_txt))
# Only include caption textual content when add_content is True
if cap_res.text and params.add_content:
results.append(cap_res)
text_res = "".join([r.text for r in results])
if text_res:
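
A brief sketch of the parameter interplay above (assuming add_location and add_content are regular DocTagsParams fields, as the usages suggest; import path assumed):

# Sketch only: field names taken from the code above.
from docling_core.transforms.serializer.doctags import DocTagsParams

locations_only = DocTagsParams(add_location=True, add_content=False)
# Under these settings, serialize_captions still emits each caption's <loc_*/>
# tokens but omits the caption text; with add_content=True the text is appended too.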