Skip to content

Commit

Permalink
Amazon Bedrock - Retrieve and RetrieveAndGenerate (#39500)
Browse files Browse the repository at this point in the history
Both of these calls are super fast and neither has any kind of waiter or means of checking the status, so here can not be any sensor or trigger for them. They are simple client calls, but I think making these Operators allowed us to simplify the complicated formatting on the client API call itself, for a better UX.
  • Loading branch information
ferruzzi committed May 15, 2024
1 parent 287c188 commit 9284dc5
Show file tree
Hide file tree
Showing 8 changed files with 543 additions and 9 deletions.
20 changes: 20 additions & 0 deletions airflow/providers/amazon/aws/hooks/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,23 @@ class BedrockAgentHook(AwsBaseHook):
def __init__(self, *args, **kwargs) -> None:
kwargs["client_type"] = self.client_type
super().__init__(*args, **kwargs)


class BedrockAgentRuntimeHook(AwsBaseHook):
"""
Interact with the Amazon Agents for Bedrock API.
Provide thin wrapper around :external+boto3:py:class:`boto3.client("bedrock-agent-runtime") <AgentsforBedrockRuntime.Client>`.
Additional arguments (such as ``aws_conn_id``) may be specified and
are passed down to the underlying AwsBaseHook.
.. seealso::
- :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
"""

client_type = "bedrock-agent-runtime"

def __init__(self, *args, **kwargs) -> None:
kwargs["client_type"] = self.client_type
super().__init__(*args, **kwargs)
220 changes: 218 additions & 2 deletions airflow/providers/amazon/aws/operators/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,25 @@
from time import sleep
from typing import TYPE_CHECKING, Any, Sequence

import botocore
from botocore.exceptions import ClientError

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.bedrock import BedrockAgentHook, BedrockHook, BedrockRuntimeHook
from airflow.providers.amazon.aws.hooks.bedrock import (
BedrockAgentHook,
BedrockAgentRuntimeHook,
BedrockHook,
BedrockRuntimeHook,
)
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
from airflow.providers.amazon.aws.triggers.bedrock import (
BedrockCustomizeModelCompletedTrigger,
BedrockIngestionJobTrigger,
BedrockKnowledgeBaseActiveTrigger,
BedrockProvisionModelThroughputCompletedTrigger,
)
from airflow.providers.amazon.aws.utils import validate_execute_complete_event
from airflow.providers.amazon.aws.utils import get_botocore_version, validate_execute_complete_event
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
from airflow.utils.helpers import prune_dict
from airflow.utils.timezone import utcnow
Expand Down Expand Up @@ -664,3 +670,213 @@ def execute(self, context: Context) -> str:
)

return ingestion_job_id


class BedrockRaGOperator(AwsBaseOperator[BedrockAgentRuntimeHook]):
"""
Query a knowledge base and generate responses based on the retrieved results with sources citations.
NOTE: Support for EXTERNAL SOURCES was added in botocore 1.34.90
.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:BedrockRaGOperator`
:param input: The query to be made to the knowledge base. (templated)
:param source_type: The type of resource that is queried by the request. (templated)
Must be one of 'KNOWLEDGE_BASE' or 'EXTERNAL_SOURCES', and the appropriate config values must also be provided.
If set to 'KNOWLEDGE_BASE' then `knowledge_base_id` must be provided, and `vector_search_config` may be.
If set to `EXTERNAL_SOURCES` then `sources` must also be provided.
NOTE: Support for EXTERNAL SOURCES was added in botocore 1.34.90
:param model_arn: The ARN of the foundation model used to generate a response. (templated)
:param prompt_template: The template for the prompt that's sent to the model for response generation.
You can include prompt placeholders, which are replaced before the prompt is sent to the model
to provide instructions and context to the model. In addition, you can include XML tags to delineate
meaningful sections of the prompt template. (templated)
:param knowledge_base_id: The unique identifier of the knowledge base that is queried. (templated)
Can only be specified if source_type='KNOWLEDGE_BASE'.
:param vector_search_config: How the results from the vector search should be returned. (templated)
Can only be specified if source_type='KNOWLEDGE_BASE'.
For more information, see https://docs.aws.amazon.com/bedrock/latest/userguide/kb-test-config.html.
:param sources: The documents used as reference for the response. (templated)
Can only be specified if source_type='EXTERNAL_SOURCES'
NOTE: Support for EXTERNAL SOURCES was added in botocore 1.34.90
:param rag_kwargs: Additional keyword arguments to pass to the API call. (templated)
"""

aws_hook_class = BedrockAgentRuntimeHook
template_fields: Sequence[str] = aws_template_fields(
"input",
"source_type",
"model_arn",
"prompt_template",
"knowledge_base_id",
"vector_search_config",
"sources",
"rag_kwargs",
)

def __init__(
self,
input: str,
source_type: str,
model_arn: str,
prompt_template: str | None = None,
knowledge_base_id: str | None = None,
vector_search_config: dict[str, Any] | None = None,
sources: list[dict[str, Any]] | None = None,
rag_kwargs: dict[str, Any] | None = None,
**kwargs,
):
super().__init__(**kwargs)
self.input = input
self.prompt_template = prompt_template
self.source_type = source_type.upper()
self.knowledge_base_id = knowledge_base_id
self.model_arn = model_arn
self.vector_search_config = vector_search_config
self.sources = sources
self.rag_kwargs = rag_kwargs or {}

def validate_inputs(self):
if self.source_type == "KNOWLEDGE_BASE":
if self.knowledge_base_id is None:
raise AttributeError(
"If `source_type` is set to 'KNOWLEDGE_BASE' then `knowledge_base_id` must be provided."
)
if self.sources is not None:
raise AttributeError(
"`sources` can not be used when `source_type` is set to 'KNOWLEDGE_BASE'."
)
elif self.source_type == "EXTERNAL_SOURCES":
if not self.sources is not None:
raise AttributeError(
"If `source_type` is set to `EXTERNAL_SOURCES` then `sources` must also be provided."
)
if self.vector_search_config or self.knowledge_base_id:
raise AttributeError(
"`vector_search_config` and `knowledge_base_id` can not be used "
"when `source_type` is set to `EXTERNAL_SOURCES`"
)
else:
raise AttributeError(
"`source_type` must be one of 'KNOWLEDGE_BASE' or 'EXTERNAL_SOURCES', "
"and the appropriate config values must also be provided."
)

def build_rag_config(self) -> dict[str, Any]:
result: dict[str, Any] = {}
base_config: dict[str, Any] = {
"modelArn": self.model_arn,
}

if self.prompt_template:
base_config["generationConfiguration"] = {
"promptTemplate": {"textPromptTemplate": self.prompt_template}
}

if self.source_type == "KNOWLEDGE_BASE":
if self.vector_search_config:
base_config["retrievalConfiguration"] = {
"vectorSearchConfiguration": self.vector_search_config
}

result = {
"type": self.source_type,
"knowledgeBaseConfiguration": {
**base_config,
"knowledgeBaseId": self.knowledge_base_id,
},
}

if self.source_type == "EXTERNAL_SOURCES":
result = {
"type": self.source_type,
"externalSourcesConfiguration": {**base_config, "sources": self.sources},
}
return result

def execute(self, context: Context) -> Any:
self.validate_inputs()

try:
result = self.hook.conn.retrieve_and_generate(
input={"text": self.input},
retrieveAndGenerateConfiguration=self.build_rag_config(),
**self.rag_kwargs,
)
except botocore.exceptions.ParamValidationError as error:
if (
'Unknown parameter in retrieveAndGenerateConfiguration: "externalSourcesConfiguration"'
in str(error)
) and (self.source_type == "EXTERNAL_SOURCES"):
self.log.error(
"You are attempting to use External Sources and the BOTO API returned an "
"error message which may indicate the need to update botocore to do this. \n"
"Support for External Sources was added in botocore 1.34.90 and you are using botocore %s",
".".join(map(str, get_botocore_version())),
)
raise

self.log.info(
"\nPrompt: %s\nResponse: %s\nCitations: %s",
self.input,
result["output"]["text"],
result["citations"],
)
return result


class BedrockRetrieveOperator(AwsBaseOperator[BedrockAgentRuntimeHook]):
"""
Query a knowledge base and retrieve results with source citations.
.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:BedrockRetrieveOperator`
:param retrieval_query: The query to be made to the knowledge base. (templated)
:param knowledge_base_id: The unique identifier of the knowledge base that is queried. (templated)
:param vector_search_config: How the results from the vector search should be returned. (templated)
For more information, see https://docs.aws.amazon.com/bedrock/latest/userguide/kb-test-config.html.
:param retrieve_kwargs: Additional keyword arguments to pass to the API call. (templated)
"""

aws_hook_class = BedrockAgentRuntimeHook
template_fields: Sequence[str] = aws_template_fields(
"retrieval_query",
"knowledge_base_id",
"vector_search_config",
"retrieve_kwargs",
)

def __init__(
self,
retrieval_query: str,
knowledge_base_id: str,
vector_search_config: dict[str, Any] | None = None,
retrieve_kwargs: dict[str, Any] | None = None,
**kwargs,
):
super().__init__(**kwargs)
self.retrieval_query = retrieval_query
self.knowledge_base_id = knowledge_base_id
self.vector_search_config = vector_search_config
self.retrieve_kwargs = retrieve_kwargs or {}

def execute(self, context: Context) -> Any:
retrieval_configuration = (
{"retrievalConfiguration": {"vectorSearchConfiguration": self.vector_search_config}}
if self.vector_search_config
else {}
)

result = self.hook.conn.retrieve(
retrievalQuery={"text": self.retrieval_query},
knowledgeBaseId=self.knowledge_base_id,
**retrieval_configuration,
**self.retrieve_kwargs,
)

self.log.info("\nQuery: %s\nRetrieved: %s", self.retrieval_query, result["retrievalResults"])
return result
7 changes: 7 additions & 0 deletions airflow/providers/amazon/aws/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from enum import Enum
from typing import Any

import importlib_metadata

from airflow.exceptions import AirflowException
from airflow.utils.helpers import prune_dict
from airflow.version import version
Expand Down Expand Up @@ -74,6 +76,11 @@ def get_airflow_version() -> tuple[int, ...]:
return tuple(int(x) for x in match.groups())


def get_botocore_version() -> tuple[int, ...]:
"""Return the version number of the installed botocore package in the form of a tuple[int,...]."""
return tuple(map(int, importlib_metadata.version("botocore").split(".")[:3]))


def validate_execute_complete_event(event: dict[str, Any] | None = None) -> dict[str, Any]:
if event is None:
err_msg = "Trigger error: event is None"
Expand Down
52 changes: 52 additions & 0 deletions docs/apache-airflow-providers-amazon/operators/bedrock.rst
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,9 @@ Create an Amazon Bedrock Knowledge Base
To create an Amazon Bedrock Knowledge Base, you can use
:class:`~airflow.providers.amazon.aws.operators.bedrock.BedrockCreateKnowledgeBaseOperator`.

For more information on which models support embedding data into a vector store, see
https://docs.aws.amazon.com/bedrock/latest/userguide/knowledge-base-supported.html

.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_bedrock_knowledge_base.py
:language: python
:dedent: 4
Expand Down Expand Up @@ -174,6 +177,55 @@ To add data from an Amazon S3 bucket into an Amazon Bedrock Data Source, you can
:start-after: [START howto_operator_bedrock_ingest_data]
:end-before: [END howto_operator_bedrock_ingest_data]

.. _howto/operator:BedrockRetrieveOperator:

Amazon Bedrock Retrieve
=======================

To query a knowledge base, you can use :class:`~airflow.providers.amazon.aws.operators.bedrock.BedrockRetrieveOperator`.

The response will only contain citations to sources that are relevant to the query. If you
would like to pass the results through an LLM in order to generate a text response, see
:class:`~airflow.providers.amazon.aws.operators.bedrock.BedrockRaGOperator`

For more information on which models support retrieving information from a knowledge base, see
https://docs.aws.amazon.com/bedrock/latest/userguide/knowledge-base-supported.html

.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_bedrock_knowledge_base.py
:language: python
:dedent: 4
:start-after: [START howto_operator_bedrock_retrieve]
:end-before: [END howto_operator_bedrock_retrieve]

.. _howto/operator:BedrockRaGOperator:

Amazon Bedrock Retrieve and Generate (RaG)
==========================================

To query a knowledge base or external sources and generate a text response based on the retrieved
results, you can use :class:`~airflow.providers.amazon.aws.operators.bedrock.BedrockRaGOperator`.

The response will contain citations to sources that are relevant to the query as well as a generated text reply.
For more information on which models support retrieving information from a knowledge base, see
https://docs.aws.amazon.com/bedrock/latest/userguide/knowledge-base-supported.html

NOTE: Support for "external sources" was added in boto 1.34.90

Example using an Amazon Bedrock Knowledge Base:

.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_bedrock_knowledge_base.py
:language: python
:dedent: 4
:start-after: [START howto_operator_bedrock_knowledge_base_rag]
:end-before: [END howto_operator_bedrock_knowledge_base_rag]

Example using a PDF file in an Amazon S3 Bucket:

.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_bedrock_knowledge_base.py
:language: python
:dedent: 4
:start-after: [START howto_operator_bedrock_external_sources_rag]
:end-before: [END howto_operator_bedrock_external_sources_rag]


Sensors
Expand Down
1 change: 1 addition & 0 deletions docs/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1305,6 +1305,7 @@ queueing
quickstart
quotechar
rabbitmq
RaG
RBAC
rbac
rc
Expand Down
8 changes: 7 additions & 1 deletion tests/providers/amazon/aws/hooks/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@

import pytest

from airflow.providers.amazon.aws.hooks.bedrock import BedrockAgentHook, BedrockHook, BedrockRuntimeHook
from airflow.providers.amazon.aws.hooks.bedrock import (
BedrockAgentHook,
BedrockAgentRuntimeHook,
BedrockHook,
BedrockRuntimeHook,
)


class TestBedrockHooks:
Expand All @@ -28,6 +33,7 @@ class TestBedrockHooks:
pytest.param(BedrockHook(), "bedrock", id="bedrock"),
pytest.param(BedrockRuntimeHook(), "bedrock-runtime", id="bedrock-runtime"),
pytest.param(BedrockAgentHook(), "bedrock-agent", id="bedrock-agent"),
pytest.param(BedrockAgentRuntimeHook(), "bedrock-agent-runtime", id="bedrock-agent-runtime"),
],
)
def test_bedrock_hooks(self, test_hook, service_name):
Expand Down
Loading

0 comments on commit 9284dc5

Please sign in to comment.