Skip to content

Commit

Permalink
feat: make model listing more generic (#5022)
Browse files Browse the repository at this point in the history
* feat: make model listing more generic

* feat(playgrround): Make input optional for models query

* cleanup

* fix

* fix more

* Update src/phoenix/server/api/queries.py

Co-authored-by: Xander Song <axiomofjoy@gmail.com>

* remove azure

---------

Co-authored-by: Tony Powell <apowell@arize.com>
Co-authored-by: Xander Song <axiomofjoy@gmail.com>
  • Loading branch information
3 people authored Oct 16, 2024
1 parent 93ca8c4 commit fd0c6ce
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 84 deletions.
11 changes: 8 additions & 3 deletions app/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -833,6 +833,11 @@ type Functionality {
tracing: Boolean!
}

type GenerativeModel {
name: String!
providerKey: GenerativeProviderKey!
}

input GenerativeModelInput {
providerKey: GenerativeProviderKey!
name: String!
Expand Down Expand Up @@ -963,8 +968,8 @@ type Model {
): PerformanceTimeSeries!
}

input ModelNamesInput {
providerKey: GenerativeProviderKey!
input ModelsInput {
providerKey: GenerativeProviderKey
}

type Mutation {
Expand Down Expand Up @@ -1160,7 +1165,7 @@ type PromptResponse {

type Query {
modelProviders: [GenerativeProvider!]!
modelNames(input: ModelNamesInput!): [String!]!
models(input: ModelsInput = null): [GenerativeModel!]!
users(first: Int = 50, last: Int, after: String, before: String): UserConnection!
userRoles: [UserRole!]!
userApiKeys: [UserApiKey!]!
Expand Down
8 changes: 5 additions & 3 deletions app/src/pages/playground/ModelPicker.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ export function ModelPicker({ query, onChange, ...props }: ModelPickerProps) {
@argumentDefinitions(
providerKey: { type: "GenerativeProviderKey!", defaultValue: OPENAI }
) {
modelNames(input: { providerKey: $providerKey })
models(input: { providerKey: $providerKey }) {
name
}
}
`,
query
Expand All @@ -42,8 +44,8 @@ export function ModelPicker({ query, onChange, ...props }: ModelPickerProps) {
width={"100%"}
{...props}
>
{data.modelNames.map((modelName) => {
return <Item key={modelName}>{modelName}</Item>;
{data.models.map(({ name }) => {
return <Item key={name}>{name}</Item>;
})}
</Picker>
);
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

109 changes: 49 additions & 60 deletions src/phoenix/server/api/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from strawberry import ID, UNSET
from strawberry.relay import Connection, GlobalID, Node
from strawberry.types import Info
from typing_extensions import Annotated, TypeAlias, assert_never
from typing_extensions import Annotated, TypeAlias

from phoenix.db import enums, models
from phoenix.db.models import (
Expand Down Expand Up @@ -58,6 +58,7 @@
from phoenix.server.api.types.ExperimentComparison import ExperimentComparison, RunComparisonItem
from phoenix.server.api.types.ExperimentRun import ExperimentRun, to_gql_experiment_run
from phoenix.server.api.types.Functionality import Functionality
from phoenix.server.api.types.GenerativeModel import GenerativeModel
from phoenix.server.api.types.GenerativeProvider import (
GenerativeProvider,
GenerativeProviderKey,
Expand All @@ -81,8 +82,8 @@


@strawberry.input
class ModelNamesInput:
provider_key: GenerativeProviderKey
class ModelsInput:
provider_key: Optional[GenerativeProviderKey]


@strawberry.type
Expand All @@ -105,63 +106,51 @@ async def model_providers(self) -> List[GenerativeProvider]:
]

@strawberry.field
async def model_names(self, input: ModelNamesInput) -> List[str]:
if (provider_key := input.provider_key) == GenerativeProviderKey.OPENAI:
return [
"o1-preview",
"o1-preview-2024-09-12",
"o1-mini",
"o1-mini-2024-09-12",
"gpt-4o",
"gpt-4o-2024-08-06",
"gpt-4o-2024-05-13",
"chatgpt-4o-latest",
"gpt-4o-mini",
"gpt-4o-mini-2024-07-18",
"gpt-4-turbo",
"gpt-4-turbo-2024-04-09",
"gpt-4-turbo-preview",
"gpt-4-0125-preview",
"gpt-4-1106-preview",
"gpt-4",
"gpt-4-0613",
"gpt-3.5-turbo-0125",
"gpt-3.5-turbo",
"gpt-3.5-turbo-1106",
"gpt-3.5-turbo-instruct",
]
if provider_key == GenerativeProviderKey.AZURE_OPENAI:
return [
"o1-preview",
"o1-preview-2024-09-12",
"o1-mini",
"o1-mini-2024-09-12",
"gpt-4o",
"gpt-4o-2024-08-06",
"gpt-4o-2024-05-13",
"chatgpt-4o-latest",
"gpt-4o-mini",
"gpt-4o-mini-2024-07-18",
"gpt-4-turbo",
"gpt-4-turbo-2024-04-09",
"gpt-4-turbo-preview",
"gpt-4-0125-preview",
"gpt-4-1106-preview",
"gpt-4",
"gpt-4-0613",
"gpt-3.5-turbo-0125",
"gpt-3.5-turbo",
"gpt-3.5-turbo-1106",
"gpt-3.5-turbo-instruct",
]
if provider_key == GenerativeProviderKey.ANTHROPIC:
return [
"claude-3-5-sonnet-20240620",
"claude-3-opus-20240229",
"claude-3-sonnet-20240229",
"claude-3-haiku-20240307",
]
assert_never(provider_key)
async def models(self, input: Optional[ModelsInput] = None) -> List[GenerativeModel]:
openai_models = [
"o1-preview",
"o1-preview-2024-09-12",
"o1-mini",
"o1-mini-2024-09-12",
"gpt-4o",
"gpt-4o-2024-08-06",
"gpt-4o-2024-05-13",
"chatgpt-4o-latest",
"gpt-4o-mini",
"gpt-4o-mini-2024-07-18",
"gpt-4-turbo",
"gpt-4-turbo-2024-04-09",
"gpt-4-turbo-preview",
"gpt-4-0125-preview",
"gpt-4-1106-preview",
"gpt-4",
"gpt-4-0613",
"gpt-3.5-turbo-0125",
"gpt-3.5-turbo",
"gpt-3.5-turbo-1106",
"gpt-3.5-turbo-instruct",
]
anthropic_models = [
"claude-3-5-sonnet-20240620",
"claude-3-opus-20240229",
"claude-3-sonnet-20240229",
"claude-3-haiku-20240307",
]
openai_generative_models = [
GenerativeModel(name=model_name, provider_key=GenerativeProviderKey.OPENAI)
for model_name in openai_models
]
anthropic_generative_models = [
GenerativeModel(name=model_name, provider_key=GenerativeProviderKey.ANTHROPIC)
for model_name in anthropic_models
]

all_models = openai_generative_models + anthropic_generative_models

if input is not None and input.provider_key is not None:
return [model for model in all_models if model.provider_key == input.provider_key]

return all_models

@strawberry.field(permission_classes=[IsAdmin]) # type: ignore
async def users(
Expand Down
9 changes: 9 additions & 0 deletions src/phoenix/server/api/types/GenerativeModel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import strawberry

from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey


@strawberry.type
class GenerativeModel:
name: str
provider_key: GenerativeProviderKey

0 comments on commit fd0c6ce

Please sign in to comment.