Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENG-1158] Default participant identifier to None and select 'id' as dict #59

Merged
6 changes: 4 additions & 2 deletions literalai/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,9 +235,10 @@ def get_threads(
before: Optional[str] = None,
filters: Optional[threads_filters] = None,
order_by: Optional[threads_order_by] = None,
step_types_to_keep: Optional[List[StepType]] = None,
):
return self.gql_helper(
*get_threads_helper(first, after, before, filters, order_by)
*get_threads_helper(first, after, before, filters, order_by, step_types_to_keep)
)

def list_threads(
Expand Down Expand Up @@ -812,9 +813,10 @@ async def get_threads(
before: Optional[str] = None,
filters: Optional[threads_filters] = None,
order_by: Optional[threads_order_by] = None,
step_types_to_keep: Optional[List[StepType]] = None,
):
return await self.gql_helper(
*get_threads_helper(first, after, before, filters, order_by)
*get_threads_helper(first, after, before, filters, order_by, step_types_to_keep)
)

async def list_threads(
Expand Down
2 changes: 2 additions & 0 deletions literalai/api/gql.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@
$first: Int,
$last: Int,
$projectId: String,
$stepTypesToKeep: [StepType!],
) {
threads(
after: $after,
Expand All @@ -193,6 +194,7 @@
first: $first,
last: $last,
projectId: $projectId,
stepTypesToKeep: $stepTypesToKeep,
) {
pageInfo {
startCursor
Expand Down
4 changes: 4 additions & 0 deletions literalai/api/thread_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from literalai.filter import threads_filters, threads_order_by
from literalai.my_types import PaginatedResponse
from literalai.step import StepType
from literalai.thread import Thread

from . import gql
Expand All @@ -13,6 +14,7 @@ def get_threads_helper(
before: Optional[str] = None,
filters: Optional[threads_filters] = None,
order_by: Optional[threads_order_by] = None,
step_types_to_keep: Optional[List[StepType]] = None,
):
variables: Dict[str, Any] = {}

Expand All @@ -26,6 +28,8 @@ def get_threads_helper(
variables["filters"] = filters
if order_by:
variables["orderBy"] = order_by
if step_types_to_keep:
variables["stepTypesToKeep"] = step_types_to_keep

def process_response(response):
processed_response = response["data"]["threads"]
Expand Down
6 changes: 3 additions & 3 deletions literalai/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class Thread:
tags: Optional[List[str]]
steps: Optional[List[Step]]
participant_id: Optional[str]
participant_identifier: Optional[str]
participant_identifier: Optional[str] = None
created_at: Optional[str] # read-only, set by server
needs_upsert: Optional[bool]

Expand Down Expand Up @@ -64,7 +64,7 @@ def to_dict(self) -> ThreadDict:
id=self.participant_id, identifier=self.participant_identifier
)
if self.participant_id
else None,
else UserDict(),
"createdAt": getattr(self, "created_at", None),
}

Expand Down Expand Up @@ -125,7 +125,7 @@ def upsert(self):
thread_data_to_upsert["metadata"] = metadata
if tags := thread_data.get("tags"):
thread_data_to_upsert["tags"] = tags
if participant_id := thread_data.get("participant_id"):
if participant_id := thread_data.get("participant", {}).get("id"):
thread_data_to_upsert["participant_id"] = participant_id

try:
Expand Down
7 changes: 7 additions & 0 deletions tests/e2e/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from literalai import AsyncLiteralClient, LiteralClient
from literalai.context import active_steps_var
from literalai.my_types import ChatGeneration
from literalai.thread import Thread

"""
End to end tests for the SDK
Expand Down Expand Up @@ -573,3 +574,9 @@ async def test_gracefulness(self, broken_client: LiteralClient):

broken_client.flush()
assert True

@pytest.mark.timeout(5)
async def test_thread_to_dict(self, client: LiteralClient):
thread = Thread(id="thread-id", participant_id="participant-id")
participant = thread.to_dict().get("participant", {})
assert participant and participant["id"] == "participant-id"
Loading