From 497e2275a329b21028fd3fde3560780a3f59084e Mon Sep 17 00:00:00 2001 From: Willy Douhard Date: Wed, 3 Apr 2024 13:24:09 +0200 Subject: [PATCH] add tagds and metadata to steps and messages --- backend/chainlit/data/__init__.py | 18 +++++++++++------- backend/chainlit/message.py | 9 +++++++++ backend/chainlit/step.py | 11 ++++++++++- 3 files changed, 30 insertions(+), 8 deletions(-) diff --git a/backend/chainlit/data/__init__.py b/backend/chainlit/data/__init__.py index d8210d9ff0..19ee762cd0 100644 --- a/backend/chainlit/data/__init__.py +++ b/backend/chainlit/data/__init__.py @@ -327,13 +327,16 @@ async def delete_element(self, element_id: str): @queue_until_user_message() async def create_step(self, step_dict: "StepDict"): - metadata = { - "disableFeedback": step_dict.get("disableFeedback"), - "isError": step_dict.get("isError"), - "waitForAnswer": step_dict.get("waitForAnswer"), - "language": step_dict.get("language"), - "showInput": step_dict.get("showInput"), - } + metadata = dict( + step_dict.get("metadata", {}), + **{ + "disableFeedback": step_dict.get("disableFeedback"), + "isError": step_dict.get("isError"), + "waitForAnswer": step_dict.get("waitForAnswer"), + "language": step_dict.get("language"), + "showInput": step_dict.get("showInput"), + }, + ) step: LiteralStepDict = { "createdAt": step_dict.get("createdAt"), @@ -345,6 +348,7 @@ async def create_step(self, step_dict: "StepDict"): "name": step_dict.get("name"), "threadId": step_dict.get("threadId"), "type": step_dict.get("type"), + "tags": step_dict.get("tags"), "metadata": metadata, } if step_dict.get("input"): diff --git a/backend/chainlit/message.py b/backend/chainlit/message.py index 85f9d67f6d..3abb7cb8cd 100644 --- a/backend/chainlit/message.py +++ b/backend/chainlit/message.py @@ -39,6 +39,8 @@ class MessageBase(ABC): persisted = False is_error = False language: Optional[str] = None + metadata: Optional[Dict] = None + tags: Optional[List[str]] = None wait_for_answer = False indent: Optional[int] = None generation: Optional[BaseGeneration] = None @@ -83,6 +85,8 @@ def to_dict(self) -> StepDict: "waitForAnswer": self.wait_for_answer, "indent": self.indent, "generation": self.generation.to_dict() if self.generation else None, + "metadata": self.metadata or {}, + "tags": self.tags, } return _dict @@ -209,6 +213,8 @@ def __init__( disable_feedback: bool = False, type: MessageStepType = "assistant_message", generation: Optional[BaseGeneration] = None, + metadata: Optional[Dict] = None, + tags: Optional[List[str]] = None, id: Optional[str] = None, created_at: Union[str, None] = None, ): @@ -234,6 +240,9 @@ def __init__( if created_at: self.created_at = created_at + self.metadata = metadata + self.tags = tags + self.author = author self.type = type self.actions = actions if actions is not None else [] diff --git a/backend/chainlit/step.py b/backend/chainlit/step.py index 46c9a263c3..eb335556a5 100644 --- a/backend/chainlit/step.py +++ b/backend/chainlit/step.py @@ -29,6 +29,7 @@ class StepDict(TypedDict, total=False): waitForAnswer: Optional[bool] isError: Optional[bool] metadata: Dict + tags: Optional[List[str]] input: str output: str createdAt: Optional[str] @@ -47,6 +48,7 @@ def step( name: Optional[str] = "", type: TrueStepType = "undefined", id: Optional[str] = None, + tags: Optional[List[str]] = None, disable_feedback: bool = True, root: bool = False, language: Optional[str] = None, @@ -71,6 +73,7 @@ async def async_wrapper(*args, **kwargs): id=id, disable_feedback=disable_feedback, root=root, + tags=tags, language=language, show_input=show_input, ) as step: @@ -97,6 +100,7 @@ def sync_wrapper(*args, **kwargs): id=id, disable_feedback=disable_feedback, root=root, + tags=tags, language=language, show_input=show_input, ) as step: @@ -137,6 +141,7 @@ class Step: is_error: Optional[bool] metadata: Dict + tags: Optional[List[str]] thread_id: str created_at: Union[str, None] start: Union[str, None] @@ -153,6 +158,8 @@ def __init__( id: Optional[str] = None, parent_id: Optional[str] = None, elements: Optional[List[Element]] = None, + metadata: Optional[Dict] = None, + tags: Optional[List[str]] = None, disable_feedback: bool = True, root: bool = False, language: Optional[str] = None, @@ -167,7 +174,8 @@ def __init__( self.type = type self.id = id or str(uuid.uuid4()) self.disable_feedback = disable_feedback - self.metadata = {} + self.metadata = metadata or {} + self.tags = tags self.is_error = False self.show_input = show_input self.parent_id = parent_id @@ -231,6 +239,7 @@ def to_dict(self) -> StepDict: "disableFeedback": self.disable_feedback, "streaming": self.streaming, "metadata": self.metadata, + "tags": self.tags, "input": self.input, "isError": self.is_error, "output": self.output,