22import copy
33from itertools import groupby
44from operator import itemgetter
5- from typing import Generator , List , Tuple , Union
5+ from typing import Generator , List , Tuple , Union , Iterator , Dict
66from uuid import uuid4
77
88from pydantic import BaseModel
2424 VideoMaskAnnotation ,
2525 VideoObjectAnnotation ,
2626)
27+ from labelbox .types import DocumentRectangle , DocumentEntity
2728from .classification import (
2829 NDChecklistSubclass ,
2930 NDClassification ,
@@ -61,9 +62,7 @@ class NDLabel(BaseModel):
6162 annotations : AnnotationType
6263
6364 @classmethod
64- def from_common (
65- cls , data : LabelCollection
66- ) -> Generator ["NDLabel" , None , None ]:
65+ def from_common (cls , data : LabelCollection ) -> Generator ["NDLabel" , None , None ]:
6766 for label in data :
6867 yield from cls ._create_relationship_annotations (label )
6968 yield from cls ._create_non_video_annotations (label )
@@ -127,16 +126,12 @@ def _create_video_annotations(
127126 if isinstance (
128127 annot , (VideoClassificationAnnotation , VideoObjectAnnotation )
129128 ):
130- video_annotations [annot .feature_schema_id or annot .name ].append (
131- annot
132- )
129+ video_annotations [annot .feature_schema_id or annot .name ].append (annot )
133130 elif isinstance (annot , VideoMaskAnnotation ):
134131 yield NDObject .from_common (annotation = annot , data = label .data )
135132
136133 for annotation_group in video_annotations .values ():
137- segment_frame_ranges = cls ._get_segment_frame_ranges (
138- annotation_group
139- )
134+ segment_frame_ranges = cls ._get_segment_frame_ranges (annotation_group )
140135 if isinstance (annotation_group [0 ], VideoClassificationAnnotation ):
141136 annotation = annotation_group [0 ]
142137 frames_data = []
@@ -169,6 +164,7 @@ def _create_non_video_annotations(cls, label: Label):
169164 VideoClassificationAnnotation ,
170165 VideoObjectAnnotation ,
171166 VideoMaskAnnotation ,
167+ RelationshipAnnotation ,
172168 ),
173169 )
174170 ]
@@ -179,8 +175,6 @@ def _create_non_video_annotations(cls, label: Label):
179175 yield NDObject .from_common (annotation , label .data )
180176 elif isinstance (annotation , (ScalarMetric , ConfusionMatrixMetric )):
181177 yield NDMetricAnnotation .from_common (annotation , label .data )
182- elif isinstance (annotation , RelationshipAnnotation ):
183- yield NDRelationship .from_common (annotation , label .data )
184178 elif isinstance (annotation , PromptClassificationAnnotation ):
185179 yield NDPromptClassification .from_common (annotation , label .data )
186180 elif isinstance (annotation , MessageEvaluationTaskAnnotation ):
@@ -191,19 +185,35 @@ def _create_non_video_annotations(cls, label: Label):
191185 )
192186
193187 @classmethod
194- def _create_relationship_annotations (cls , label : Label ):
188+ def _create_relationship_annotations (
189+ cls , label : Label
190+ ) -> Generator [NDRelationship , None , None ]:
195191 for annotation in label .annotations :
196192 if isinstance (annotation , RelationshipAnnotation ):
197193 uuid1 = uuid4 ()
198194 uuid2 = uuid4 ()
199195 source = copy .copy (annotation .value .source )
200196 target = copy .copy (annotation .value .target )
201- if not isinstance (source , ObjectAnnotation ) or not isinstance (
202- target , ObjectAnnotation
203- ):
197+
198+ # Check if source type is valid based on target type
199+ if isinstance (target .value , (DocumentRectangle , DocumentEntity )):
200+ if not isinstance (
201+ source , (ObjectAnnotation , ClassificationAnnotation )
202+ ):
203+ raise TypeError (
204+ f"Unable to create relationship with invalid source. For PDF targets, "
205+ f"source must be ObjectAnnotation or ClassificationAnnotation. Got: { type (source )} "
206+ )
207+ elif not isinstance (source , ObjectAnnotation ):
204208 raise TypeError (
205- f"Unable to create relationship with non ObjectAnnotations. `Source : { type (source )} Target: { type ( target ) } ` "
209+ f"Unable to create relationship with non ObjectAnnotation source : { type (source )} "
206210 )
211+
212+ if not isinstance (target , ObjectAnnotation ):
213+ raise TypeError (
214+ f"Unable to create relationship with non ObjectAnnotation target: { type (target )} "
215+ )
216+
207217 if not source ._uuid :
208218 source ._uuid = uuid1
209219 if not target ._uuid :
0 commit comments