forked from run-llama/LlamaIndexTS
-
Notifications
You must be signed in to change notification settings - Fork 0
/
RetrieverQueryEngine.ts
91 lines (80 loc) · 2.61 KB
/
RetrieverQueryEngine.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import { BaseQueryEngine } from "@llamaindex/core/query-engine";
import type { BaseSynthesizer } from "@llamaindex/core/response-synthesizers";
import { getResponseSynthesizer } from "@llamaindex/core/response-synthesizers";
import { type NodeWithScore } from "@llamaindex/core/schema";
import { extractText } from "@llamaindex/core/utils";
import type { BaseNodePostprocessor } from "../../postprocessors/index.js";
import type { BaseRetriever } from "../../Retriever.js";
/**
 * A query engine that answers a query in two stages:
 * 1. It retrieves nodes with a {@link BaseRetriever} and passes them through any
 *    configured node postprocessors.
 * 2. It hands the query plus the resulting nodes to a {@link BaseSynthesizer},
 *    which produces the response.
 */
export class RetrieverQueryEngine extends BaseQueryEngine {
  /** Retriever used to fetch candidate nodes for each query. */
  retriever: BaseRetriever;
  /** Synthesizer that turns the query and the retrieved nodes into a response. */
  responseSynthesizer: BaseSynthesizer;
  /** Postprocessors applied, in array order, to the retrieved nodes before synthesis. */
  nodePostprocessors: BaseNodePostprocessor[];
  /** Opaque filters forwarded unchanged to `retriever.retrieve` on every query. */
  preFilters?: unknown;

  /**
   * @param retriever - Retriever used to fetch nodes for each query.
   * @param responseSynthesizer - Synthesizer for the final response; defaults to
   *   `getResponseSynthesizer("compact")` when omitted.
   * @param preFilters - Passed unchanged to the retriever as `preFilters`.
   * @param nodePostprocessors - Applied in order to the retrieved nodes; defaults
   *   to an empty list (no postprocessing).
   */
  constructor(
    retriever: BaseRetriever,
    responseSynthesizer?: BaseSynthesizer,
    preFilters?: unknown,
    nodePostprocessors?: BaseNodePostprocessor[],
  ) {
    // This callback reads fields through `this`. It is stored by the base class
    // and only invoked when a query is issued, which is after the assignments
    // below have run.
    super(async (strOrQueryBundle, stream) => {
      // Normalize the input once. The synthesizer always receives a query
      // bundle; the retriever and postprocessors work on plain query text.
      const query =
        typeof strOrQueryBundle === "string"
          ? { query: strOrQueryBundle }
          : strOrQueryBundle;
      const queryText =
        typeof strOrQueryBundle === "string"
          ? strOrQueryBundle
          : extractText(strOrQueryBundle);

      const nodes = await this.retrieve(queryText);
      const params = { query, nodes };

      // Two distinct calls instead of forwarding `stream` directly.
      // NOTE(review): presumably `synthesize` is overloaded on the literal
      // `true` (streaming) vs. omitted (non-streaming) — confirm against
      // BaseSynthesizer.
      if (stream) {
        return this.responseSynthesizer.synthesize(params, true);
      }
      return this.responseSynthesizer.synthesize(params);
    });
    this.retriever = retriever;
    this.responseSynthesizer =
      responseSynthesizer ?? getResponseSynthesizer("compact");
    this.preFilters = preFilters;
    this.nodePostprocessors = nodePostprocessors ?? [];
  }

  /** This engine defines no prompts of its own; see `_getPromptModules`. */
  protected _getPrompts() {
    return {};
  }

  /** No-op: there are no engine-level prompts to update. */
  protected _updatePrompts() {}

  /** Exposes the response synthesizer as this engine's only prompt-bearing sub-module. */
  _getPromptModules() {
    return {
      responseSynthesizer: this.responseSynthesizer,
    };
  }

  /**
   * Runs `nodes` through every configured postprocessor in array order.
   * Each postprocessor receives the previous one's output plus the query text.
   *
   * @param nodes - Nodes as returned by the retriever.
   * @param query - Plain-text query forwarded to each postprocessor.
   * @returns The output of the last postprocessor, or `nodes` unchanged if none
   *   are configured.
   */
  private async applyNodePostprocessors(nodes: NodeWithScore[], query: string) {
    let nodesWithScore = nodes;
    for (const postprocessor of this.nodePostprocessors) {
      // Sequential on purpose: each step consumes the previous step's result.
      nodesWithScore = await postprocessor.postprocessNodes(
        nodesWithScore,
        query,
      );
    }
    return nodesWithScore;
  }

  /**
   * Retrieves nodes for `query` (forwarding `preFilters`) and then applies the
   * node postprocessors.
   *
   * @param query - Plain-text query.
   * @returns The postprocessed nodes.
   */
  private async retrieve(query: string) {
    const nodes = await this.retriever.retrieve({
      query,
      preFilters: this.preFilters,
    });
    return await this.applyNodePostprocessors(nodes, query);
  }
}