Skip to content

Commit

Permalink
feat: add inference for RunnableMap RunOutput type (#3517)
Browse files Browse the repository at this point in the history
* add RunnableMapLike to infer RunnableMap output

* remove unneeded changes

* fix linting

* format

* fix runnable_stream_log.test

* upgrade typescript version

* clean types

* fix structured_output_runnables.int.test

* ts version ~5.1.6

* remove unused eslint-disable-next-line

* remove another disable no-explicit-any

* remove another no-explicit-any

* move eslint

* Format

* Default runnable maps to any type in case inference is not possible

* Add tests

---------

Co-authored-by: David Illing <dilling123@gmail.com>
Co-authored-by: jacoblee93 <jacoblee93@gmail.com>
  • Loading branch information
3 people authored Dec 13, 2023
1 parent b455950 commit 7b574ff
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 47 deletions.
36 changes: 23 additions & 13 deletions langchain-core/src/runnables/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,15 @@ export type RunnableFunc<RunInput, RunOutput> = (
| (Record<string, any> & { config: RunnableConfig })
) => RunOutput | Promise<RunOutput>;

/**
 * Maps every key of the desired output object type to a RunnableLike that
 * produces that key's value type from the shared RunInput. This lets a
 * RunnableMap infer its RunOutput type from the steps object it is given.
 */
export type RunnableMapLike<RunInput, RunOutput> = {
[K in keyof RunOutput]: RunnableLike<RunInput, RunOutput[K]>;
};

// eslint-disable-next-line @typescript-eslint/no-explicit-any
export type RunnableLike<RunInput = any, RunOutput = any> =
| Runnable<RunInput, RunOutput>
| RunnableFunc<RunInput, RunOutput>
| { [key: string]: RunnableLike<RunInput, RunOutput> };
| RunnableMapLike<RunInput, RunOutput>;

export type RunnableBatchOptions = {
maxConcurrency?: number;
Expand Down Expand Up @@ -1368,11 +1372,12 @@ export class RunnableSequence<
* const result = await mapChain.invoke({ topic: "bear" });
* ```
*/
export class RunnableMap<RunInput> extends Runnable<
RunInput,
export class RunnableMap<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
Record<string, any>
> {
RunInput = any,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
RunOutput extends Record<string, any> = Record<string, any>
> extends Runnable<RunInput, RunOutput> {
static lc_name() {
return "RunnableMap";
}
Expand All @@ -1387,23 +1392,28 @@ export class RunnableMap<RunInput> extends Runnable<
return Object.keys(this.steps);
}

constructor(fields: { steps: Record<string, RunnableLike<RunInput>> }) {
constructor(fields: { steps: RunnableMapLike<RunInput, RunOutput> }) {
super(fields);
this.steps = {};
for (const [key, value] of Object.entries(fields.steps)) {
this.steps[key] = _coerceToRunnable(value);
}
}

static from<RunInput>(steps: Record<string, RunnableLike<RunInput>>) {
return new RunnableMap<RunInput>({ steps });
static from<
RunInput,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
RunOutput extends Record<string, any> = Record<string, any>
>(
steps: RunnableMapLike<RunInput, RunOutput>
): RunnableMap<RunInput, RunOutput> {
return new RunnableMap<RunInput, RunOutput>({ steps });
}

async invoke(
input: RunInput,
options?: Partial<BaseCallbackConfig>
// eslint-disable-next-line @typescript-eslint/no-explicit-any
): Promise<Record<string, any>> {
): Promise<RunOutput> {
const callbackManager_ = await getCallbackMangerForConfig(options);
const runManager = await callbackManager_?.handleChainStart(
this.toJSON(),
Expand Down Expand Up @@ -1432,7 +1442,7 @@ export class RunnableMap<RunInput> extends Runnable<
throw e;
}
await runManager?.handleChainEnd(output);
return output;
return output as RunOutput;
}
}

Expand Down Expand Up @@ -1665,9 +1675,9 @@ export function _coerceToRunnable<RunInput, RunOutput>(
} else if (!Array.isArray(coerceable) && typeof coerceable === "object") {
const runnables: Record<string, Runnable<RunInput>> = {};
for (const [key, value] of Object.entries(coerceable)) {
runnables[key] = _coerceToRunnable(value);
runnables[key] = _coerceToRunnable(value as RunnableLike);
}
return new RunnableMap<RunInput>({
return new RunnableMap({
steps: runnables,
}) as unknown as Runnable<RunInput, Exclude<RunOutput, Error>>;
} else {
Expand Down
34 changes: 1 addition & 33 deletions langchain-core/src/runnables/tests/runnable.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,18 @@ import { StringOutputParser } from "../../output_parsers/string.js";
import {
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
} from "../../prompts/chat.js";
import { PromptTemplate } from "../../prompts/prompt.js";
import {
FakeLLM,
FakeChatModel,
FakeRetriever,
FakeStreamingLLM,
FakeSplitIntoListParser,
FakeRunnable,
FakeListChatModel,
} from "../../utils/testing/index.js";
import { RunnableSequence, RunnableMap, RunnableLambda } from "../base.js";
import { RunnableSequence, RunnableLambda } from "../base.js";
import { RouterRunnable } from "../router.js";
import { Document } from "../../documents/document.js";

test("Test batch", async () => {
const llm = new FakeLLM({});
Expand Down Expand Up @@ -70,35 +67,6 @@ test("Pipe from one runnable to the next", async () => {
expect(result).toBe("Hello world!");
});

test("Create a runnable sequence with a runnable map", async () => {
const promptTemplate = ChatPromptTemplate.fromMessages<{
documents: string;
question: string;
}>([
SystemMessagePromptTemplate.fromTemplate(`You are a nice assistant.`),
HumanMessagePromptTemplate.fromTemplate(
`Context:\n{documents}\n\nQuestion:\n{question}`
),
]);
const llm = new FakeChatModel({});
const inputs = {
question: (input: string) => input,
documents: RunnableSequence.from([
new FakeRetriever(),
(docs: Document[]) => JSON.stringify(docs),
]),
extraField: new FakeLLM({}),
};
const runnable = new RunnableMap({ steps: inputs })
.pipe(promptTemplate)
.pipe(llm);
const result = await runnable.invoke("Do you know the Muffin Man?");
console.log(result);
expect(result.content).toEqual(
`You are a nice assistant.\nContext:\n[{"pageContent":"foo","metadata":{}},{"pageContent":"bar","metadata":{}}]\n\nQuestion:\nDo you know the Muffin Man?`
);
});

test("Stream the entire way through", async () => {
const llm = new FakeStreamingLLM({});
const stream = await llm.pipe(new StringOutputParser()).stream("Hi there!");
Expand Down
105 changes: 105 additions & 0 deletions langchain-core/src/runnables/tests/runnable_map.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/* eslint-disable no-promise-executor-return */
/* eslint-disable @typescript-eslint/no-explicit-any */

import { Document } from "../../documents/document.js";
import { StringOutputParser } from "../../output_parsers/string.js";
import {
  ChatPromptTemplate,
  SystemMessagePromptTemplate,
  HumanMessagePromptTemplate,
} from "../../prompts/chat.js";
import { RunnableSequence, RunnableMap } from "../base.js";
import { RunnablePassthrough } from "../passthrough.js";
import {
  FakeLLM,
  FakeChatModel,
  FakeRetriever,
} from "../../utils/testing/index.js";

test("Create a runnable sequence with a runnable map", async () => {
const promptTemplate = ChatPromptTemplate.fromMessages<{
documents: string;
question: string;
}>([
SystemMessagePromptTemplate.fromTemplate(`You are a nice assistant.`),
HumanMessagePromptTemplate.fromTemplate(
`Context:\n{documents}\n\nQuestion:\n{question}`
),
]);
const llm = new FakeChatModel({});
const inputs = {
question: (input: string) => input,
documents: RunnableSequence.from([
new FakeRetriever(),
(docs: Document[]) => JSON.stringify(docs),
]),
extraField: new FakeLLM({}),
};
const runnable = new RunnableMap({ steps: inputs })
.pipe(promptTemplate)
.pipe(llm);
const result = await runnable.invoke("Do you know the Muffin Man?");
console.log(result);
expect(result.content).toEqual(
`You are a nice assistant.\nContext:\n[{"pageContent":"foo","metadata":{}},{"pageContent":"bar","metadata":{}}]\n\nQuestion:\nDo you know the Muffin Man?`
);
});

test("Test map inference in a sequence", async () => {
const prompt = ChatPromptTemplate.fromTemplate(
"context: {context}, question: {question}"
);
const chain = RunnableSequence.from([
{
question: new RunnablePassthrough(),
context: async () => "SOME STUFF",
},
prompt,
new FakeLLM({}),
new StringOutputParser(),
]);
const response = await chain.invoke("Just passing through.");
console.log(response);
expect(response).toBe(
`Human: context: SOME STUFF, question: Just passing through.`
);
});

test("Should not allow mismatched inputs", async () => {
const prompt = ChatPromptTemplate.fromTemplate(
"context: {context}, question: {question}"
);
const badChain = RunnableSequence.from([
{
// @ts-expect-error TS compiler should flag mismatched input types
question: new FakeLLM({}),
context: async (input: number) => input,
},
prompt,
new FakeLLM({}),
new StringOutputParser(),
]);
console.log(badChain);
});

test("Should not allow improper inputs into a map in a sequence", async () => {
const prompt = ChatPromptTemplate.fromTemplate(
"context: {context}, question: {question}"
);
const map = RunnableMap.from({
question: new FakeLLM({}),
context: async (_input: string) => 9,
});
// @ts-expect-error TS compiler should flag mismatched output types
const runnable = prompt.pipe(map);
console.log(runnable);
});

test("Should not allow improper outputs from a map into the next item in a sequence", async () => {
const map = RunnableMap.from({
question: new FakeLLM({}),
context: async (_input: string) => 9,
});
// @ts-expect-error TS compiler should flag mismatched output types
const runnable = map.pipe(new FakeLLM({}));
console.log(runnable);
});
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,10 @@ test("Runnable streamLog method with a more complicated sequence", async () => {
response: "testing",
}).withConfig({ tags: ["only_one"] }),
};
const runnable = new RunnableMap({ steps: inputs })

const runnable = new RunnableMap({
steps: inputs,
})
.pipe(promptTemplate)
.pipe(llm);
const stream = await runnable.streamLog(
Expand Down

1 comment on commit 7b574ff

@vercel
Copy link

@vercel vercel bot commented on 7b574ff Dec 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.