Skip to content

Commit

Permalink
feat: add Amazon Bedrock Retriever (#1219)
Browse files Browse the repository at this point in the history
Co-authored-by: Arnaud JEAN <arnajean@amazon.com>
Co-authored-by: ajohn-wick <ajohnwick@mrwick.org>
Co-authored-by: Alex Yang <himself65@outlook.com>
  • Loading branch information
4 people authored Sep 23, 2024
1 parent 8b7fdba commit 50e6b57
Show file tree
Hide file tree
Showing 6 changed files with 1,314 additions and 52 deletions.
5 changes: 5 additions & 0 deletions .changeset/sour-dots-try.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@llamaindex/community": patch
---

feat: add Amazon Bedrock Retriever
1 change: 1 addition & 0 deletions packages/community/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
- Bedrock support for the Anthropic Claude Models [usage](https://ts.llamaindex.ai/modules/llms/available_llms/bedrock)
- Bedrock support for the Meta LLama 2, 3 and 3.1 Models [usage](https://ts.llamaindex.ai/modules/llms/available_llms/bedrock)
- Meta LLama3.1 405b tool call support
- Bedrock support for querying Knowledge Base

## LICENSE

Expand Down
1 change: 1 addition & 0 deletions packages/community/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
},
"dependencies": {
"@aws-sdk/client-bedrock-runtime": "^3.642.0",
"@aws-sdk/client-bedrock-agent-runtime": "^3.642.0",
"@llamaindex/core": "workspace:*",
"@llamaindex/env": "workspace:*"
}
Expand Down
1 change: 1 addition & 0 deletions packages/community/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ export {
BEDROCK_MODEL_MAX_TOKENS,
Bedrock,
} from "./llm/bedrock/index.js";
export { AmazonKnowledgeBaseRetriever } from "./retrievers/bedrock.js";
165 changes: 165 additions & 0 deletions packages/community/src/retrievers/bedrock.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
import type { KnowledgeBaseVectorSearchConfiguration } from "@aws-sdk/client-bedrock-agent-runtime";
import {
BedrockAgentRuntimeClient,
type BedrockAgentRuntimeClientConfig,
type RetrievalFilter,
RetrieveCommand,
type SearchType,
} from "@aws-sdk/client-bedrock-agent-runtime";
import type { QueryBundle } from "@llamaindex/core/query-engine";
import { BaseRetriever } from "@llamaindex/core/retriever";
import { Document, type NodeWithScore } from "@llamaindex/core/schema";
import { extractText } from "@llamaindex/core/utils";

/**
* Interface for the arguments required to initialize an
* AmazonKnowledgeBaseRetriever instance.
*/
export interface AmazonKnowledgeBaseRetrieverArgs {
knowledgeBaseId: string;
topK: number;
region: string;
clientOptions?: BedrockAgentRuntimeClientConfig;
filter?: RetrievalFilter;
overrideSearchType?: SearchType;
}

/**
* Class for interacting with Amazon Bedrock Knowledge Bases, a RAG workflow oriented service
* Extends the BaseRetriever class.
* @example
* ```typescript
* const retriever = new AmazonKnowledgeBaseRetriever({
* topK: 10,
* knowledgeBaseId: "YOUR_KNOWLEDGE_BASE_ID",
* region: "us-east-2",
* clientOptions: {
* credentials: {
* accessKeyId: "YOUR_ACCESS_KEY_ID",
* secretAccessKey: "YOUR_SECRET_ACCESS_KEY",
* },
* },
* });
*
* const docs = await retriever.retrieve({query: "How are clouds formed?"});
* ```
*/
export class AmazonKnowledgeBaseRetriever extends BaseRetriever {
static lc_name() {
return "AmazonKnowledgeBaseRetriever";
}

lc_namespace = ["llamaindex", "retrievers", "amazon_bedrock_knowledge_base"];

knowledgeBaseId: string;

topK: number;

bedrockAgentRuntimeClient: BedrockAgentRuntimeClient;

filter: RetrievalFilter | undefined;

overrideSearchType: SearchType | undefined;

constructor({
knowledgeBaseId,
topK = 10,
clientOptions,
region,
filter,
overrideSearchType,
}: AmazonKnowledgeBaseRetrieverArgs) {
super();

this.topK = topK;
this.filter = filter;
this.overrideSearchType = overrideSearchType;
this.bedrockAgentRuntimeClient = new BedrockAgentRuntimeClient({
region,
...clientOptions,
});
this.knowledgeBaseId = knowledgeBaseId;
}

/**
* Cleans the result text by replacing sequences of whitespace with a
* single space and removing ellipses.
* @param resText The result text to clean.
* @returns The cleaned result text.
*/
cleanResult(resText: string) {
const res = resText.replace(/\s+/g, " ").replace(/\.\.\./g, "");
return res;
}

async queryKnowledgeBase(
query: QueryBundle,
topK: number,
filter?: RetrievalFilter,
overrideSearchType?: SearchType,
): Promise<NodeWithScore[]> {
const retrieveCommand = new RetrieveCommand({
knowledgeBaseId: this.knowledgeBaseId,
retrievalQuery: {
text: extractText(query),
},
retrievalConfiguration: {
vectorSearchConfiguration: {
numberOfResults: topK,
overrideSearchType,
filter,
} as KnowledgeBaseVectorSearchConfiguration,
},
});

const retrieveResponse =
await this.bedrockAgentRuntimeClient.send(retrieveCommand);

return (
retrieveResponse.retrievalResults?.map((result) => {
let source;
switch (result.location?.type) {
case "CONFLUENCE":
source = result.location?.confluenceLocation?.url;
break;
case "S3":
source = result.location?.s3Location?.uri;
break;
case "SALESFORCE":
source = result.location?.salesforceLocation?.url;
break;
case "SHAREPOINT":
source = result.location?.sharePointLocation?.url;
break;
case "WEB":
source = result.location?.webLocation?.url;
break;
default:
source = result.location?.s3Location?.uri;
break;
}

return {
node: new Document({
text: this.cleanResult(result.content?.text || ""),
metadata: {
source,
score: result.score,
...result.metadata,
},
}),
score: result.score ?? 1.0,
};
}) ?? []
);
}

async _retrieve(query: QueryBundle): Promise<NodeWithScore[]> {
return await this.queryKnowledgeBase(
query,
this.topK,
this.filter,
this.overrideSearchType,
);
}
}
Loading

0 comments on commit 50e6b57

Please sign in to comment.