Skip to content

Commit

Permalink
core[patch]: Adds support for plain message objects as shorthand for …
Browse files Browse the repository at this point in the history
…messages (#5954)

* Adds support for OpenAI style objects as shorthand for messages

* Use more specific error message

* Fix tests

* Refine types
  • Loading branch information
jacoblee93 authored Jul 2, 2024
1 parent 2c9ae22 commit 8dcd6f8
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 39 deletions.
19 changes: 19 additions & 0 deletions langchain-core/src/language_models/tests/chat_models.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,25 @@ import { z } from "zod";
import { zodToJsonSchema } from "zod-to-json-schema";
import { FakeChatModel, FakeListChatModel } from "../../utils/testing/index.js";

test("Test ChatModel accepts array shorthand for messages", async () => {
const model = new FakeChatModel({});
const response = await model.invoke([["human", "Hello there!"]]);
expect(response.content).toEqual("Hello there!");
});

test("Test ChatModel accepts object shorthand for messages", async () => {
const model = new FakeChatModel({});
const response = await model.invoke([
{
type: "human",
content: "Hello there!",
additional_kwargs: {},
example: true,
},
]);
expect(response.content).toEqual("Hello there!");
});

test("Test ChatModel uses callbacks", async () => {
const model = new FakeChatModel({});
let acc = "";
Expand Down
4 changes: 4 additions & 0 deletions langchain-core/src/messages/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,10 @@ export abstract class BaseMessageChunk extends BaseMessage {

export type BaseMessageLike =
| BaseMessage
| ({
type: MessageType | "user" | "assistant" | "placeholder";
} & BaseMessageFields &
Record<string, unknown>)
| [
StringWithAutocomplete<
MessageType | "user" | "assistant" | "placeholder"
Expand Down
32 changes: 22 additions & 10 deletions langchain-core/src/messages/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import {
isBaseMessage,
StoredMessage,
StoredMessageV1,
BaseMessageFields,
} from "./base.js";
import {
ChatMessage,
Expand All @@ -20,6 +21,23 @@ import { HumanMessage, HumanMessageChunk } from "./human.js";
import { SystemMessage, SystemMessageChunk } from "./system.js";
import { ToolMessage, ToolMessageFieldsWithToolCallId } from "./tool.js";

function _constructMessageFromParams(
params: BaseMessageFields & { type: string }
) {
const { type, ...rest } = params;
if (type === "human" || type === "user") {
return new HumanMessage(rest);
} else if (type === "ai" || type === "assistant") {
return new AIMessage(rest);
} else if (type === "system") {
return new SystemMessage(rest);
} else {
throw new Error(
`Unable to coerce message from array: only human, AI, or system message coercion is currently supported.`
);
}
}

export function coerceMessageLikeToMessage(
messageLike: BaseMessageLike
): BaseMessage {
Expand All @@ -28,17 +46,11 @@ export function coerceMessageLikeToMessage(
} else if (isBaseMessage(messageLike)) {
return messageLike;
}
const [type, content] = messageLike;
if (type === "human" || type === "user") {
return new HumanMessage({ content });
} else if (type === "ai" || type === "assistant") {
return new AIMessage({ content });
} else if (type === "system") {
return new SystemMessage({ content });
if (Array.isArray(messageLike)) {
const [type, content] = messageLike;
return _constructMessageFromParams({ type, content });
} else {
throw new Error(
`Unable to coerce message from array: only human, AI, or system message coercion is currently supported.`
);
return _constructMessageFromParams(messageLike);
}
}

Expand Down
49 changes: 21 additions & 28 deletions langchain-core/src/prompts/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -134,50 +134,43 @@ export class MessagesPlaceholder<
return [this.variableName];
}

validateInputOrThrow(
input: Array<unknown> | undefined,
variableName: Extract<keyof RunInput, string>
): input is BaseMessage[] {
async formatMessages(
values: TypedPromptInputValues<RunInput>
): Promise<BaseMessage[]> {
const input = values[this.variableName];
if (this.optional && !input) {
return false;
return [];
} else if (!input) {
const error = new Error(
`Error: Field "${variableName}" in prompt uses a MessagesPlaceholder, which expects an array of BaseMessages as an input value. Received: undefined`
`Field "${this.variableName}" in prompt uses a MessagesPlaceholder, which expects an array of BaseMessages as an input value. Received: undefined`
);
error.name = "InputFormatError";
throw error;
}

let isInputBaseMessage = false;

if (Array.isArray(input)) {
isInputBaseMessage = input.every((message) =>
isBaseMessage(message as BaseMessage)
);
} else {
isInputBaseMessage = isBaseMessage(input as BaseMessage);
}

if (!isInputBaseMessage) {
let formattedMessages;
try {
if (Array.isArray(input)) {
formattedMessages = input.map(coerceMessageLikeToMessage);
} else {
formattedMessages = [coerceMessageLikeToMessage(input)];
}
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} catch (e: any) {
const readableInput =
typeof input === "string" ? input : JSON.stringify(input, null, 2);

const error = new Error(
`Error: Field "${variableName}" in prompt uses a MessagesPlaceholder, which expects an array of BaseMessages as an input value. Received: ${readableInput}`
[
`Field "${this.variableName}" in prompt uses a MessagesPlaceholder, which expects an array of BaseMessages or coerceable values as input.`,
`Received value: ${readableInput}`,
`Additional message: ${e.message}`,
].join("\n\n")
);
error.name = "InputFormatError";
throw error;
}

return true;
}

async formatMessages(
values: TypedPromptInputValues<RunInput>
): Promise<BaseMessage[]> {
this.validateInputOrThrow(values[this.variableName], this.variableName);

return values[this.variableName] ?? [];
return formattedMessages;
}
}

Expand Down
56 changes: 55 additions & 1 deletion langchain-core/src/prompts/tests/chat.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ test("Test MessagesPlaceholder not optional", async () => {
});
// eslint-disable-next-line @typescript-eslint/no-explicit-any
await expect(prompt.formatMessages({} as any)).rejects.toThrow(
'Error: Field "foo" in prompt uses a MessagesPlaceholder, which expects an array of BaseMessages as an input value. Received: undefined'
'Field "foo" in prompt uses a MessagesPlaceholder, which expects an array of BaseMessages as an input value. Received: undefined'
);
});

Expand All @@ -323,6 +323,60 @@ test("Test MessagesPlaceholder shorthand in a chat prompt template", async () =>
]);
});

test("Test MessagesPlaceholder shorthand in a chat prompt template with object format", async () => {
const prompt = ChatPromptTemplate.fromMessages([["placeholder", "{foo}"]]);
const messages = await prompt.formatMessages({
foo: [
{
type: "system",
content: "some initial content",
},
{
type: "human",
content: [
{
text: "page: 1\ndescription: One Purchase Flow\ntimestamp: '2024-06-04T14:46:46.062Z'\ntype: navigate\nscreenshot_present: true\n",
type: "text",
},
{
text: "page: 3\ndescription: intent_str=buy,mode_str=redirect,screenName_str=order-completed,\ntimestamp: '2024-06-04T14:46:58.846Z'\ntype: Screen View\nscreenshot_present: false\n",
type: "text",
},
],
},
{
type: "assistant",
content: "some captivating response",
},
],
});
expect(messages).toEqual([
new SystemMessage("some initial content"),
new HumanMessage({
content: [
{
text: "page: 1\ndescription: One Purchase Flow\ntimestamp: '2024-06-04T14:46:46.062Z'\ntype: navigate\nscreenshot_present: true\n",
type: "text",
},
{
text: "page: 3\ndescription: intent_str=buy,mode_str=redirect,screenName_str=order-completed,\ntimestamp: '2024-06-04T14:46:58.846Z'\ntype: Screen View\nscreenshot_present: false\n",
type: "text",
},
],
}),
new AIMessage("some captivating response"),
]);
});

test("Test MessagesPlaceholder with invalid shorthand should throw", async () => {
const prompt = ChatPromptTemplate.fromMessages([["placeholder", "{foo}"]]);
await expect(() =>
prompt.formatMessages({
foo: [{ badFormatting: true }],
})
).rejects.toThrow();
});

test("Test using partial", async () => {
const userPrompt = new PromptTemplate({
template: "{foo}{bar}",
Expand Down

0 comments on commit 8dcd6f8

Please sign in to comment.