diff --git a/app/schema.graphql b/app/schema.graphql index 3a5441f43d..26d7db3b26 100644 --- a/app/schema.graphql +++ b/app/schema.graphql @@ -824,6 +824,11 @@ type Functionality { tracing: Boolean! } +type GenerativeModel { + name: String! + providerKey: GenerativeProviderKey! +} + input GenerativeModelInput { providerKey: GenerativeProviderKey! name: String! @@ -944,8 +949,8 @@ type Model { ): PerformanceTimeSeries! } -input ModelNamesInput { - providerKey: GenerativeProviderKey! +input ModelsInput { + providerKey: GenerativeProviderKey } type Mutation { @@ -1141,7 +1146,7 @@ type PromptResponse { type Query { modelProviders: [GenerativeProvider!]! - modelNames(input: ModelNamesInput!): [String!]! + models(input: ModelsInput = null): [GenerativeModel!]! users(first: Int = 50, last: Int, after: String, before: String): UserConnection! userRoles: [UserRole!]! userApiKeys: [UserApiKey!]! diff --git a/app/src/pages/playground/ModelPicker.tsx b/app/src/pages/playground/ModelPicker.tsx index 1d9258e803..3a91eb8503 100644 --- a/app/src/pages/playground/ModelPicker.tsx +++ b/app/src/pages/playground/ModelPicker.tsx @@ -22,7 +22,9 @@ export function ModelPicker({ query, onChange, ...props }: ModelPickerProps) { @argumentDefinitions( providerKey: { type: "GenerativeProviderKey!", defaultValue: OPENAI } ) { - modelNames(input: { providerKey: $providerKey }) + models(input: { providerKey: $providerKey }) { + name + } } `, query @@ -42,8 +44,8 @@ export function ModelPicker({ query, onChange, ...props }: ModelPickerProps) { width={"100%"} {...props} > - {data.modelNames.map((modelName) => { - return {modelName}; + {data.models.map(({ name }) => { + return {name}; })} ); diff --git a/app/src/pages/playground/__generated__/ModelConfigButtonDialogQuery.graphql.ts b/app/src/pages/playground/__generated__/ModelConfigButtonDialogQuery.graphql.ts index 187da00815..50d6f089ee 100644 --- a/app/src/pages/playground/__generated__/ModelConfigButtonDialogQuery.graphql.ts +++ b/app/src/pages/playground/__generated__/ModelConfigButtonDialogQuery.graphql.ts @@ -1,5 +1,5 @@ /** - * @generated SignedSource<<176456afea57f0245ab80564600db337>> + * @generated SignedSource<<814210e1a750a2c446a2043b5a6ab0b8>> * @lightSyntaxTransform * @nogrep */ @@ -36,7 +36,14 @@ v1 = [ "name": "providerKey", "variableName": "providerKey" } -]; +], +v2 = { + "alias": null, + "args": null, + "kind": "ScalarField", + "name": "name", + "storageKey": null +}; return { "fragment": { "argumentDefinitions": (v0/*: any*/), @@ -79,13 +86,7 @@ return { "name": "key", "storageKey": null }, - { - "alias": null, - "args": null, - "kind": "ScalarField", - "name": "name", - "storageKey": null - } + (v2/*: any*/) ], "storageKey": null }, @@ -98,19 +99,24 @@ return { "name": "input" } ], - "kind": "ScalarField", - "name": "modelNames", + "concreteType": "GenerativeModel", + "kind": "LinkedField", + "name": "models", + "plural": true, + "selections": [ + (v2/*: any*/) + ], "storageKey": null } ] }, "params": { - "cacheID": "34f8d81e91b335ca310c9be756719426", + "cacheID": "8e6ad232aae761280ca29a0571fe7c23", "id": null, "metadata": {}, "name": "ModelConfigButtonDialogQuery", "operationKind": "query", - "text": "query ModelConfigButtonDialogQuery(\n $providerKey: GenerativeProviderKey!\n) {\n ...ModelProviderPickerFragment\n ...ModelPickerFragment_3rERSq\n}\n\nfragment ModelPickerFragment_3rERSq on Query {\n modelNames(input: {providerKey: $providerKey})\n}\n\nfragment ModelProviderPickerFragment on Query {\n modelProviders {\n key\n name\n }\n}\n" + "text": "query ModelConfigButtonDialogQuery(\n $providerKey: GenerativeProviderKey!\n) {\n ...ModelProviderPickerFragment\n ...ModelPickerFragment_3rERSq\n}\n\nfragment ModelPickerFragment_3rERSq on Query {\n models(input: {providerKey: $providerKey}) {\n name\n }\n}\n\nfragment ModelProviderPickerFragment on Query {\n modelProviders {\n key\n name\n }\n}\n" } }; })(); diff --git a/app/src/pages/playground/__generated__/ModelPickerFragment.graphql.ts b/app/src/pages/playground/__generated__/ModelPickerFragment.graphql.ts index 9dda921ba2..e77eac8851 100644 --- a/app/src/pages/playground/__generated__/ModelPickerFragment.graphql.ts +++ b/app/src/pages/playground/__generated__/ModelPickerFragment.graphql.ts @@ -1,5 +1,5 @@ /** - * @generated SignedSource<<6931dc528aea2b22801320e6d297dd58>> + * @generated SignedSource<> * @lightSyntaxTransform * @nogrep */ @@ -11,7 +11,9 @@ import { Fragment, ReaderFragment } from 'relay-runtime'; import { FragmentRefs } from "relay-runtime"; export type ModelPickerFragment$data = { - readonly modelNames: ReadonlyArray; + readonly models: ReadonlyArray<{ + readonly name: string; + }>; readonly " $fragmentType": "ModelPickerFragment"; }; export type ModelPickerFragment$key = { @@ -46,8 +48,19 @@ const node: ReaderFragment = { "name": "input" } ], - "kind": "ScalarField", - "name": "modelNames", + "concreteType": "GenerativeModel", + "kind": "LinkedField", + "name": "models", + "plural": true, + "selections": [ + { + "alias": null, + "args": null, + "kind": "ScalarField", + "name": "name", + "storageKey": null + } + ], "storageKey": null } ], @@ -55,6 +68,6 @@ const node: ReaderFragment = { "abstractKey": null }; -(node as any).hash = "bb2557396c978bb5f57c7a4f67d756b1"; +(node as any).hash = "1e660ad77ce19db1c1bbe8698a661b4f"; export default node; diff --git a/src/phoenix/db/migrations/versions/10460e46d750_datasets.py b/src/phoenix/db/migrations/versions/10460e46d750_datasets.py index 3a4aeec79e..8d4eea00c4 100644 --- a/src/phoenix/db/migrations/versions/10460e46d750_datasets.py +++ b/src/phoenix/db/migrations/versions/10460e46d750_datasets.py @@ -20,7 +20,7 @@ class JSONB(JSON): __visit_name__ = "JSONB" -@compiles(JSONB, "sqlite") # type: ignore +@compiles(JSONB, "sqlite") def _(*args: Any, **kwargs: Any) -> str: # See https://docs.sqlalchemy.org/en/20/core/custom_types.html return "JSONB" diff --git a/src/phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py b/src/phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py index 141a378335..9b5a36c553 100644 --- a/src/phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py +++ b/src/phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py @@ -32,7 +32,7 @@ class JSONB(JSON): __visit_name__ = "JSONB" -@compiles(JSONB, "sqlite") # type: ignore +@compiles(JSONB, "sqlite") def _(*args: Any, **kwargs: Any) -> str: # See https://docs.sqlalchemy.org/en/20/core/custom_types.html return "JSONB" diff --git a/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py b/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py index e838f04f04..0baa6b90d5 100644 --- a/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py +++ b/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py @@ -20,7 +20,7 @@ class JSONB(JSON): __visit_name__ = "JSONB" -@compiles(JSONB, "sqlite") # type: ignore +@compiles(JSONB, "sqlite") def _(*args: Any, **kwargs: Any) -> str: # See https://docs.sqlalchemy.org/en/20/core/custom_types.html return "JSONB" diff --git a/src/phoenix/db/models.py b/src/phoenix/db/models.py index ad3070f6db..2adef3b19f 100644 --- a/src/phoenix/db/models.py +++ b/src/phoenix/db/models.py @@ -50,7 +50,7 @@ class JSONB(JSON): __visit_name__ = "JSONB" -@compiles(JSONB, "sqlite") # type: ignore +@compiles(JSONB, "sqlite") def _(*args: Any, **kwargs: Any) -> str: # See https://docs.sqlalchemy.org/en/20/core/custom_types.html return "JSONB" @@ -271,7 +271,7 @@ class LatencyMs(expression.FunctionElement[float]): name = "latency_ms" -@compiles(LatencyMs) # type: ignore +@compiles(LatencyMs) def _(element: Any, compiler: Any, **kw: Any) -> Any: # See https://docs.sqlalchemy.org/en/20/core/compiler.html start_time, end_time = list(element.clauses) @@ -287,7 +287,7 @@ def _(element: Any, compiler: Any, **kw: Any) -> Any: ) -@compiles(LatencyMs, "sqlite") # type: ignore +@compiles(LatencyMs, "sqlite") def _(element: Any, compiler: Any, **kw: Any) -> Any: # See https://docs.sqlalchemy.org/en/20/core/compiler.html start_time, end_time = list(element.clauses) @@ -308,21 +308,21 @@ class TextContains(expression.FunctionElement[str]): name = "text_contains" -@compiles(TextContains) # type: ignore +@compiles(TextContains) def _(element: Any, compiler: Any, **kw: Any) -> Any: # See https://docs.sqlalchemy.org/en/20/core/compiler.html string, substring = list(element.clauses) return compiler.process(string.contains(substring), **kw) -@compiles(TextContains, "postgresql") # type: ignore +@compiles(TextContains, "postgresql") def _(element: Any, compiler: Any, **kw: Any) -> Any: # See https://docs.sqlalchemy.org/en/20/core/compiler.html string, substring = list(element.clauses) return compiler.process(func.strpos(string, substring) > 0, **kw) -@compiles(TextContains, "sqlite") # type: ignore +@compiles(TextContains, "sqlite") def _(element: Any, compiler: Any, **kw: Any) -> Any: # See https://docs.sqlalchemy.org/en/20/core/compiler.html string, substring = list(element.clauses) diff --git a/src/phoenix/server/api/queries.py b/src/phoenix/server/api/queries.py index 4c8319555e..fe09ca6f25 100644 --- a/src/phoenix/server/api/queries.py +++ b/src/phoenix/server/api/queries.py @@ -11,7 +11,7 @@ from strawberry import ID, UNSET from strawberry.relay import Connection, GlobalID, Node from strawberry.types import Info -from typing_extensions import Annotated, TypeAlias, assert_never +from typing_extensions import Annotated, TypeAlias from phoenix.db import enums, models from phoenix.db.models import ( @@ -58,6 +58,7 @@ from phoenix.server.api.types.ExperimentComparison import ExperimentComparison, RunComparisonItem from phoenix.server.api.types.ExperimentRun import ExperimentRun, to_gql_experiment_run from phoenix.server.api.types.Functionality import Functionality +from phoenix.server.api.types.GenerativeModel import GenerativeModel from phoenix.server.api.types.GenerativeProvider import ( GenerativeProvider, GenerativeProviderKey, @@ -81,8 +82,8 @@ @strawberry.input -class ModelNamesInput: - provider_key: GenerativeProviderKey +class ModelsInput: + provider_key: Optional[GenerativeProviderKey] @strawberry.type @@ -105,63 +106,51 @@ async def model_providers(self) -> List[GenerativeProvider]: ] @strawberry.field - async def model_names(self, input: ModelNamesInput) -> List[str]: - if (provider_key := input.provider_key) == GenerativeProviderKey.OPENAI: - return [ - "o1-preview", - "o1-preview-2024-09-12", - "o1-mini", - "o1-mini-2024-09-12", - "gpt-4o", - "gpt-4o-2024-08-06", - "gpt-4o-2024-05-13", - "chatgpt-4o-latest", - "gpt-4o-mini", - "gpt-4o-mini-2024-07-18", - "gpt-4-turbo", - "gpt-4-turbo-2024-04-09", - "gpt-4-turbo-preview", - "gpt-4-0125-preview", - "gpt-4-1106-preview", - "gpt-4", - "gpt-4-0613", - "gpt-3.5-turbo-0125", - "gpt-3.5-turbo", - "gpt-3.5-turbo-1106", - "gpt-3.5-turbo-instruct", - ] - if provider_key == GenerativeProviderKey.AZURE_OPENAI: - return [ - "o1-preview", - "o1-preview-2024-09-12", - "o1-mini", - "o1-mini-2024-09-12", - "gpt-4o", - "gpt-4o-2024-08-06", - "gpt-4o-2024-05-13", - "chatgpt-4o-latest", - "gpt-4o-mini", - "gpt-4o-mini-2024-07-18", - "gpt-4-turbo", - "gpt-4-turbo-2024-04-09", - "gpt-4-turbo-preview", - "gpt-4-0125-preview", - "gpt-4-1106-preview", - "gpt-4", - "gpt-4-0613", - "gpt-3.5-turbo-0125", - "gpt-3.5-turbo", - "gpt-3.5-turbo-1106", - "gpt-3.5-turbo-instruct", - ] - if provider_key == GenerativeProviderKey.ANTHROPIC: - return [ - "claude-3-5-sonnet-20240620", - "claude-3-opus-20240229", - "claude-3-sonnet-20240229", - "claude-3-haiku-20240307", - ] - assert_never(provider_key) + async def models(self, input: Optional[ModelsInput] = None) -> List[GenerativeModel]: + openai_models = [ + "o1-preview", + "o1-preview-2024-09-12", + "o1-mini", + "o1-mini-2024-09-12", + "gpt-4o", + "gpt-4o-2024-08-06", + "gpt-4o-2024-05-13", + "chatgpt-4o-latest", + "gpt-4o-mini", + "gpt-4o-mini-2024-07-18", + "gpt-4-turbo", + "gpt-4-turbo-2024-04-09", + "gpt-4-turbo-preview", + "gpt-4-0125-preview", + "gpt-4-1106-preview", + "gpt-4", + "gpt-4-0613", + "gpt-3.5-turbo-0125", + "gpt-3.5-turbo", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo-instruct", + ] + anthropic_models = [ + "claude-3-5-sonnet-20240620", + "claude-3-opus-20240229", + "claude-3-sonnet-20240229", + "claude-3-haiku-20240307", + ] + openai_generative_models = [ + GenerativeModel(name=model_name, provider_key=GenerativeProviderKey.OPENAI) + for model_name in openai_models + ] + anthropic_generative_models = [ + GenerativeModel(name=model_name, provider_key=GenerativeProviderKey.ANTHROPIC) + for model_name in anthropic_models + ] + + all_models = openai_generative_models + anthropic_generative_models + + if input is not None and input.provider_key is not None: + return [model for model in all_models if model.provider_key == input.provider_key] + + return all_models @strawberry.field(permission_classes=[IsAdmin]) # type: ignore async def users( diff --git a/src/phoenix/server/api/types/GenerativeModel.py b/src/phoenix/server/api/types/GenerativeModel.py new file mode 100644 index 0000000000..00e5313f6f --- /dev/null +++ b/src/phoenix/server/api/types/GenerativeModel.py @@ -0,0 +1,9 @@ +import strawberry + +from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey + + +@strawberry.type +class GenerativeModel: + name: str + provider_key: GenerativeProviderKey