diff --git a/app/src/pages/playground/PlaygroundChatTemplate.tsx b/app/src/pages/playground/PlaygroundChatTemplate.tsx index 1fbc6139bd..9e9631de7d 100644 --- a/app/src/pages/playground/PlaygroundChatTemplate.tsx +++ b/app/src/pages/playground/PlaygroundChatTemplate.tsx @@ -30,6 +30,7 @@ import { usePlaygroundContext } from "@phoenix/contexts/PlaygroundContext"; import { useChatMessageStyles } from "@phoenix/hooks/useChatMessageStyles"; import { ChatMessage, + ChatMessageRole, generateMessageId, PlaygroundChatTemplate as PlaygroundChatTemplateType, } from "@phoenix/store"; @@ -143,7 +144,7 @@ export function PlaygroundChatTemplate(props: PlaygroundChatTemplateProps) { ...template.messages, { id: generateMessageId(), - role: "user", + role: ChatMessageRole.user, content: "", }, ], diff --git a/app/src/pages/playground/PlaygroundOutput.tsx b/app/src/pages/playground/PlaygroundOutput.tsx index 58f6e589c0..2ac91203d7 100644 --- a/app/src/pages/playground/PlaygroundOutput.tsx +++ b/app/src/pages/playground/PlaygroundOutput.tsx @@ -5,7 +5,12 @@ import { graphql, GraphQLSubscriptionConfig } from "relay-runtime"; import { Card, Flex, Icon, Icons } from "@arizeai/components"; import { usePlaygroundContext } from "@phoenix/contexts/PlaygroundContext"; -import { ChatMessage, ChatMessageRole } from "@phoenix/store"; +import { useChatMessageStyles } from "@phoenix/hooks/useChatMessageStyles"; +import { + ChatMessage, + ChatMessageRole, + generateMessageId, +} from "@phoenix/store"; import { assertUnreachable } from "@phoenix/typeUtils"; import { @@ -15,11 +20,22 @@ import { PlaygroundOutputSubscription$data, PlaygroundOutputSubscription$variables, } from "./__generated__/PlaygroundOutputSubscription.graphql"; +import { isChatMessages } from "./playgroundUtils"; import { TitleWithAlphabeticIndex } from "./TitleWithAlphabeticIndex"; import { PlaygroundInstanceProps } from "./types"; interface PlaygroundOutputProps extends PlaygroundInstanceProps {} +function PlaygroundOutputMessage({ message }: { message: ChatMessage }) { + const styles = useChatMessageStyles(message.role); + + return ( + + {message.content} + + ); +} + export function PlaygroundOutput(props: PlaygroundOutputProps) { const instanceId = props.playgroundInstanceId; const instance = usePlaygroundContext((state) => @@ -29,22 +45,46 @@ export function PlaygroundOutput(props: PlaygroundOutputProps) { state.instances.findIndex((instance) => instance.id === instanceId) ); if (!instance) { - return null; + throw new Error("Playground instance not found"); } const runId = instance.activeRunId; const hasRunId = runId !== null; + + const OutputEl = useMemo(() => { + if (hasRunId) { + return ( + + ); + } + if (isChatMessages(instance.output)) { + const messages = instance.output; + + return messages.map((message, index) => { + return ; + }); + } + if (typeof instance.output === "string") { + return ( + + ); + } + return "click run to see output"; + }, [hasRunId, instance.output, instanceId, runId]); + return ( } collapsible variant="compact" > - {hasRunId ? ( - - ) : ( - "click run to see output" - )} + {OutputEl} ); } @@ -104,13 +144,13 @@ function toGqlChatCompletionRole( role: ChatMessageRole ): ChatCompletionMessageRole { switch (role) { - case "system": + case ChatMessageRole.system: return "SYSTEM"; - case "user": + case ChatMessageRole.user: return "USER"; - case "tool": + case ChatMessageRole.tool: return "TOOL"; - case "ai": + case ChatMessageRole.ai: return "AI"; default: assertUnreachable(role); @@ -118,13 +158,13 @@ function toGqlChatCompletionRole( } function PlaygroundOutputText(props: PlaygroundInstanceProps) { - const instance = usePlaygroundContext( - (state) => state.instances[props.playgroundInstanceId] + const instances = usePlaygroundContext((state) => state.instances); + const instance = instances.find( + (instance) => instance.id === props.playgroundInstanceId ); const markPlaygroundInstanceComplete = usePlaygroundContext( (state) => state.markPlaygroundInstanceComplete ); - const [output, setOutput] = useState(""); if (!instance) { throw new Error("No instance found"); } @@ -136,6 +176,8 @@ function PlaygroundOutputText(props: PlaygroundInstanceProps) { throw new Error("We only support chat templates for now"); } + const [output, setOutput] = useState(""); + useChatCompletionSubscription({ params: { messages: instance.template.messages.map(toGqlChatCompletionMessage), @@ -157,5 +199,13 @@ function PlaygroundOutputText(props: PlaygroundInstanceProps) { ); } - return {output}; + return ( + + ); } diff --git a/app/src/pages/playground/SpanPlaygroundPage.tsx b/app/src/pages/playground/SpanPlaygroundPage.tsx index e528008424..14879be630 100644 --- a/app/src/pages/playground/SpanPlaygroundPage.tsx +++ b/app/src/pages/playground/SpanPlaygroundPage.tsx @@ -3,8 +3,6 @@ import { useLoaderData, useNavigate } from "react-router"; import { Alert, Button, Flex, Icon, Icons } from "@arizeai/components"; -import { createPlaygroundInstance } from "@phoenix/store"; - import { spanPlaygroundPageLoaderQuery$data } from "./__generated__/spanPlaygroundPageLoaderQuery.graphql"; import { Playground } from "./Playground"; import { transformSpanAttributesToPlaygroundInstance } from "./playgroundUtils"; @@ -22,37 +20,37 @@ export function SpanPlaygroundPage() { throw new Error("Span not found"); } - const playgroundInstance = useMemo( + const { playgroundInstance, parsingErrors } = useMemo( () => transformSpanAttributesToPlaygroundInstance(span), [span] ); return ( - - + + ); } function SpanPlaygroundBanners({ span, + parsingErrors, }: { span: Extract< NonNullable, { __typename: "Span" } >; + + parsingErrors?: string[]; }) { const navigate = useNavigate(); + const hasParsingErrors = parsingErrors && parsingErrors.length > 0; const [showBackBanner, setShowBackBanner] = useState(true); + const [showParsingErrorsBanner, setShowParsingErrorsBanner] = + useState(hasParsingErrors); return ( -
+ {showBackBanner && ( {`Replay and iterate on your LLM call from your ${span.project.name} project`} )} -
+ {showParsingErrorsBanner && hasParsingErrors && ( + { + setShowParsingErrorsBanner(false); + }} + title="The following errors occurred when parsing span attributes:" + > +
    + {parsingErrors.map((error) => ( +
  • {error}
  • + ))} +
+
+ )} + ); } diff --git a/app/src/pages/playground/__tests__/fixtures.ts b/app/src/pages/playground/__tests__/fixtures.ts index bca3182b2e..bf647d9dc0 100644 --- a/app/src/pages/playground/__tests__/fixtures.ts +++ b/app/src/pages/playground/__tests__/fixtures.ts @@ -34,7 +34,7 @@ export const spanAttributesWithInputMessages = { }, { message: { - content: "Anser me the following question. Are you sentient?", + content: "hello?", role: "user", }, }, diff --git a/app/src/pages/playground/__tests__/playgroundUtils.test.ts b/app/src/pages/playground/__tests__/playgroundUtils.test.ts index 0fca1fb8fe..37f297f58e 100644 --- a/app/src/pages/playground/__tests__/playgroundUtils.test.ts +++ b/app/src/pages/playground/__tests__/playgroundUtils.test.ts @@ -1,7 +1,16 @@ -import { _resetInstanceId, _resetMessageId } from "@phoenix/store"; +import { + _resetInstanceId, + _resetMessageId, + ChatMessageRole, + PlaygroundInstance, +} from "@phoenix/store"; import { getChatRole, + INPUT_MESSAGES_PARSING_ERROR, + OUTPUT_MESSAGES_PARSING_ERROR, + OUTPUT_VALUE_PARSING_ERROR, + SPAN_ATTRIBUTES_PARSING_ERROR, transformSpanAttributesToPlaygroundInstance, } from "../playgroundUtils"; @@ -10,26 +19,42 @@ import { spanAttributesWithInputMessages, } from "./fixtures"; -const expectedPlaygroundInstance = { +const expectedPlaygroundInstanceWithIO: PlaygroundInstance = { id: 0, activeRunId: null, isRunning: false, input: { variables: {}, }, + tools: {}, template: { __type: "chat", - messages: spanAttributesWithInputMessages.llm.input_messages.map( - ({ message }, index) => { - return { - id: index, - ...message, - }; - } - ), + // These id's are not 0, 1, 2, because we create a playground instance (including messages) at the top of the transformSpanAttributesToPlaygroundInstance function + // Doing so increments the message id counter + messages: [ + { id: 2, content: "You are a chatbot", role: ChatMessageRole.system }, + { id: 3, content: "hello?", role: ChatMessageRole.user }, + ], }, - output: spanAttributesWithInputMessages.llm.output_messages, - tools: undefined, + output: [ + { id: 4, content: "This is an AI Answer", role: ChatMessageRole.ai }, + ], +}; + +const defaultTemplate = { + __type: "chat", + messages: [ + { + id: 0, + role: ChatMessageRole.system, + content: "You are a chatbot", + }, + { + id: 1, + role: ChatMessageRole.user, + content: "{{question}}", + }, + ], }; describe("transformSpanAttributesToPlaygroundInstance", () => { @@ -37,46 +62,65 @@ describe("transformSpanAttributesToPlaygroundInstance", () => { _resetInstanceId(); _resetMessageId(); }); - it("should throw if the attributes are not parsable", () => { + it("should return the default instance with parsing errors if the span attributes are unparsable", () => { const span = { ...basePlaygroundSpan, attributes: "invalid json", }; - expect(() => transformSpanAttributesToPlaygroundInstance(span)).toThrow( - "Invalid span attributes, attributes must be valid JSON" - ); + expect(transformSpanAttributesToPlaygroundInstance(span)).toStrictEqual({ + playgroundInstance: { + ...expectedPlaygroundInstanceWithIO, + template: defaultTemplate, + output: undefined, + }, + parsingErrors: [SPAN_ATTRIBUTES_PARSING_ERROR], + }); }); - it("should return null if the attributes do not match the schema", () => { + it("should return the default instance with parsing errors if the attributes don't contain any information", () => { const span = { ...basePlaygroundSpan, attributes: JSON.stringify({}), }; - expect(transformSpanAttributesToPlaygroundInstance(span)).toBeNull(); + expect(transformSpanAttributesToPlaygroundInstance(span)).toStrictEqual({ + playgroundInstance: { + ...expectedPlaygroundInstanceWithIO, + template: defaultTemplate, + + output: undefined, + }, + parsingErrors: [ + INPUT_MESSAGES_PARSING_ERROR, + OUTPUT_MESSAGES_PARSING_ERROR, + OUTPUT_VALUE_PARSING_ERROR, + ], + }); }); - it("should return a PlaygroundInstance if the attributes contain llm.input_messages", () => { + it("should return a PlaygroundInstance with template messages and output parsing errors if the attributes contain llm.input_messages", () => { const span = { ...basePlaygroundSpan, - attributes: JSON.stringify(spanAttributesWithInputMessages), + attributes: JSON.stringify({ + ...spanAttributesWithInputMessages, + llm: { + ...spanAttributesWithInputMessages.llm, + output_messages: undefined, + }, + }), }; - - const instance = transformSpanAttributesToPlaygroundInstance(span); - expect(instance?.template.__type).toEqual("chat"); - if (instance?.template.__type !== "chat") { - throw new Error("Invalid template type constructed"); - } - expect(instance?.template.messages).toHaveLength(2); - instance?.template.messages.forEach((message, index) => { - expect(message.role).toEqual( - expectedPlaygroundInstance.template.messages[index].role - ); - expect(message.content).toEqual( - expectedPlaygroundInstance.template.messages[index].content - ); + expect(transformSpanAttributesToPlaygroundInstance(span)).toEqual({ + playgroundInstance: { + ...expectedPlaygroundInstanceWithIO, + output: undefined, + }, + parsingErrors: [ + OUTPUT_MESSAGES_PARSING_ERROR, + OUTPUT_VALUE_PARSING_ERROR, + ], }); }); - it("should return a PlaygroundInstance if the attributes contain llm.input_messages, even if output_messages are not present", () => { + + it("should fallback to output.value if output_messages is not present", () => { const span = { ...basePlaygroundSpan, attributes: JSON.stringify({ @@ -85,14 +129,73 @@ describe("transformSpanAttributesToPlaygroundInstance", () => { ...spanAttributesWithInputMessages.llm, output_messages: undefined, }, + output: { + value: "This is an AI Answer", + }, }), }; - const instance = transformSpanAttributesToPlaygroundInstance(span); - expect(instance?.template.__type).toEqual("chat"); - if (instance?.template.__type !== "chat") { - throw new Error("Invalid template type constructed"); - } - expect(Array.isArray(instance?.template.messages)).toBeTruthy(); + + expect(transformSpanAttributesToPlaygroundInstance(span)).toEqual({ + playgroundInstance: { + ...expectedPlaygroundInstanceWithIO, + output: "This is an AI Answer", + }, + parsingErrors: [OUTPUT_MESSAGES_PARSING_ERROR], + }); + }); + + it("should return a PlaygroundInstance if the attributes contain llm.input_messages and output_messages", () => { + const span = { + ...basePlaygroundSpan, + attributes: JSON.stringify(spanAttributesWithInputMessages), + }; + expect(transformSpanAttributesToPlaygroundInstance(span)).toEqual({ + playgroundInstance: expectedPlaygroundInstanceWithIO, + parsingErrors: [], + }); + }); + + it("should normalize message roles in input and output messages", () => { + const span = { + ...basePlaygroundSpan, + attributes: JSON.stringify({ + llm: { + input_messages: [ + { + message: { + role: "human", + content: "You are a chatbot", + }, + }, + ], + output_messages: [ + { + message: { + role: "assistant", + content: "This is an AI Answer", + }, + }, + ], + }, + }), + }; + expect(transformSpanAttributesToPlaygroundInstance(span)).toEqual({ + playgroundInstance: { + ...expectedPlaygroundInstanceWithIO, + template: { + __type: "chat", + messages: [ + { + id: 2, + role: "user", + content: "You are a chatbot", + }, + ], + }, + output: [{ id: 3, content: "This is an AI Answer", role: "ai" }], + }, + parsingErrors: [], + }); }); }); @@ -103,9 +206,9 @@ describe("getChatRole", () => { it("should return the ChatMessageRole if the role is included in ChatRoleMap", () => { expect(getChatRole("assistant")).toEqual("ai"); - // expect(getChatRole("bot")).toEqual("ai"); - // expect(getChatRole("system")).toEqual("system"); - // expect(getChatRole("human:")).toEqual("user"); + expect(getChatRole("bot")).toEqual("ai"); + expect(getChatRole("system")).toEqual("system"); + expect(getChatRole("human:")).toEqual("user"); }); it("should return DEFAULT_CHAT_ROLE if the role is not found", () => { diff --git a/app/src/pages/playground/constants.tsx b/app/src/pages/playground/constants.tsx index a4b1bb2b59..8f34607c82 100644 --- a/app/src/pages/playground/constants.tsx +++ b/app/src/pages/playground/constants.tsx @@ -2,7 +2,7 @@ import { ChatMessageRole } from "@phoenix/store"; export const NUM_MAX_PLAYGROUND_INSTANCES = 4; -export const DEFAULT_CHAT_ROLE = "user"; +export const DEFAULT_CHAT_ROLE = ChatMessageRole.user; /** * Map of {@link ChatMessageRole} to potential role values. diff --git a/app/src/pages/playground/playgroundUtils.ts b/app/src/pages/playground/playgroundUtils.ts index bd6df28085..ffe15aa51c 100644 --- a/app/src/pages/playground/playgroundUtils.ts +++ b/app/src/pages/playground/playgroundUtils.ts @@ -1,20 +1,28 @@ -import { generateInstanceId, PlaygroundInstance } from "@phoenix/store"; +import { PlaygroundInstance } from "@phoenix/store"; import { + ChatMessage, ChatMessageRole, - chatMessageRoles, + createPlaygroundInstance, generateMessageId, -} from "@phoenix/store/playgroundStore"; +} from "@phoenix/store"; import { safelyParseJSON } from "@phoenix/utils/jsonUtils"; import { ChatRoleMap, DEFAULT_CHAT_ROLE } from "./constants"; -import { llmAttributesSchema } from "./schemas"; +import { + chatMessageRolesSchema, + chatMessagesSchema, + llmInputMessageSchema, + llmOutputMessageSchema, + MessageSchema, + outputSchema, +} from "./schemas"; import { PlaygroundSpan } from "./spanPlaygroundPageLoader"; /** * Checks if a string is a valid chat message role */ export function isChatMessageRole(role: unknown): role is ChatMessageRole { - return chatMessageRoles.includes(role as ChatMessageRole); + return chatMessageRolesSchema.safeParse(role).success; } /** @@ -38,39 +46,140 @@ export function getChatRole(role: string): ChatMessageRole { return DEFAULT_CHAT_ROLE; } +/** + * Takes a list of messages from span attributes and transforms them into a list of {@link ChatMessage|ChatMessages} + * @param messages messages from attributes either input or output @see {@link https://github.com/Arize-ai/openinference/blob/main/spec/semantic_conventions.md|Semantic Conventions}} + * returns a list of {@link ChatMessage|ChatMessages} + */ +function processAttributeMessagesToChatMessage( + messages: MessageSchema[] +): ChatMessage[] { + return messages.map(({ message }) => { + return { + id: generateMessageId(), + role: getChatRole(message.role), + content: message.content, + }; + }); +} + +export const INPUT_MESSAGES_PARSING_ERROR = + "Unable to parse span input messages, expected messages which include a role and content."; +export const OUTPUT_MESSAGES_PARSING_ERROR = + "Unable to parse span output messages, expected messages which include a role and content."; +export const OUTPUT_VALUE_PARSING_ERROR = + "Unable to parse span output expected output.value to be present."; +export const SPAN_ATTRIBUTES_PARSING_ERROR = + "Unable to parse span attributes, attributes must be valid JSON."; + +/** + * Attempts to parse the input messages from the span attributes. + * @param parsedAttributes the JSON parsed span attributes + * @returns an object containing the parsed {@link ChatMessage|ChatMessages} and any parsing errors + */ +function getTemplateMessagesFromAttributes(parsedAttributes: unknown) { + const inputMessages = llmInputMessageSchema.safeParse(parsedAttributes); + if (!inputMessages.success) { + return { + messageParsingErrors: [INPUT_MESSAGES_PARSING_ERROR], + messages: null, + }; + } + + return { + messageParsingErrors: [], + messages: processAttributeMessagesToChatMessage( + inputMessages.data.llm.input_messages + ), + }; +} + +function getOutputFromAttributes(parsedAttributes: unknown) { + const outputParsingErrors: string[] = []; + const outputMessages = llmOutputMessageSchema.safeParse(parsedAttributes); + if (outputMessages.success) { + return { + output: processAttributeMessagesToChatMessage( + outputMessages.data.llm.output_messages + ), + outputParsingErrors, + }; + } + + outputParsingErrors.push(OUTPUT_MESSAGES_PARSING_ERROR); + + const parsedOutput = outputSchema.safeParse(parsedAttributes); + if (parsedOutput.success) { + return { + output: parsedOutput.data.output.value, + outputParsingErrors, + }; + } + + outputParsingErrors.push(OUTPUT_VALUE_PARSING_ERROR); + + return { + output: undefined, + outputParsingErrors, + }; +} + +/** + * Takes a {@link PlaygroundSpan|Span} and attempts to transform it's attributes into various fields on a {@link PlaygroundInstance}. + * @param span the {@link PlaygroundSpan|Span} to transform into a playground instance + * @returns a {@link PlaygroundInstance} with certain fields pre-populated from the span attributes + */ export function transformSpanAttributesToPlaygroundInstance( span: PlaygroundSpan -): PlaygroundInstance | null { +): { + playgroundInstance: PlaygroundInstance; + /** + * Errors that occurred during parsing of initial playground data. + * For example, when coming from a span to the playground, the span may + * not have the correct attributes, or the attributes may be of the wrong shape. + * This field is used to store any issues encountered when parsing to display in the playground. + */ + parsingErrors: string[]; +} { + const basePlaygroundInstance = createPlaygroundInstance(); const { json: parsedAttributes, parseError } = safelyParseJSON( span.attributes ); if (parseError) { - throw new Error("Invalid span attributes, attributes must be valid JSON"); - } - const { data, success } = llmAttributesSchema.safeParse(parsedAttributes); - if (!success) { - return null; + return { + playgroundInstance: basePlaygroundInstance, + parsingErrors: [SPAN_ATTRIBUTES_PARSING_ERROR], + }; } + + const { messages, messageParsingErrors } = + getTemplateMessagesFromAttributes(parsedAttributes); + const { output, outputParsingErrors } = + getOutputFromAttributes(parsedAttributes); + // TODO(parker): add support for tools, variables, and input / output variants // https://github.com/Arize-ai/phoenix/issues/4886 return { - id: generateInstanceId(), - activeRunId: null, - isRunning: false, - input: { - variables: {}, + playgroundInstance: { + ...basePlaygroundInstance, + template: + messages != null + ? { + __type: "chat", + messages, + } + : basePlaygroundInstance.template, + output, }, - template: { - __type: "chat", - messages: data.llm.input_messages.map(({ message }) => { - return { - id: generateMessageId(), - role: getChatRole(message.role), - content: message.content, - }; - }), - }, - output: data.llm.output_messages, - tools: undefined, + parsingErrors: [...messageParsingErrors, ...outputParsingErrors], }; } + +/** + * Checks if something is a valid {@link ChatMessage} + */ +export const isChatMessages = ( + messages: unknown +): messages is ChatMessage[] => { + return chatMessagesSchema.safeParse(messages).success; +}; diff --git a/app/src/pages/playground/schemas.ts b/app/src/pages/playground/schemas.ts index f9f964f4fa..f572e20bf0 100644 --- a/app/src/pages/playground/schemas.ts +++ b/app/src/pages/playground/schemas.ts @@ -1,13 +1,14 @@ import { z } from "zod"; import { - ImageAttributesPostfixes, LLMAttributePostfixes, MessageAttributePostfixes, - MessageContentsAttributePostfixes, SemanticAttributePrefixes, } from "@arizeai/openinference-semantic-conventions"; +import { ChatMessage, ChatMessageRole } from "@phoenix/store"; +import { schemaForType } from "@phoenix/typeUtils"; + /** * The zod schema for llm tool calls in an input message * @see {@link https://github.com/Arize-ai/openinference/blob/main/spec/semantic_conventions.md|Semantic Conventions} @@ -23,28 +24,6 @@ const toolCallSchema = z }) .partial(); -/** - * The zod schema for llm message contents - * @see {@link https://github.com/Arize-ai/openinference/blob/main/spec/semantic_conventions.md|Semantic Conventions} - */ -const messageContentSchema = z.object({ - [SemanticAttributePrefixes.message_content]: z - .object({ - [MessageContentsAttributePostfixes.type]: z.string(), - [MessageContentsAttributePostfixes.text]: z.string(), - [MessageContentsAttributePostfixes.image]: z - .object({ - [MessageContentsAttributePostfixes.image]: z - .object({ - [ImageAttributesPostfixes.url]: z.string(), - }) - .partial(), - }) - .partial(), - }) - .partial(), -}); - /** * The zod schema for llm messages * @see {@link https://github.com/Arize-ai/openinference/blob/main/spec/semantic_conventions.md|Semantic Conventions} @@ -53,20 +32,61 @@ const messageSchema = z.object({ [SemanticAttributePrefixes.message]: z.object({ [MessageAttributePostfixes.role]: z.string(), [MessageAttributePostfixes.content]: z.string(), - [MessageAttributePostfixes.name]: z.string().optional(), [MessageAttributePostfixes.tool_calls]: z.array(toolCallSchema).optional(), - [MessageAttributePostfixes.contents]: z - .array(messageContentSchema) - .optional(), }), }); + +/** + * The type of each message in either the input or output messages + * on a spans attributes + */ +export type MessageSchema = z.infer; + /** - * The zod schema for llm attributes + * The zod schema for llm.input_messages attributes * @see {@link https://github.com/Arize-ai/openinference/blob/main/spec/semantic_conventions.md|Semantic Conventions} */ -export const llmAttributesSchema = z.object({ +export const llmInputMessageSchema = z.object({ [SemanticAttributePrefixes.llm]: z.object({ [LLMAttributePostfixes.input_messages]: z.array(messageSchema), - [LLMAttributePostfixes.output_messages]: z.optional(z.array(messageSchema)), }), }); + +/** + * The zod schema for llm.output_messages attributes + * @see {@link https://github.com/Arize-ai/openinference/blob/main/spec/semantic_conventions.md|Semantic Conventions} + */ +export const llmOutputMessageSchema = z.object({ + [SemanticAttributePrefixes.llm]: z.object({ + [LLMAttributePostfixes.output_messages]: z.array(messageSchema), + }), +}); + +/** + * The zod schema for output attributes + * @see {@link https://github.com/Arize-ai/openinference/blob/main/spec/semantic_conventions.md|Semantic Conventions} + + */ +export const outputSchema = z.object({ + [SemanticAttributePrefixes.output]: z.object({ + value: z.string(), + }), +}); + +/** + * The zod schema for {@link chatMessageRoles} + */ +export const chatMessageRolesSchema = z.nativeEnum(ChatMessageRole); + +const chatMessageSchema = schemaForType()( + z.object({ + id: z.number(), + role: chatMessageRolesSchema, + content: z.string(), + }) +); + +/** + * The zod schema for ChatMessages + */ +export const chatMessagesSchema = z.array(chatMessageSchema); diff --git a/app/src/store/playgroundStore.tsx b/app/src/store/playgroundStore.tsx index 70040e18ef..248d237de9 100644 --- a/app/src/store/playgroundStore.tsx +++ b/app/src/store/playgroundStore.tsx @@ -33,7 +33,7 @@ export const _resetInstanceId = () => { * NB: This is only used for testing purposes */ export const _resetMessageId = () => { - playgroundInstanceIdIndex = 0; + playgroundMessageIdIndex = 0; }; /** @@ -52,12 +52,12 @@ export type PlaygroundTemplate = /** * Array of roles for a chat message with a LLM */ -export const chatMessageRoles = ["user", "ai", "system", "tool"] as const; - -/** - * The role of a chat message with a LLM - */ -export type ChatMessageRole = (typeof chatMessageRoles)[number]; +export enum ChatMessageRole { + system = "system", + user = "user", + tool = "tool", + ai = "ai", +} /** * A chat message with a role and content @@ -131,7 +131,7 @@ export interface PlaygroundInstance { template: PlaygroundTemplate; tools: unknown; input: PlaygroundInput; - output: unknown; + output: ChatMessage[] | undefined | string; activeRunId: number | null; /** * Whether or not the playground instance is actively running or not @@ -196,12 +196,12 @@ const generateChatCompletionTemplate = (): PlaygroundChatTemplate => ({ messages: [ { id: generateMessageId(), - role: "system", + role: ChatMessageRole.system, content: "You are a chatbot", }, { id: generateMessageId(), - role: "user", + role: ChatMessageRole.user, content: "{{question}}", }, ], @@ -218,7 +218,7 @@ export function createPlaygroundInstance(): PlaygroundInstance { template: generateChatCompletionTemplate(), tools: {}, input: { variables: {} }, - output: {}, + output: undefined, activeRunId: null, isRunning: false, }; @@ -242,7 +242,7 @@ export const createPlaygroundStore = ( template: generateChatCompletionTemplate(), tools: {}, input: { variables: {} }, - output: {}, + output: undefined, activeRunId: null, isRunning: false, }, @@ -256,7 +256,7 @@ export const createPlaygroundStore = ( template: DEFAULT_TEXT_COMPLETION_TEMPLATE, tools: {}, input: { variables: {} }, - output: {}, + output: undefined, activeRunId: null, isRunning: false, }, diff --git a/app/src/typeUtils.ts b/app/src/typeUtils.ts index 936a698c57..813a83e902 100644 --- a/app/src/typeUtils.ts +++ b/app/src/typeUtils.ts @@ -1,3 +1,5 @@ +import { z } from "zod"; + /** * Utility function that uses the type system to check if a switch statement is exhaustive. * If the switch statement is not exhaustive, there will be a type error caught in typescript @@ -45,3 +47,25 @@ export function isObject(value: unknown): value is object { export type Mutable = { -readonly [P in keyof T]: T[P]; }; + +/** + * A zod type utility that ensures that the schema is written to correctly match (at least) what is included in the type. + * Note it does not guard against extra fields in the schema not present in the type. + * @example + * ```typescript + * const chatMessageSchema = schemaForType()( + * z.object({ + * id: z.number(), + * role: chatMessageRolesSchema, + * content: z.string(), + * }) + * ); + * ``` + * Taken from the zod maintainer here: + * @see https://github.com/colinhacks/zod/issues/372#issuecomment-826380330 + */ +export const schemaForType = + () => + >(arg: S) => { + return arg; + };