Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(playground): plumb through message tool_calls from span to playground #5197

Merged
merged 2 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
281 changes: 281 additions & 0 deletions app/src/pages/playground/__tests__/playgroundUtils.test.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { TemplateLanguage } from "@phoenix/components/templateEditor/types";
import { DEFAULT_MODEL_PROVIDER } from "@phoenix/constants/generativeConstants";
import {
_resetInstanceId,
Expand All @@ -14,8 +15,13 @@ import {
SPAN_ATTRIBUTES_PARSING_ERROR,
} from "../constants";
import {
extractVariablesFromInstances,
getChatRole,
getModelConfigFromAttributes,
getModelProviderFromModelName,
getOutputFromAttributes,
getTemplateMessagesFromAttributes,
processAttributeToolCalls,
transformSpanAttributesToPlaygroundInstance,
} from "../playgroundUtils";

Expand All @@ -24,6 +30,25 @@ import {
spanAttributesWithInputMessages,
} from "./fixtures";

// Minimal playground instance used as a spread base for per-test overrides.
const baseTestPlaygroundInstance: PlaygroundInstance = {
  id: 0,
  spanId: null,
  activeRunId: null,
  isRunning: false,
  input: { variablesValueCache: {} },
  template: {
    __type: "chat",
    messages: [],
  },
  model: {
    provider: "OPENAI",
    modelName: "gpt-3.5-turbo",
    invocationParameters: {},
  },
  tools: [],
  toolChoice: "auto",
};

const expectedPlaygroundInstanceWithIO: PlaygroundInstance = {
id: 0,
activeRunId: null,
Expand Down Expand Up @@ -410,3 +435,259 @@ describe("getModelProviderFromModelName", () => {
);
});
});

// The playground-shaped tool call we expect back from processing.
const expectedTestToolCall = {
  id: "1",
  function: {
    name: "functionName",
    arguments: JSON.stringify({ arg1: "value1" }),
  },
};

// The same tool call as it appears on span attributes: wrapped in a
// `tool_call` envelope, matching the OpenInference message schema.
const testSpanToolCall = {
  tool_call: expectedTestToolCall,
};
describe("processAttributeToolCalls", () => {
  it("should transform tool calls correctly", () => {
    // A span tool call is unwrapped into its playground representation.
    const result = processAttributeToolCalls([testSpanToolCall]);
    expect(result).toEqual([expectedTestToolCall]);
  });

  it("should filter out nullish tool calls", () => {
    // Entries with no tool_call payload are dropped from the output.
    const result = processAttributeToolCalls([{}, testSpanToolCall]);
    expect(result).toEqual([expectedTestToolCall]);
  });
});

describe("getTemplateMessagesFromAttributes", () => {
  it("should return parsing errors if input messages are invalid", () => {
    // A non-array input_messages value cannot be parsed into messages.
    const attributes = { llm: { input_messages: "invalid" } };
    expect(getTemplateMessagesFromAttributes(attributes)).toEqual({
      messageParsingErrors: [INPUT_MESSAGES_PARSING_ERROR],
      messages: null,
    });
  });

  it("should return parsed messages as ChatMessages if input messages are valid", () => {
    const attributes = {
      llm: {
        input_messages: [
          {
            message: {
              role: "human",
              content: "Hello",
              tool_calls: [testSpanToolCall],
            },
          },
        ],
      },
    };
    const result = getTemplateMessagesFromAttributes(attributes);
    // "human" maps to the playground "user" role; tool calls are unwrapped.
    expect(result).toEqual({
      messageParsingErrors: [],
      messages: [
        {
          id: expect.any(Number),
          role: "user",
          content: "Hello",
          toolCalls: [expectedTestToolCall],
        },
      ],
    });
  });
});

describe("getOutputFromAttributes", () => {
  it("should return parsing errors if output messages are invalid", () => {
    // Both the message parse and the output.value fallback fail here.
    const attributes = { llm: { output_messages: "invalid" } };
    expect(getOutputFromAttributes(attributes)).toEqual({
      output: undefined,
      outputParsingErrors: [
        OUTPUT_MESSAGES_PARSING_ERROR,
        OUTPUT_VALUE_PARSING_ERROR,
      ],
    });
  });

  it("should return parsed output if output messages are valid", () => {
    const attributes = {
      llm: {
        output_messages: [
          { message: { role: "ai", content: "This is an AI Answer" } },
        ],
      },
    };
    expect(getOutputFromAttributes(attributes)).toEqual({
      output: [
        {
          id: expect.any(Number),
          role: "ai",
          content: "This is an AI Answer",
        },
      ],
      outputParsingErrors: [],
    });
  });

  it("should fallback to output.value if output_messages is not present", () => {
    // The message parse failure is still surfaced alongside the fallback value.
    const attributes = { output: { value: "This is an AI Answer" } };
    expect(getOutputFromAttributes(attributes)).toEqual({
      output: "This is an AI Answer",
      outputParsingErrors: [OUTPUT_MESSAGES_PARSING_ERROR],
    });
  });
});

describe("getModelConfigFromAttributes", () => {
  it("should return parsing errors if model config is invalid", () => {
    // model_name must be a string; a number fails the schema parse.
    expect(getModelConfigFromAttributes({ llm: { model_name: 123 } })).toEqual({
      modelConfig: null,
      parsingErrors: [MODEL_CONFIG_PARSING_ERROR],
    });
  });

  it("should return parsed model config if valid with the provider inferred", () => {
    const attributes = {
      llm: {
        model_name: "gpt-3.5-turbo",
        invocation_parameters: '{"top_p": 0.5, "max_tokens": 100}',
      },
    };
    // Snake_case invocation parameters are camelCased for the playground.
    expect(getModelConfigFromAttributes(attributes)).toEqual({
      modelConfig: {
        modelName: "gpt-3.5-turbo",
        provider: "OPENAI",
        invocationParameters: { topP: 0.5, maxTokens: 100 },
      },
      parsingErrors: [],
    });
  });

  it("should return invocation parameters parsing errors if they are malformed", () => {
    const attributes = {
      llm: {
        model_name: "gpt-3.5-turbo",
        // Not a JSON string — the model config still parses, params do not.
        invocation_parameters: 100,
      },
    };
    expect(getModelConfigFromAttributes(attributes)).toEqual({
      modelConfig: {
        modelName: "gpt-3.5-turbo",
        provider: "OPENAI",
        invocationParameters: {},
      },
      parsingErrors: [MODEL_CONFIG_WITH_INVOCATION_PARAMETERS_PARSING_ERROR],
    });
  });
});

describe("extractVariablesFromInstances", () => {
  it("should extract variables from chat messages", () => {
    const instances: PlaygroundInstance[] = [
      {
        ...baseTestPlaygroundInstance,
        template: {
          __type: "chat",
          messages: [
            { id: 0, content: "Hello {{name}}", role: "user" },
            { id: 1, content: "How are you, {{name}}?", role: "ai" },
          ],
        },
      },
    ];
    // Duplicate variables across messages are deduplicated.
    expect(
      extractVariablesFromInstances({ instances, templateLanguage: "MUSTACHE" })
    ).toEqual(["name"]);
  });

  it("should extract variables from text completion prompts", () => {
    const instances: PlaygroundInstance[] = [
      {
        ...baseTestPlaygroundInstance,
        template: {
          __type: "text_completion",
          prompt: "Hello {{name}}",
        },
      },
    ];
    expect(
      extractVariablesFromInstances({ instances, templateLanguage: "MUSTACHE" })
    ).toEqual(["name"]);
  });

  it("should handle multiple instances and variable extraction", () => {
    // Variables are collected across chat and text completion instances.
    const instances: PlaygroundInstance[] = [
      {
        ...baseTestPlaygroundInstance,
        template: {
          __type: "chat",
          messages: [
            { id: 0, content: "Hello {{name}}", role: "user" },
            { id: 1, content: "How are you, {{name}}?", role: "ai" },
          ],
        },
      },
      {
        ...baseTestPlaygroundInstance,
        template: {
          __type: "text_completion",
          prompt: "Your age is {{age}}",
        },
      },
    ];
    expect(
      extractVariablesFromInstances({ instances, templateLanguage: "MUSTACHE" })
    ).toEqual(["name", "age"]);
  });

  it("should handle multiple instances and variable extraction with fstring", () => {
    const instances: PlaygroundInstance[] = [
      {
        ...baseTestPlaygroundInstance,
        template: {
          __type: "chat",
          messages: [
            // {name} is an f-string variable; {{escaped}} is an escaped brace.
            { id: 0, content: "Hello {name}", role: "user" },
            { id: 1, content: "How are you, {{escaped}}?", role: "ai" },
          ],
        },
      },
      {
        ...baseTestPlaygroundInstance,
        template: {
          __type: "text_completion",
          prompt: "Your age is {age}",
        },
      },
    ];
    const templateLanguage: TemplateLanguage = "F_STRING";
    expect(
      extractVariablesFromInstances({ instances, templateLanguage })
    ).toEqual(["name", "age"]);
  });
});
44 changes: 41 additions & 3 deletions app/src/pages/playground/playgroundUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,37 @@ export function getChatRole(role: string): ChatMessageRole {
return DEFAULT_CHAT_ROLE;
}

/**
 * Transforms the tool calls attached to a span message into the tool call
 * shape used by playground chat messages.
 * @param toolCalls tool calls read off a span message's attributes
 * @returns the equivalent playground tool calls, or undefined when absent
 *
 * NB: Only exported for testing
 */
export function processAttributeToolCalls(
  toolCalls?: MessageSchema["message"]["tool_calls"]
): ChatMessage["toolCalls"] {
  if (toolCalls == null) {
    return;
  }
  // flatMap maps and drops empty entries in a single pass: entries whose
  // tool_call payload is nullish contribute nothing to the result.
  return toolCalls.flatMap(({ tool_call }) => {
    if (tool_call == null) {
      return [];
    }
    return [
      {
        id: tool_call.id ?? "",
        function: {
          name: tool_call.function?.name ?? "",
          // NOTE(review): this fallback is an object while span tool call
          // arguments are a JSON string elsewhere — confirm the intended type.
          arguments: tool_call.function?.arguments ?? {},
        },
      },
    ];
  });
}

/**
* Takes a list of messages from span attributes and transforms them into a list of {@link ChatMessage|ChatMessages}
* @param messages messages from attributes either input or output @see {@link https://github.com/Arize-ai/openinference/blob/main/spec/semantic_conventions.md|Semantic Conventions}}
Expand All @@ -77,6 +108,7 @@ function processAttributeMessagesToChatMessage(
id: generateMessageId(),
role: getChatRole(message.role),
content: message.content,
toolCalls: processAttributeToolCalls(message.tool_calls),
};
});
}
Expand All @@ -85,8 +117,10 @@ function processAttributeMessagesToChatMessage(
* Attempts to parse the input messages from the span attributes.
* @param parsedAttributes the JSON parsed span attributes
* @returns an object containing the parsed {@link ChatMessage|ChatMessages} and any parsing errors
*
* NB: Only exported for testing
*/
function getTemplateMessagesFromAttributes(parsedAttributes: unknown) {
export function getTemplateMessagesFromAttributes(parsedAttributes: unknown) {
const inputMessages = llmInputMessageSchema.safeParse(parsedAttributes);
if (!inputMessages.success) {
return {
Expand All @@ -107,8 +141,10 @@ function getTemplateMessagesFromAttributes(parsedAttributes: unknown) {
* Attempts to get llm.output_messages then output.value from the span attributes.
* @param parsedAttributes the JSON parsed span attributes
* @returns an object containing the parsed output and any parsing errors
*
* NB: Only exported for testing
*/
function getOutputFromAttributes(parsedAttributes: unknown) {
export function getOutputFromAttributes(parsedAttributes: unknown) {
const outputParsingErrors: string[] = [];
const outputMessages = llmOutputMessageSchema.safeParse(parsedAttributes);
if (outputMessages.success) {
Expand Down Expand Up @@ -161,8 +197,10 @@ export function getModelProviderFromModelName(
* Attempts to get the llm.model_name, inferred provider, and invocation parameters from the span attributes.
* @param parsedAttributes the JSON parsed span attributes
* @returns the model config if it exists or parsing errors if it does not
*
* NB: Only exported for testing
*/
function getModelConfigFromAttributes(parsedAttributes: unknown): {
export function getModelConfigFromAttributes(parsedAttributes: unknown): {
modelConfig: ModelConfig | null;
parsingErrors: string[];
} {
Expand Down
Loading
Loading