feat: implement amazon s3 to dynamodb transfer operator
dondaum committed May 15, 2024
1 parent 9284dc5 commit 513a641
Showing 9 changed files with 994 additions and 9 deletions.
35 changes: 34 additions & 1 deletion airflow/providers/amazon/aws/hooks/dynamodb.py
@@ -19,11 +19,17 @@

from __future__ import annotations

from typing import Iterable
from functools import cached_property
from typing import TYPE_CHECKING, Iterable

from botocore.exceptions import ClientError

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook

if TYPE_CHECKING:
from botocore.client import BaseClient


class DynamoDBHook(AwsBaseHook):
"""
@@ -50,6 +56,11 @@ def __init__(
kwargs["resource_type"] = "dynamodb"
super().__init__(*args, **kwargs)

@cached_property
def client(self) -> BaseClient:
"""Return boto3 client."""
return self.get_conn().meta.client

def write_batch_data(self, items: Iterable) -> bool:
"""
Write batch items to DynamoDB table with provisioned throughput capacity.
@@ -70,3 +81,25 @@ def write_batch_data(self, items: Iterable) -> bool:
return True
except Exception as general_error:
raise AirflowException(f"Failed to insert items in dynamodb, error: {general_error}")

def get_import_status(self, import_arn: str) -> tuple[str, str | None, str | None]:
"""
Get import status from DynamoDB.

:param import_arn: The Amazon Resource Name (ARN) for the import.
:return: Import status, error code and error message
"""
self.log.info("Poking for Dynamodb import %s", import_arn)

try:
describe_import = self.client.describe_import(ImportArn=import_arn)
status = describe_import["ImportTableDescription"]["ImportStatus"]
error_code = describe_import["ImportTableDescription"].get("FailureCode")
error_msg = describe_import["ImportTableDescription"].get("FailureMessage")
return status, error_code, error_msg
except ClientError as e:
error_code = e.response.get("Error", {}).get("Code")
if error_code == "ImportNotFoundException":
raise AirflowException("S3 import into Dynamodb job not found.")
else:
raise e
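
The hook changes above add a cached low-level botocore client and a status lookup for import jobs. Below is a minimal sketch of how they might be used outside the operator, assuming a configured AWS connection; the connection id, import ARN and poll interval are placeholder values, not part of this commit:

import time

from airflow.providers.amazon.aws.hooks.dynamodb import DynamoDBHook

hook = DynamoDBHook(aws_conn_id="aws_default")
# Hypothetical import ARN returned by an earlier import_table call.
import_arn = "arn:aws:dynamodb:eu-central-1:123456789012:table/my_table/import/0123456789"

# hook.client is the low-level botocore client (get_conn().meta.client).
print(hook.client.describe_import(ImportArn=import_arn)["ImportTableDescription"]["ImportStatus"])

# Poll until the import leaves the IN_PROGRESS state, then report any failure details.
status, error_code, error_msg = hook.get_import_status(import_arn)
while status == "IN_PROGRESS":
    time.sleep(30)
    status, error_code, error_msg = hook.get_import_status(import_arn)
print(f"Import finished: status={status}, error_code={error_code}, error_msg={error_msg}")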
313 changes: 313 additions & 0 deletions airflow/providers/amazon/aws/transfers/s3_to_dynamodb.py
@@ -0,0 +1,313 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Literal, Sequence, TypedDict

from botocore.exceptions import ClientError, WaiterError

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.dynamodb import DynamoDBHook

if TYPE_CHECKING:
from airflow.utils.context import Context


class AttributeDefinition(TypedDict):
"""Attribute Definition Type."""

AttributeName: str
AttributeType: Literal["S", "N", "B"]


class KeySchema(TypedDict):
"""Key Schema Type."""

AttributeName: str
KeyType: Literal["HASH", "RANGE"]


class ProvisionedThroughput(TypedDict):
"""Provisioned Throughput Type."""

ReadCapacityUnits: int
WriteCapacityUnits: int


class OnDemandThroughput(TypedDict):
"""OnDemand Throughput Type."""

MaxReadRequestUnits: int
MaxWriteRequestUnits: int


class SSESpecification(TypedDict):
"""SSE Specification Type."""

Enabled: bool
SSEType: Literal["AES256", "KMS"]
KMSMasterKeyId: str


class Projection(TypedDict):
"""Projection Type."""

ProjectionType: Literal["ALL", "KEYS_ONLY", "INCLUDE"]
NonKeyAttributes: list[str]


class GlobalSecondaryIndexes(TypedDict):
"""Global Secondary Indexes Type."""

IndexName: str
KeySchema: list[KeySchema]
Projection: Projection
ProvisionedThroughput: ProvisionedThroughput
OnDemandThroughput: OnDemandThroughput


class S3ToDynamoDBOperator(BaseOperator):
"""Load Data from S3 into a DynamoDB.
Data stored in S3 can be uploaded to a new or existing DynamoDB. Supported file formats CSV, DynamoDB JSON and
Amazon ION.
:param s3_bucket: The S3 bucket that is imported
:param s3_key: key prefix that imports single or multiple objects from S3
:param dynamodb_table_name: Name of the table that shall be created
:param dynamodb_key_schema: Primary key and sort key. Each element represents one primary key
attribute. AttributeName is the name of the attribute. KeyType is the role for the attribute. Valid values
HASH or RANGE
:param dynamodb_attributes: Name of the attributes of a table. AttributeName is the name for the attribute
AttributeType is the data type for the attribute. Valid values for AttributeType are
S - attribute is of type String
N - attribute is of type Number
B - attribute is of type Binary
:param dynamodb_tmp_table_prefix: Prefix for the temporary DynamoDB table
:param delete_dynamodb_tmp_table: If set, deletes the temporary DynamoDB table that is used for staging
:param use_existing_table: Whether to import to an existing non new DynamoDB table. If set to
true data is loaded first into a temporary DynamoDB table, then retrieved as chunks into memory and loaded
into the target table
:param client_token: ClientToken can be used for idempotent calls that have the same effect as one call.
A token is valid for 8 hours after the first request. Afterwards nay new request with the same token is
considered as a new request
:param input_format: The format for the imported data. Valid values for InputFormat are CSV, DYNAMODB_JSON
or ION
:param input_format_options: Extra options on how the input is formatted such as CSV delimiter options
:param input_compression: Compression options. Valid values for InputCompressionType are NONE, GZIP,
or ZSTD
:param billing_mode: Billing mode for the table. Valid values are PROVISIONED or PAY_PER_REQUEST
:param provisioned_throughput: Extra options for provisioned throughput settings
:param on_demand_throughput: Extra options for maximum number of read and write units
:param sse_specification: Extra options for server-side encryption
:param global_secondary_indexes: Extra options for server-side encryption
:param wait_for_completion: Whether to wait for cluster to stop
:param check_interval: Time in seconds to wait between status checks
:param max_attempts: Maximum number of attempts to check for job completion
:param aws_conn_id: The reference to the AWS connection details
"""

template_fields: Sequence[str] = (
"s3_bucket",
"s3_key",
"dynamodb_table_name",
"dynamodb_key_schema",
"dynamodb_attributes",
"global_secondary_indexes",
)
ui_color = "#e2e8f0"

def __init__(
self,
*,
s3_bucket: str,
s3_key: str,
dynamodb_table_name: str,
dynamodb_key_schema: list[KeySchema],
dynamodb_attributes: list[AttributeDefinition] | None = None,
dynamodb_tmp_table_prefix: str = "tmp",
delete_dynamodb_tmp_table: bool = True,
use_existing_table: bool = False,
client_token: str | None = None,
input_format: Literal["CSV", "DYNAMODB_JSON", "ION"] = "DYNAMODB_JSON",
input_format_options: dict | None = None,
input_compression: Literal["GZIP", "ZSTD", "NONE"] = "NONE",
billing_mode: Literal["PROVISIONED", "PAY_PER_REQUEST"] = "PAY_PER_REQUEST",
provisioned_throughput: ProvisionedThroughput | dict | None = None,
sse_specification: SSESpecification | dict | None = None,
global_secondary_indexes: list[GlobalSecondaryIndexes] | None = None,
wait_for_completion: bool = True,
check_interval: int = 30,
max_attempts: int = 240,
aws_conn_id: str | None = "aws_default",
**kwargs,
) -> None:
super().__init__(**kwargs)
self.s3_bucket = s3_bucket
self.s3_key = s3_key
self.dynamodb_table_name = dynamodb_table_name
self.dynamodb_attributes = dynamodb_attributes
self.dynamodb_tmp_table_prefix = dynamodb_tmp_table_prefix
self.delete_dynamodb_tmp_table = delete_dynamodb_tmp_table
self.use_existing_table = use_existing_table
self.dynamodb_key_schema = dynamodb_key_schema
self.client_token = client_token
self.input_format = input_format
self.input_format_options = input_format_options
self.input_compression = input_compression
self.billing_mode = billing_mode
self.provisioned_throughput = provisioned_throughput
self.sse_specification = sse_specification
self.global_secondary_indexes = global_secondary_indexes
self.wait_for_completion = wait_for_completion
self.check_interval = check_interval
self.max_attempts = max_attempts
self.aws_conn_id = aws_conn_id

@property
def tmp_table_name(self):
"""Temporary table name."""
return f"{self.dynamodb_tmp_table_prefix}_{self.dynamodb_table_name}"

def _load_into_new_table(self, context: Context, tmp_table: bool = False) -> str:
"""
Import S3 key or keys into a new DynamoDB table.

:param context: the current context of the task instance
:param tmp_table: whether to create the DynamoDB table with the tmp prefix
:return: the Amazon Resource Name (ARN) of the import
"""
dynamodb_hook = DynamoDBHook(aws_conn_id=self.aws_conn_id)
client = dynamodb_hook.client
table_name = self.tmp_table_name if tmp_table else self.dynamodb_table_name

import_table_config: dict[str, Any] = {}
if self.client_token:
import_table_config["ClientToken"] = self.client_token
if self.input_compression:
import_table_config["InputCompressionType"] = self.input_compression
if self.input_format_options:
import_table_config["InputFormatOptions"] = self.input_format_options

import_table_creation_config: dict[str, Any] = {}
if self.provisioned_throughput:
import_table_creation_config["ProvisionedThroughput"] = self.provisioned_throughput
if self.sse_specification:
import_table_creation_config["SSESpecification"] = self.sse_specification
if self.global_secondary_indexes:
import_table_creation_config["GlobalSecondaryIndexes"] = self.global_secondary_indexes
try:
response = client.import_table(
S3BucketSource={
"S3Bucket": self.s3_bucket,
"S3KeyPrefix": self.s3_key,
},
InputFormat=self.input_format,
TableCreationParameters={
"TableName": table_name,
"AttributeDefinitions": self.dynamodb_attributes,
"KeySchema": self.dynamodb_key_schema,
"BillingMode": self.billing_moode,
**import_table_creation_config,
},
**import_table_config,
)
except ClientError as e:
self.log.error("Error: failed to load from S3 into DynamoDB table. Error: %s", str(e))
raise AirflowException(f"S3 load into DynamoDB table failed with error: {e}")

if response["ImportTableDescription"]["ImportStatus"] == "FAILED":
raise AirflowException(
"S3 into Dynamodb job creation failed. Code: "
f"{response['ImportTableDescription']['FailureCode']}. "
f"Failure: {response['ImportTableDescription']['FailureMessage']}"
)

if self.wait_for_completion:
self.log.info("Waiting for S3 into Dynamodb job to complete")
waiter = dynamodb_hook.get_waiter("import_table")
try:
waiter.wait(
ImportArn=response["ImportTableDescription"]["ImportArn"],
WaiterConfig={"Delay": self.check_interval, "MaxAttempts": self.max_attempts},
)
except WaiterError:
status, error_code, error_msg = dynamodb_hook.get_import_status(
response["ImportTableDescription"]["ImportArn"]
)
client.delete_table(TableName=table_name)
raise AirflowException(
f"S3 import into Dynamodb job failed: Status: {status}. Error: {error_code}. Error message: {error_msg}"
)
return response["ImportTableDescription"]["ImportArn"]

def _load_into_existing_table(self, context: Context) -> str:
"""
Import S3 key or keys into an existing DynamoDB table.

:param context: the current context of the task instance
:return: the Amazon Resource Name (ARN) of the target table
"""
if not self.wait_for_completion:
raise ValueError("wait_for_completion must be set to True when loading into an existing table")
table_keys = [key["AttributeName"] for key in self.dynamodb_key_schema]

dynamodb_hook = DynamoDBHook(
aws_conn_id=self.aws_conn_id, table_name=self.dynamodb_table_name, table_keys=table_keys
)
client = dynamodb_hook.client

self.log.info("Loading from S3 into a tmp DynamoDB table %s", self.tmp_table_name)
self._load_into_new_table(context=context, tmp_table=True)
total_items = 0
try:
paginator = client.get_paginator("scan")
paginate = paginator.paginate(
TableName=self.tmp_table_name,
Select="ALL_ATTRIBUTES",
ReturnConsumedCapacity="NONE",
ConsistentRead=True,
)
self.log.info(
"Loading data from %s to %s DynamoDB table", self.tmp_table_name, self.dynamodb_table_name
)
for page in paginate:
total_items += page.get("Count", 0)
dynamodb_hook.write_batch_data(items=page["Items"])
self.log.info("Number of items loaded: %s", total_items)
finally:
if self.delete_dynamodb_tmp_table:
self.log.info("Delete tmp DynamoDB table %s", self.tmp_table_name)
client.delete_table(TableName=self.tmp_table_name)
return dynamodb_hook.get_conn().Table(self.dynamodb_table_name).table_arn

def execute(self, context: Context) -> str:
"""
Execute S3 to DynamoDB job from Airflow.

:param context: the current context of the task instance
:return: the Amazon Resource Name (ARN)
"""
if self.use_existing_table:
self.log.info("Loading from S3 into new DynamoDB table %s", self.dynamodb_table_name)
return self._load_into_existing_table(context=context)
self.log.info("Loading from S3 into existing DynamoDB table %s", self.dynamodb_table_name)
return self._load_into_new_table(context=context)
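
For reference, a hedged usage sketch of the new operator; the DAG id, bucket, key prefix, table and attribute names below are placeholders, not part of this commit. The first task lets the import job create a new table; the second, with use_existing_table=True, stages the data in a tmp-prefixed table, scans it in pages and batch-writes the items into the existing target table before dropping the staging table.

from __future__ import annotations

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.transfers.s3_to_dynamodb import S3ToDynamoDBOperator

# Placeholder names: bucket, key prefix, table and attribute names are illustrative only.
with DAG(
    dag_id="example_s3_to_dynamodb",
    start_date=datetime(2024, 1, 1),
    schedule=None,
    catchup=False,
) as dag:
    # Import into a brand-new table created by the DynamoDB import job.
    import_to_new_table = S3ToDynamoDBOperator(
        task_id="s3_to_dynamodb_new_table",
        s3_bucket="my-import-bucket",
        s3_key="exports/2024/05/",
        dynamodb_table_name="cocktails",
        dynamodb_key_schema=[{"AttributeName": "cocktail_id", "KeyType": "HASH"}],
        dynamodb_attributes=[{"AttributeName": "cocktail_id", "AttributeType": "S"}],
        input_format="DYNAMODB_JSON",
        aws_conn_id="aws_default",
    )

    # Import into an already existing table via the temporary staging table.
    import_to_existing_table = S3ToDynamoDBOperator(
        task_id="s3_to_dynamodb_existing_table",
        s3_bucket="my-import-bucket",
        s3_key="exports/2024/05/",
        dynamodb_table_name="cocktails",
        dynamodb_key_schema=[{"AttributeName": "cocktail_id", "KeyType": "HASH"}],
        dynamodb_attributes=[{"AttributeName": "cocktail_id", "AttributeType": "S"}],
        use_existing_table=True,
        aws_conn_id="aws_default",
    )

    import_to_new_table >> import_to_existing_table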