diff --git a/app/schema.graphql b/app/schema.graphql index 3a5441f43d..ba94ee2c66 100644 --- a/app/schema.graphql +++ b/app/schema.graphql @@ -68,7 +68,8 @@ union Bin = NominalBin | IntervalBin | MissingValueBin input ChatCompletionInput { messages: [ChatCompletionMessageInput!]! model: GenerativeModelInput! - apiKey: String = null + invocationParameters: InvocationParameters! + apiKey: String } input ChatCompletionMessageInput { @@ -895,6 +896,15 @@ type IntervalBin { range: NumericRange! } +input InvocationParameters { + temperature: Float + maxCompletionTokens: Int + maxTokens: Int + topP: Float + stop: [String!] + seed: Int +} + """ The `JSON` scalar type represents JSON values as specified by [ECMA-404](https://ecma-international.org/wp-content/uploads/ECMA-404_2nd_edition_december_2017.pdf). """ diff --git a/app/src/pages/playground/PlaygroundOutput.tsx b/app/src/pages/playground/PlaygroundOutput.tsx index 0676c05a64..50996fc47d 100644 --- a/app/src/pages/playground/PlaygroundOutput.tsx +++ b/app/src/pages/playground/PlaygroundOutput.tsx @@ -105,10 +105,16 @@ function useChatCompletionSubscription({ subscription PlaygroundOutputSubscription( $messages: [ChatCompletionMessageInput!]! $model: GenerativeModelInput! + $invocationParameters: InvocationParameters! $apiKey: String ) { chatCompletion( - input: { messages: $messages, model: $model, apiKey: $apiKey } + input: { + messages: $messages + model: $model + invocationParameters: $invocationParameters + apiKey: $apiKey + } ) } `, @@ -187,6 +193,9 @@ function PlaygroundOutputText(props: PlaygroundInstanceProps) { providerKey: instance.model.provider, name: instance.model.modelName || "", }, + invocationParameters: { + temperature: 0.1, // TODO: add invocation parameters + }, apiKey: credentials[instance.model.provider], }, runId: instance.activeRunId, diff --git a/app/src/pages/playground/__generated__/PlaygroundOutputSubscription.graphql.ts b/app/src/pages/playground/__generated__/PlaygroundOutputSubscription.graphql.ts index ad17234ac7..071e575d5e 100644 --- a/app/src/pages/playground/__generated__/PlaygroundOutputSubscription.graphql.ts +++ b/app/src/pages/playground/__generated__/PlaygroundOutputSubscription.graphql.ts @@ -1,5 +1,5 @@ /** - * @generated SignedSource<<2cdc9b61c363ef02eee1db666410161c>> + * @generated SignedSource<<2541e9e2bfed22e1e5bd7970dd955d56>> * @lightSyntaxTransform * @nogrep */ @@ -19,8 +19,17 @@ export type GenerativeModelInput = { name: string; providerKey: GenerativeProviderKey; }; +export type InvocationParameters = { + maxCompletionTokens?: number | null; + maxTokens?: number | null; + seed?: number | null; + stop?: ReadonlyArray | null; + temperature?: number | null; + topP?: number | null; +}; export type PlaygroundOutputSubscription$variables = { apiKey?: string | null; + invocationParameters: InvocationParameters; messages: ReadonlyArray; model: GenerativeModelInput; }; @@ -41,14 +50,19 @@ var v0 = { v1 = { "defaultValue": null, "kind": "LocalArgument", - "name": "messages" + "name": "invocationParameters" }, v2 = { + "defaultValue": null, + "kind": "LocalArgument", + "name": "messages" +}, +v3 = { "defaultValue": null, "kind": "LocalArgument", "name": "model" }, -v3 = [ +v4 = [ { "alias": null, "args": [ @@ -59,6 +73,11 @@ v3 = [ "name": "apiKey", "variableName": "apiKey" }, + { + "kind": "Variable", + "name": "invocationParameters", + "variableName": "invocationParameters" + }, { "kind": "Variable", "name": "messages", @@ -84,37 +103,39 @@ return { "argumentDefinitions": [ (v0/*: any*/), (v1/*: any*/), - (v2/*: any*/) + (v2/*: any*/), + (v3/*: any*/) ], "kind": "Fragment", "metadata": null, "name": "PlaygroundOutputSubscription", - "selections": (v3/*: any*/), + "selections": (v4/*: any*/), "type": "Subscription", "abstractKey": null }, "kind": "Request", "operation": { "argumentDefinitions": [ - (v1/*: any*/), (v2/*: any*/), + (v3/*: any*/), + (v1/*: any*/), (v0/*: any*/) ], "kind": "Operation", "name": "PlaygroundOutputSubscription", - "selections": (v3/*: any*/) + "selections": (v4/*: any*/) }, "params": { - "cacheID": "5cff076e526a229159c9105f632f563b", + "cacheID": "6d0368256a4708eceb692fac9b469a70", "id": null, "metadata": {}, "name": "PlaygroundOutputSubscription", "operationKind": "subscription", - "text": "subscription PlaygroundOutputSubscription(\n $messages: [ChatCompletionMessageInput!]!\n $model: GenerativeModelInput!\n $apiKey: String\n) {\n chatCompletion(input: {messages: $messages, model: $model, apiKey: $apiKey})\n}\n" + "text": "subscription PlaygroundOutputSubscription(\n $messages: [ChatCompletionMessageInput!]!\n $model: GenerativeModelInput!\n $invocationParameters: InvocationParameters!\n $apiKey: String\n) {\n chatCompletion(input: {messages: $messages, model: $model, invocationParameters: $invocationParameters, apiKey: $apiKey})\n}\n" } }; })(); -(node as any).hash = "cafb70821a7cde3503c13af42b602aa1"; +(node as any).hash = "0ffc0f049247896bfff22a90effc24a9"; export default node; diff --git a/src/phoenix/db/migrations/versions/10460e46d750_datasets.py b/src/phoenix/db/migrations/versions/10460e46d750_datasets.py index 3a4aeec79e..8d4eea00c4 100644 --- a/src/phoenix/db/migrations/versions/10460e46d750_datasets.py +++ b/src/phoenix/db/migrations/versions/10460e46d750_datasets.py @@ -20,7 +20,7 @@ class JSONB(JSON): __visit_name__ = "JSONB" -@compiles(JSONB, "sqlite") # type: ignore +@compiles(JSONB, "sqlite") def _(*args: Any, **kwargs: Any) -> str: # See https://docs.sqlalchemy.org/en/20/core/custom_types.html return "JSONB" diff --git a/src/phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py b/src/phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py index 141a378335..9b5a36c553 100644 --- a/src/phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py +++ b/src/phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py @@ -32,7 +32,7 @@ class JSONB(JSON): __visit_name__ = "JSONB" -@compiles(JSONB, "sqlite") # type: ignore +@compiles(JSONB, "sqlite") def _(*args: Any, **kwargs: Any) -> str: # See https://docs.sqlalchemy.org/en/20/core/custom_types.html return "JSONB" diff --git a/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py b/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py index e838f04f04..0baa6b90d5 100644 --- a/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py +++ b/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py @@ -20,7 +20,7 @@ class JSONB(JSON): __visit_name__ = "JSONB" -@compiles(JSONB, "sqlite") # type: ignore +@compiles(JSONB, "sqlite") def _(*args: Any, **kwargs: Any) -> str: # See https://docs.sqlalchemy.org/en/20/core/custom_types.html return "JSONB" diff --git a/src/phoenix/db/models.py b/src/phoenix/db/models.py index ad3070f6db..2adef3b19f 100644 --- a/src/phoenix/db/models.py +++ b/src/phoenix/db/models.py @@ -50,7 +50,7 @@ class JSONB(JSON): __visit_name__ = "JSONB" -@compiles(JSONB, "sqlite") # type: ignore +@compiles(JSONB, "sqlite") def _(*args: Any, **kwargs: Any) -> str: # See https://docs.sqlalchemy.org/en/20/core/custom_types.html return "JSONB" @@ -271,7 +271,7 @@ class LatencyMs(expression.FunctionElement[float]): name = "latency_ms" -@compiles(LatencyMs) # type: ignore +@compiles(LatencyMs) def _(element: Any, compiler: Any, **kw: Any) -> Any: # See https://docs.sqlalchemy.org/en/20/core/compiler.html start_time, end_time = list(element.clauses) @@ -287,7 +287,7 @@ def _(element: Any, compiler: Any, **kw: Any) -> Any: ) -@compiles(LatencyMs, "sqlite") # type: ignore +@compiles(LatencyMs, "sqlite") def _(element: Any, compiler: Any, **kw: Any) -> Any: # See https://docs.sqlalchemy.org/en/20/core/compiler.html start_time, end_time = list(element.clauses) @@ -308,21 +308,21 @@ class TextContains(expression.FunctionElement[str]): name = "text_contains" -@compiles(TextContains) # type: ignore +@compiles(TextContains) def _(element: Any, compiler: Any, **kw: Any) -> Any: # See https://docs.sqlalchemy.org/en/20/core/compiler.html string, substring = list(element.clauses) return compiler.process(string.contains(substring), **kw) -@compiles(TextContains, "postgresql") # type: ignore +@compiles(TextContains, "postgresql") def _(element: Any, compiler: Any, **kw: Any) -> Any: # See https://docs.sqlalchemy.org/en/20/core/compiler.html string, substring = list(element.clauses) return compiler.process(func.strpos(string, substring) > 0, **kw) -@compiles(TextContains, "sqlite") # type: ignore +@compiles(TextContains, "sqlite") def _(element: Any, compiler: Any, **kw: Any) -> Any: # See https://docs.sqlalchemy.org/en/20/core/compiler.html string, substring = list(element.clauses) diff --git a/src/phoenix/server/api/input_types/InvocationParameters.py b/src/phoenix/server/api/input_types/InvocationParameters.py new file mode 100644 index 0000000000..736db6b052 --- /dev/null +++ b/src/phoenix/server/api/input_types/InvocationParameters.py @@ -0,0 +1,18 @@ +from typing import List, Optional + +import strawberry +from strawberry import UNSET + + +@strawberry.input +class InvocationParameters: + """ + Invocation parameters interface shared between different providers. + """ + + temperature: Optional[float] = UNSET + max_completion_tokens: Optional[int] = UNSET + max_tokens: Optional[int] = UNSET + top_p: Optional[float] = UNSET + stop: Optional[List[str]] = UNSET + seed: Optional[int] = UNSET diff --git a/src/phoenix/server/api/subscriptions.py b/src/phoenix/server/api/subscriptions.py index 24755dfc44..b2ae9b0c40 100644 --- a/src/phoenix/server/api/subscriptions.py +++ b/src/phoenix/server/api/subscriptions.py @@ -1,12 +1,9 @@ -import json -from dataclasses import asdict from datetime import datetime -from enum import Enum from itertools import chain -from json import JSONEncoder -from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, Iterator, List, Optional, Tuple import strawberry +from openinference.instrumentation import safe_json_dumps from openinference.semconv.trace import ( MessageAttributes, OpenInferenceMimeTypeValues, @@ -17,18 +14,20 @@ from opentelemetry.sdk.trace.export import SimpleSpanProcessor from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from opentelemetry.trace import StatusCode -from pydantic import BaseModel from sqlalchemy import insert, select +from strawberry import UNSET from strawberry.types import Info from typing_extensions import assert_never from phoenix.db import models from phoenix.server.api.context import Context from phoenix.server.api.input_types.ChatCompletionMessageInput import ChatCompletionMessageInput +from phoenix.server.api.input_types.InvocationParameters import InvocationParameters from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey from phoenix.server.dml_event import SpanInsertEvent from phoenix.trace.attributes import unflatten +from phoenix.utilities.json import jsonify if TYPE_CHECKING: from openai.types.chat import ( @@ -48,7 +47,8 @@ class GenerativeModelInput: class ChatCompletionInput: messages: List[ChatCompletionMessageInput] model: GenerativeModelInput - api_key: Optional[str] = None + invocation_parameters: InvocationParameters + api_key: Optional[str] = UNSET def to_openai_chat_completion_param( @@ -94,7 +94,9 @@ async def chat_completion( ) -> AsyncIterator[str]: from openai import AsyncOpenAI - client = AsyncOpenAI(api_key=input.api_key) + api_key = input.api_key or None + client = AsyncOpenAI(api_key=api_key) + invocation_parameters = jsonify(input.invocation_parameters) in_memory_span_exporter = InMemorySpanExporter() tracer_provider = TracerProvider() @@ -109,8 +111,9 @@ async def chat_completion( chain( _llm_span_kind(), _llm_model_name(input.model.name), - _input_value_and_mime_type(input), _llm_input_messages(input.messages), + _llm_invocation_parameters(invocation_parameters), + _input_value_and_mime_type(input), ) ), ) as span: @@ -121,6 +124,7 @@ async def chat_completion( messages=(to_openai_chat_completion_param(message) for message in input.messages), model=input.model.name, stream=True, + **invocation_parameters, ): chunks.append(chunk) choice = chunk.choices[0] @@ -206,14 +210,18 @@ def _llm_model_name(model_name: str) -> Iterator[Tuple[str, Any]]: yield LLM_MODEL_NAME, model_name +def _llm_invocation_parameters(invocation_parameters: Dict[str, Any]) -> Iterator[Tuple[str, Any]]: + yield LLM_INVOCATION_PARAMETERS, safe_json_dumps(invocation_parameters) + + def _input_value_and_mime_type(input: ChatCompletionInput) -> Iterator[Tuple[str, Any]]: yield INPUT_MIME_TYPE, JSON - yield INPUT_VALUE, json.dumps(asdict(input), cls=GraphQLInputJSONEncoder) + yield INPUT_VALUE, safe_json_dumps(jsonify(input)) def _output_value_and_mime_type(output: Any) -> Iterator[Tuple[str, Any]]: yield OUTPUT_MIME_TYPE, JSON - yield OUTPUT_VALUE, json.dumps(output, cls=ChatCompletionOutputJSONEncoder) + yield OUTPUT_VALUE, safe_json_dumps(jsonify(output)) def _llm_input_messages(messages: List[ChatCompletionMessageInput]) -> Iterator[Tuple[str, Any]]: @@ -242,20 +250,6 @@ def _datetime(*, epoch_nanoseconds: float) -> datetime: return datetime.fromtimestamp(epoch_seconds) -class GraphQLInputJSONEncoder(JSONEncoder): - def default(self, obj: Any) -> Any: - if isinstance(obj, Enum): - return obj.value - return super().default(obj) - - -class ChatCompletionOutputJSONEncoder(JSONEncoder): - def default(self, obj: Any) -> Any: - if isinstance(obj, BaseModel): - return obj.model_dump() - return super().default(obj) - - JSON = OpenInferenceMimeTypeValues.JSON.value LLM = OpenInferenceSpanKindValues.LLM.value @@ -268,6 +262,7 @@ def default(self, obj: Any) -> Any: LLM_INPUT_MESSAGES = SpanAttributes.LLM_INPUT_MESSAGES LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES LLM_MODEL_NAME = SpanAttributes.LLM_MODEL_NAME +LLM_INVOCATION_PARAMETERS = SpanAttributes.LLM_INVOCATION_PARAMETERS MESSAGE_CONTENT = MessageAttributes.MESSAGE_CONTENT MESSAGE_ROLE = MessageAttributes.MESSAGE_ROLE diff --git a/src/phoenix/utilities/json.py b/src/phoenix/utilities/json.py index 691a31c8b1..d9c456dcbf 100644 --- a/src/phoenix/utilities/json.py +++ b/src/phoenix/utilities/json.py @@ -5,6 +5,8 @@ from typing import Any, Mapping, Sequence, Union, get_args, get_origin import numpy as np +from strawberry import UNSET +from strawberry.types.base import StrawberryObjectDefinition def jsonify(obj: Any) -> Any: @@ -19,6 +21,15 @@ def jsonify(obj: Any) -> Any: return [jsonify(v) for v in obj] if isinstance(obj, (dict, Mapping)): return {jsonify(k): jsonify(v) for k, v in obj.items()} + is_strawberry_type = isinstance( + getattr(obj, "__strawberry_definition__", None), StrawberryObjectDefinition + ) + if is_strawberry_type: + return { + k: jsonify(v) + for field in dataclasses.fields(obj) + if (v := getattr(obj, (k := field.name))) is not UNSET + } if dataclasses.is_dataclass(obj): return { k: jsonify(v)