From 7b574ff195be6187c8b743a336c3291fd52a2eb1 Mon Sep 17 00:00:00 2001 From: David Illing Date: Tue, 12 Dec 2023 21:16:26 -0500 Subject: [PATCH] feat: add inference for RunnableMap RunOutput type (#3517) * add RunnableMapLike to infer RunnableMap output * remove unneeded changes * fix linting * format * fix runnable_stream_log.test * upgrade typescript version * clean types * fix structured_output_runnables.int.test * ts version ~5.1.6 * remove unused eslint-disable-next-line * remove another disable no-explicit-any * remove another no-explicit-any * move eslint * Format * Default runnable maps to any type in case inference is not possible * Add tests --------- Co-authored-by: David Illing Co-authored-by: jacoblee93 --- langchain-core/src/runnables/base.ts | 36 +++--- .../src/runnables/tests/runnable.test.ts | 34 +----- .../src/runnables/tests/runnable_map.test.ts | 105 ++++++++++++++++++ .../tests/runnable_stream_log.test.ts | 5 +- 4 files changed, 133 insertions(+), 47 deletions(-) create mode 100644 langchain-core/src/runnables/tests/runnable_map.test.ts diff --git a/langchain-core/src/runnables/base.ts b/langchain-core/src/runnables/base.ts index cd8e50eab4f5..8504b24fdedd 100644 --- a/langchain-core/src/runnables/base.ts +++ b/langchain-core/src/runnables/base.ts @@ -31,11 +31,15 @@ export type RunnableFunc = ( | (Record & { config: RunnableConfig }) ) => RunOutput | Promise; +export type RunnableMapLike = { + [K in keyof RunOutput]: RunnableLike; +}; + // eslint-disable-next-line @typescript-eslint/no-explicit-any export type RunnableLike = | Runnable | RunnableFunc - | { [key: string]: RunnableLike }; + | RunnableMapLike; export type RunnableBatchOptions = { maxConcurrency?: number; @@ -1368,11 +1372,12 @@ export class RunnableSequence< * const result = await mapChain.invoke({ topic: "bear" }); * ``` */ -export class RunnableMap extends Runnable< - RunInput, +export class RunnableMap< // eslint-disable-next-line @typescript-eslint/no-explicit-any - Record -> { + RunInput = any, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + RunOutput extends Record = Record +> extends Runnable { static lc_name() { return "RunnableMap"; } @@ -1387,7 +1392,7 @@ export class RunnableMap extends Runnable< return Object.keys(this.steps); } - constructor(fields: { steps: Record> }) { + constructor(fields: { steps: RunnableMapLike }) { super(fields); this.steps = {}; for (const [key, value] of Object.entries(fields.steps)) { @@ -1395,15 +1400,20 @@ export class RunnableMap extends Runnable< } } - static from(steps: Record>) { - return new RunnableMap({ steps }); + static from< + RunInput, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + RunOutput extends Record = Record + >( + steps: RunnableMapLike + ): RunnableMap { + return new RunnableMap({ steps }); } async invoke( input: RunInput, options?: Partial - // eslint-disable-next-line @typescript-eslint/no-explicit-any - ): Promise> { + ): Promise { const callbackManager_ = await getCallbackMangerForConfig(options); const runManager = await callbackManager_?.handleChainStart( this.toJSON(), @@ -1432,7 +1442,7 @@ export class RunnableMap extends Runnable< throw e; } await runManager?.handleChainEnd(output); - return output; + return output as RunOutput; } } @@ -1665,9 +1675,9 @@ export function _coerceToRunnable( } else if (!Array.isArray(coerceable) && typeof coerceable === "object") { const runnables: Record> = {}; for (const [key, value] of Object.entries(coerceable)) { - runnables[key] = _coerceToRunnable(value); + runnables[key] = _coerceToRunnable(value as RunnableLike); } - return new RunnableMap({ + return new RunnableMap({ steps: runnables, }) as unknown as Runnable>; } else { diff --git a/langchain-core/src/runnables/tests/runnable.test.ts b/langchain-core/src/runnables/tests/runnable.test.ts index 149a39112c4f..647df66282e2 100644 --- a/langchain-core/src/runnables/tests/runnable.test.ts +++ b/langchain-core/src/runnables/tests/runnable.test.ts @@ -10,21 +10,18 @@ import { StringOutputParser } from "../../output_parsers/string.js"; import { ChatPromptTemplate, SystemMessagePromptTemplate, - HumanMessagePromptTemplate, } from "../../prompts/chat.js"; import { PromptTemplate } from "../../prompts/prompt.js"; import { FakeLLM, FakeChatModel, - FakeRetriever, FakeStreamingLLM, FakeSplitIntoListParser, FakeRunnable, FakeListChatModel, } from "../../utils/testing/index.js"; -import { RunnableSequence, RunnableMap, RunnableLambda } from "../base.js"; +import { RunnableSequence, RunnableLambda } from "../base.js"; import { RouterRunnable } from "../router.js"; -import { Document } from "../../documents/document.js"; test("Test batch", async () => { const llm = new FakeLLM({}); @@ -70,35 +67,6 @@ test("Pipe from one runnable to the next", async () => { expect(result).toBe("Hello world!"); }); -test("Create a runnable sequence with a runnable map", async () => { - const promptTemplate = ChatPromptTemplate.fromMessages<{ - documents: string; - question: string; - }>([ - SystemMessagePromptTemplate.fromTemplate(`You are a nice assistant.`), - HumanMessagePromptTemplate.fromTemplate( - `Context:\n{documents}\n\nQuestion:\n{question}` - ), - ]); - const llm = new FakeChatModel({}); - const inputs = { - question: (input: string) => input, - documents: RunnableSequence.from([ - new FakeRetriever(), - (docs: Document[]) => JSON.stringify(docs), - ]), - extraField: new FakeLLM({}), - }; - const runnable = new RunnableMap({ steps: inputs }) - .pipe(promptTemplate) - .pipe(llm); - const result = await runnable.invoke("Do you know the Muffin Man?"); - console.log(result); - expect(result.content).toEqual( - `You are a nice assistant.\nContext:\n[{"pageContent":"foo","metadata":{}},{"pageContent":"bar","metadata":{}}]\n\nQuestion:\nDo you know the Muffin Man?` - ); -}); - test("Stream the entire way through", async () => { const llm = new FakeStreamingLLM({}); const stream = await llm.pipe(new StringOutputParser()).stream("Hi there!"); diff --git a/langchain-core/src/runnables/tests/runnable_map.test.ts b/langchain-core/src/runnables/tests/runnable_map.test.ts new file mode 100644 index 000000000000..d820e53ee0ed --- /dev/null +++ b/langchain-core/src/runnables/tests/runnable_map.test.ts @@ -0,0 +1,105 @@ +/* eslint-disable no-promise-executor-return */ +/* eslint-disable @typescript-eslint/no-explicit-any */ + +import { StringOutputParser } from "../../output_parsers/string.js"; +import { + ChatPromptTemplate, + SystemMessagePromptTemplate, + HumanMessagePromptTemplate, +} from "../../prompts/chat.js"; +import { + FakeLLM, + FakeChatModel, + FakeRetriever, +} from "../../utils/testing/index.js"; +import { RunnableSequence, RunnableMap } from "../base.js"; +import { RunnablePassthrough } from "../passthrough.js"; + +test("Create a runnable sequence with a runnable map", async () => { + const promptTemplate = ChatPromptTemplate.fromMessages<{ + documents: string; + question: string; + }>([ + SystemMessagePromptTemplate.fromTemplate(`You are a nice assistant.`), + HumanMessagePromptTemplate.fromTemplate( + `Context:\n{documents}\n\nQuestion:\n{question}` + ), + ]); + const llm = new FakeChatModel({}); + const inputs = { + question: (input: string) => input, + documents: RunnableSequence.from([ + new FakeRetriever(), + (docs: Document[]) => JSON.stringify(docs), + ]), + extraField: new FakeLLM({}), + }; + const runnable = new RunnableMap({ steps: inputs }) + .pipe(promptTemplate) + .pipe(llm); + const result = await runnable.invoke("Do you know the Muffin Man?"); + console.log(result); + expect(result.content).toEqual( + `You are a nice assistant.\nContext:\n[{"pageContent":"foo","metadata":{}},{"pageContent":"bar","metadata":{}}]\n\nQuestion:\nDo you know the Muffin Man?` + ); +}); + +test("Test map inference in a sequence", async () => { + const prompt = ChatPromptTemplate.fromTemplate( + "context: {context}, question: {question}" + ); + const chain = RunnableSequence.from([ + { + question: new RunnablePassthrough(), + context: async () => "SOME STUFF", + }, + prompt, + new FakeLLM({}), + new StringOutputParser(), + ]); + const response = await chain.invoke("Just passing through."); + console.log(response); + expect(response).toBe( + `Human: context: SOME STUFF, question: Just passing through.` + ); +}); + +test("Should not allow mismatched inputs", async () => { + const prompt = ChatPromptTemplate.fromTemplate( + "context: {context}, question: {question}" + ); + const badChain = RunnableSequence.from([ + { + // @ts-expect-error TS compiler should flag mismatched input types + question: new FakeLLM({}), + context: async (input: number) => input, + }, + prompt, + new FakeLLM({}), + new StringOutputParser(), + ]); + console.log(badChain); +}); + +test("Should not allow improper inputs into a map in a sequence", async () => { + const prompt = ChatPromptTemplate.fromTemplate( + "context: {context}, question: {question}" + ); + const map = RunnableMap.from({ + question: new FakeLLM({}), + context: async (_input: string) => 9, + }); + // @ts-expect-error TS compiler should flag mismatched output types + const runnable = prompt.pipe(map); + console.log(runnable); +}); + +test("Should not allow improper outputs from a map into the next item in a sequence", async () => { + const map = RunnableMap.from({ + question: new FakeLLM({}), + context: async (_input: string) => 9, + }); + // @ts-expect-error TS compiler should flag mismatched output types + const runnable = map.pipe(new FakeLLM({})); + console.log(runnable); +}); diff --git a/langchain-core/src/runnables/tests/runnable_stream_log.test.ts b/langchain-core/src/runnables/tests/runnable_stream_log.test.ts index 8f7c1d1e589a..f9ab046be648 100644 --- a/langchain-core/src/runnables/tests/runnable_stream_log.test.ts +++ b/langchain-core/src/runnables/tests/runnable_stream_log.test.ts @@ -62,7 +62,10 @@ test("Runnable streamLog method with a more complicated sequence", async () => { response: "testing", }).withConfig({ tags: ["only_one"] }), }; - const runnable = new RunnableMap({ steps: inputs }) + + const runnable = new RunnableMap({ + steps: inputs, + }) .pipe(promptTemplate) .pipe(llm); const stream = await runnable.streamLog(