diff --git a/literalai/api/__init__.py b/literalai/api/__init__.py index 76d49ae..2c8a9f2 100644 --- a/literalai/api/__init__.py +++ b/literalai/api/__init__.py @@ -235,9 +235,10 @@ def get_threads( before: Optional[str] = None, filters: Optional[threads_filters] = None, order_by: Optional[threads_order_by] = None, + step_types_to_keep: Optional[List[StepType]] = None, ): return self.gql_helper( - *get_threads_helper(first, after, before, filters, order_by) + *get_threads_helper(first, after, before, filters, order_by, step_types_to_keep) ) def list_threads( @@ -812,9 +813,10 @@ async def get_threads( before: Optional[str] = None, filters: Optional[threads_filters] = None, order_by: Optional[threads_order_by] = None, + step_types_to_keep: Optional[List[StepType]] = None, ): return await self.gql_helper( - *get_threads_helper(first, after, before, filters, order_by) + *get_threads_helper(first, after, before, filters, order_by, step_types_to_keep) ) async def list_threads( diff --git a/literalai/api/gql.py b/literalai/api/gql.py index e6b8e11..f39078e 100644 --- a/literalai/api/gql.py +++ b/literalai/api/gql.py @@ -183,6 +183,7 @@ $first: Int, $last: Int, $projectId: String, + $stepTypesToKeep: [StepType!], ) { threads( after: $after, @@ -193,6 +194,7 @@ first: $first, last: $last, projectId: $projectId, + stepTypesToKeep: $stepTypesToKeep, ) { pageInfo { startCursor diff --git a/literalai/api/thread_helpers.py b/literalai/api/thread_helpers.py index 631253b..88602d0 100644 --- a/literalai/api/thread_helpers.py +++ b/literalai/api/thread_helpers.py @@ -2,6 +2,7 @@ from literalai.filter import threads_filters, threads_order_by from literalai.my_types import PaginatedResponse +from literalai.step import StepType from literalai.thread import Thread from . import gql @@ -13,6 +14,7 @@ def get_threads_helper( before: Optional[str] = None, filters: Optional[threads_filters] = None, order_by: Optional[threads_order_by] = None, + step_types_to_keep: Optional[List[StepType]] = None, ): variables: Dict[str, Any] = {} @@ -26,6 +28,8 @@ def get_threads_helper( variables["filters"] = filters if order_by: variables["orderBy"] = order_by + if step_types_to_keep: + variables["stepTypesToKeep"] = step_types_to_keep def process_response(response): processed_response = response["data"]["threads"] diff --git a/literalai/thread.py b/literalai/thread.py index 6d2ab97..95c9148 100644 --- a/literalai/thread.py +++ b/literalai/thread.py @@ -32,7 +32,7 @@ class Thread: tags: Optional[List[str]] steps: Optional[List[Step]] participant_id: Optional[str] - participant_identifier: Optional[str] + participant_identifier: Optional[str] = None created_at: Optional[str] # read-only, set by server needs_upsert: Optional[bool] @@ -64,7 +64,7 @@ def to_dict(self) -> ThreadDict: id=self.participant_id, identifier=self.participant_identifier ) if self.participant_id - else None, + else UserDict(), "createdAt": getattr(self, "created_at", None), } @@ -125,7 +125,7 @@ def upsert(self): thread_data_to_upsert["metadata"] = metadata if tags := thread_data.get("tags"): thread_data_to_upsert["tags"] = tags - if participant_id := thread_data.get("participant_id"): + if participant_id := thread_data.get("participant", {}).get("id"): thread_data_to_upsert["participant_id"] = participant_id try: diff --git a/tests/e2e/test_e2e.py b/tests/e2e/test_e2e.py index ca54a77..bab0a6f 100644 --- a/tests/e2e/test_e2e.py +++ b/tests/e2e/test_e2e.py @@ -8,6 +8,7 @@ from literalai import AsyncLiteralClient, LiteralClient from literalai.context import active_steps_var from literalai.my_types import ChatGeneration +from literalai.thread import Thread """ End to end tests for the SDK @@ -573,3 +574,9 @@ async def test_gracefulness(self, broken_client: LiteralClient): broken_client.flush() assert True + + @pytest.mark.timeout(5) + async def test_thread_to_dict(self, client: LiteralClient): + thread = Thread(id="thread-id", participant_id="participant-id") + participant = thread.to_dict().get("participant", {}) + assert participant and participant["id"] == "participant-id"