Skip to content

Commit

Permalink
fix(playground): plumb through credentials (#5003)
Browse files Browse the repository at this point in the history
Co-authored-by: Parker Stafford <52351508+Parker-Stafford@users.noreply.github.com>
  • Loading branch information
axiomofjoy and Parker-Stafford authored Oct 15, 2024
1 parent 716b1c7 commit 0fa0c87
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 56 deletions.
1 change: 1 addition & 0 deletions app/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ union Bin = NominalBin | IntervalBin | MissingValueBin
input ChatCompletionInput {
messages: [ChatCompletionMessageInput!]!
model: GenerativeModelInput!
apiKey: String = null
}

input ChatCompletionMessageInput {
Expand Down
20 changes: 9 additions & 11 deletions app/src/pages/playground/PlaygroundCredentialsDropdown.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,12 @@ import {

import { useCredentialsContext } from "@phoenix/contexts/CredentialsContext";
import { usePlaygroundContext } from "@phoenix/contexts/PlaygroundContext";
import { CredentialKey } from "@phoenix/store";

export const ProviderToCredentialKeyMap: Record<ModelProvider, CredentialKey> =
{
OPENAI: "OPENAI_API_KEY",
ANTHROPIC: "ANTHROPIC_API_KEY",
AZURE_OPENAI: "AZURE_OPENAI_API_KEY",
};
export const ProviderToCredentialNameMap: Record<ModelProvider, string> = {
OPENAI: "OPENAI_API_KEY",
ANTHROPIC: "ANTHROPIC_API_KEY",
AZURE_OPENAI: "AZURE_OPENAI_API_KEY",
};

export function PlaygroundCredentialsDropdown() {
const currentProviders = usePlaygroundContext((state) =>
Expand Down Expand Up @@ -54,17 +52,17 @@ export function PlaygroundCredentialsDropdown() {
</Text>
<Form>
{currentProviders.map((provider) => {
const credentialKey = ProviderToCredentialKeyMap[provider];
const credentialName = ProviderToCredentialNameMap[provider];
return (
<TextField
key={provider}
label={credentialKey}
label={credentialName}
type="password"
isRequired
onChange={(value) => {
setCredential({ credential: credentialKey, value });
setCredential({ provider, value });
}}
value={credentials[credentialKey]}
value={credentials[provider]}
/>
);
})}
Expand Down
8 changes: 7 additions & 1 deletion app/src/pages/playground/PlaygroundOutput.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { graphql, GraphQLSubscriptionConfig } from "relay-runtime";

import { Card, Flex, Icon, Icons } from "@arizeai/components";

import { useCredentialsContext } from "@phoenix/contexts/CredentialsContext";
import { usePlaygroundContext } from "@phoenix/contexts/PlaygroundContext";
import { useChatMessageStyles } from "@phoenix/hooks/useChatMessageStyles";
import { ChatMessage, generateMessageId } from "@phoenix/store";
Expand Down Expand Up @@ -104,8 +105,11 @@ function useChatCompletionSubscription({
subscription PlaygroundOutputSubscription(
$messages: [ChatCompletionMessageInput!]!
$model: GenerativeModelInput!
$apiKey: String
) {
chatCompletion(input: { messages: $messages, model: $model })
chatCompletion(
input: { messages: $messages, model: $model, apiKey: $apiKey }
)
}
`,
variables: params,
Expand Down Expand Up @@ -156,6 +160,7 @@ function toGqlChatCompletionRole(

function PlaygroundOutputText(props: PlaygroundInstanceProps) {
const instances = usePlaygroundContext((state) => state.instances);
const credentials = useCredentialsContext((state) => state);
const instance = instances.find(
(instance) => instance.id === props.playgroundInstanceId
);
Expand All @@ -182,6 +187,7 @@ function PlaygroundOutputText(props: PlaygroundInstanceProps) {
providerKey: instance.model.provider,
name: instance.model.modelName || "",
},
apiKey: credentials[instance.model.provider],
},
runId: instance.activeRunId,
onNext: (response) => {
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

26 changes: 4 additions & 22 deletions app/src/store/credentialsStore.tsx
Original file line number Diff line number Diff line change
@@ -1,41 +1,23 @@
import { create, StateCreator } from "zustand";
import { devtools, persist } from "zustand/middleware";

export interface CredentialsProps {
/**
* The API key for the OpenAI API.
*/
OPENAI_API_KEY?: string;
/**
* The API key for the Azure OpenAI API.
*/
AZURE_OPENAI_API_KEY?: string;
/**
* The API key for the Anthropic API.
*/
ANTHROPIC_API_KEY?: string;
}

export type CredentialKey = keyof CredentialsProps;
export type CredentialsProps = Partial<Record<ModelProvider, string>>;

export interface CredentialsState extends CredentialsProps {
/**
* Setter for a given credential
* @param credential the name of the credential to set
* @param value the value of the credential
*/
setCredential: (params: {
credential: keyof CredentialsProps;
value: string;
}) => void;
setCredential: (params: { provider: ModelProvider; value: string }) => void;
}

export const createCredentialsStore = (
initialProps: Partial<CredentialsProps>
) => {
const credentialsStore: StateCreator<CredentialsState> = (set) => ({
setCredential: ({ credential, value }) => {
set({ [credential]: value });
setCredential: ({ provider, value }) => {
set({ [provider]: value });
},
...initialProps,
});
Expand Down
3 changes: 2 additions & 1 deletion src/phoenix/server/api/subscriptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class GenerativeModelInput:
class ChatCompletionInput:
messages: List[ChatCompletionMessageInput]
model: GenerativeModelInput
api_key: Optional[str] = None


def to_openai_chat_completion_param(
Expand Down Expand Up @@ -93,7 +94,7 @@ async def chat_completion(
) -> AsyncIterator[str]:
from openai import AsyncOpenAI

client = AsyncOpenAI()
client = AsyncOpenAI(api_key=input.api_key)

in_memory_span_exporter = InMemorySpanExporter()
tracer_provider = TracerProvider()
Expand Down

0 comments on commit 0fa0c87

Please sign in to comment.