11from typing import Any , Dict , List , Union , Optional
22
3+ from labelbox .data .annotation_types import ImageData , TextData , VideoData
34from labelbox .data .mixins import ConfidenceMixin , CustomMetric , CustomMetricsMixin
45from labelbox .data .serialization .ndjson .base import DataRow , NDAnnotation
56
7+ from ....annotated_types import Cuid
8+
69from ...annotation_types .annotation import ClassificationAnnotation
710from ...annotation_types .video import VideoClassificationAnnotation
811from ...annotation_types .llm_prompt_response .prompt import PromptClassificationAnnotation , PromptText
912from ...annotation_types .classification .classification import ClassificationAnswer , Text , Checklist , Radio
10- from ...annotation_types .types import Cuid
11- from ...annotation_types .data import TextData , VideoData , ImageData
1213from pydantic import model_validator , Field , BaseModel , ConfigDict , model_serializer
1314from pydantic .alias_generators import to_camel
1415from .base import _SubclassRegistryBase
@@ -18,11 +19,12 @@ class NDAnswer(ConfidenceMixin, CustomMetricsMixin):
1819 name : Optional [str ] = None
1920 schema_id : Optional [Cuid ] = None
2021 classifications : Optional [List ['NDSubclassificationType' ]] = None
21- model_config = ConfigDict (populate_by_name = True , alias_generator = to_camel )
22+ model_config = ConfigDict (populate_by_name = True , alias_generator = to_camel )
2223
2324 @model_validator (mode = "after" )
2425 def must_set_one (self ):
25- if (not hasattr (self , "schema_id" ) or self .schema_id is None ) and (not hasattr (self , "name" ) or self .name is None ):
26+ if (not hasattr (self , "schema_id" ) or self .schema_id
27+ is None ) and (not hasattr (self , "name" ) or self .name is None ):
2628 raise ValueError ("Schema id or name are not set. Set either one." )
2729 return self
2830
@@ -102,7 +104,10 @@ def from_common(cls, checklist: Checklist, name: str,
102104 NDAnswer (name = answer .name ,
103105 schema_id = answer .feature_schema_id ,
104106 confidence = answer .confidence ,
105- classifications = [NDSubclassification .from_common (annot ) for annot in answer .classifications ] if answer .classifications else None ,
107+ classifications = [
108+ NDSubclassification .from_common (annot )
109+ for annot in answer .classifications
110+ ] if answer .classifications else None ,
106111 custom_metrics = answer .custom_metrics )
107112 for answer in checklist .answer
108113 ],
@@ -152,8 +157,8 @@ class NDPromptTextSubclass(NDAnswer):
152157
153158 def to_common (self ) -> PromptText :
154159 return PromptText (answer = self .answer ,
155- confidence = self .confidence ,
156- custom_metrics = self .custom_metrics )
160+ confidence = self .confidence ,
161+ custom_metrics = self .custom_metrics )
157162
158163 @classmethod
159164 def from_common (cls , prompt_text : PromptText , name : str ,
@@ -194,7 +199,8 @@ def from_common(cls,
194199 )
195200
196201
197- class NDChecklist (NDAnnotation , NDChecklistSubclass , VideoSupported , _SubclassRegistryBase ):
202+ class NDChecklist (NDAnnotation , NDChecklistSubclass , VideoSupported ,
203+ _SubclassRegistryBase ):
198204
199205 @model_serializer (mode = "wrap" )
200206 def serialize_model (self , handler ):
@@ -237,7 +243,8 @@ def from_common(
237243 confidence = confidence )
238244
239245
240- class NDRadio (NDAnnotation , NDRadioSubclass , VideoSupported , _SubclassRegistryBase ):
246+ class NDRadio (NDAnnotation , NDRadioSubclass , VideoSupported ,
247+ _SubclassRegistryBase ):
241248
242249 @classmethod
243250 def from_common (
@@ -266,35 +273,32 @@ def from_common(
266273 frames = extra .get ('frames' ),
267274 message_id = message_id ,
268275 confidence = confidence )
269-
276+
270277 @model_serializer (mode = "wrap" )
271278 def serialize_model (self , handler ):
272279 res = handler (self )
273280 if "classifications" in res and res ["classifications" ] == []:
274281 del res ["classifications" ]
275282 return res
276-
277-
283+
284+
278285class NDPromptText (NDAnnotation , NDPromptTextSubclass , _SubclassRegistryBase ):
279-
286+
280287 @classmethod
281- def from_common (
282- cls ,
283- uuid : str ,
284- text : PromptText ,
285- name ,
286- data : Dict ,
287- feature_schema_id : Cuid ,
288- confidence : Optional [float ] = None
289- ) -> "NDPromptText" :
290- return cls (
291- answer = text .answer ,
292- data_row = DataRow (id = data .uid , global_key = data .global_key ),
293- name = name ,
294- schema_id = feature_schema_id ,
295- uuid = uuid ,
296- confidence = text .confidence ,
297- custom_metrics = text .custom_metrics )
288+ def from_common (cls ,
289+ uuid : str ,
290+ text : PromptText ,
291+ name ,
292+ data : Dict ,
293+ feature_schema_id : Cuid ,
294+ confidence : Optional [float ] = None ) -> "NDPromptText" :
295+ return cls (answer = text .answer ,
296+ data_row = DataRow (id = data .uid , global_key = data .global_key ),
297+ name = name ,
298+ schema_id = feature_schema_id ,
299+ uuid = uuid ,
300+ confidence = text .confidence ,
301+ custom_metrics = text .custom_metrics )
298302
299303
300304class NDSubclassification :
@@ -350,7 +354,8 @@ def to_common(
350354 for frame in annotation .frames :
351355 for idx in range (frame .start , frame .end + 1 , 1 ):
352356 results .append (
353- VideoClassificationAnnotation (frame = idx , ** common .model_dump (exclude_none = True )))
357+ VideoClassificationAnnotation (
358+ frame = idx , ** common .model_dump (exclude_none = True )))
354359 return results
355360
356361 @classmethod
@@ -382,6 +387,7 @@ def lookup_classification(
382387 Radio : NDRadio
383388 }.get (type (annotation .value ))
384389
390+
385391class NDPromptClassification :
386392
387393 @staticmethod
@@ -404,8 +410,7 @@ def from_common(
404410 data : Union [VideoData , TextData , ImageData ]
405411 ) -> Union [NDTextSubclass , NDChecklistSubclass , NDRadioSubclass ]:
406412 return NDPromptText .from_common (str (annotation ._uuid ), annotation .value ,
407- annotation .name ,
408- data ,
413+ annotation .name , data ,
409414 annotation .feature_schema_id ,
410415 annotation .confidence )
411416
@@ -427,4 +432,4 @@ def from_common(
427432# Make sure to keep NDChecklist prior to NDRadio in the list,
428433# otherwise list of answers gets parsed by NDRadio whereas NDChecklist must be used
429434NDClassificationType = Union [NDChecklist , NDRadio , NDText ]
430- NDPromptClassificationType = Union [NDPromptText ]
435+ NDPromptClassificationType = Union [NDPromptText ]
0 commit comments