diff --git a/dco_fix b/dco_fix new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/dco_fix @@ -0,0 +1 @@ + diff --git a/docling_core/transforms/serializer/doctags.py b/docling_core/transforms/serializer/doctags.py index 807b7750..881a5e24 100644 --- a/docling_core/transforms/serializer/doctags.py +++ b/docling_core/transforms/serializer/doctags.py @@ -23,6 +23,7 @@ from docling_core.transforms.serializer.common import ( CommonParams, DocSerializer, + _iterate_items, _should_use_legacy_annotations, create_ser_result, ) @@ -33,6 +34,7 @@ DoclingDocument, FloatingItem, FormItem, + FormulaItem, GroupItem, InlineGroup, KeyValueItem, @@ -79,6 +81,18 @@ class Mode(str, Enum): do_self_closing: bool = False + # Task-specific parameters for filtering content + include_ocr: bool = True + include_layout: bool = True + include_otsl: bool = True + include_code: bool = True + include_picture: bool = True + include_chart: bool = True + include_formula: bool = True + + # Layout mode: when True, only include structure with locations, no content + layout_mode_only: bool = False + def _get_delim(params: DocTagsParams) -> str: if params.mode == DocTagsParams.Mode.HUMAN_FRIENDLY: @@ -90,6 +104,61 @@ def _get_delim(params: DocTagsParams) -> str: return delim +def create_task_filtered_params(task_list: list[str], **kwargs) -> DocTagsParams: + """Create DocTagsParams with task-specific filtering enabled. + + Args: + task_list: List of tasks to include (e.g., ['ocr', 'layout', 'otsl']) + **kwargs: Additional parameters to override defaults + + Returns: + DocTagsParams configured for the specified tasks + """ + # Default: include all tasks + params = { + "include_ocr": True, + "include_layout": True, + "include_otsl": True, + "include_code": True, + "include_picture": True, + "include_chart": True, + "include_formula": True, + "add_location": True, + "add_content": True, + "add_caption": True, + } + + # Override based on task list + if task_list: + params["include_ocr"] = "ocr" in task_list + params["include_layout"] = "layout" in task_list + params["include_otsl"] = "otsl" in task_list + params["include_code"] = "code" in task_list + params["include_picture"] = "picture" in task_list + params["include_chart"] = "chart" in task_list + params["include_formula"] = "formula" in task_list + + # Special handling for layout mode + if "layout" in task_list: + # When layout is present, always include locations + params["add_location"] = True + # For layout mode, we want to show structure but strip content + # The individual serializers will handle content stripping based on task parameters + params["add_content"] = True # Let serializers handle content filtering + else: + # When layout is NOT present, only include locations for specific tasks + params["add_location"] = False + # Only include content for explicitly requested tasks + # But don't disable add_content globally - let individual serializers handle it + if not params["include_otsl"]: + params["add_table_cell_text"] = False + + # Apply any additional kwargs + params.update(kwargs) + + return DocTagsParams(**params) + + class DocTagsTextSerializer(BaseModel, BaseTextSerializer): """DocTags-specific text item serializer.""" @@ -117,7 +186,36 @@ def serialize( if meta_res.text: parts.append(meta_res.text) - if params.add_location: + # Check if this item type should be included based on task parameters + should_include_content = True + if isinstance(item, CodeItem) and not params.include_code: + should_include_content = False + elif isinstance(item, FormulaItem) and not params.include_formula: + should_include_content = False + elif ( + isinstance(item, (TextItem, SectionHeaderItem, ListItem)) + and not params.include_ocr + ): + should_include_content = False + + # In layout mode, skip content for items whose specific flags are disabled + if params.include_layout: + if isinstance(item, CodeItem) and not params.include_code: + should_include_content = False + elif isinstance(item, FormulaItem) and not params.include_formula: + should_include_content = False + + # For code items, if include_code is False, return empty result + # Exception: in layout mode, always include items (with location but no content) + if not params.include_layout: + if isinstance(item, CodeItem) and not params.include_code: + return create_ser_result() + + # For formula items, if include_formula is False, return empty result + if isinstance(item, FormulaItem) and not params.include_formula: + return create_ser_result() + + if params.add_location and params.include_layout: location = item.get_location_tokens( doc=doc, xsize=params.xsize, @@ -127,7 +225,7 @@ def serialize( if location: parts.append(location) - if params.add_content: + if params.add_content and should_include_content: if ( item.text == "" and len(item.children) == 1 @@ -191,7 +289,7 @@ def serialize( res_parts: list[SerializationResult] = [] if item.self_ref not in doc_serializer.get_excluded_refs(**kwargs): - if params.add_location: + if params.add_location and params.include_layout: loc_text = item.get_location_tokens( doc=doc, xsize=params.xsize, @@ -200,16 +298,18 @@ def serialize( ) res_parts.append(create_ser_result(text=loc_text, span_source=item)) - otsl_text = item.export_to_otsl( - doc=doc, - add_cell_location=params.add_table_cell_location, - add_cell_text=params.add_table_cell_text, - xsize=params.xsize, - ysize=params.ysize, - visited=visited, - table_token=self._get_table_token(), - ) - res_parts.append(create_ser_result(text=otsl_text, span_source=item)) + # Check if OTSL content should be included + if params.include_otsl: + otsl_text = item.export_to_otsl( + doc=doc, + add_cell_location=params.add_table_cell_location, + add_cell_text=params.add_table_cell_text, + xsize=params.xsize, + ysize=params.ysize, + visited=visited, + table_token=self._get_table_token(), + ) + res_parts.append(create_ser_result(text=otsl_text, span_source=item)) if params.add_caption: cap_res = doc_serializer.serialize_captions(item=item, **kwargs) @@ -226,99 +326,113 @@ def serialize( class DocTagsPictureSerializer(BasePictureSerializer): """DocTags-specific picture item serializer.""" - @override - def serialize( + def _get_predicted_class( + self, item: PictureItem, params: DocTagsParams + ) -> Optional[str]: + """Get the predicted class from item metadata or annotations.""" + if item.meta: + if item.meta.classification: + return item.meta.classification.get_main_prediction().class_name + elif _should_use_legacy_annotations( + params=params, + item=item, + kind=PictureClassificationData.model_fields["kind"].default, + ): + if classifications := [ + ann + for ann in item.annotations + if isinstance(ann, PictureClassificationData) + ]: + if classifications[0].predicted_classes: + return classifications[0].predicted_classes[0].class_name + return None + + def _is_chart_type(self, predicted_class: Optional[str]) -> bool: + """Check if predicted class indicates a chart.""" + if not predicted_class: + return False + return predicted_class in [ + PictureClassificationLabel.PIE_CHART, + PictureClassificationLabel.BAR_CHART, + PictureClassificationLabel.STACKED_BAR_CHART, + PictureClassificationLabel.LINE_CHART, + PictureClassificationLabel.FLOW_CHART, + PictureClassificationLabel.SCATTER_CHART, + PictureClassificationLabel.HEATMAP, + ] + + def _get_molecule_smi( + self, item: PictureItem, params: DocTagsParams + ) -> Optional[str]: + """Get SMILES string from item metadata or annotations.""" + if item.meta: + if item.meta.molecule: + return item.meta.molecule.smi + elif _should_use_legacy_annotations( + params=params, + item=item, + kind=PictureMoleculeData.model_fields["kind"].default, + ): + if smiles_annotations := [ + ann for ann in item.annotations if isinstance(ann, PictureMoleculeData) + ]: + return smiles_annotations[0].smi + return None + + def _get_tabular_chart_data( + self, item: PictureItem, params: DocTagsParams + ) -> Optional[TableData]: + """Get tabular chart data from item metadata or annotations.""" + if item.meta: + if item.meta.tabular_chart: + return item.meta.tabular_chart.chart_data + elif _should_use_legacy_annotations( + params=params, + item=item, + kind=PictureTabularChartData.model_fields["kind"].default, + ): + if tabular_chart_annotations := [ + ann + for ann in item.annotations + if isinstance(ann, PictureTabularChartData) + ]: + return tabular_chart_annotations[0].chart_data + return None + + def _build_body_content( self, - *, item: PictureItem, - doc_serializer: BaseDocSerializer, doc: DoclingDocument, - **kwargs: Any, - ) -> SerializationResult: - """Serializes the passed item.""" - params = DocTagsParams(**kwargs) - res_parts: list[SerializationResult] = [] - is_chart = False + params: DocTagsParams, + predicted_class: Optional[str], + is_chart: bool, + ) -> str: + """Build the body content for the picture item.""" + body = "" + if params.add_location and params.include_layout: + body += item.get_location_tokens( + doc=doc, + xsize=params.xsize, + ysize=params.ysize, + self_closing=params.do_self_closing, + ) - if item.self_ref not in doc_serializer.get_excluded_refs(**kwargs): - body = "" - if params.add_location: - body += item.get_location_tokens( - doc=doc, - xsize=params.xsize, - ysize=params.ysize, - self_closing=params.do_self_closing, - ) + should_include_content = True + if params.include_layout: + if is_chart and not params.include_chart: + should_include_content = False + elif not is_chart and not params.include_picture: + should_include_content = False - # handle classification data - predicted_class: Optional[str] = None - if item.meta: - if item.meta.classification: - predicted_class = ( - item.meta.classification.get_main_prediction().class_name - ) - elif _should_use_legacy_annotations( - params=params, - item=item, - kind=PictureClassificationData.model_fields["kind"].default, - ): - if classifications := [ - ann - for ann in item.annotations - if isinstance(ann, PictureClassificationData) - ]: - if classifications[0].predicted_classes: - predicted_class = ( - classifications[0].predicted_classes[0].class_name - ) - if predicted_class: - body += DocumentToken.get_picture_classification_token(predicted_class) - if predicted_class in [ - PictureClassificationLabel.PIE_CHART, - PictureClassificationLabel.BAR_CHART, - PictureClassificationLabel.STACKED_BAR_CHART, - PictureClassificationLabel.LINE_CHART, - PictureClassificationLabel.FLOW_CHART, - PictureClassificationLabel.SCATTER_CHART, - PictureClassificationLabel.HEATMAP, - ]: - is_chart = True - - # handle molecule data - smi: Optional[str] = None - if item.meta: - if item.meta.molecule: - smi = item.meta.molecule.smi - elif _should_use_legacy_annotations( - params=params, - item=item, - kind=PictureMoleculeData.model_fields["kind"].default, - ): - if smiles_annotations := [ - ann - for ann in item.annotations - if isinstance(ann, PictureMoleculeData) - ]: - smi = smiles_annotations[0].smi + if should_include_content and predicted_class: + body += DocumentToken.get_picture_classification_token(predicted_class) + + if should_include_content: + smi = self._get_molecule_smi(item, params) if smi: body += _wrap(text=smi, wrap_tag=DocumentToken.SMILES.value) - # handle tabular chart data - chart_data: Optional[TableData] = None - if item.meta: - if item.meta.tabular_chart: - chart_data = item.meta.tabular_chart.chart_data - elif _should_use_legacy_annotations( - params=params, - item=item, - kind=PictureTabularChartData.model_fields["kind"].default, - ): - if tabular_chart_annotations := [ - ann - for ann in item.annotations - if isinstance(ann, PictureTabularChartData) - ]: - chart_data = tabular_chart_annotations[0].chart_data + chart_data = self._get_tabular_chart_data(item, params) if chart_data and chart_data.table_cells: temp_doc = DoclingDocument(name="temp") temp_table = temp_doc.add_table(data=chart_data) @@ -326,6 +440,34 @@ def serialize( temp_doc, add_cell_location=False ) body += otsl_content + return body + + @override + def serialize( + self, + *, + item: PictureItem, + doc_serializer: BaseDocSerializer, + doc: DoclingDocument, + **kwargs: Any, + ) -> SerializationResult: + """Serializes the passed item.""" + params = DocTagsParams(**kwargs) + res_parts: list[SerializationResult] = [] + + predicted_class = self._get_predicted_class(item, params) + is_chart = self._is_chart_type(predicted_class) + + if not params.include_layout: + if is_chart and not params.include_chart: + return create_ser_result() + elif not is_chart and not params.include_picture: + return create_ser_result() + + if item.self_ref not in doc_serializer.get_excluded_refs(**kwargs): + body = self._build_body_content( + item, doc, params, predicted_class, is_chart + ) res_parts.append(create_ser_result(text=body, span_source=item)) if params.add_caption: @@ -363,7 +505,7 @@ def serialize( if len(item.prov) > 0: page_no = item.prov[0].page_no - if params.add_location: + if params.add_location and params.include_layout: body += item.get_location_tokens( doc=doc, xsize=params.xsize, @@ -380,7 +522,7 @@ def serialize( for cell in item.graph.cells: cell_txt = "" - if cell.prov is not None: + if cell.prov is not None and params.add_location and params.include_layout: if len(doc.pages.keys()): page_w, page_h = doc.pages[page_no].size.as_tuple() cell_txt += DocumentToken.get_location( @@ -532,7 +674,7 @@ def serialize( my_visited = visited if visited is not None else set() params = DocTagsParams(**kwargs) parts: List[SerializationResult] = [] - if params.add_location: + if params.add_location and params.include_layout: inline_loc_tags_ser_res = self._get_inline_location_tags( doc=doc, item=item, @@ -605,6 +747,80 @@ class DocTagsDocSerializer(DocSerializer): params: DocTagsParams = DocTagsParams() + @override + def get_parts( + self, + item: Optional[NodeItem] = None, + *, + traverse_pictures: bool = False, + list_level: int = 0, + is_inline_scope: bool = False, + visited: Optional[set[str]] = None, + **kwargs: Any, + ) -> list[SerializationResult]: + """Get the components to be combined for serializing this node with task filtering.""" + parts: list[SerializationResult] = [] + my_visited: set[str] = visited if visited is not None else set() + params = self.params.merge_with_patch(patch=kwargs) + for node, lvl in _iterate_items( + doc=self.doc, + layers=params.layers, + node=item, + traverse_pictures=traverse_pictures, + add_page_breaks=self.requires_page_break(), + ): + if node.self_ref in my_visited: + continue + else: + my_visited.add(node.self_ref) + + # Task-based filtering: only process items that match the requested tasks + should_process = self._should_process_item(node, params) + if not should_process: + continue + + part = self.serialize( + item=node, + list_level=list_level, + is_inline_scope=is_inline_scope, + visited=my_visited, + **(dict(level=lvl) | kwargs), + ) + if part.text: + parts.append(part) + return parts + + def _should_process_item(self, node: NodeItem, params: DocTagsParams) -> bool: + """Determine if an item should be processed based on task parameters.""" + if not isinstance(node, DocItem): + return True # Process non-DocItem nodes (groups, etc.) + + # For layout mode, include all elements (they'll be processed with locations but no content) + if params.include_layout: + return True + + # For non-layout mode, only include elements for explicitly requested tasks + # Tables: allow through if layout is enabled (for locations/captions) or if OTSL is enabled + if isinstance(node, TableItem): + if params.include_layout or params.include_otsl: + return True + return False + elif isinstance(node, PictureItem) and not params.include_picture: + return False + elif isinstance(node, CodeItem) and not params.include_code: + return False + elif isinstance(node, FormulaItem) and not params.include_formula: + return False + elif ( + isinstance(node, (TextItem, SectionHeaderItem, ListItem)) + and not params.include_ocr + ): + return False + elif isinstance(node, FormItem) and not params.include_ocr: + return False + + return True + @override def serialize_doc( self, @@ -635,9 +851,10 @@ def serialize_captions( params = DocTagsParams(**kwargs) results: list[SerializationResult] = [] if item.captions: - cap_res = super().serialize_captions(item, **kwargs) - if cap_res.text: - if params.add_location: + # Always include caption structure when layout is present + if params.include_layout: + # For layout mode, include captions with locations but without content + if params.add_location and params.include_layout: for caption in item.captions: if caption.cref not in self.get_excluded_refs(**kwargs): if isinstance(cap := caption.resolve(self.doc), DocItem): @@ -648,7 +865,33 @@ def serialize_captions( self_closing=params.do_self_closing, ) results.append(create_ser_result(text=loc_txt)) - results.append(cap_res) + # Don't include caption content when only layout is requested + if not params.include_ocr: + pass # Skip content, only include locations + else: + # Include content when OCR is also requested + cap_res = super().serialize_captions(item, **kwargs) + if cap_res.text: + results.append(cap_res) + else: + # For non-layout mode, only include captions if OCR is requested + if params.include_ocr: + cap_res = super().serialize_captions(item, **kwargs) + if cap_res.text: + if params.add_location and params.include_layout: + for caption in item.captions: + if caption.cref not in self.get_excluded_refs(**kwargs): + if isinstance( + cap := caption.resolve(self.doc), DocItem + ): + loc_txt = cap.get_location_tokens( + doc=self.doc, + xsize=params.xsize, + ysize=params.ysize, + self_closing=params.do_self_closing, + ) + results.append(create_ser_result(text=loc_txt)) + results.append(cap_res) text_res = "".join([r.text for r in results]) if text_res: text_res = _wrap(text=text_res, wrap_tag=DocumentToken.CAPTION.value) diff --git a/test/test_doctags_filtering.py b/test/test_doctags_filtering.py new file mode 100644 index 00000000..f77188a6 --- /dev/null +++ b/test/test_doctags_filtering.py @@ -0,0 +1,239 @@ +"""Test DocTags serialization filtering functionality.""" + +from docling_core.transforms.serializer.doctags import ( + DocTagsDocSerializer, + DocTagsParams, + create_task_filtered_params, +) + + +def test_create_task_filtered_params_defaults(): + """Test default behavior.""" + params = create_task_filtered_params([]) + assert params.include_ocr is True + assert params.include_layout is True + assert params.include_otsl is True + assert params.include_code is True + assert params.include_picture is True + assert params.include_chart is True + assert params.include_formula is True + + +def test_create_task_filtered_params_specific_tasks(): + """Test with specific task list.""" + params = create_task_filtered_params(["ocr", "layout"]) + assert params.include_ocr is True + assert params.include_layout is True + assert params.include_otsl is False + assert params.include_code is False + assert params.include_picture is False + assert params.include_chart is False + assert params.include_formula is False + + +def test_create_task_filtered_params_with_layout(): + """Test with layout in task list.""" + params = create_task_filtered_params(["layout"]) + assert params.include_layout is True + assert params.add_location is True + + +def test_create_task_filtered_params_without_layout(): + """Test without layout in task list.""" + params = create_task_filtered_params(["ocr"]) + assert params.include_layout is False + assert params.add_location is False + + +def test_create_task_filtered_params_with_kwargs(): + """Test with additional kwargs.""" + params = create_task_filtered_params(["ocr"], xsize=1000, ysize=1000) + assert params.xsize == 1000 + assert params.ysize == 1000 + assert params.include_ocr is True + + +def test_doctags_exclude_ocr(sample_doc): + """Test excluding OCR.""" + serializer = DocTagsDocSerializer(doc=sample_doc) + serializer.params = serializer.params.merge_with_patch( + DocTagsParams(include_ocr=False).model_dump() + ) + result = serializer.serialize() + assert result.text is not None + + +def test_doctags_exclude_otsl(sample_doc): + """Test excluding OTSL.""" + serializer = DocTagsDocSerializer(doc=sample_doc) + serializer.params = serializer.params.merge_with_patch( + DocTagsParams(include_otsl=False, include_layout=True).model_dump() + ) + result = serializer.serialize() + assert result.text is not None + + +def test_doctags_exclude_picture(sample_doc): + """Test excluding pictures.""" + serializer = DocTagsDocSerializer(doc=sample_doc) + serializer.params = serializer.params.merge_with_patch( + DocTagsParams(include_picture=False).model_dump() + ) + result = serializer.serialize() + assert result.text is not None + + +def test_doctags_exclude_chart(sample_doc): + """Test excluding charts.""" + serializer = DocTagsDocSerializer(doc=sample_doc) + serializer.params = serializer.params.merge_with_patch( + DocTagsParams(include_chart=False).model_dump() + ) + result = serializer.serialize() + assert result.text is not None + + +def test_doctags_exclude_code(sample_doc): + """Test excluding code.""" + serializer = DocTagsDocSerializer(doc=sample_doc) + serializer.params = serializer.params.merge_with_patch( + DocTagsParams(include_code=False).model_dump() + ) + result = serializer.serialize() + assert result.text is not None + + +def test_doctags_exclude_formula(sample_doc): + """Test excluding formulas.""" + serializer = DocTagsDocSerializer(doc=sample_doc) + serializer.params = serializer.params.merge_with_patch( + DocTagsParams(include_formula=False).model_dump() + ) + result = serializer.serialize() + assert result.text is not None + + +def test_doctags_no_layout_no_locations(sample_doc): + """Test no locations when layout is disabled.""" + serializer = DocTagsDocSerializer(doc=sample_doc) + serializer.params = serializer.params.merge_with_patch( + DocTagsParams(include_layout=False, add_location=True).model_dump() + ) + result = serializer.serialize() + assert "