feat: use query bundle (run-llama#702)
himself65 authored Jul 18, 2024
1 parent b7cfe5b commit 92f0782
Showing 8 changed files with 58 additions and 40 deletions.
5 changes: 5 additions & 0 deletions .changeset/quiet-cows-rule.md
@@ -0,0 +1,5 @@
+---
+"llamaindex": patch
+---
+
+feat: use query bundle
10 changes: 6 additions & 4 deletions packages/llamaindex/src/engines/query/RouterQueryEngine.ts
@@ -1,7 +1,9 @@
 import type { NodeWithScore } from "@llamaindex/core/schema";
+import { extractText } from "@llamaindex/core/utils";
 import { EngineResponse } from "../../EngineResponse.js";
 import type { ServiceContext } from "../../ServiceContext.js";
 import { llmFromSettingsOrContext } from "../../Settings.js";
+import { toQueryBundle } from "../../internal/utils.js";
 import { PromptMixin } from "../../prompts/index.js";
 import type { BaseSelector } from "../../selectors/index.js";
 import { LLMSingleSelector } from "../../selectors/index.js";
@@ -44,7 +46,7 @@ async function combineResponses(
   }

   const summary = await summarizer.getResponse({
-    query: queryBundle.queryStr,
+    query: extractText(queryBundle.query),
     textChunks: responseStrs,
   });

@@ -117,7 +119,7 @@ export class RouterQueryEngine extends PromptMixin implements QueryEngine {
   ): Promise<EngineResponse | AsyncIterable<EngineResponse>> {
     const { query, stream } = params;

-    const response = await this.queryRoute({ queryStr: query });
+    const response = await this.queryRoute(toQueryBundle(query));

     if (stream) {
       throw new Error("Streaming is not supported yet.");
@@ -142,7 +144,7 @@ export class RouterQueryEngine extends PromptMixin implements QueryEngine {
       const selectedQueryEngine = this.queryEngines[engineInd.index];
       responses.push(
         await selectedQueryEngine.query({
-          query: queryBundle.queryStr,
+          query: extractText(queryBundle.query),
         }),
       );
     }
@@ -179,7 +181,7 @@ export class RouterQueryEngine extends PromptMixin implements QueryEngine {
     }

     const finalResponse = await selectedQueryEngine.query({
-      query: queryBundle.queryStr,
+      query: extractText(queryBundle.query),
     });

     // add selected result
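The router now carries a `QueryBundle` internally and flattens it with `extractText` whenever a sub-engine or summarizer still expects plain text. A minimal sketch of that flattening step (the sample query value is illustrative, not from the commit):

```ts
import { extractText } from "@llamaindex/core/utils";

// A QueryBundle's `query` may be a plain string or richer MessageContent;
// extractText collapses either form to the text that prompt templates expect.
const queryBundle = { query: "What is the total revenue for Q2?" };
const text: string = extractText(queryBundle.query);
```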
8 changes: 8 additions & 0 deletions packages/llamaindex/src/internal/utils.ts
@@ -3,6 +3,7 @@ import type { JSONValue } from "@llamaindex/core/global";
 import type { ImageType } from "@llamaindex/core/schema";
 import { fs } from "@llamaindex/env";
 import { filetypemime } from "magic-bytes.js";
+import type { QueryBundle } from "../types.js";

 export const isAsyncIterable = (
   obj: unknown,
@@ -202,3 +203,10 @@ export async function imageToDataUrl(input: ImageType): Promise<string> {
   }
   return await blobToDataUrl(input);
 }
+
+export function toQueryBundle(query: QueryBundle | string): QueryBundle {
+  if (typeof query === "string") {
+    return { query };
+  }
+  return query;
+}
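A quick sketch of how the new helper behaves; the relative import path assumes code living inside the `llamaindex` package, since `internal/utils.js` is not part of the public API:

```ts
import { toQueryBundle } from "./internal/utils.js";

// A bare string is wrapped into a bundle...
const fromString = toQueryBundle("What did the author do growing up?");

// ...while an existing bundle passes through unchanged.
const fromBundle = toQueryBundle({
  query: "What did the author do growing up?",
  embeddings: [0.1, 0.2, 0.3],
});

console.log(fromString.query === fromBundle.query); // true
```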
1 change: 1 addition & 0 deletions packages/llamaindex/src/prompts/Mixin.ts
@@ -75,6 +75,7 @@ export class PromptMixin {
   }

   // Must be implemented by subclasses
+  // fixme: says must but never implemented
   protected _getPrompts(): PromptsDict {
     return {};
   }
19 changes: 5 additions & 14 deletions packages/llamaindex/src/selectors/base.ts
@@ -1,3 +1,4 @@
+import { toQueryBundle } from "../internal/utils.js";
 import { PromptMixin } from "../prompts/Mixin.js";
 import type { QueryBundle, ToolMetadataOnlyDescription } from "../types.js";

@@ -10,8 +11,6 @@ export type SelectorResult = {
   selections: SingleSelection[];
 };

-type QueryType = string | QueryBundle;
-
 function wrapChoice(
   choice: string | ToolMetadataOnlyDescription,
 ): ToolMetadataOnlyDescription {
@@ -22,21 +21,13 @@ function wrapChoice(
   }
 }

-function wrapQuery(query: QueryType): QueryBundle {
-  if (typeof query === "string") {
-    return { queryStr: query };
-  }
-
-  return query;
-}
-
 type MetadataType = string | ToolMetadataOnlyDescription;

 export abstract class BaseSelector extends PromptMixin {
-  async select(choices: MetadataType[], query: QueryType) {
-    const metadatas = choices.map((choice) => wrapChoice(choice));
-    const queryBundle = wrapQuery(query);
-    return await this._select(metadatas, queryBundle);
+  async select(choices: MetadataType[], query: string | QueryBundle) {
+    const metadata = choices.map((choice) => wrapChoice(choice));
+    const queryBundle = toQueryBundle(query);
+    return await this._select(metadata, queryBundle);
   }

   abstract _select(
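With the local `wrapQuery` removed, `select` accepts either a bare string or a `QueryBundle` and normalizes both through the shared `toQueryBundle` helper. A hypothetical call site, assuming `BaseSelector` is importable from the package entry point and used via a concrete subclass such as `LLMSingleSelector`:

```ts
import type { BaseSelector } from "llamaindex";

// Hypothetical helper: both call forms are normalized to the same QueryBundle.
async function pickEngine(selector: BaseSelector) {
  const choices = [
    "Useful for summarization questions",
    "Useful for retrieving specific facts",
  ];

  // Passing a bare string...
  const fromString = await selector.select(choices, "Summarize the document");

  // ...or an explicit QueryBundle produces the same kind of SelectorResult.
  const fromBundle = await selector.select(choices, {
    query: "Summarize the document",
  });

  return { fromString, fromBundle };
}
```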
9 changes: 7 additions & 2 deletions packages/llamaindex/src/selectors/llmSelectors.ts
@@ -1,4 +1,5 @@
 import type { LLM } from "@llamaindex/core/llms";
+import { extractText } from "@llamaindex/core/utils";
 import type { Answer } from "../outputParsers/selectors.js";
 import { SelectionOutputParser } from "../outputParsers/selectors.js";
 import type {
@@ -88,7 +89,7 @@ export class LLMMultiSelector extends BaseSelector {
     const prompt = this.prompt(
       choicesText.length,
       choicesText,
-      query.queryStr,
+      extractText(query.query),
       this.maxOutputs,
     );

@@ -152,7 +153,11 @@ export class LLMSingleSelector extends BaseSelector {
   ): Promise<SelectorResult> {
     const choicesText = buildChoicesText(choices);

-    const prompt = this.prompt(choicesText.length, choicesText, query.queryStr);
+    const prompt = this.prompt(
+      choicesText.length,
+      choicesText,
+      extractText(query.query),
+    );

     const formattedPrompt = this.outputParser.format(prompt);

23 changes: 15 additions & 8 deletions packages/llamaindex/src/synthesizers/builders.ts
@@ -1,5 +1,6 @@
 import type { LLM } from "@llamaindex/core/llms";
-import { streamConverter } from "@llamaindex/core/utils";
+import { extractText, streamConverter } from "@llamaindex/core/utils";
+import { toQueryBundle } from "../internal/utils.js";
 import type {
   RefinePrompt,
   SimplePrompt,
@@ -61,7 +62,7 @@ export class SimpleResponseBuilder implements ResponseBuilder {
     AsyncIterable<string> | string
   > {
     const input = {
-      query,
+      query: extractText(toQueryBundle(query).query),
       context: textChunks.join("\n\n"),
     };

@@ -142,14 +143,14 @@ export class Refine extends PromptMixin implements ResponseBuilder {
       const lastChunk = i === textChunks.length - 1;
       if (!response) {
         response = await this.giveResponseSingle(
-          query,
+          extractText(toQueryBundle(query).query),
           chunk,
           !!stream && lastChunk,
         );
       } else {
         response = await this.refineResponseSingle(
           response as string,
-          query,
+          extractText(toQueryBundle(query).query),
           chunk,
           !!stream && lastChunk,
         );
@@ -254,9 +255,15 @@ export class CompactAndRefine extends Refine {
     AsyncIterable<string> | string
   > {
     const textQATemplate: SimplePrompt = (input) =>
-      this.textQATemplate({ ...input, query: query });
+      this.textQATemplate({
+        ...input,
+        query: extractText(toQueryBundle(query).query),
+      });
     const refineTemplate: SimplePrompt = (input) =>
-      this.refineTemplate({ ...input, query: query });
+      this.refineTemplate({
+        ...input,
+        query: extractText(toQueryBundle(query).query),
+      });

     const maxPrompt = getBiggestPrompt([textQATemplate, refineTemplate]);
     const newTexts = this.promptHelper.repack(maxPrompt, textChunks);
@@ -335,7 +342,7 @@ export class TreeSummarize extends PromptMixin implements ResponseBuilder {
       const params = {
         prompt: this.summaryTemplate({
           context: packedTextChunks[0],
-          query,
+          query: extractText(toQueryBundle(query).query),
         }),
       };
       if (stream) {
@@ -349,7 +356,7 @@ export class TreeSummarize extends PromptMixin implements ResponseBuilder {
           this.llm.complete({
             prompt: this.summaryTemplate({
               context: chunk,
-              query,
+              query: extractText(toQueryBundle(query).query),
             }),
           }),
         ),
23 changes: 11 additions & 12 deletions packages/llamaindex/src/types.ts
@@ -1,7 +1,7 @@
 /**
  * Top level types to avoid circular dependencies
  */
-import type { ToolMetadata } from "@llamaindex/core/llms";
+import type { MessageContent, ToolMetadata } from "@llamaindex/core/llms";
 import type { EngineResponse } from "./EngineResponse.js";

 /**
@@ -52,16 +52,15 @@ export interface StructuredOutput<T> {

 export type ToolMetadataOnlyDescription = Pick<ToolMetadata, "description">;

-export class QueryBundle {
-  queryStr: string;
-
-  constructor(queryStr: string) {
-    this.queryStr = queryStr;
-  }
-
-  toString(): string {
-    return this.queryStr;
-  }
-}
+/**
+ * @link https://docs.llamaindex.ai/en/stable/api_reference/schema/?h=querybundle#llama_index.core.schema.QueryBundle
+ *
+ * We don't have `image_path` here, because it is included in the `query` field.
+ */
+export type QueryBundle = {
+  query: string | MessageContent;
+  customEmbedding?: string[];
+  embeddings?: number[];
+};

 export type UUID = `${string}-${string}-${string}-${string}-${string}`;
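The class is replaced by a plain type whose `query` field can hold either a string or `MessageContent`, which is why no separate `image_path` field is needed. A sketch of both shapes, assuming `QueryBundle` is re-exported from the package root and that the multi-modal content parts follow the OpenAI-style shape used by `MessageContent` at the time:

```ts
import type { QueryBundle } from "llamaindex";

// Plain-text query, optionally carrying precomputed embeddings.
const textQuery: QueryBundle = {
  query: "Summarize the quarterly report",
  embeddings: [0.12, -0.34, 0.56],
};

// Multi-modal query: image content rides inside `query` itself
// (the content-part shape here is an assumption, mirroring MessageContent).
const multiModalQuery: QueryBundle = {
  query: [
    { type: "text", text: "Describe this image" },
    { type: "image_url", image_url: { url: "https://example.com/chart.png" } },
  ],
};
```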