diff --git a/adala/runtimes/_litellm.py b/adala/runtimes/_litellm.py index 7c654cfd..8aaa0222 100644 --- a/adala/runtimes/_litellm.py +++ b/adala/runtimes/_litellm.py @@ -1,7 +1,18 @@ import asyncio import logging from collections import defaultdict -from typing import Any, Dict, List, Optional, Type, Union, Literal, TypedDict, Iterable, Generator +from typing import ( + Any, + Dict, + List, + Optional, + Type, + Union, + Literal, + TypedDict, + Iterable, + Generator, +) from functools import cached_property from enum import Enum import litellm @@ -651,7 +662,9 @@ def add_to_current_chunk( return chunk # Build chunks by iterating through parsed template parts - def build_chunks(parsed: Iterable[TemplateChunks]) -> Generator[MessageChunk, None, None]: + def build_chunks( + parsed: Iterable[TemplateChunks], + ) -> Generator[MessageChunk, None, None]: current_chunk: Optional[MessageChunk] = None for part in parsed: diff --git a/adala/skills/collection/label_studio.py b/adala/skills/collection/label_studio.py index 70b047b6..32f299e3 100644 --- a/adala/skills/collection/label_studio.py +++ b/adala/skills/collection/label_studio.py @@ -1,6 +1,6 @@ import logging import pandas as pd -from typing import Type, Iterator, Optional +from typing import List, Optional, Type from functools import cached_property from copy import deepcopy from collections import defaultdict @@ -44,26 +44,31 @@ def label_interface(self) -> LabelInterface: return LabelInterface(self.label_config) @cached_property - def ner_tags(self) -> Iterator[ControlTag]: + def ner_tags(self) -> List[ControlTag]: # check if the input config has NER tag ( + ), and return its `from_name` and `to_name` control_tag_names = self.allowed_control_tags or list( self.label_interface._controls.keys() ) + tags = [] for tag_name in control_tag_names: tag = self.label_interface.get_control(tag_name) if tag.tag.lower() in {"labels", "hypertextlabels"}: - yield tag + tags.append(tag) + return tags @cached_property - def image_tags(self) -> Iterator[ObjectTag]: + def image_tags(self) -> List[ObjectTag]: # check if any image tags are used as input variables object_tag_names = self.allowed_object_tags or list( self.label_interface._objects.keys() ) + tags = [] for tag_name in object_tag_names: tag = self.label_interface.get_object(tag_name) if tag.tag.lower() == "image": - yield tag + tags.append(tag) + return tags + def __getstate__(self): """Exclude cached properties when pickling - otherwise the 'Agent' can not be serialized in celery"""