Skip to content

Commit

Permalink
feat(playground): plumb through message tool_calls from span to playground (#5197)
Browse files Browse the repository at this point in the history

* feat(playground): plumb through message tool_calls from span to playground

* cleanup
  • Loading branch information
Parker-Stafford authored Oct 25, 2024
1 parent b0436d7 commit a1886a0
Show file tree
Hide file tree
Showing 4 changed files with 334 additions and 12 deletions.
281 changes: 281 additions & 0 deletions app/src/pages/playground/__tests__/playgroundUtils.test.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { TemplateLanguage } from "@phoenix/components/templateEditor/types";
import { DEFAULT_MODEL_PROVIDER } from "@phoenix/constants/generativeConstants";
import {
_resetInstanceId,
Expand All @@ -14,8 +15,13 @@ import {
SPAN_ATTRIBUTES_PARSING_ERROR,
} from "../constants";
import {
extractVariablesFromInstances,
getChatRole,
getModelConfigFromAttributes,
getModelProviderFromModelName,
getOutputFromAttributes,
getTemplateMessagesFromAttributes,
processAttributeToolCalls,
transformSpanAttributesToPlaygroundInstance,
} from "../playgroundUtils";

Expand All @@ -24,6 +30,25 @@ import {
spanAttributesWithInputMessages,
} from "./fixtures";

const baseTestPlaygroundInstance: PlaygroundInstance = {
id: 0,
activeRunId: null,
isRunning: false,
model: {
provider: "OPENAI",
modelName: "gpt-3.5-turbo",
invocationParameters: {},
},
input: { variablesValueCache: {} },
tools: [],
toolChoice: "auto",
spanId: null,
template: {
__type: "chat",
messages: [],
},
};

const expectedPlaygroundInstanceWithIO: PlaygroundInstance = {
id: 0,
activeRunId: null,
Expand Down Expand Up @@ -410,3 +435,259 @@ describe("getModelProviderFromModelName", () => {
);
});
});

const testSpanToolCall = {
tool_call: {
id: "1",
function: {
name: "functionName",
arguments: JSON.stringify({ arg1: "value1" }),
},
},
};

const expectedTestToolCall = {
id: "1",
function: {
name: "functionName",
arguments: JSON.stringify({ arg1: "value1" }),
},
};
describe("processAttributeToolCalls", () => {
it("should transform tool calls correctly", () => {
const toolCalls = [testSpanToolCall];
expect(processAttributeToolCalls(toolCalls)).toEqual([
expectedTestToolCall,
]);
});

it("should filter out nullish tool calls", () => {
const toolCalls = [{}, testSpanToolCall];
expect(processAttributeToolCalls(toolCalls)).toEqual([
expectedTestToolCall,
]);
});
});

describe("getTemplateMessagesFromAttributes", () => {
it("should return parsing errors if input messages are invalid", () => {
const parsedAttributes = { llm: { input_messages: "invalid" } };
expect(getTemplateMessagesFromAttributes(parsedAttributes)).toEqual({
messageParsingErrors: [INPUT_MESSAGES_PARSING_ERROR],
messages: null,
});
});

it("should return parsed messages as ChatMessages if input messages are valid", () => {
const parsedAttributes = {
llm: {
input_messages: [
{
message: {
role: "human",
content: "Hello",
tool_calls: [testSpanToolCall],
},
},
],
},
};
expect(getTemplateMessagesFromAttributes(parsedAttributes)).toEqual({
messageParsingErrors: [],
messages: [
{
id: expect.any(Number),
role: "user",
content: "Hello",
toolCalls: [expectedTestToolCall],
},
],
});
});
});

describe("getOutputFromAttributes", () => {
it("should return parsing errors if output messages are invalid", () => {
const parsedAttributes = { llm: { output_messages: "invalid" } };
expect(getOutputFromAttributes(parsedAttributes)).toEqual({
output: undefined,
outputParsingErrors: [
OUTPUT_MESSAGES_PARSING_ERROR,
OUTPUT_VALUE_PARSING_ERROR,
],
});
});

it("should return parsed output if output messages are valid", () => {
const parsedAttributes = {
llm: {
output_messages: [
{
message: {
role: "ai",
content: "This is an AI Answer",
},
},
],
},
};
expect(getOutputFromAttributes(parsedAttributes)).toEqual({
output: [
{
id: expect.any(Number),
role: "ai",
content: "This is an AI Answer",
},
],
outputParsingErrors: [],
});
});

it("should fallback to output.value if output_messages is not present", () => {
const parsedAttributes = {
output: {
value: "This is an AI Answer",
},
};
expect(getOutputFromAttributes(parsedAttributes)).toEqual({
output: "This is an AI Answer",
outputParsingErrors: [OUTPUT_MESSAGES_PARSING_ERROR],
});
});
});

describe("getModelConfigFromAttributes", () => {
it("should return parsing errors if model config is invalid", () => {
const parsedAttributes = { llm: { model_name: 123 } };
expect(getModelConfigFromAttributes(parsedAttributes)).toEqual({
modelConfig: null,
parsingErrors: [MODEL_CONFIG_PARSING_ERROR],
});
});

it("should return parsed model config if valid with the provider inferred", () => {
const parsedAttributes = {
llm: {
model_name: "gpt-3.5-turbo",
invocation_parameters: '{"top_p": 0.5, "max_tokens": 100}',
},
};
expect(getModelConfigFromAttributes(parsedAttributes)).toEqual({
modelConfig: {
modelName: "gpt-3.5-turbo",
provider: "OPENAI",
invocationParameters: {
topP: 0.5,
maxTokens: 100,
},
},
parsingErrors: [],
});
});

it("should return invocation parameters parsing errors if they are malformed", () => {
const parsedAttributes = {
llm: {
model_name: "gpt-3.5-turbo",
invocation_parameters: 100,
},
};
expect(getModelConfigFromAttributes(parsedAttributes)).toEqual({
modelConfig: {
modelName: "gpt-3.5-turbo",
provider: "OPENAI",
invocationParameters: {},
},
parsingErrors: [MODEL_CONFIG_WITH_INVOCATION_PARAMETERS_PARSING_ERROR],
});
});
});

describe("extractVariablesFromInstances", () => {
it("should extract variables from chat messages", () => {
const instances: PlaygroundInstance[] = [
{
...baseTestPlaygroundInstance,
template: {
__type: "chat",
messages: [
{ id: 0, content: "Hello {{name}}", role: "user" },
{ id: 1, content: "How are you, {{name}}?", role: "ai" },
],
},
},
];
const templateLanguage = "MUSTACHE";
expect(
extractVariablesFromInstances({ instances, templateLanguage })
).toEqual(["name"]);
});

it("should extract variables from text completion prompts", () => {
const instances: PlaygroundInstance[] = [
{
...baseTestPlaygroundInstance,
template: {
__type: "text_completion",
prompt: "Hello {{name}}",
},
},
];
const templateLanguage = "MUSTACHE";
expect(
extractVariablesFromInstances({ instances, templateLanguage })
).toEqual(["name"]);
});

it("should handle multiple instances and variable extraction", () => {
const instances: PlaygroundInstance[] = [
{
...baseTestPlaygroundInstance,
template: {
__type: "chat",
messages: [
{ id: 0, content: "Hello {{name}}", role: "user" },
{ id: 1, content: "How are you, {{name}}?", role: "ai" },
],
},
},
{
...baseTestPlaygroundInstance,
template: {
__type: "text_completion",
prompt: "Your age is {{age}}",
},
},
];
const templateLanguage = "MUSTACHE";
expect(
extractVariablesFromInstances({ instances, templateLanguage })
).toEqual(["name", "age"]);
});

it("should handle multiple instances and variable extraction with fstring", () => {
const instances: PlaygroundInstance[] = [
{
...baseTestPlaygroundInstance,
template: {
__type: "chat",
messages: [
{ id: 0, content: "Hello {name}", role: "user" },
{ id: 1, content: "How are you, {{escaped}}?", role: "ai" },
],
},
},
{
...baseTestPlaygroundInstance,
template: {
__type: "text_completion",
prompt: "Your age is {age}",
},
},
];
const templateLanguage: TemplateLanguage = "F_STRING";
expect(
extractVariablesFromInstances({ instances, templateLanguage })
).toEqual(["name", "age"]);
});
});
44 changes: 41 additions & 3 deletions app/src/pages/playground/playgroundUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,37 @@ export function getChatRole(role: string): ChatMessageRole {
return DEFAULT_CHAT_ROLE;
}

/**
* Takes tool calls on a message from span attributes and transforms them into tool calls for a message in the playground
* @param toolCalls Tool calls from a spans message to tool calls from a chat message in the playground
* @returns Tool calls for a message in the playground
*
* NB: Only exported for testing
*/
export function processAttributeToolCalls(
toolCalls?: MessageSchema["message"]["tool_calls"]
): ChatMessage["toolCalls"] {
if (toolCalls == null) {
return;
}
return toolCalls
.map(({ tool_call }) => {
if (tool_call == null) {
return null;
}
return {
id: tool_call.id ?? "",
function: {
name: tool_call.function?.name ?? "",
arguments: tool_call.function?.arguments ?? {},
},
};
})
.filter((toolCall): toolCall is NonNullable<typeof toolCall> => {
return toolCall != null;
});
}

/**
* Takes a list of messages from span attributes and transforms them into a list of {@link ChatMessage|ChatMessages}
* @param messages messages from attributes either input or output @see {@link https://github.com/Arize-ai/openinference/blob/main/spec/semantic_conventions.md|Semantic Conventions}}
Expand All @@ -77,6 +108,7 @@ function processAttributeMessagesToChatMessage(
id: generateMessageId(),
role: getChatRole(message.role),
content: message.content,
toolCalls: processAttributeToolCalls(message.tool_calls),
};
});
}
Expand All @@ -85,8 +117,10 @@ function processAttributeMessagesToChatMessage(
* Attempts to parse the input messages from the span attributes.
* @param parsedAttributes the JSON parsed span attributes
* @returns an object containing the parsed {@link ChatMessage|ChatMessages} and any parsing errors
*
* NB: Only exported for testing
*/
function getTemplateMessagesFromAttributes(parsedAttributes: unknown) {
export function getTemplateMessagesFromAttributes(parsedAttributes: unknown) {
const inputMessages = llmInputMessageSchema.safeParse(parsedAttributes);
if (!inputMessages.success) {
return {
Expand All @@ -107,8 +141,10 @@ function getTemplateMessagesFromAttributes(parsedAttributes: unknown) {
* Attempts to get llm.output_messages then output.value from the span attributes.
* @param parsedAttributes the JSON parsed span attributes
* @returns an object containing the parsed output and any parsing errors
*
* NB: Only exported for testing
*/
function getOutputFromAttributes(parsedAttributes: unknown) {
export function getOutputFromAttributes(parsedAttributes: unknown) {
const outputParsingErrors: string[] = [];
const outputMessages = llmOutputMessageSchema.safeParse(parsedAttributes);
if (outputMessages.success) {
Expand Down Expand Up @@ -161,8 +197,10 @@ export function getModelProviderFromModelName(
* Attempts to get the llm.model_name, inferred provider, and invocation parameters from the span attributes.
* @param parsedAttributes the JSON parsed span attributes
* @returns the model config if it exists or parsing errors if it does not
*
* NB: Only exported for testing
*/
function getModelConfigFromAttributes(parsedAttributes: unknown): {
export function getModelConfigFromAttributes(parsedAttributes: unknown): {
modelConfig: ModelConfig | null;
parsingErrors: string[];
} {
Expand Down
Loading

0 comments on commit a1886a0

Please sign in to comment.