diff --git a/airflow/providers/amazon/aws/hooks/dynamodb.py b/airflow/providers/amazon/aws/hooks/dynamodb.py index 07c9b99a51a46b..29707b3f495d9a 100644 --- a/airflow/providers/amazon/aws/hooks/dynamodb.py +++ b/airflow/providers/amazon/aws/hooks/dynamodb.py @@ -19,11 +19,17 @@ from __future__ import annotations -from typing import Iterable +from functools import cached_property +from typing import TYPE_CHECKING, Iterable + +from botocore.exceptions import ClientError from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook +if TYPE_CHECKING: + from botocore.client import BaseClient + class DynamoDBHook(AwsBaseHook): """ @@ -50,6 +56,11 @@ def __init__( kwargs["resource_type"] = "dynamodb" super().__init__(*args, **kwargs) + @cached_property + def client(self) -> BaseClient: + """Return boto3 client.""" + return self.get_conn().meta.client + def write_batch_data(self, items: Iterable) -> bool: """ Write batch items to DynamoDB table with provisioned throughout capacity. @@ -70,3 +81,25 @@ def write_batch_data(self, items: Iterable) -> bool: return True except Exception as general_error: raise AirflowException(f"Failed to insert items in dynamodb, error: {general_error}") + + def get_import_status(self, import_arn: str) -> tuple[str, str | None, str | None]: + """ + Get import status from Dynamodb. + + :param import_arn: The Amazon Resource Name (ARN) for the import. + :return: Import status, Error code and Error message + """ + self.log.info("Poking for Dynamodb import %s", import_arn) + + try: + describe_import = self.client.describe_import(ImportArn=import_arn) + status = describe_import["ImportTableDescription"]["ImportStatus"] + error_code = describe_import["ImportTableDescription"].get("FailureCode") + error_msg = describe_import["ImportTableDescription"].get("FailureMessage") + return status, error_code, error_msg + except ClientError as e: + error_code = e.response.get("Error", {}).get("Code") + if error_code == "ImportNotFoundException": + raise AirflowException("S3 import into Dynamodb job not found.") + else: + raise e diff --git a/airflow/providers/amazon/aws/transfers/s3_to_dynamodb.py b/airflow/providers/amazon/aws/transfers/s3_to_dynamodb.py new file mode 100644 index 00000000000000..b8c50e8032ffa4 --- /dev/null +++ b/airflow/providers/amazon/aws/transfers/s3_to_dynamodb.py @@ -0,0 +1,261 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Literal, Sequence, TypedDict + +from botocore.exceptions import ClientError, WaiterError + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.dynamodb import DynamoDBHook + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +class AttributeDefinition(TypedDict): + """Attribute Definition Type.""" + + AttributeName: str + AttributeType: Literal["S", "N", "B"] + + +class KeySchema(TypedDict): + """Key Schema Type.""" + + AttributeName: str + KeyType: Literal["HASH", "RANGE"] + + +class S3ToDynamoDBOperator(BaseOperator): + """Load data from S3 into a DynamoDB table. + + Data stored in S3 can be uploaded to a new or existing DynamoDB table. Supported file formats are CSV, DynamoDB JSON and + Amazon ION. + + + :param s3_bucket: The S3 bucket to import data from + :param s3_key: Key prefix that selects one or multiple objects in S3 + :param dynamodb_table_name: Name of the DynamoDB table to be created, or of the existing target table when + ``use_existing_table`` is set to True + :param dynamodb_key_schema: Primary key and sort key. Each element represents one primary key + attribute. AttributeName is the name of the attribute. KeyType is the role for the attribute. Valid values are + HASH or RANGE + :param dynamodb_attributes: Name of the attributes of a table. AttributeName is the name for the attribute. + AttributeType is the data type for the attribute. Valid values for AttributeType are + S - attribute is of type String + N - attribute is of type Number + B - attribute is of type Binary + :param dynamodb_tmp_table_prefix: Prefix for the temporary DynamoDB table + :param delete_dynamodb_tmp_table: If set, deletes the temporary DynamoDB table that is used for staging + :param delete_on_error: If set, the new DynamoDB table will be deleted in case of import errors + :param use_existing_table: Whether to import into an existing DynamoDB table rather than a new one. If set to + true, data is loaded first into a temporary DynamoDB table, then read back in chunks and written + into the target table + :param input_format: The format of the imported data. Valid values for InputFormat are CSV, DYNAMODB_JSON + or ION + :param billing_mode: Billing mode for the table. Valid values are PROVISIONED or PAY_PER_REQUEST + :param import_table_kwargs: Any additional optional import table parameters to pass, such as ClientToken, + InputCompressionType, or InputFormatOptions. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb/client/import_table.html + :param import_table_creation_kwargs: Any additional optional import table creation parameters to pass, such as + ProvisionedThroughput, SSESpecification, or GlobalSecondaryIndexes.
See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb/client/import_table.html + :param wait_for_completion: Whether to wait for the import to complete + :param check_interval: Time in seconds to wait between status checks + :param max_attempts: Maximum number of attempts to check for job completion + :param aws_conn_id: The reference to the AWS connection details + """ + + template_fields: Sequence[str] = ( + "s3_bucket", + "s3_key", + "dynamodb_table_name", + "dynamodb_key_schema", + "dynamodb_attributes", + "dynamodb_tmp_table_prefix", + "delete_dynamodb_tmp_table", + "delete_on_error", + "use_existing_table", + "input_format", + "billing_mode", + "import_table_kwargs", + "import_table_creation_kwargs", + ) + ui_color = "#e2e8f0" + + def __init__( + self, + *, + s3_bucket: str, + s3_key: str, + dynamodb_table_name: str, + dynamodb_key_schema: list[KeySchema], + dynamodb_attributes: list[AttributeDefinition] | None = None, + dynamodb_tmp_table_prefix: str = "tmp", + delete_dynamodb_tmp_table: bool = True, + delete_on_error: bool = False, + use_existing_table: bool = False, + input_format: Literal["CSV", "DYNAMODB_JSON", "ION"] = "DYNAMODB_JSON", + billing_mode: Literal["PROVISIONED", "PAY_PER_REQUEST"] = "PAY_PER_REQUEST", + import_table_kwargs: dict[str, Any] | None = None, + import_table_creation_kwargs: dict[str, Any] | None = None, + wait_for_completion: bool = True, + check_interval: int = 30, + max_attempts: int = 240, + aws_conn_id: str | None = "aws_default", + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.s3_bucket = s3_bucket + self.s3_key = s3_key + self.dynamodb_table_name = dynamodb_table_name + self.dynamodb_attributes = dynamodb_attributes + self.dynamodb_tmp_table_prefix = dynamodb_tmp_table_prefix + self.delete_dynamodb_tmp_table = delete_dynamodb_tmp_table + self.delete_on_error = delete_on_error + self.use_existing_table = use_existing_table + self.dynamodb_key_schema = dynamodb_key_schema + self.input_format = input_format + self.billing_mode = billing_mode + self.import_table_kwargs = import_table_kwargs + self.import_table_creation_kwargs = import_table_creation_kwargs + self.wait_for_completion = wait_for_completion + self.check_interval = check_interval + self.max_attempts = max_attempts + self.aws_conn_id = aws_conn_id + + @property + def tmp_table_name(self): + """Temporary table name.""" + return f"{self.dynamodb_tmp_table_prefix}_{self.dynamodb_table_name}" + + def _load_into_new_table(self, context: Context, tmp_table: bool = False) -> str: + """ + Import S3 key or keys into a new DynamoDB table.
+ + :param context: the current context of the task instance + :param tmp_table: whether to create the DynamoDB table with the tmp prefix + :return: the Amazon Resource Name (ARN) of the import + """ + dynamodb_hook = DynamoDBHook(aws_conn_id=self.aws_conn_id) + client = dynamodb_hook.client + table_name = self.tmp_table_name if tmp_table else self.dynamodb_table_name + + import_table_config = self.import_table_kwargs or {} + import_table_creation_config = self.import_table_creation_kwargs or {} + + try: + response = client.import_table( + S3BucketSource={ + "S3Bucket": self.s3_bucket, + "S3KeyPrefix": self.s3_key, + }, + InputFormat=self.input_format, + TableCreationParameters={ + "TableName": table_name, + "AttributeDefinitions": self.dynamodb_attributes, + "KeySchema": self.dynamodb_key_schema, + "BillingMode": self.billing_mode, + **import_table_creation_config, + }, + **import_table_config, + ) + except ClientError as e: + self.log.error("Error: failed to load from S3 into DynamoDB table. Error: %s", str(e)) + raise AirflowException(f"S3 load into DynamoDB table failed with error: {e}") + + if response["ImportTableDescription"]["ImportStatus"] == "FAILED": + raise AirflowException( + "S3 into Dynamodb job creation failed. Code: " + f"{response['ImportTableDescription']['FailureCode']}. " + f"Failure: {response['ImportTableDescription']['FailureMessage']}" + ) + + if self.wait_for_completion: + self.log.info("Waiting for S3 into Dynamodb job to complete") + waiter = dynamodb_hook.get_waiter("import_table") + try: + waiter.wait( + ImportArn=response["ImportTableDescription"]["ImportArn"], + WaiterConfig={"Delay": self.check_interval, "MaxAttempts": self.max_attempts}, + ) + except WaiterError: + status, error_code, error_msg = dynamodb_hook.get_import_status( + response["ImportTableDescription"]["ImportArn"] + ) + if self.delete_on_error: + client.delete_table(TableName=table_name) + raise AirflowException( + f"S3 import into Dynamodb job failed: Status: {status}. Error: {error_code}. Error message: {error_msg}" + ) + return response["ImportTableDescription"]["ImportArn"] + + def _load_into_existing_table(self, context: Context) -> str: + """ + Import S3 key or keys into an existing DynamoDB table.
+ + :param context: the current context of the task instance + :return: the Amazon Resource Name (ARN) of the target table + """ + if not self.wait_for_completion: + raise ValueError("wait_for_completion must be set to True when loading into an existing table") + table_keys = [key["AttributeName"] for key in self.dynamodb_key_schema] + + dynamodb_hook = DynamoDBHook( + aws_conn_id=self.aws_conn_id, table_name=self.dynamodb_table_name, table_keys=table_keys + ) + client = dynamodb_hook.client + + self.log.info("Loading from S3 into a tmp DynamoDB table %s", self.tmp_table_name) + self._load_into_new_table(context=context, tmp_table=True) + total_items = 0 + try: + paginator = client.get_paginator("scan") + paginate = paginator.paginate( + TableName=self.tmp_table_name, + Select="ALL_ATTRIBUTES", + ReturnConsumedCapacity="NONE", + ConsistentRead=True, + ) + self.log.info( + "Loading data from %s to %s DynamoDB table", self.tmp_table_name, self.dynamodb_table_name + ) + for page in paginate: + total_items += page.get("Count", 0) + dynamodb_hook.write_batch_data(items=page["Items"]) + self.log.info("Number of items loaded: %s", total_items) + finally: + if self.delete_dynamodb_tmp_table: + self.log.info("Deleting tmp DynamoDB table %s", self.tmp_table_name) + client.delete_table(TableName=self.tmp_table_name) + return dynamodb_hook.get_conn().Table(self.dynamodb_table_name).table_arn + + def execute(self, context: Context) -> str: + """ + Execute S3 to DynamoDB Job from Airflow. + + :param context: the current context of the task instance + :return: the Amazon Resource Name (ARN) + """ + if self.use_existing_table: + self.log.info("Loading from S3 into existing DynamoDB table %s", self.dynamodb_table_name) + return self._load_into_existing_table(context=context) + self.log.info("Loading from S3 into new DynamoDB table %s", self.dynamodb_table_name) + return self._load_into_new_table(context=context) diff --git a/airflow/providers/amazon/aws/waiters/dynamodb.json b/airflow/providers/amazon/aws/waiters/dynamodb.json index acd23268ab7b65..1051530630920e 100644 --- a/airflow/providers/amazon/aws/waiters/dynamodb.json +++ b/airflow/providers/amazon/aws/waiters/dynamodb.json @@ -25,6 +25,43 @@ "state": "retry" } ] + }, + "import_table": { + "operation": "DescribeImport", + "delay": 30, + "maxAttempts": 240, + "acceptors": [ + { + "matcher": "path", + "expected": "COMPLETED", + "argument": "ImportTableDescription.ImportStatus", + "state": "success" + }, + { + "matcher": "path", + "expected": "CANCELLING", + "argument": "ImportTableDescription.ImportStatus", + "state": "failure" + }, + { + "matcher": "path", + "expected": "CANCELLED", + "argument": "ImportTableDescription.ImportStatus", + "state": "failure" + }, + { + "matcher": "path", + "expected": "FAILED", + "argument": "ImportTableDescription.ImportStatus", + "state": "failure" + }, + { + "matcher": "path", + "expected": "IN_PROGRESS", + "argument": "ImportTableDescription.ImportStatus", + "state": "retry" + } + ] } } } diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml index 7c06879143ef18..d3a26dbbab0807 100644 --- a/airflow/providers/amazon/provider.yaml +++ b/airflow/providers/amazon/provider.yaml @@ -821,6 +821,10 @@ transfers: target-integration-name: Amazon Simple Storage Service (S3) how-to-guide: /docs/apache-airflow-providers-amazon/transfer/azure_blob_to_s3.rst python-module: airflow.providers.amazon.aws.transfers.azure_blob_to_s3 + - source-integration-name: Amazon Simple Storage Service (S3) + 
target-integration-name: Amazon DynamoDB + how-to-guide: /docs/apache-airflow-providers-amazon/transfer/s3_to_dynamodb.rst + python-module: airflow.providers.amazon.aws.transfers.s3_to_dynamodb extra-links: - airflow.providers.amazon.aws.links.athena.AthenaQueryResultsLink diff --git a/docs/apache-airflow-providers-amazon/transfer/s3_to_dynamodb.rst b/docs/apache-airflow-providers-amazon/transfer/s3_to_dynamodb.rst new file mode 100644 index 00000000000000..630863547d7a48 --- /dev/null +++ b/docs/apache-airflow-providers-amazon/transfer/s3_to_dynamodb.rst @@ -0,0 +1,71 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +============================ +Amazon S3 to DynamoDB +============================ + +Use the ``S3ToDynamoDBOperator`` transfer to load data stored in an Amazon Simple Storage Service (S3) bucket +into an existing or new Amazon DynamoDB table. + +Prerequisite Tasks +------------------ + +.. include:: ../_partials/prerequisite_tasks.rst + +Operators +--------- + +.. _howto/transfer:S3ToDynamoDBOperator: + +Amazon S3 To DynamoDB transfer operator +============================================== + +This operator loads data from Amazon S3 into an Amazon DynamoDB table. It uses the Amazon DynamoDB +ImportTable service, which interacts with other AWS services such as Amazon S3 and CloudWatch. The +default behavior is to load S3 data into a new Amazon DynamoDB table. The import into an existing +table is currently not supported by the service, so the operator uses a custom approach: it creates +a temporary DynamoDB table and loads the S3 data into it, then scans the temporary Amazon +DynamoDB table and writes the received records to the target table. + + +For more information visit: +:class:`~airflow.providers.amazon.aws.transfers.s3_to_dynamodb.S3ToDynamoDBOperator` + +Example usage: + +.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_s3_to_dynamodb.py + :language: python + :dedent: 4 + :start-after: [START howto_transfer_s3_to_dynamodb] + :end-before: [END howto_transfer_s3_to_dynamodb] + + +To load S3 data into an existing DynamoDB table use: + +.. 
exampleinclude:: /../../tests/system/providers/amazon/aws/example_s3_to_dynamodb.py + :language: python + :dedent: 4 + :start-after: [START howto_transfer_s3_to_dynamodb_existing_table] + :end-before: [END howto_transfer_s3_to_dynamodb_existing_table] + + +Reference +--------- + +* `AWS boto3 library documentation for Amazon DynamoDB <https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb.html>`__ +* `AWS boto3 library documentation for Amazon S3 <https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html>`__ diff --git a/tests/providers/amazon/aws/hooks/test_dynamodb.py b/tests/providers/amazon/aws/hooks/test_dynamodb.py index 4e3e96c0dd47a3..51131d29aea162 100644 --- a/tests/providers/amazon/aws/hooks/test_dynamodb.py +++ b/tests/providers/amazon/aws/hooks/test_dynamodb.py @@ -20,16 +20,29 @@ import uuid from unittest import mock +import pytest +from botocore.exceptions import ClientError from moto import mock_aws +from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.dynamodb import DynamoDBHook +TEST_IMPORT_ARN = "arn:aws:dynamodb:us-east-1:255683865591:table/test-table/import/01662190284205-aa94decf" + class TestDynamoDBHook: @mock_aws def test_get_conn_returns_a_boto3_connection(self): hook = DynamoDBHook(aws_conn_id="aws_default") - assert hook.get_conn() is not None + conn = hook.get_conn() + assert conn is not None + assert conn.__class__.__name__ == "dynamodb.ServiceResource" + + @mock_aws + def test_get_client_from_dynamodb_resource(self): + hook = DynamoDBHook(aws_conn_id="aws_default") + client = hook.client + assert client.__class__.__name__ == "DynamoDB" @mock_aws def test_insert_batch_items_dynamodb_table(self): @@ -61,3 +74,76 @@ def test_waiter_path_generated_from_resource_type(self, _): hook = DynamoDBHook(aws_conn_id="aws_default") path = hook.waiter_path assert path.as_uri().endswith("/airflow/providers/amazon/aws/waiters/dynamodb.json") + + @pytest.mark.parametrize( + "response, status, error", + [ + pytest.param( + {"ImportTableDescription": {"ImportStatus": "COMPLETED"}}, "COMPLETED", False, id="complete" + ), + pytest.param( + { + "ImportTableDescription": { + "ImportStatus": "CANCELLING", + "FailureCode": "Failure1", + "FailureMessage": "Message", + } + }, + "CANCELLING", + True, + id="cancel", + ), + pytest.param( + {"ImportTableDescription": {"ImportStatus": "IN_PROGRESS"}}, + "IN_PROGRESS", + False, + id="progress", + ), + ], + ) + @mock.patch("botocore.client.BaseClient._make_api_call") + def test_get_s3_import_status(self, mock_make_api_call, response, status, error): + mock_make_api_call.return_value = response + hook = DynamoDBHook(aws_conn_id="aws_default") + sta, code, msg = hook.get_import_status(import_arn=TEST_IMPORT_ARN) + mock_make_api_call.assert_called_once_with("DescribeImport", {"ImportArn": TEST_IMPORT_ARN}) + assert sta == status + if error: + assert code == "Failure1" + assert msg == "Message" + else: + assert code is None + assert msg is None + + @pytest.mark.parametrize( + "effect, error", + [ + pytest.param( + ClientError( + error_response={"Error": {"Message": "Error message", "Code": "GeneralException"}}, + operation_name="UnitTest", + ), + ClientError, + id="general-exception", + ), + pytest.param( + ClientError( + error_response={"Error": {"Message": "Error message", "Code": "ImportNotFoundException"}}, + operation_name="UnitTest", + ), + AirflowException, + id="not-found-exception", + ), + ], + ) + @mock.patch("botocore.client.BaseClient._make_api_call") + def test_get_s3_import_status_with_error(self, mock_make_api_call, effect, error): + mock_make_api_call.side_effect = effect + hook = 
DynamoDBHook(aws_conn_id="aws_default") + with pytest.raises(error): + hook.get_import_status(import_arn=TEST_IMPORT_ARN) + + def test_hook_has_import_waiters(self): + hook = DynamoDBHook(aws_conn_id="aws_default") + waiter = hook.get_waiter("import_table") + assert waiter is not None diff --git a/tests/providers/amazon/aws/transfers/test_s3_to_dynamodb.py b/tests/providers/amazon/aws/transfers/test_s3_to_dynamodb.py new file mode 100644 index 00000000000000..f530c3cdc4f5e6 --- /dev/null +++ b/tests/providers/amazon/aws/transfers/test_s3_to_dynamodb.py @@ -0,0 +1,232 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock + +import pytest +from botocore.exceptions import ClientError, WaiterError + +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.dynamodb import DynamoDBHook +from airflow.providers.amazon.aws.transfers.s3_to_dynamodb import S3ToDynamoDBOperator + +TASK_ID = "transfer_1" +BUCKET = "test-bucket" +S3_KEY_PREFIX = "test/test_data" +S3_KEY = "test/test_data_file_1.csv" +S3_CONN_ID = "aws_default" +DYNAMODB_TABLE_NAME = "test-table" +DYNAMODB_ATTRIBUTES = [ + {"AttributeName": "attribute_a", "AttributeType": "S"}, + {"AttributeName": "attribute_b", "AttributeType": "N"}, +] +DYNAMODB_KEY_SCHEMA = [ + {"AttributeName": "attribute_a", "KeyType": "HASH"}, +] + +DYNAMODB_PROV_THROUGHPUT = {"ReadCapacityUnits": 123, "WriteCapacityUnits": 123} +SUCCESS_S3_RESPONSE = { + "ImportTableDescription": { + "ImportArn": "arn:aws:dynamodb:import", + "ImportStatus": "IN_PROGRESS", + "TableArn": "arn:aws:dynamodb:table", + "TableId": "test-table", + "ClientToken": "client", + } +} +FAILURE_S3_RESPONSE = { + "ImportTableDescription": { + "ImportArn": "arn:aws:dynamodb:import", + "ImportStatus": "FAILED", + "TableArn": "arn:aws:dynamodb:table", + "TableId": "test-table", + "ClientToken": "client", + "FailureCode": "300", + "FailureMessage": "invalid csv format", + } +} +IMPORT_TABLE_RESPONSE = { + "S3BucketSource": {"S3Bucket": "test-bucket", "S3KeyPrefix": "test/test_data"}, + "InputFormat": "DYNAMODB_JSON", + "TableCreationParameters": { + "TableName": "test-table", + "AttributeDefinitions": [ + {"AttributeName": "attribute_a", "AttributeType": "S"}, + {"AttributeName": "attribute_b", "AttributeType": "N"}, + ], + "KeySchema": [{"AttributeName": "attribute_a", "KeyType": "HASH"}], + "BillingMode": "PAY_PER_REQUEST", + "ProvisionedThroughput": {"ReadCapacityUnits": 123, "WriteCapacityUnits": 123}, + }, +} + + +@pytest.fixture +def new_table_op(): + return S3ToDynamoDBOperator( + task_id=TASK_ID, + s3_key=S3_KEY_PREFIX, + s3_bucket=BUCKET, + dynamodb_table_name=DYNAMODB_TABLE_NAME, + dynamodb_attributes=DYNAMODB_ATTRIBUTES, + 
dynamodb_key_schema=DYNAMODB_KEY_SCHEMA, + aws_conn_id=S3_CONN_ID, + import_table_creation_kwargs={"ProvisionedThroughput": DYNAMODB_PROV_THROUGHPUT}, + ) + + +@pytest.fixture +def exist_table_op(): + return S3ToDynamoDBOperator( + task_id=TASK_ID, + s3_key=S3_KEY, + dynamodb_key_schema=DYNAMODB_KEY_SCHEMA, + s3_bucket=BUCKET, + dynamodb_table_name=DYNAMODB_TABLE_NAME, + use_existing_table=True, + aws_conn_id=S3_CONN_ID, + ) + + +class TestS3ToDynamoDBOperator: + @mock.patch.object(DynamoDBHook, "get_waiter") + @mock.patch("botocore.client.BaseClient._make_api_call") + def test_s3_to_dynamodb_new_table_wait_for_completion(self, mock_make_api_call, mock_wait, new_table_op): + mock_make_api_call.return_value = SUCCESS_S3_RESPONSE + + res = new_table_op.execute(None) + + mock_make_api_call.assert_called_once_with("ImportTable", IMPORT_TABLE_RESPONSE) + mock_wait.assert_called_once_with("import_table") + mock_wait.return_value.wait.assert_called_once_with( + ImportArn="arn:aws:dynamodb:import", WaiterConfig={"Delay": 30, "MaxAttempts": 240} + ) + assert res == "arn:aws:dynamodb:import" + + @pytest.mark.parametrize( + "delete_on_error", + [ + pytest.param( + True, + id="delete-on-error", + ), + pytest.param( + False, + id="no-delete-on-error", + ), + ], + ) + @mock.patch("airflow.providers.amazon.aws.transfers.s3_to_dynamodb.DynamoDBHook") + def test_s3_to_dynamodb_new_table_delete_on_error(self, mock_hook, new_table_op, delete_on_error): + mock_wait = mock.Mock() + mock_wait.side_effect = WaiterError(name="NetworkError", reason="unit test error", last_response={}) + mock_hook.return_value.get_waiter.return_value.wait = mock_wait + new_table_op.delete_on_error = delete_on_error + mock_hook.return_value.get_import_status.return_value = "FAILED", "400", "General error" + + with pytest.raises(AirflowException): + new_table_op.execute(None) + + if delete_on_error: + mock_hook.return_value.client.delete_table.assert_called_once_with(TableName="test-table") + else: + mock_hook.return_value.client.delete_table.assert_not_called() + + @mock.patch("botocore.client.BaseClient._make_api_call") + def test_s3_to_dynamodb_new_table_no_wait(self, mock_make_api_call): + mock_make_api_call.return_value = SUCCESS_S3_RESPONSE + op = S3ToDynamoDBOperator( + task_id=TASK_ID, + s3_key=S3_KEY_PREFIX, + s3_bucket=BUCKET, + dynamodb_table_name=DYNAMODB_TABLE_NAME, + dynamodb_attributes=DYNAMODB_ATTRIBUTES, + dynamodb_key_schema=DYNAMODB_KEY_SCHEMA, + aws_conn_id=S3_CONN_ID, + import_table_creation_kwargs={"ProvisionedThroughput": DYNAMODB_PROV_THROUGHPUT}, + wait_for_completion=False, + ) + res = op.execute(None) + + mock_make_api_call.assert_called_once_with("ImportTable", IMPORT_TABLE_RESPONSE) + assert res == "arn:aws:dynamodb:import" + + @mock.patch("botocore.client.BaseClient._make_api_call") + def test_s3_to_dynamodb_new_table_client_error(self, mock_make_api_call, new_table_op): + mock_make_api_call.side_effect = ClientError( + error_response={"Error": {"Message": "Error message", "Code": "GeneralException"}}, + operation_name="UnitTest", + ) + with pytest.raises(AirflowException) as excinfo: + new_table_op.execute(None) + assert "S3 load into DynamoDB table failed with error" in str( + excinfo.value + ), "Exception message not passed correctly" + + @mock.patch("botocore.client.BaseClient._make_api_call") + def test_s3_to_dynamodb_new_table_job_startup_error(self, mock_make_api_call, new_table_op): + mock_make_api_call.return_value = FAILURE_S3_RESPONSE + exp_err_msg = "S3 into Dynamodb job creation failed. 
Code: 300. Failure: invalid csv format" + with pytest.raises(AirflowException) as excinfo: + new_table_op.execute(None) + assert str(excinfo.value) == exp_err_msg, "Exception message not passed correctly" + + @mock.patch( + "airflow.providers.amazon.aws.transfers.s3_to_dynamodb.S3ToDynamoDBOperator._load_into_new_table" + ) + @mock.patch.object(DynamoDBHook, "get_conn") + def test_s3_to_dynamodb_existing_table(self, mock_get_conn, new_table_load_mock, exist_table_op): + response = [ + { + "Items": [ + {"Date": {"N": "54675846"}, "Message": {"S": "Message1"}, "_id": {"S": "1"}}, + {"Date": {"N": "54675847"}, "Message": {"S": "Message2"}, "_id": {"S": "2"}}, + {"Date": {"N": "54675857"}, "Message": {"S": "Message3"}, "_id": {"S": "4"}}, + ] + } + ] + batch_writer_calls = [mock.call(Item=item) for item in response[0]["Items"]] + mock_paginator = mock.Mock() + mock_paginator.paginate.return_value = response + + mock_conn = mock.MagicMock() + mock_client = mock.Mock() + mock_put_item = mock.Mock() + + mock_client.get_paginator.return_value = mock_paginator + mock_conn.meta.client = mock_client + mock_conn.Table.return_value.batch_writer.return_value.__enter__.return_value.put_item = mock_put_item + mock_conn.Table.return_value.table_arn = "arn:aws:dynamodb" + mock_get_conn.return_value = mock_conn + + res = exist_table_op.execute(None) + + new_table_load_mock.assert_called_once_with(context=None, tmp_table=True) + mock_client.get_paginator.assert_called_once_with("scan") + mock_client.get_paginator.return_value.paginate.assert_called_once_with( + TableName=exist_table_op.tmp_table_name, + Select="ALL_ATTRIBUTES", + ReturnConsumedCapacity="NONE", + ConsistentRead=True, + ) + mock_conn.Table.assert_called_with("test-table") + mock_conn.Table.return_value.batch_writer.assert_called_once_with(overwrite_by_pkeys=["attribute_a"]) + mock_put_item.assert_has_calls(batch_writer_calls) + mock_client.delete_table.assert_called_once_with(TableName=exist_table_op.tmp_table_name) + assert res == "arn:aws:dynamodb" diff --git a/tests/providers/amazon/aws/waiters/test_dynamo.py b/tests/providers/amazon/aws/waiters/test_dynamo.py index be94f68081af2a..f68663037868fb 100644 --- a/tests/providers/amazon/aws/waiters/test_dynamo.py +++ b/tests/providers/amazon/aws/waiters/test_dynamo.py @@ -24,11 +24,17 @@ from airflow.providers.amazon.aws.hooks.dynamodb import DynamoDBHook +TEST_IMPORT_ARN = "arn:aws:dynamodb:us-east-1:255683865591:table/test-table/import/01662190284205-aa94decf" + class TestCustomDynamoDBServiceWaiters: - STATUS_COMPLETED = "COMPLETED" - STATUS_FAILED = "FAILED" - STATUS_IN_PROGRESS = "IN_PROGRESS" + EXPORT_STATUS_COMPLETED = "COMPLETED" + EXPORT_STATUS_FAILED = "FAILED" + EXPORT_STATUS_IN_PROGRESS = "IN_PROGRESS" + + IMPORT_STATUS_FAILED = ("CANCELLING", "CANCELLED", "FAILED") + IMPORT_STATUS_COMPLETED = "COMPLETED" + IMPORT_STATUS_IN_PROGRESS = "IN_PROGRESS" @pytest.fixture(autouse=True) def setup_test_cases(self, monkeypatch): @@ -42,9 +48,16 @@ def mock_describe_export(self): with mock.patch.object(self.client, "describe_export") as m: yield m + @pytest.fixture + def mock_describe_import(self): + """Mock ``DynamoDBHook.Client.describe_import`` method.""" + with mock.patch.object(self.client, "describe_import") as m: + yield m + def test_service_waiters(self): hook_waiters = DynamoDBHook(aws_conn_id=None).list_waiters() assert "export_table" in hook_waiters + assert "import_table" in hook_waiters @staticmethod def describe_export(status: str): @@ -54,12 +67,20 @@ def describe_export(status: 
str): """ return {"ExportDescription": {"ExportStatus": status}} + @staticmethod + def describe_import(status: str): + """ + Helper function to generate a minimal DescribeImport response for a single job. + https://docs.aws.amazon.com/amazondynamodb/latest/APIReference/API_DescribeImport.html + """ + return {"ImportTableDescription": {"ImportStatus": status}} + def test_export_table_to_point_in_time_completed(self, mock_describe_export): """Test state transition from `in progress` to `completed` during init.""" waiter = DynamoDBHook(aws_conn_id=None).get_waiter("export_table") mock_describe_export.side_effect = [ - self.describe_export(self.STATUS_IN_PROGRESS), - self.describe_export(self.STATUS_COMPLETED), + self.describe_export(self.EXPORT_STATUS_IN_PROGRESS), + self.describe_export(self.EXPORT_STATUS_COMPLETED), ] waiter.wait( ExportArn="LoremIpsumissimplydummytextoftheprintingandtypesettingindustry", @@ -71,8 +92,8 @@ def test_export_table_to_point_in_time_failed(self, mock_describe_export): with mock.patch("boto3.client") as client: client.return_value = self.client mock_describe_export.side_effect = [ - self.describe_export(self.STATUS_IN_PROGRESS), - self.describe_export(self.STATUS_FAILED), + self.describe_export(self.EXPORT_STATUS_IN_PROGRESS), + self.describe_export(self.EXPORT_STATUS_FAILED), ] waiter = DynamoDBHook(aws_conn_id=None).get_waiter("export_table", client=self.client) with pytest.raises(WaiterError, match='we matched expected path: "FAILED"'): @@ -80,3 +101,31 @@ def test_export_table_to_point_in_time_failed(self, mock_describe_export): ExportArn="LoremIpsumissimplydummytextoftheprintingandtypesettingindustry", WaiterConfig={"Delay": 0.01, "MaxAttempts": 3}, ) + + def test_import_table_completed(self, mock_describe_import): + waiter = DynamoDBHook(aws_conn_id=None).get_waiter("import_table") + mock_describe_import.side_effect = [ + self.describe_import(self.IMPORT_STATUS_IN_PROGRESS), + self.describe_import(self.IMPORT_STATUS_COMPLETED), + ] + waiter.wait( + ImportArn=TEST_IMPORT_ARN, + WaiterConfig={"Delay": 0.01, "MaxAttempts": 3}, + ) + + @pytest.mark.parametrize( + "status", + [ + pytest.param(IMPORT_STATUS_FAILED[0]), + pytest.param(IMPORT_STATUS_FAILED[1]), + pytest.param(IMPORT_STATUS_FAILED[2]), + ], + ) + def test_import_table_failed(self, status, mock_describe_import): + waiter = DynamoDBHook(aws_conn_id=None).get_waiter("import_table") + mock_describe_import.side_effect = [ + self.describe_import(self.IMPORT_STATUS_IN_PROGRESS), + self.describe_import(status), + ] + with pytest.raises(WaiterError, match=f'we matched expected path: "{status}"'): + waiter.wait(ImportArn=TEST_IMPORT_ARN, WaiterConfig={"Delay": 0.01, "MaxAttempts": 3}) diff --git a/tests/system/providers/amazon/aws/example_s3_to_dynamodb.py b/tests/system/providers/amazon/aws/example_s3_to_dynamodb.py new file mode 100644 index 00000000000000..b415ffad7bdc10 --- /dev/null +++ b/tests/system/providers/amazon/aws/example_s3_to_dynamodb.py @@ -0,0 +1,192 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import logging +from datetime import datetime + +import boto3 + +from airflow.decorators import task +from airflow.models.baseoperator import chain +from airflow.models.dag import DAG +from airflow.providers.amazon.aws.operators.s3 import ( + S3CreateBucketOperator, + S3CreateObjectOperator, + S3DeleteBucketOperator, +) +from airflow.providers.amazon.aws.transfers.s3_to_dynamodb import S3ToDynamoDBOperator +from airflow.utils.trigger_rule import TriggerRule +from tests.system.providers.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder + +log = logging.getLogger(__name__) + +DAG_ID = "example_s3_to_dynamodb" + +sys_test_context_task = SystemTestContextBuilder().build() + +TABLE_ATTRIBUTES = [ + {"AttributeName": "cocktail_id", "AttributeType": "S"}, +] +TABLE_KEY_SCHEMA = [ + {"AttributeName": "cocktail_id", "KeyType": "HASH"}, +] +TABLE_THROUGHPUT = {"ReadCapacityUnits": 1, "WriteCapacityUnits": 1} + +SAMPLE_DATA = r"""cocktail_id,cocktail_name,base_spirit +1,Caipirinha,Cachaca +2,Bramble,Gin +3,Daiquiri,Rum +""" + + +@task +def set_up_table(table_name: str): + dynamo_resource = boto3.resource("dynamodb") + dynamo_resource.create_table( + AttributeDefinitions=TABLE_ATTRIBUTES, + TableName=table_name, + KeySchema=TABLE_KEY_SCHEMA, + ProvisionedThroughput=TABLE_THROUGHPUT, + ) + boto3.client("dynamodb").get_waiter("table_exists").wait( + TableName=table_name, WaiterConfig={"Delay": 10, "MaxAttempts": 10} + ) + + +@task +def wait_for_bucket(s3_bucket_name): + waiter = boto3.client("s3").get_waiter("bucket_exists") + waiter.wait(Bucket=s3_bucket_name) + + +@task(trigger_rule=TriggerRule.ALL_DONE) +def delete_dynamodb_table(table_name: str): + boto3.resource("dynamodb").Table(table_name).delete() + boto3.client("dynamodb").get_waiter("table_not_exists").wait( + TableName=table_name, WaiterConfig={"Delay": 10, "MaxAttempts": 10} + ) + + +with DAG( + dag_id=DAG_ID, + schedule="@once", + start_date=datetime(2021, 1, 1), + catchup=False, + tags=["example"], +) as dag: + test_context = sys_test_context_task() + env_id = test_context[ENV_ID_KEY] + existing_table_name = f"{env_id}-ex-dynamodb-table" + new_table_name = f"{env_id}-new-dynamodb-table" + bucket_name = f"{env_id}-dynamodb-bucket" + s3_key = f"{env_id}/files/cocktail_list.csv" + + create_table = set_up_table(table_name=existing_table_name) + + create_bucket = S3CreateBucketOperator(task_id="create_bucket", bucket_name=bucket_name) + + create_object = S3CreateObjectOperator( + task_id="create_object", + s3_bucket=bucket_name, + s3_key=s3_key, + data=SAMPLE_DATA, + replace=True, + ) + + # [START howto_transfer_s3_to_dynamodb] + transfer_1 = S3ToDynamoDBOperator( + task_id="s3_to_dynamodb", + s3_bucket=bucket_name, + s3_key=s3_key, + dynamodb_table_name=new_table_name, + input_format="CSV", + import_table_kwargs={ + "InputFormatOptions": { + "Csv": { + "Delimiter": ",", + } + } + }, + dynamodb_attributes=[ + {"AttributeName": "cocktail_id", "AttributeType": "S"}, + ], + dynamodb_key_schema=[ + {"AttributeName": "cocktail_id", "KeyType": "HASH"}, + ], + ) + # [END 
howto_transfer_s3_to_dynamodb] + + # [START howto_transfer_s3_to_dynamodb_existing_table] + transfer_2 = S3ToDynamoDBOperator( + task_id="s3_to_dynamodb_existing_table", + s3_bucket=bucket_name, + s3_key=s3_key, + dynamodb_table_name=existing_table_name, + use_existing_table=True, + input_format="CSV", + import_table_kwargs={ + "InputFormatOptions": { + "Csv": { + "Delimiter": ",", + } + } + }, + dynamodb_attributes=[ + {"AttributeName": "cocktail_id", "AttributeType": "S"}, + ], + dynamodb_key_schema=[ + {"AttributeName": "cocktail_id", "KeyType": "HASH"}, + ], + ) + # [END howto_transfer_s3_to_dynamodb_existing_table] + + delete_existing_table = delete_dynamodb_table(table_name=existing_table_name) + delete_new_table = delete_dynamodb_table(table_name=new_table_name) + + delete_bucket = S3DeleteBucketOperator( + task_id="delete_bucket", + bucket_name=bucket_name, + trigger_rule=TriggerRule.ALL_DONE, + force_delete=True, + ) + + chain( + # TEST SETUP + test_context, + create_table, + create_bucket, + wait_for_bucket(s3_bucket_name=bucket_name), + create_object, + # TEST BODY + transfer_1, + transfer_2, + # TEST TEARDOWN + delete_existing_table, + delete_new_table, + delete_bucket, + ) + from tests.system.utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + +from tests.system.utils import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag)
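
For reviewers who want to poke at the new ``DynamoDBHook.get_import_status`` method outside of a DAG, a minimal sketch follows. The ``aws_default`` connection is assumed to be configured and the import ARN is a placeholder (in practice it comes from the ImportTableDescription returned by the operator); the method and its (status, error code, error message) return tuple are the ones added in this diff.

    from airflow.providers.amazon.aws.hooks.dynamodb import DynamoDBHook

    # Placeholder ARN for illustration; a real value is returned by S3ToDynamoDBOperator / ImportTable.
    IMPORT_ARN = "arn:aws:dynamodb:us-east-1:123456789012:table/my-table/import/01234567890123-abcdef12"

    hook = DynamoDBHook(aws_conn_id="aws_default")  # assumes a configured AWS connection
    status, error_code, error_msg = hook.get_import_status(import_arn=IMPORT_ARN)
    if status == "FAILED":
        # FailureCode / FailureMessage are only populated for failed or cancelled imports.
        print(f"Import failed: {error_code} - {error_msg}")
    else:
        print(f"Import status: {status}")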