Skip to content

Commit

Permalink
feat (rsc): add streamUI onFinish callback
Browse files Browse the repository at this point in the history
  • Loading branch information
gclark-eightfold committed Jun 11, 2024
1 parent 3cabf07 commit f0d9f71
Showing 1 changed file with 44 additions and 3 deletions.
47 changes: 44 additions & 3 deletions packages/core/rsc/stream-ui/stream-ui.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,14 @@ import { getValidatedPrompt } from '../../core/prompt/get-validated-prompt';
import { prepareCallSettings } from '../../core/prompt/prepare-call-settings';
import { prepareToolsAndToolChoice } from '../../core/prompt/prepare-tools-and-tool-choice';
import { Prompt } from '../../core/prompt/prompt';
import { CoreToolChoice } from '../../core/types';
import { CallWarning, CoreToolChoice, FinishReason } from '../../core/types';
import { retryWithExponentialBackoff } from '../../core/util/retry-with-exponential-backoff';
import { createStreamableUI } from '../streamable';
import { createResolvablePromise } from '../utils';
import {
TokenUsage,
calculateTokenUsage,
} from '../../core/generate-text/token-usage';

type Streamable = ReactNode | Promise<ReactNode>;

Expand Down Expand Up @@ -84,6 +88,7 @@ export async function streamUI<
abortSignal,
initial,
text,
onFinish,
...settings
}: CallSettings &
Prompt & {
Expand All @@ -100,12 +105,42 @@ export async function streamUI<
};

/**
The tool choice strategy. Default: 'auto'.
* The tool choice strategy. Default: 'auto'.
*/
toolChoice?: CoreToolChoice<TOOLS>;

text?: RenderText;
initial?: ReactNode;
/**
* Callback that is called when the LLM response and the final object validation are finished.
*/
onFinish?: (event: {
/**
* The reason why the generation finished.
*/
finishReason: FinishReason;
/**
* The token usage of the generated response.
*/
usage: TokenUsage;
/**
* The final ui node that was generated.
*/
value: ReactNode;
/**
* Warnings from the model provider (e.g. unsupported settings)
*/
warnings?: CallWarning[];
/**
* Optional raw response data.
*/
rawResponse?: {
/**
* Response headers.
*/
headers?: Record<string, string>;
};
}) => Promise<void> | void;
}): Promise<RenderResult> {
// TODO: Remove these errors after the experimental phase.
if (typeof model === 'string') {
Expand Down Expand Up @@ -311,7 +346,13 @@ The tool choice strategy. Default: 'auto'.
}

case 'finish': {
// Nothing to do here.
onFinish?.({
finishReason: value.finishReason,
usage: calculateTokenUsage(value.usage),
value: ui.value,
warnings: result.warnings,
rawResponse: result.rawResponse,
});
}
}
}
Expand Down

0 comments on commit f0d9f71

Please sign in to comment.