Skip to content

Commit

Permalink
fix: DIA-1451: return list instead of generator (#268)
Browse files Browse the repository at this point in the history
Co-authored-by: Matt Bernstein <matt@humansignal.com>
  • Loading branch information
pakelley and matt-bernstein authored Dec 6, 2024
1 parent 8896f3b commit 1dc47e4
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 7 deletions.
17 changes: 15 additions & 2 deletions adala/runtimes/_litellm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
import asyncio
import logging
from collections import defaultdict
from typing import Any, Dict, List, Optional, Type, Union, Literal, TypedDict, Iterable, Generator
from typing import (
Any,
Dict,
List,
Optional,
Type,
Union,
Literal,
TypedDict,
Iterable,
Generator,
)
from functools import cached_property
from enum import Enum
import litellm
Expand Down Expand Up @@ -651,7 +662,9 @@ def add_to_current_chunk(
return chunk

# Build chunks by iterating through parsed template parts
def build_chunks(parsed: Iterable[TemplateChunks]) -> Generator[MessageChunk, None, None]:
def build_chunks(
parsed: Iterable[TemplateChunks],
) -> Generator[MessageChunk, None, None]:
current_chunk: Optional[MessageChunk] = None

for part in parsed:
Expand Down
15 changes: 10 additions & 5 deletions adala/skills/collection/label_studio.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import pandas as pd
from typing import Type, Iterator, Optional
from typing import List, Optional, Type
from functools import cached_property
from copy import deepcopy
from collections import defaultdict
Expand Down Expand Up @@ -44,26 +44,31 @@ def label_interface(self) -> LabelInterface:
return LabelInterface(self.label_config)

@cached_property
def ner_tags(self) -> Iterator[ControlTag]:
def ner_tags(self) -> List[ControlTag]:
# check if the input config has NER tag (<Labels> + <Text>), and return its `from_name` and `to_name`
control_tag_names = self.allowed_control_tags or list(
self.label_interface._controls.keys()
)
tags = []
for tag_name in control_tag_names:
tag = self.label_interface.get_control(tag_name)
if tag.tag.lower() in {"labels", "hypertextlabels"}:
yield tag
tags.append(tag)
return tags

@cached_property
def image_tags(self) -> Iterator[ObjectTag]:
def image_tags(self) -> List[ObjectTag]:
# check if any image tags are used as input variables
object_tag_names = self.allowed_object_tags or list(
self.label_interface._objects.keys()
)
tags = []
for tag_name in object_tag_names:
tag = self.label_interface.get_object(tag_name)
if tag.tag.lower() == "image":
yield tag
tags.append(tag)
return tags


def __getstate__(self):
"""Exclude cached properties when pickling - otherwise the 'Agent' can not be serialized in celery"""
Expand Down

0 comments on commit 1dc47e4

Please sign in to comment.