diff --git a/app/schema.graphql b/app/schema.graphql
index 3a5441f43d..26d7db3b26 100644
--- a/app/schema.graphql
+++ b/app/schema.graphql
@@ -824,6 +824,11 @@ type Functionality {
tracing: Boolean!
}
+type GenerativeModel {
+ name: String!
+ providerKey: GenerativeProviderKey!
+}
+
input GenerativeModelInput {
providerKey: GenerativeProviderKey!
name: String!
@@ -944,8 +949,8 @@ type Model {
): PerformanceTimeSeries!
}
-input ModelNamesInput {
- providerKey: GenerativeProviderKey!
+input ModelsInput {
+ providerKey: GenerativeProviderKey
}
type Mutation {
@@ -1141,7 +1146,7 @@ type PromptResponse {
type Query {
modelProviders: [GenerativeProvider!]!
- modelNames(input: ModelNamesInput!): [String!]!
+ models(input: ModelsInput = null): [GenerativeModel!]!
users(first: Int = 50, last: Int, after: String, before: String): UserConnection!
userRoles: [UserRole!]!
userApiKeys: [UserApiKey!]!
diff --git a/app/src/pages/playground/ModelPicker.tsx b/app/src/pages/playground/ModelPicker.tsx
index 1d9258e803..3a91eb8503 100644
--- a/app/src/pages/playground/ModelPicker.tsx
+++ b/app/src/pages/playground/ModelPicker.tsx
@@ -22,7 +22,9 @@ export function ModelPicker({ query, onChange, ...props }: ModelPickerProps) {
@argumentDefinitions(
providerKey: { type: "GenerativeProviderKey!", defaultValue: OPENAI }
) {
- modelNames(input: { providerKey: $providerKey })
+ models(input: { providerKey: $providerKey }) {
+ name
+ }
}
`,
query
@@ -42,8 +44,8 @@ export function ModelPicker({ query, onChange, ...props }: ModelPickerProps) {
width={"100%"}
{...props}
>
- {data.modelNames.map((modelName) => {
- return - {modelName}
;
+ {data.models.map(({ name }) => {
+ return - {name}
;
})}
);
diff --git a/app/src/pages/playground/__generated__/ModelConfigButtonDialogQuery.graphql.ts b/app/src/pages/playground/__generated__/ModelConfigButtonDialogQuery.graphql.ts
index 187da00815..50d6f089ee 100644
--- a/app/src/pages/playground/__generated__/ModelConfigButtonDialogQuery.graphql.ts
+++ b/app/src/pages/playground/__generated__/ModelConfigButtonDialogQuery.graphql.ts
@@ -1,5 +1,5 @@
/**
- * @generated SignedSource<<176456afea57f0245ab80564600db337>>
+ * @generated SignedSource<<814210e1a750a2c446a2043b5a6ab0b8>>
* @lightSyntaxTransform
* @nogrep
*/
@@ -36,7 +36,14 @@ v1 = [
"name": "providerKey",
"variableName": "providerKey"
}
-];
+],
+v2 = {
+ "alias": null,
+ "args": null,
+ "kind": "ScalarField",
+ "name": "name",
+ "storageKey": null
+};
return {
"fragment": {
"argumentDefinitions": (v0/*: any*/),
@@ -79,13 +86,7 @@ return {
"name": "key",
"storageKey": null
},
- {
- "alias": null,
- "args": null,
- "kind": "ScalarField",
- "name": "name",
- "storageKey": null
- }
+ (v2/*: any*/)
],
"storageKey": null
},
@@ -98,19 +99,24 @@ return {
"name": "input"
}
],
- "kind": "ScalarField",
- "name": "modelNames",
+ "concreteType": "GenerativeModel",
+ "kind": "LinkedField",
+ "name": "models",
+ "plural": true,
+ "selections": [
+ (v2/*: any*/)
+ ],
"storageKey": null
}
]
},
"params": {
- "cacheID": "34f8d81e91b335ca310c9be756719426",
+ "cacheID": "8e6ad232aae761280ca29a0571fe7c23",
"id": null,
"metadata": {},
"name": "ModelConfigButtonDialogQuery",
"operationKind": "query",
- "text": "query ModelConfigButtonDialogQuery(\n $providerKey: GenerativeProviderKey!\n) {\n ...ModelProviderPickerFragment\n ...ModelPickerFragment_3rERSq\n}\n\nfragment ModelPickerFragment_3rERSq on Query {\n modelNames(input: {providerKey: $providerKey})\n}\n\nfragment ModelProviderPickerFragment on Query {\n modelProviders {\n key\n name\n }\n}\n"
+ "text": "query ModelConfigButtonDialogQuery(\n $providerKey: GenerativeProviderKey!\n) {\n ...ModelProviderPickerFragment\n ...ModelPickerFragment_3rERSq\n}\n\nfragment ModelPickerFragment_3rERSq on Query {\n models(input: {providerKey: $providerKey}) {\n name\n }\n}\n\nfragment ModelProviderPickerFragment on Query {\n modelProviders {\n key\n name\n }\n}\n"
}
};
})();
diff --git a/app/src/pages/playground/__generated__/ModelPickerFragment.graphql.ts b/app/src/pages/playground/__generated__/ModelPickerFragment.graphql.ts
index 9dda921ba2..e77eac8851 100644
--- a/app/src/pages/playground/__generated__/ModelPickerFragment.graphql.ts
+++ b/app/src/pages/playground/__generated__/ModelPickerFragment.graphql.ts
@@ -1,5 +1,5 @@
/**
- * @generated SignedSource<<6931dc528aea2b22801320e6d297dd58>>
+ * @generated SignedSource<>
* @lightSyntaxTransform
* @nogrep
*/
@@ -11,7 +11,9 @@
import { Fragment, ReaderFragment } from 'relay-runtime';
import { FragmentRefs } from "relay-runtime";
export type ModelPickerFragment$data = {
- readonly modelNames: ReadonlyArray;
+ readonly models: ReadonlyArray<{
+ readonly name: string;
+ }>;
readonly " $fragmentType": "ModelPickerFragment";
};
export type ModelPickerFragment$key = {
@@ -46,8 +48,19 @@ const node: ReaderFragment = {
"name": "input"
}
],
- "kind": "ScalarField",
- "name": "modelNames",
+ "concreteType": "GenerativeModel",
+ "kind": "LinkedField",
+ "name": "models",
+ "plural": true,
+ "selections": [
+ {
+ "alias": null,
+ "args": null,
+ "kind": "ScalarField",
+ "name": "name",
+ "storageKey": null
+ }
+ ],
"storageKey": null
}
],
@@ -55,6 +68,6 @@ const node: ReaderFragment = {
"abstractKey": null
};
-(node as any).hash = "bb2557396c978bb5f57c7a4f67d756b1";
+(node as any).hash = "1e660ad77ce19db1c1bbe8698a661b4f";
export default node;
diff --git a/src/phoenix/db/migrations/versions/10460e46d750_datasets.py b/src/phoenix/db/migrations/versions/10460e46d750_datasets.py
index 3a4aeec79e..8d4eea00c4 100644
--- a/src/phoenix/db/migrations/versions/10460e46d750_datasets.py
+++ b/src/phoenix/db/migrations/versions/10460e46d750_datasets.py
@@ -20,7 +20,7 @@ class JSONB(JSON):
__visit_name__ = "JSONB"
-@compiles(JSONB, "sqlite") # type: ignore
+@compiles(JSONB, "sqlite")
def _(*args: Any, **kwargs: Any) -> str:
# See https://docs.sqlalchemy.org/en/20/core/custom_types.html
return "JSONB"
diff --git a/src/phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py b/src/phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py
index 141a378335..9b5a36c553 100644
--- a/src/phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py
+++ b/src/phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py
@@ -32,7 +32,7 @@ class JSONB(JSON):
__visit_name__ = "JSONB"
-@compiles(JSONB, "sqlite") # type: ignore
+@compiles(JSONB, "sqlite")
def _(*args: Any, **kwargs: Any) -> str:
# See https://docs.sqlalchemy.org/en/20/core/custom_types.html
return "JSONB"
diff --git a/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py b/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py
index e838f04f04..0baa6b90d5 100644
--- a/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py
+++ b/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py
@@ -20,7 +20,7 @@ class JSONB(JSON):
__visit_name__ = "JSONB"
-@compiles(JSONB, "sqlite") # type: ignore
+@compiles(JSONB, "sqlite")
def _(*args: Any, **kwargs: Any) -> str:
# See https://docs.sqlalchemy.org/en/20/core/custom_types.html
return "JSONB"
diff --git a/src/phoenix/db/models.py b/src/phoenix/db/models.py
index ad3070f6db..2adef3b19f 100644
--- a/src/phoenix/db/models.py
+++ b/src/phoenix/db/models.py
@@ -50,7 +50,7 @@ class JSONB(JSON):
__visit_name__ = "JSONB"
-@compiles(JSONB, "sqlite") # type: ignore
+@compiles(JSONB, "sqlite")
def _(*args: Any, **kwargs: Any) -> str:
# See https://docs.sqlalchemy.org/en/20/core/custom_types.html
return "JSONB"
@@ -271,7 +271,7 @@ class LatencyMs(expression.FunctionElement[float]):
name = "latency_ms"
-@compiles(LatencyMs) # type: ignore
+@compiles(LatencyMs)
def _(element: Any, compiler: Any, **kw: Any) -> Any:
# See https://docs.sqlalchemy.org/en/20/core/compiler.html
start_time, end_time = list(element.clauses)
@@ -287,7 +287,7 @@ def _(element: Any, compiler: Any, **kw: Any) -> Any:
)
-@compiles(LatencyMs, "sqlite") # type: ignore
+@compiles(LatencyMs, "sqlite")
def _(element: Any, compiler: Any, **kw: Any) -> Any:
# See https://docs.sqlalchemy.org/en/20/core/compiler.html
start_time, end_time = list(element.clauses)
@@ -308,21 +308,21 @@ class TextContains(expression.FunctionElement[str]):
name = "text_contains"
-@compiles(TextContains) # type: ignore
+@compiles(TextContains)
def _(element: Any, compiler: Any, **kw: Any) -> Any:
# See https://docs.sqlalchemy.org/en/20/core/compiler.html
string, substring = list(element.clauses)
return compiler.process(string.contains(substring), **kw)
-@compiles(TextContains, "postgresql") # type: ignore
+@compiles(TextContains, "postgresql")
def _(element: Any, compiler: Any, **kw: Any) -> Any:
# See https://docs.sqlalchemy.org/en/20/core/compiler.html
string, substring = list(element.clauses)
return compiler.process(func.strpos(string, substring) > 0, **kw)
-@compiles(TextContains, "sqlite") # type: ignore
+@compiles(TextContains, "sqlite")
def _(element: Any, compiler: Any, **kw: Any) -> Any:
# See https://docs.sqlalchemy.org/en/20/core/compiler.html
string, substring = list(element.clauses)
diff --git a/src/phoenix/server/api/queries.py b/src/phoenix/server/api/queries.py
index 4c8319555e..fe09ca6f25 100644
--- a/src/phoenix/server/api/queries.py
+++ b/src/phoenix/server/api/queries.py
@@ -11,7 +11,7 @@
from strawberry import ID, UNSET
from strawberry.relay import Connection, GlobalID, Node
from strawberry.types import Info
-from typing_extensions import Annotated, TypeAlias, assert_never
+from typing_extensions import Annotated, TypeAlias
from phoenix.db import enums, models
from phoenix.db.models import (
@@ -58,6 +58,7 @@
from phoenix.server.api.types.ExperimentComparison import ExperimentComparison, RunComparisonItem
from phoenix.server.api.types.ExperimentRun import ExperimentRun, to_gql_experiment_run
from phoenix.server.api.types.Functionality import Functionality
+from phoenix.server.api.types.GenerativeModel import GenerativeModel
from phoenix.server.api.types.GenerativeProvider import (
GenerativeProvider,
GenerativeProviderKey,
@@ -81,8 +82,8 @@
@strawberry.input
-class ModelNamesInput:
- provider_key: GenerativeProviderKey
+class ModelsInput:
+ provider_key: Optional[GenerativeProviderKey]
@strawberry.type
@@ -105,63 +106,51 @@ async def model_providers(self) -> List[GenerativeProvider]:
]
@strawberry.field
- async def model_names(self, input: ModelNamesInput) -> List[str]:
- if (provider_key := input.provider_key) == GenerativeProviderKey.OPENAI:
- return [
- "o1-preview",
- "o1-preview-2024-09-12",
- "o1-mini",
- "o1-mini-2024-09-12",
- "gpt-4o",
- "gpt-4o-2024-08-06",
- "gpt-4o-2024-05-13",
- "chatgpt-4o-latest",
- "gpt-4o-mini",
- "gpt-4o-mini-2024-07-18",
- "gpt-4-turbo",
- "gpt-4-turbo-2024-04-09",
- "gpt-4-turbo-preview",
- "gpt-4-0125-preview",
- "gpt-4-1106-preview",
- "gpt-4",
- "gpt-4-0613",
- "gpt-3.5-turbo-0125",
- "gpt-3.5-turbo",
- "gpt-3.5-turbo-1106",
- "gpt-3.5-turbo-instruct",
- ]
- if provider_key == GenerativeProviderKey.AZURE_OPENAI:
- return [
- "o1-preview",
- "o1-preview-2024-09-12",
- "o1-mini",
- "o1-mini-2024-09-12",
- "gpt-4o",
- "gpt-4o-2024-08-06",
- "gpt-4o-2024-05-13",
- "chatgpt-4o-latest",
- "gpt-4o-mini",
- "gpt-4o-mini-2024-07-18",
- "gpt-4-turbo",
- "gpt-4-turbo-2024-04-09",
- "gpt-4-turbo-preview",
- "gpt-4-0125-preview",
- "gpt-4-1106-preview",
- "gpt-4",
- "gpt-4-0613",
- "gpt-3.5-turbo-0125",
- "gpt-3.5-turbo",
- "gpt-3.5-turbo-1106",
- "gpt-3.5-turbo-instruct",
- ]
- if provider_key == GenerativeProviderKey.ANTHROPIC:
- return [
- "claude-3-5-sonnet-20240620",
- "claude-3-opus-20240229",
- "claude-3-sonnet-20240229",
- "claude-3-haiku-20240307",
- ]
- assert_never(provider_key)
+ async def models(self, input: Optional[ModelsInput] = None) -> List[GenerativeModel]:
+ openai_models = [
+ "o1-preview",
+ "o1-preview-2024-09-12",
+ "o1-mini",
+ "o1-mini-2024-09-12",
+ "gpt-4o",
+ "gpt-4o-2024-08-06",
+ "gpt-4o-2024-05-13",
+ "chatgpt-4o-latest",
+ "gpt-4o-mini",
+ "gpt-4o-mini-2024-07-18",
+ "gpt-4-turbo",
+ "gpt-4-turbo-2024-04-09",
+ "gpt-4-turbo-preview",
+ "gpt-4-0125-preview",
+ "gpt-4-1106-preview",
+ "gpt-4",
+ "gpt-4-0613",
+ "gpt-3.5-turbo-0125",
+ "gpt-3.5-turbo",
+ "gpt-3.5-turbo-1106",
+ "gpt-3.5-turbo-instruct",
+ ]
+ anthropic_models = [
+ "claude-3-5-sonnet-20240620",
+ "claude-3-opus-20240229",
+ "claude-3-sonnet-20240229",
+ "claude-3-haiku-20240307",
+ ]
+ openai_generative_models = [
+ GenerativeModel(name=model_name, provider_key=GenerativeProviderKey.OPENAI)
+ for model_name in openai_models
+ ]
+ anthropic_generative_models = [
+ GenerativeModel(name=model_name, provider_key=GenerativeProviderKey.ANTHROPIC)
+ for model_name in anthropic_models
+ ]
+
+ all_models = openai_generative_models + anthropic_generative_models
+
+ if input is not None and input.provider_key is not None:
+ return [model for model in all_models if model.provider_key == input.provider_key]
+
+ return all_models
@strawberry.field(permission_classes=[IsAdmin]) # type: ignore
async def users(
diff --git a/src/phoenix/server/api/types/GenerativeModel.py b/src/phoenix/server/api/types/GenerativeModel.py
new file mode 100644
index 0000000000..00e5313f6f
--- /dev/null
+++ b/src/phoenix/server/api/types/GenerativeModel.py
@@ -0,0 +1,9 @@
+import strawberry
+
+from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey
+
+
+@strawberry.type
+class GenerativeModel:
+ name: str
+ provider_key: GenerativeProviderKey