Skip to content

Commit

Permalink
feat(playground): rudimentary graphql support for messages input (#4907)
Browse files Browse the repository at this point in the history
* feat(playground): Chat message with role

* feat(playground): pass in messages and roles

* Update app/src/pages/playground/MessageRolePicker.tsx
  • Loading branch information
mikeldking committed Oct 11, 2024
1 parent 72f05de commit ee1f85b
Show file tree
Hide file tree
Showing 8 changed files with 159 additions and 27 deletions.
16 changes: 15 additions & 1 deletion app/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,21 @@ enum AuthMethod {
union Bin = NominalBin | IntervalBin | MissingValueBin

input ChatCompletionInput {
message: String!
messages: [ChatCompletionMessageInput!]!
}

input ChatCompletionMessageInput {
role: ChatCompletionMessageRole!

"""The content of the message as JSON to support text and tools"""
content: JSON!
}

enum ChatCompletionMessageRole {
USER
SYSTEM
TOOL
AI
}

input ClearProjectInput {
Expand Down
6 changes: 4 additions & 2 deletions app/src/pages/playground/MessageRolePicker.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,13 @@ export function MessageRolePicker({
data-testid="inferences-time-range"
aria-label={`Time range for the primary inferences`}
size="compact"
onSelectionChange={() => {}}
onSelectionChange={() => {
// TODO: fill out
}}
>
<Item key="system">System</Item>
<Item key="user">User</Item>
<Item key="assistant">Assistant</Item>
<Item key="ai">AI</Item>
</Picker>
);
}
65 changes: 54 additions & 11 deletions app/src/pages/playground/PlaygroundOutput.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,18 @@ import React, { useMemo, useState } from "react";
import { useSubscription } from "react-relay";
import { graphql, GraphQLSubscriptionConfig } from "relay-runtime";

import { Card } from "@arizeai/components";
import { Card, Flex, Icon, Icons } from "@arizeai/components";

import { usePlaygroundContext } from "@phoenix/contexts/PlaygroundContext";
import { ChatMessage, ChatMessageRole } from "@phoenix/store";
import { assertUnreachable } from "@phoenix/typeUtils";

import {
ChatCompletionMessageInput,
ChatCompletionMessageRole,
PlaygroundOutputSubscription,
PlaygroundOutputSubscription$data,
PlaygroundOutputSubscription$variables,
} from "./__generated__/PlaygroundOutputSubscription.graphql";
import { PlaygroundInstanceProps } from "./types";

Expand All @@ -34,12 +39,12 @@ export function PlaygroundOutput(props: PlaygroundOutputProps) {
}

function useChatCompletionSubscription({
message,
params,
runId,
onNext,
onCompleted,
}: {
message: string;
params: PlaygroundOutputSubscription$variables;
runId: number;
onNext: (response: PlaygroundOutputSubscription$data) => void;
onCompleted: () => void;
Expand All @@ -49,11 +54,13 @@ function useChatCompletionSubscription({
>(
() => ({
subscription: graphql`
subscription PlaygroundOutputSubscription($message: String!) {
chatCompletion(input: { message: $message })
subscription PlaygroundOutputSubscription(
$messages: [ChatCompletionMessageInput!]!
) {
chatCompletion(input: { messages: $messages })
}
`,
variables: { message },
variables: params,
onNext: (response) => {
if (response) {
onNext(response);
Expand All @@ -70,6 +77,35 @@ function useChatCompletionSubscription({
return useSubscription(config);
}

/**
* A utility function to convert playground messages content to GQL chat completion message input
*/
function toGqlChatCompletionMessage(
message: ChatMessage
): ChatCompletionMessageInput {
return {
content: message.content,
role: toGqlChatCompletionRole(message.role),
};
}

function toGqlChatCompletionRole(
role: ChatMessageRole
): ChatCompletionMessageRole {
switch (role) {
case "system":
return "SYSTEM";
case "user":
return "USER";
case "tool":
return "TOOL";
case "ai":
return "AI";
default:
assertUnreachable(role);
}
}

function PlaygroundOutputText(props: PlaygroundInstanceProps) {
const instance = usePlaygroundContext(
(state) => state.instances[props.playgroundInstanceId]
Expand All @@ -89,12 +125,10 @@ function PlaygroundOutputText(props: PlaygroundInstanceProps) {
throw new Error("We only support chat templates for now");
}

const message = instance.template.messages.reduce((acc, message) => {
return acc + message.content;
}, "");

useChatCompletionSubscription({
message: message,
params: {
messages: instance.template.messages.map(toGqlChatCompletionMessage),
},
runId: instance.activeRunId,
onNext: (response) => {
setOutput((acc) => acc + response.chatCompletion);
Expand All @@ -103,5 +137,14 @@ function PlaygroundOutputText(props: PlaygroundInstanceProps) {
markPlaygroundInstanceComplete(props.playgroundInstanceId);
},
});

if (!output) {
return (
<Flex direction="row" gap="size-100" alignItems="center">
<Icon svg={<Icons.LoadingOutline />} />
Running...
</Flex>
);
}
return <span>{output}</span>;
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion app/src/store/playgroundStore.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ export type PlaygroundTemplate =
/**
* The role of a chat message with a LLM
*/
export type ChatMessageRole = "user" | "assistant" | "system" | "tool";
export type ChatMessageRole = "user" | "ai" | "system" | "tool";

/**
* A chat message with a role and content
Expand Down
12 changes: 12 additions & 0 deletions src/phoenix/server/api/input_types/ChatCompletionMessageInput.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import strawberry
from strawberry.scalars import JSON

from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole


@strawberry.input
class ChatCompletionMessageInput:
role: ChatCompletionMessageRole
content: JSON = strawberry.field(
description="The content of the message as JSON to support text and tools",
)
53 changes: 49 additions & 4 deletions src/phoenix/server/api/subscriptions.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,57 @@
from datetime import datetime
from typing import AsyncIterator
from typing import TYPE_CHECKING, AsyncIterator, List

import strawberry
from sqlalchemy import insert, select
from strawberry.types import Info

from phoenix.db import models
from phoenix.server.api.context import Context
from phoenix.server.api.input_types.ChatCompletionMessageInput import ChatCompletionMessageInput
from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole

if TYPE_CHECKING:
from openai.types.chat import (
ChatCompletionMessageParam,
)


@strawberry.input
class ChatCompletionInput:
message: str
messages: List[ChatCompletionMessageInput]


def to_openai_chat_completion_param(
message: ChatCompletionMessageInput,
) -> "ChatCompletionMessageParam":
from openai.types.chat import (
ChatCompletionAssistantMessageParam,
ChatCompletionSystemMessageParam,
ChatCompletionUserMessageParam,
)

if message.role is ChatCompletionMessageRole.USER:
return ChatCompletionUserMessageParam(
{
"content": message.content,
"role": "user",
}
)
if message.role is ChatCompletionMessageRole.SYSTEM:
return ChatCompletionSystemMessageParam(
{
"content": message.content,
"role": "system",
}
)
if message.role is ChatCompletionMessageRole.AI:
return ChatCompletionAssistantMessageParam(
{
"content": message.content,
"role": "assistant",
}
)
raise ValueError(f"Unsupported role: {message.role}")


@strawberry.type
Expand All @@ -21,13 +61,18 @@ async def chat_completion(
self, info: Info[Context, None], input: ChatCompletionInput
) -> AsyncIterator[str]:
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionUserMessageParam

client = AsyncOpenAI()

# Loop over the input messages and map them to the OpenAI format

messages: List[ChatCompletionMessageParam] = [
to_openai_chat_completion_param(message) for message in input.messages
]
chunk_contents = []
start_time = datetime.now()
async for chunk in await client.chat.completions.create(
messages=[ChatCompletionUserMessageParam(role="user", content=input.message)],
messages=messages,
model="gpt-4",
stream=True,
):
Expand Down
11 changes: 11 additions & 0 deletions src/phoenix/server/api/types/ChatCompletionMessageRole.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from enum import Enum

import strawberry


@strawberry.enum
class ChatCompletionMessageRole(Enum):
USER = "USER"
SYSTEM = "SYSTEM"
TOOL = "TOOL"
AI = "AI" # E.g. the assistant. Normalize to AI for consistency.

0 comments on commit ee1f85b

Please sign in to comment.