Skip to content

Commit

Permalink
Merge pull request #165 from Chainlit/willy/concurrency
Browse files Browse the repository at this point in the history
fix: thread/step concurrency
  • Loading branch information
willydouhard authored Feb 18, 2025
2 parents cafe70d + 3ae371a commit 3eebec5
Show file tree
Hide file tree
Showing 21 changed files with 206 additions and 181 deletions.
7 changes: 5 additions & 2 deletions literalai/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from literalai.client import AsyncLiteralClient, LiteralClient
from literalai.evaluation.dataset import Dataset
from literalai.evaluation.dataset_experiment import (
DatasetExperiment,
DatasetExperimentItem,
)
from literalai.evaluation.dataset_item import DatasetItem
from literalai.evaluation.dataset_experiment import DatasetExperiment, DatasetExperimentItem
from literalai.prompt_engineering.prompt import Prompt
from literalai.my_types import * # noqa
from literalai.observability.generation import (
BaseGeneration,
Expand All @@ -13,6 +15,7 @@
from literalai.observability.message import Message
from literalai.observability.step import Attachment, Score, Step
from literalai.observability.thread import Thread
from literalai.prompt_engineering.prompt import Prompt
from literalai.version import __version__

__all__ = [
Expand Down
79 changes: 41 additions & 38 deletions literalai/api/asynchronous.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,11 @@
import logging
import uuid
from typing import Any, Callable, Dict, List, Literal, Optional, TypeVar, Union, cast

import httpx
from typing_extensions import deprecated
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
TypeVar,
Union,
cast,
)

from literalai.api.base import BaseLiteralAPI, prepare_variables

from literalai.api.helpers.attachment_helpers import (
AttachmentUpload,
create_attachment_helper,
Expand Down Expand Up @@ -91,6 +81,7 @@
DatasetExperimentItem,
)
from literalai.evaluation.dataset_item import DatasetItem
from literalai.my_types import PaginatedResponse, User
from literalai.observability.filter import (
generations_filters,
generations_order_by,
Expand All @@ -102,12 +93,6 @@
threads_order_by,
users_filters,
)
from literalai.observability.thread import Thread
from literalai.prompt_engineering.prompt import Prompt, ProviderSettings

import httpx

from literalai.my_types import PaginatedResponse, User
from literalai.observability.generation import (
BaseGeneration,
ChatGeneration,
Expand All @@ -123,6 +108,8 @@
StepDict,
StepType,
)
from literalai.observability.thread import Thread
from literalai.prompt_engineering.prompt import Prompt, ProviderSettings

logger = logging.getLogger(__name__)

Expand All @@ -141,7 +128,11 @@ class AsyncLiteralAPI(BaseLiteralAPI):
R = TypeVar("R")

async def make_gql_call(
self, description: str, query: str, variables: Dict[str, Any], timeout: Optional[int] = 10
self,
description: str,
query: str,
variables: Dict[str, Any],
timeout: Optional[int] = 10,
) -> Dict:
def raise_error(error):
logger.error(f"Failed to {description}: {error}")
Expand All @@ -166,8 +157,7 @@ def raise_error(error):
json = response.json()
except ValueError as e:
raise_error(
f"""Failed to parse JSON response: {
e}, content: {response.content!r}"""
f"Failed to parse JSON response: {e}, content: {response.content!r}"
)

if json.get("errors"):
Expand All @@ -178,8 +168,7 @@ def raise_error(error):
for value in json["data"].values():
if value and value.get("ok") is False:
raise_error(
f"""Failed to {description}: {
value.get('message')}"""
f"""Failed to {description}: {value.get("message")}"""
)
return json

Expand All @@ -203,9 +192,9 @@ async def make_rest_call(self, subpath: str, body: Dict[str, Any]) -> Dict:
return response.json()
except ValueError as e:
raise ValueError(
f"""Failed to parse JSON response: {
e}, content: {response.content!r}"""
f"Failed to parse JSON response: {e}, content: {response.content!r}"
)

async def gql_helper(
self,
query: str,
Expand Down Expand Up @@ -235,7 +224,9 @@ async def get_user(
) -> "User":
return await self.gql_helper(*get_user_helper(id, identifier))

async def create_user(self, identifier: str, metadata: Optional[Dict] = None) -> "User":
async def create_user(
self, identifier: str, metadata: Optional[Dict] = None
) -> "User":
return await self.gql_helper(*create_user_helper(identifier, metadata))

async def update_user(
Expand All @@ -245,7 +236,7 @@ async def update_user(

async def delete_user(self, id: str) -> Dict:
return await self.gql_helper(*delete_user_helper(id))

async def get_or_create_user(
self, identifier: str, metadata: Optional[Dict] = None
) -> "User":
Expand Down Expand Up @@ -273,7 +264,7 @@ async def get_threads(
first, after, before, filters, order_by, step_types_to_keep
)
)

async def list_threads(
self,
first: Optional[int] = None,
Expand Down Expand Up @@ -491,7 +482,7 @@ async def create_attachment(
thread_id = active_thread.id

if not step_id:
if active_steps := active_steps_var.get([]):
if active_steps := active_steps_var.get():
step_id = active_steps[-1].id
else:
raise Exception("No step_id provided and no active step found.")
Expand Down Expand Up @@ -532,7 +523,9 @@ async def create_attachment(
response = await self.make_gql_call(description, query, variables)
return process_response(response)

async def update_attachment(self, id: str, update_params: AttachmentUpload) -> "Attachment":
async def update_attachment(
self, id: str, update_params: AttachmentUpload
) -> "Attachment":
return await self.gql_helper(*update_attachment_helper(id, update_params))

async def get_attachment(self, id: str) -> Optional["Attachment"]:
Expand All @@ -545,7 +538,6 @@ async def delete_attachment(self, id: str) -> Dict:
# Step APIs #
##################################################################################


async def create_step(
self,
thread_id: Optional[str] = None,
Expand Down Expand Up @@ -646,7 +638,7 @@ async def get_generations(
return await self.gql_helper(
*get_generations_helper(first, after, before, filters, order_by)
)

async def create_generation(
self, generation: Union["ChatGeneration", "CompletionGeneration"]
) -> Union["ChatGeneration", "CompletionGeneration"]:
Expand All @@ -667,8 +659,10 @@ async def create_dataset(
return await self.gql_helper(
*create_dataset_helper(sync_api, name, description, metadata, type)
)

async def get_dataset(self, id: Optional[str] = None, name: Optional[str] = None) -> "Dataset":

async def get_dataset(
self, id: Optional[str] = None, name: Optional[str] = None
) -> "Dataset":
sync_api = LiteralAPI(self.api_key, self.url)
subpath, _, variables, process_response = get_dataset_helper(
sync_api, id=id, name=name
Expand Down Expand Up @@ -738,7 +732,7 @@ async def create_experiment_item(
result.scores = await self.create_scores(experiment_item.scores)

return result

##################################################################################
# DatasetItem APIs #
##################################################################################
Expand All @@ -753,7 +747,7 @@ async def create_dataset_item(
return await self.gql_helper(
*create_dataset_item_helper(dataset_id, input, expected_output, metadata)
)

async def get_dataset_item(self, id: str) -> "DatasetItem":
return await self.gql_helper(*get_dataset_item_helper(id))

Expand Down Expand Up @@ -784,7 +778,9 @@ async def get_or_create_prompt_lineage(
return await self.gql_helper(*create_prompt_lineage_helper(name, description))

@deprecated('Please use "get_or_create_prompt_lineage" instead.')
async def create_prompt_lineage(self, name: str, description: Optional[str] = None) -> Dict:
async def create_prompt_lineage(
self, name: str, description: Optional[str] = None
) -> Dict:
return await self.get_or_create_prompt_lineage(name, description)

async def get_or_create_prompt(
Expand Down Expand Up @@ -838,7 +834,14 @@ async def get_prompt(
raise ValueError("At least the `id` or the `name` must be provided.")

sync_api = LiteralAPI(self.api_key, self.url)
get_prompt_query, description, variables, process_response, timeout, cached_prompt = get_prompt_helper(
(
get_prompt_query,
description,
variables,
process_response,
timeout,
cached_prompt,
) = get_prompt_helper(
api=sync_api, id=id, name=name, version=version, cache=self.cache
)

Expand Down
56 changes: 14 additions & 42 deletions literalai/api/base.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,16 @@
import os

from abc import ABC, abstractmethod
from typing import (
Any,
Dict,
List,
Optional,
Union,
)
from typing import Any, Dict, List, Optional, Union

from typing_extensions import deprecated

from literalai.my_types import Environment

from literalai.api.helpers.attachment_helpers import AttachmentUpload
from literalai.api.helpers.prompt_helpers import PromptRollout
from literalai.api.helpers.score_helpers import ScoreUpdate
from literalai.cache.shared_cache import SharedCache
from literalai.evaluation.dataset import DatasetType
from literalai.evaluation.dataset_experiment import (
DatasetExperimentItem,
)
from literalai.api.helpers.attachment_helpers import (
AttachmentUpload)
from literalai.api.helpers.score_helpers import (
ScoreUpdate,
)

from literalai.evaluation.dataset_experiment import DatasetExperimentItem
from literalai.my_types import Environment
from literalai.observability.filter import (
generations_filters,
generations_order_by,
Expand All @@ -35,24 +22,14 @@
threads_order_by,
users_filters,
)
from literalai.prompt_engineering.prompt import ProviderSettings


from literalai.api.helpers.prompt_helpers import (
PromptRollout)

from literalai.observability.generation import (
ChatGeneration,
CompletionGeneration,
GenerationMessage,
)
from literalai.observability.step import (
ScoreDict,
ScoreType,
Step,
StepDict,
StepType,
)
from literalai.observability.step import ScoreDict, ScoreType, Step, StepDict, StepType
from literalai.prompt_engineering.prompt import ProviderSettings


def prepare_variables(variables: Dict[str, Any]) -> Dict[str, Any]:
"""
Expand All @@ -72,6 +49,7 @@ def handle_bytes(item):

return handle_bytes(variables)


class BaseLiteralAPI(ABC):
def __init__(
self,
Expand Down Expand Up @@ -676,7 +654,7 @@ def delete_step(
@abstractmethod
def send_steps(self, steps: List[Union[StepDict, "Step"]]):
"""
Sends a list of steps to process.
Sends a list of steps to process.
Step ingestion happens asynchronously if you configured a cache. See [Cache Configuration](https://docs.literalai.com/self-hosting/deployment#4-cache-configuration-optional).
Args:
Expand Down Expand Up @@ -773,9 +751,7 @@ def create_dataset(
pass

@abstractmethod
def get_dataset(
self, id: Optional[str] = None, name: Optional[str] = None
):
def get_dataset(self, id: Optional[str] = None, name: Optional[str] = None):
"""
Retrieves a dataset by its ID or name.
Expand Down Expand Up @@ -846,9 +822,7 @@ def create_experiment(
pass

@abstractmethod
def create_experiment_item(
self, experiment_item: DatasetExperimentItem
):
def create_experiment_item(self, experiment_item: DatasetExperimentItem):
"""
Creates an experiment item within an existing experiment.
Expand Down Expand Up @@ -1065,9 +1039,7 @@ def get_prompt_ab_testing(self, name: str):
pass

@abstractmethod
def update_prompt_ab_testing(
self, name: str, rollouts: List[PromptRollout]
):
def update_prompt_ab_testing(self, name: str, rollouts: List[PromptRollout]):
"""
Update the A/B testing configuration for a prompt lineage.
Expand Down
11 changes: 6 additions & 5 deletions literalai/api/helpers/generation_helpers.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from typing import Any, Dict, Optional, Union

from literalai.api.helpers import gql
from literalai.my_types import PaginatedResponse
from literalai.observability.filter import generations_filters, generations_order_by
from literalai.my_types import (
PaginatedResponse,
from literalai.observability.generation import (
BaseGeneration,
ChatGeneration,
CompletionGeneration,
)
from literalai.observability.generation import BaseGeneration, CompletionGeneration, ChatGeneration

from literalai.api.helpers import gql


def get_generations_helper(
Expand Down
Loading

0 comments on commit 3eebec5

Please sign in to comment.