Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(core): Generalize streaming usage for language models based on passed callback handlers #7378

Merged
merged 1 commit into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions langchain-core/src/callbacks/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,19 @@ abstract class BaseCallbackHandlerMethodsClass {
*/
export type CallbackHandlerMethods = BaseCallbackHandlerMethodsClass;

/**
* Interface for handlers that can indicate a preference for streaming responses.
* When implemented, this allows the handler to signal whether it prefers to receive
* streaming responses from language models rather than complete responses.
*/
export interface CallbackHandlerPrefersStreaming {
readonly lc_prefer_streaming: boolean;
}

export function callbackHandlerPrefersStreaming(x: BaseCallbackHandler) {
return "lc_prefer_streaming" in x && x.lc_prefer_streaming;
}

/**
* Abstract base class for creating callback handlers in the LangChain
* framework. It provides a set of optional methods that can be overridden
Expand Down
9 changes: 4 additions & 5 deletions langchain-core/src/language_models/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,10 @@ import {
RunnableSequence,
RunnableToolLike,
} from "../runnables/base.js";
import { isStreamEventsHandler } from "../tracers/event_stream.js";
import { isLogStreamHandler } from "../tracers/log_stream.js";
import { concat } from "../utils/stream.js";
import { RunnablePassthrough } from "../runnables/passthrough.js";
import { isZodSchema } from "../utils/types/is_zod_schema.js";
import { callbackHandlerPrefersStreaming } from "../callbacks/base.js";

// eslint-disable-next-line @typescript-eslint/no-explicit-any
export type ToolChoice = string | Record<string, any> | "auto" | "any";
Expand Down Expand Up @@ -370,9 +369,9 @@ export abstract class BaseChatModel<
// Even if stream is not explicitly called, check if model is implicitly
// called from streamEvents() or streamLog() to get all streamed events.
// Bail out if _streamResponseChunks not overridden
const hasStreamingHandler = !!runManagers?.[0].handlers.find((handler) => {
return isStreamEventsHandler(handler) || isLogStreamHandler(handler);
});
const hasStreamingHandler = !!runManagers?.[0].handlers.find(
callbackHandlerPrefersStreaming
);
if (
hasStreamingHandler &&
baseMessages.length === 1 &&
Expand Down
9 changes: 4 additions & 5 deletions langchain-core/src/language_models/llms.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,8 @@ import {
} from "./base.js";
import type { RunnableConfig } from "../runnables/config.js";
import type { BaseCache } from "../caches/base.js";
import { isStreamEventsHandler } from "../tracers/event_stream.js";
import { isLogStreamHandler } from "../tracers/log_stream.js";
import { concat } from "../utils/stream.js";
import { callbackHandlerPrefersStreaming } from "../callbacks/base.js";

export type SerializedLLM = {
_model: string;
Expand Down Expand Up @@ -270,9 +269,9 @@ export abstract class BaseLLM<
// Even if stream is not explicitly called, check if model is implicitly
// called from streamEvents() or streamLog() to get all streamed events.
// Bail out if _streamResponseChunks not overridden
const hasStreamingHandler = !!runManagers?.[0].handlers.find((handler) => {
return isStreamEventsHandler(handler) || isLogStreamHandler(handler);
});
const hasStreamingHandler = !!runManagers?.[0].handlers.find(
callbackHandlerPrefersStreaming
);
let output: LLMResult;
if (
hasStreamingHandler &&
Expand Down
8 changes: 7 additions & 1 deletion langchain-core/src/tracers/event_stream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { BaseTracer, type Run } from "./base.js";
import {
BaseCallbackHandler,
BaseCallbackHandlerInput,
CallbackHandlerPrefersStreaming,
} from "../callbacks/base.js";
import { IterableReadableStream } from "../utils/stream.js";
import { AIMessageChunk } from "../messages/ai.js";
Expand Down Expand Up @@ -145,7 +146,10 @@ export const isStreamEventsHandler = (
* handler that logs the execution of runs and emits `RunLog` instances to a
* `RunLogStream`.
*/
export class EventStreamCallbackHandler extends BaseTracer {
export class EventStreamCallbackHandler
extends BaseTracer
implements CallbackHandlerPrefersStreaming
{
protected autoClose = true;

protected includeNames?: string[];
Expand All @@ -172,6 +176,8 @@ export class EventStreamCallbackHandler extends BaseTracer {

name = "event_stream_tracer";

lc_prefer_streaming = true;

constructor(fields?: EventStreamCallbackHandlerInput) {
super({ _awaitHandler: true, ...fields });
this.autoClose = fields?.autoClose ?? true;
Expand Down
8 changes: 7 additions & 1 deletion langchain-core/src/tracers/log_stream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import { BaseTracer, type Run } from "./base.js";
import {
BaseCallbackHandler,
BaseCallbackHandlerInput,
CallbackHandlerPrefersStreaming,
HandleLLMNewTokenCallbackFields,
} from "../callbacks/base.js";
import { IterableReadableStream } from "../utils/stream.js";
Expand Down Expand Up @@ -210,7 +211,10 @@ function isChatGenerationChunk(
* handler that logs the execution of runs and emits `RunLog` instances to a
* `RunLogStream`.
*/
export class LogStreamCallbackHandler extends BaseTracer {
export class LogStreamCallbackHandler
extends BaseTracer
implements CallbackHandlerPrefersStreaming
{
protected autoClose = true;

protected includeNames?: string[];
Expand Down Expand Up @@ -241,6 +245,8 @@ export class LogStreamCallbackHandler extends BaseTracer {

name = "log_stream_tracer";

lc_prefer_streaming = true;

constructor(fields?: LogStreamCallbackHandlerInput) {
super({ _awaitHandler: true, ...fields });
this.autoClose = fields?.autoClose ?? true;
Expand Down
Loading