Skip to content

Commit

Permalink
add optional tags and metadata to steps and messages (#877)
Browse files Browse the repository at this point in the history
  • Loading branch information
willydouhard authored Apr 4, 2024
1 parent 22d9f58 commit e997ab5
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 8 deletions.
18 changes: 11 additions & 7 deletions backend/chainlit/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,13 +327,16 @@ async def delete_element(self, element_id: str):

@queue_until_user_message()
async def create_step(self, step_dict: "StepDict"):
metadata = {
"disableFeedback": step_dict.get("disableFeedback"),
"isError": step_dict.get("isError"),
"waitForAnswer": step_dict.get("waitForAnswer"),
"language": step_dict.get("language"),
"showInput": step_dict.get("showInput"),
}
metadata = dict(
step_dict.get("metadata", {}),
**{
"disableFeedback": step_dict.get("disableFeedback"),
"isError": step_dict.get("isError"),
"waitForAnswer": step_dict.get("waitForAnswer"),
"language": step_dict.get("language"),
"showInput": step_dict.get("showInput"),
},
)

step: LiteralStepDict = {
"createdAt": step_dict.get("createdAt"),
Expand All @@ -345,6 +348,7 @@ async def create_step(self, step_dict: "StepDict"):
"name": step_dict.get("name"),
"threadId": step_dict.get("threadId"),
"type": step_dict.get("type"),
"tags": step_dict.get("tags"),
"metadata": metadata,
}
if step_dict.get("input"):
Expand Down
9 changes: 9 additions & 0 deletions backend/chainlit/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ class MessageBase(ABC):
persisted = False
is_error = False
language: Optional[str] = None
metadata: Optional[Dict] = None
tags: Optional[List[str]] = None
wait_for_answer = False
indent: Optional[int] = None
generation: Optional[BaseGeneration] = None
Expand Down Expand Up @@ -83,6 +85,8 @@ def to_dict(self) -> StepDict:
"waitForAnswer": self.wait_for_answer,
"indent": self.indent,
"generation": self.generation.to_dict() if self.generation else None,
"metadata": self.metadata or {},
"tags": self.tags,
}

return _dict
Expand Down Expand Up @@ -209,6 +213,8 @@ def __init__(
disable_feedback: bool = False,
type: MessageStepType = "assistant_message",
generation: Optional[BaseGeneration] = None,
metadata: Optional[Dict] = None,
tags: Optional[List[str]] = None,
id: Optional[str] = None,
created_at: Union[str, None] = None,
):
Expand All @@ -234,6 +240,9 @@ def __init__(
if created_at:
self.created_at = created_at

self.metadata = metadata
self.tags = tags

self.author = author
self.type = type
self.actions = actions if actions is not None else []
Expand Down
11 changes: 10 additions & 1 deletion backend/chainlit/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class StepDict(TypedDict, total=False):
waitForAnswer: Optional[bool]
isError: Optional[bool]
metadata: Dict
tags: Optional[List[str]]
input: str
output: str
createdAt: Optional[str]
Expand All @@ -47,6 +48,7 @@ def step(
name: Optional[str] = "",
type: TrueStepType = "undefined",
id: Optional[str] = None,
tags: Optional[List[str]] = None,
disable_feedback: bool = True,
root: bool = False,
language: Optional[str] = None,
Expand All @@ -71,6 +73,7 @@ async def async_wrapper(*args, **kwargs):
id=id,
disable_feedback=disable_feedback,
root=root,
tags=tags,
language=language,
show_input=show_input,
) as step:
Expand All @@ -97,6 +100,7 @@ def sync_wrapper(*args, **kwargs):
id=id,
disable_feedback=disable_feedback,
root=root,
tags=tags,
language=language,
show_input=show_input,
) as step:
Expand Down Expand Up @@ -137,6 +141,7 @@ class Step:

is_error: Optional[bool]
metadata: Dict
tags: Optional[List[str]]
thread_id: str
created_at: Union[str, None]
start: Union[str, None]
Expand All @@ -153,6 +158,8 @@ def __init__(
id: Optional[str] = None,
parent_id: Optional[str] = None,
elements: Optional[List[Element]] = None,
metadata: Optional[Dict] = None,
tags: Optional[List[str]] = None,
disable_feedback: bool = True,
root: bool = False,
language: Optional[str] = None,
Expand All @@ -167,7 +174,8 @@ def __init__(
self.type = type
self.id = id or str(uuid.uuid4())
self.disable_feedback = disable_feedback
self.metadata = {}
self.metadata = metadata or {}
self.tags = tags
self.is_error = False
self.show_input = show_input
self.parent_id = parent_id
Expand Down Expand Up @@ -231,6 +239,7 @@ def to_dict(self) -> StepDict:
"disableFeedback": self.disable_feedback,
"streaming": self.streaming,
"metadata": self.metadata,
"tags": self.tags,
"input": self.input,
"isError": self.is_error,
"output": self.output,
Expand Down

0 comments on commit e997ab5

Please sign in to comment.