diff --git a/airbyte-config/init/src/main/resources/icons/awsdatalake.svg b/airbyte-config/init/src/main/resources/icons/awsdatalake.svg
new file mode 100644
index 000000000000..5ea8d6d6aab5
--- /dev/null
+++ b/airbyte-config/init/src/main/resources/icons/awsdatalake.svg
@@ -0,0 +1,10 @@
+
diff --git a/airbyte-config/init/src/main/resources/seed/destination_definitions.yaml b/airbyte-config/init/src/main/resources/seed/destination_definitions.yaml
index 3238f29c174b..9a714a1ca387 100644
--- a/airbyte-config/init/src/main/resources/seed/destination_definitions.yaml
+++ b/airbyte-config/init/src/main/resources/seed/destination_definitions.yaml
@@ -34,8 +34,9 @@
- name: AWS Datalake
destinationDefinitionId: 99878c90-0fbd-46d3-9d98-ffde879d17fc
dockerRepository: airbyte/destination-aws-datalake
- dockerImageTag: 0.1.1
+ dockerImageTag: 0.1.2
documentationUrl: https://docs.airbyte.com/integrations/destinations/aws-datalake
+ icon: awsdatalake.svg
releaseStage: alpha
- name: BigQuery
destinationDefinitionId: 22f6c74f-5699-40ff-833c-4a879ea40133
diff --git a/airbyte-config/init/src/main/resources/seed/destination_specs.yaml b/airbyte-config/init/src/main/resources/seed/destination_specs.yaml
index 6a56f7ff7b6c..f213f702217f 100644
--- a/airbyte-config/init/src/main/resources/seed/destination_specs.yaml
+++ b/airbyte-config/init/src/main/resources/seed/destination_specs.yaml
@@ -533,7 +533,7 @@
supported_destination_sync_modes:
- "overwrite"
- "append"
-- dockerImage: "airbyte/destination-aws-datalake:0.1.1"
+- dockerImage: "airbyte/destination-aws-datalake:0.1.2"
spec:
documentationUrl: "https://docs.airbyte.com/integrations/destinations/aws-datalake"
connectionSpecification:
@@ -544,7 +544,7 @@
- "credentials"
- "region"
- "bucket_name"
- - "bucket_prefix"
+ - "lakeformation_database_name"
additionalProperties: false
properties:
aws_account_id:
@@ -553,11 +553,7 @@
description: "target aws account id"
examples:
- "111111111111"
- region:
- title: "AWS Region"
- type: "string"
- description: "Region name"
- airbyte_secret: false
+ order: 1
credentials:
title: "Authentication mode"
description: "Choose How to Authenticate to AWS."
@@ -609,21 +605,145 @@
type: "string"
description: "Secret Access Key"
airbyte_secret: true
+ order: 2
+ region:
+ title: "S3 Bucket Region"
+ type: "string"
+ default: ""
+ description: "The region of the S3 bucket. See here for all region codes."
+ enum:
+ - ""
+ - "us-east-1"
+ - "us-east-2"
+ - "us-west-1"
+ - "us-west-2"
+ - "af-south-1"
+ - "ap-east-1"
+ - "ap-south-1"
+ - "ap-northeast-1"
+ - "ap-northeast-2"
+ - "ap-northeast-3"
+ - "ap-southeast-1"
+ - "ap-southeast-2"
+ - "ca-central-1"
+ - "cn-north-1"
+ - "cn-northwest-1"
+ - "eu-central-1"
+ - "eu-north-1"
+ - "eu-south-1"
+ - "eu-west-1"
+ - "eu-west-2"
+ - "eu-west-3"
+ - "sa-east-1"
+ - "me-south-1"
+ - "us-gov-east-1"
+ - "us-gov-west-1"
+ order: 3
bucket_name:
title: "S3 Bucket Name"
type: "string"
- description: "Name of the bucket"
- airbyte_secret: false
+ description: "The name of the S3 bucket. Read more here."
+ order: 4
bucket_prefix:
title: "Target S3 Bucket Prefix"
type: "string"
description: "S3 prefix"
- airbyte_secret: false
+ order: 5
lakeformation_database_name:
- title: "Lakeformation Database Name"
+ title: "Lake Formation Database Name"
type: "string"
- description: "Which database to use"
- airbyte_secret: false
+ description: "The default database this destination will use to create tables\
+ \ in per stream. Can be changed per connection by customizing the namespace."
+ order: 6
+ lakeformation_database_default_tag_key:
+ title: "Lake Formation Database Tag Key"
+ description: "Add a default tag key to databases created by this destination"
+ examples:
+ - "pii_level"
+ type: "string"
+ order: 7
+ lakeformation_database_default_tag_values:
+ title: "Lake Formation Database Tag Values"
+ description: "Add default values for the `Tag Key` to databases created\
+ \ by this destination. Comma separate for multiple values."
+ examples:
+ - "private,public"
+ type: "string"
+ order: 8
+ lakeformation_governed_tables:
+ title: "Lake Formation Governed Tables"
+ description: "Whether to create tables as LF governed tables."
+ type: "boolean"
+ default: false
+ order: 9
+ format:
+ title: "Output Format *"
+ type: "object"
+ description: "Format of the data output."
+ oneOf:
+ - title: "JSON Lines: Newline-delimited JSON"
+ required:
+ - "format_type"
+ properties:
+ format_type:
+ title: "Format Type *"
+ type: "string"
+ enum:
+ - "JSONL"
+ default: "JSONL"
+ compression_codec:
+ title: "Compression Codec (Optional)"
+ description: "The compression algorithm used to compress data."
+ type: "string"
+ enum:
+ - "UNCOMPRESSED"
+ - "GZIP"
+ default: "UNCOMPRESSED"
+ - title: "Parquet: Columnar Storage"
+ required:
+ - "format_type"
+ properties:
+ format_type:
+ title: "Format Type *"
+ type: "string"
+ enum:
+ - "Parquet"
+ default: "Parquet"
+ compression_codec:
+ title: "Compression Codec (Optional)"
+ description: "The compression algorithm used to compress data."
+ type: "string"
+ enum:
+ - "UNCOMPRESSED"
+ - "SNAPPY"
+ - "GZIP"
+ - "ZSTD"
+ default: "SNAPPY"
+ order: 10
+ partitioning:
+ title: "Choose how to partition data"
+ description: "Partition data by cursor fields when a cursor field is a date"
+ type: "string"
+ enum:
+ - "NO PARTITIONING"
+ - "DATE"
+ - "YEAR"
+ - "MONTH"
+ - "DAY"
+ - "YEAR/MONTH"
+ - "YEAR/MONTH/DAY"
+ default: "NO PARTITIONING"
+ order: 11
+ glue_catalog_float_as_decimal:
+ title: "Glue Catalog: Float as Decimal"
+ description: "Cast float/double as decimal(38,18). This can help achieve\
+ \ higher accuracy and represent numbers correctly as received from the\
+ \ source."
+ type: "boolean"
+ default: false
+ order: 12
supportsIncremental: true
supportsNormalization: false
supportsDBT: false
diff --git a/airbyte-integrations/connectors/destination-aws-datalake/BOOTSTRAP.md b/airbyte-integrations/connectors/destination-aws-datalake/BOOTSTRAP.md
new file mode 100644
index 000000000000..16998e8a3326
--- /dev/null
+++ b/airbyte-integrations/connectors/destination-aws-datalake/BOOTSTRAP.md
@@ -0,0 +1,5 @@
+# AWS Lake Formation Destination Connector Bootstrap
+
+This destination syncs your data to S3 and AWS Lake Formation, and automatically creates Glue catalog databases and tables for you.
+
+See [this](https://docs.aws.amazon.com/lake-formation/latest/dg/how-it-works.html) to learn more about AWS Lake Formation.
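+
+As an illustration only (not part of this connector), here is a minimal sketch of reading a synced stream back with awswrangler. It assumes a stream was written as a standard (non-governed) table named `users` in a Lake Formation database named `airbyte_db` in `us-east-1`, and that the caller can query Athena; all of these names are placeholders:
+
+```python
+import awswrangler as wr
+import boto3
+
+# Placeholder region; use the same region as the destination's S3 bucket.
+session = boto3.Session(region_name="us-east-1")
+
+# Query the Glue/Athena table the destination created for the "users" stream.
+df = wr.athena.read_sql_query(
+    "SELECT * FROM users LIMIT 10",
+    database="airbyte_db",
+    boto3_session=session,
+)
+print(df.head())
+```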
diff --git a/airbyte-integrations/connectors/destination-aws-datalake/Dockerfile b/airbyte-integrations/connectors/destination-aws-datalake/Dockerfile
index ebe2844b7244..a036661f7da7 100644
--- a/airbyte-integrations/connectors/destination-aws-datalake/Dockerfile
+++ b/airbyte-integrations/connectors/destination-aws-datalake/Dockerfile
@@ -13,5 +13,5 @@ RUN pip install .
ENV AIRBYTE_ENTRYPOINT "python /airbyte/integration_code/main.py"
ENTRYPOINT ["python", "/airbyte/integration_code/main.py"]
-LABEL io.airbyte.version=0.1.1
+LABEL io.airbyte.version=0.1.2
LABEL io.airbyte.name=airbyte/destination-aws-datalake
diff --git a/airbyte-integrations/connectors/destination-aws-datalake/README.md b/airbyte-integrations/connectors/destination-aws-datalake/README.md
index 81a16e03a8ad..e4e7f1d858b2 100644
--- a/airbyte-integrations/connectors/destination-aws-datalake/README.md
+++ b/airbyte-integrations/connectors/destination-aws-datalake/README.md
@@ -140,23 +140,28 @@ To run acceptance and custom integration tests:
./gradlew :airbyte-integrations:connectors:destination-aws-datalake:integrationTest
```
-#### Running the Destination Acceptance Tests
+#### Running the Destination Integration Tests
To successfully run the Destination Acceptance Tests, you need a `secrets/config.json` file with appropriate information. For example:
```json
{
- "bucket_name": "your-bucket-name",
- "bucket_prefix": "your-prefix",
- "region": "your-region",
- "aws_account_id": "111111111111",
- "lakeformation_database_name": "an_lf_database",
- "credentials": {
- "credentials_title": "IAM User",
- "aws_access_key_id": ".....",
- "aws_secret_access_key": "....."
- }
+ "aws_account_id": "111111111111",
+ "credentials": {
+ "credentials_title": "IAM User",
+ "aws_access_key_id": "aws_key_id",
+ "aws_secret_access_key": "aws_secret_key"
+ },
+ "region": "us-east-1",
+ "bucket_name": "datalake-bucket",
+ "lakeformation_database_name": "test",
+ "format": {
+ "format_type": "Parquet",
+ "compression_codec": "SNAPPY"
+ },
+ "partitioning": "NO PARTITIONING"
}
+
```
In the AWS account, you need to have the following elements in place:
@@ -167,7 +172,6 @@ In the AWS account, you need to have the following elements in place:
* The user must have appropriate permissions to the Lake Formation database to perform the tests (For example see: [Granting Database Permissions Using the Lake Formation Console and the Named Resource Method](https://docs.aws.amazon.com/lake-formation/latest/dg/granting-database-permissions.html))
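+
+For reference, a minimal sketch of granting those database permissions with boto3; the principal ARN, database name, and permission level below are placeholders, not values this connector requires:
+
+```python
+import boto3
+
+lf = boto3.client("lakeformation", region_name="us-east-1")
+
+# Grant the test user access to the Lake Formation database referenced in secrets/config.json.
+lf.grant_permissions(
+    Principal={"DataLakePrincipalIdentifier": "arn:aws:iam::111111111111:user/airbyte-test"},
+    Resource={"Database": {"Name": "test"}},
+    Permissions=["ALL"],
+)
+```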
-
## Dependency Management
All of your dependencies should go in `setup.py`, NOT `requirements.txt`. The requirements file is only used to connect internal Airbyte dependencies in the monorepo for local development.
diff --git a/airbyte-integrations/connectors/destination-aws-datalake/destination_aws_datalake/aws.py b/airbyte-integrations/connectors/destination-aws-datalake/destination_aws_datalake/aws.py
index c03b8305e5d7..1e1a679938d7 100644
--- a/airbyte-integrations/connectors/destination-aws-datalake/destination_aws_datalake/aws.py
+++ b/airbyte-integrations/connectors/destination-aws-datalake/destination_aws_datalake/aws.py
@@ -2,42 +2,95 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
-import json
+import logging
+from decimal import Decimal
+from typing import Dict, Optional
+import awswrangler as wr
import boto3
+import pandas as pd
from airbyte_cdk.destinations import Destination
+from awswrangler import _data_types
from botocore.exceptions import ClientError
from retrying import retry
-from .config_reader import AuthMode, ConnectorConfig
+from .config_reader import CompressionCodec, ConnectorConfig, CredentialsType, OutputFormat
+
+logger = logging.getLogger("airbyte")
+
+null_values = ["", " ", "#N/A", "#N/A N/A", "#NA", "<NA>", "N/A", "NA", "NULL", "none", "None", "NaN", "n/a", "nan", "null"]
+
+
+def _cast_pandas_column(df: pd.DataFrame, col: str, current_type: str, desired_type: str) -> pd.DataFrame:
+ if desired_type == "datetime64":
+ df[col] = pd.to_datetime(df[col])
+ elif desired_type == "date":
+ df[col] = df[col].apply(lambda x: _data_types._cast2date(value=x)).replace(to_replace={pd.NaT: None})
+ elif desired_type == "bytes":
+ df[col] = df[col].astype("string").str.encode(encoding="utf-8").replace(to_replace={pd.NA: None})
+ elif desired_type == "decimal":
+ # First cast to string
+ df = _cast_pandas_column(df=df, col=col, current_type=current_type, desired_type="string")
+ # Then cast to decimal
+ df[col] = df[col].apply(lambda x: Decimal(str(x)) if str(x) not in null_values else None)
+ elif desired_type.lower() in ["float64", "int64"]:
+ df[col] = df[col].fillna("")
+ df[col] = pd.to_numeric(df[col])
+ elif desired_type in ["boolean", "bool"]:
+ if df[col].dtype in ["string", "O"]:
+ df[col] = df[col].fillna("false").apply(lambda x: str(x).lower() in ["true", "1", "1.0", "t", "y", "yes"])
+
+ df[col] = df[col].astype(bool)
+ else:
+ try:
+ df[col] = df[col].astype(desired_type)
+ except (TypeError, ValueError) as ex:
+ if "object cannot be converted to an IntegerDtype" not in str(ex):
+ raise ex
+ logger.warning(
+ "Object cannot be converted to an IntegerDtype. Integer columns in Python cannot contain "
+ "missing values. If your input data contains missing values, it will be encoded as floats, "
+ "which may cause precision loss."
+ )
+ df[col] = df[col].apply(lambda x: int(x) if str(x) not in null_values else None).astype(desired_type)
+ return df
-class AwsHandler:
- COLUMNS_MAPPING = {"number": "float", "string": "string", "integer": "int"}
+# Overwrite to fix type conversion issues from athena to pandas
+# These happen when appending data to an existing table. awswrangler
+# tries to cast the data types to the existing table schema, examples include:
+# Fixes: ValueError: could not convert string to float: ''
+# Fixes: TypeError: Need to pass bool-like values
+_data_types._cast_pandas_column = _cast_pandas_column
+
- def __init__(self, connector_config, destination: Destination):
- self._connector_config: ConnectorConfig = connector_config
+class AwsHandler:
+ def __init__(self, connector_config: ConnectorConfig, destination: Destination):
+ self._config: ConnectorConfig = connector_config
self._destination: Destination = destination
- self._bucket_name = connector_config.bucket_name
- self.logger = self._destination.logger
+ self._session: boto3.Session = None
self.create_session()
- self.s3_client = self.session.client("s3", region_name=connector_config.region)
- self.glue_client = self.session.client("glue")
- self.lf_client = self.session.client("lakeformation")
+ self.glue_client = self._session.client("glue")
+ self.s3_client = self._session.client("s3")
+ self.lf_client = self._session.client("lakeformation")
+
+ self._table_type = "GOVERNED" if self._config.lakeformation_governed_tables else "EXTERNAL_TABLE"
@retry(stop_max_attempt_number=10, wait_random_min=1000, wait_random_max=2000)
def create_session(self):
- if self._connector_config.credentials_type == AuthMode.IAM_USER.value:
+ if self._config.credentials_type == CredentialsType.IAM_USER:
self._session = boto3.Session(
- aws_access_key_id=self._connector_config.aws_access_key,
- aws_secret_access_key=self._connector_config.aws_secret_key,
- region_name=self._connector_config.region,
+ aws_access_key_id=self._config.aws_access_key,
+ aws_secret_access_key=self._config.aws_secret_key,
+ region_name=self._config.region,
)
- elif self._connector_config.credentials_type == AuthMode.IAM_ROLE.value:
+
+ elif self._config.credentials_type == CredentialsType.IAM_ROLE:
client = boto3.client("sts")
role = client.assume_role(
- RoleArn=self._connector_config.role_arn,
+ RoleArn=self._config.role_arn,
RoleSessionName="airbyte-destination-aws-datalake",
)
creds = role.get("Credentials", {})
@@ -45,244 +98,164 @@ def create_session(self):
aws_access_key_id=creds.get("AccessKeyId"),
aws_secret_access_key=creds.get("SecretAccessKey"),
aws_session_token=creds.get("SessionToken"),
- region_name=self._connector_config.region,
+ region_name=self._config.region,
)
+
+ def _get_s3_path(self, database: str, table: str) -> str:
+ bucket = f"s3://{self._config.bucket_name}"
+ if self._config.bucket_prefix:
+ bucket += f"/{self._config.bucket_prefix}"
+
+ return f"{bucket}/{database}/{table}/"
+
+ def _get_compression_type(self, compression: CompressionCodec):
+ if compression == CompressionCodec.GZIP:
+ return "gzip"
+ elif compression == CompressionCodec.SNAPPY:
+ return "snappy"
+ elif compression == CompressionCodec.ZSTD:
+ return "zstd"
else:
- raise Exception("Session wasn't created")
+ return None
- @property
- def session(self) -> boto3.Session:
- return self._session
+ def _write_parquet(
+ self,
+ df: pd.DataFrame,
+ path: str,
+ database: str,
+ table: str,
+ mode: str,
+ dtype: Optional[Dict[str, str]],
+ partition_cols: list = None,
+ ):
+ return wr.s3.to_parquet(
+ df=df,
+ path=path,
+ dataset=True,
+ database=database,
+ table=table,
+ table_type=self._table_type,
+ mode=mode,
+ use_threads=False, # True causes s3 NoCredentialsError error
+ catalog_versioning=True,
+ boto3_session=self._session,
+ partition_cols=partition_cols,
+ compression=self._get_compression_type(self._config.compression_codec),
+ dtype=dtype,
+ )
- @retry(stop_max_attempt_number=10, wait_random_min=2000, wait_random_max=3000)
- def head_bucket(self):
- self.s3_client.head_bucket(Bucket=self._bucket_name)
+ def _write_json(
+ self,
+ df: pd.DataFrame,
+ path: str,
+ database: str,
+ table: str,
+ mode: str,
+ dtype: Optional[Dict[str, str]],
+ partition_cols: list = None,
+ ):
+ return wr.s3.to_json(
+ df=df,
+ path=path,
+ dataset=True,
+ database=database,
+ table=table,
+ table_type=self._table_type,
+ mode=mode,
+ use_threads=False, # True causes s3 NoCredentialsError error
+ orient="records",
+ lines=True,
+ catalog_versioning=True,
+ boto3_session=self._session,
+ partition_cols=partition_cols,
+ dtype=dtype,
+ compression=self._get_compression_type(self._config.compression_codec),
+ )
- @retry(stop_max_attempt_number=10, wait_random_min=2000, wait_random_max=3000)
- def head_object(self, object_key):
- return self.s3_client.head_object(Bucket=self._bucket_name, Key=object_key)
+ def _write(self, df: pd.DataFrame, path: str, database: str, table: str, mode: str, dtype: Dict[str, str], partition_cols: list = None):
+ self._create_database_if_not_exists(database)
- @retry(stop_max_attempt_number=10, wait_random_min=2000, wait_random_max=3000)
- def put_object(self, object_key, body):
- self.s3_client.put_object(Bucket=self._bucket_name, Key=object_key, Body="\n".join(body))
+ if self._config.format_type == OutputFormat.JSONL:
+ return self._write_json(df, path, database, table, mode, dtype, partition_cols)
- @staticmethod
- def batch_iterate(iterable, n=1):
- size = len(iterable)
- for ndx in range(0, size, n):
- yield iterable[ndx : min(ndx + n, size)]
+ elif self._config.format_type == OutputFormat.PARQUET:
+ return self._write_parquet(df, path, database, table, mode, dtype, partition_cols)
- def get_table(self, txid, database_name: str, table_name: str, location: str):
- table = None
- try:
- table = self.glue_client.get_table(DatabaseName=database_name, Name=table_name, TransactionId=txid)
- except ClientError as e:
- if e.response["Error"]["Code"] == "EntityNotFoundException":
- table_input = {
- "Name": table_name,
- "TableType": "GOVERNED",
- "StorageDescriptor": {
- "Location": location,
- "InputFormat": "org.apache.hadoop.mapred.TextInputFormat",
- "OutputFormat": "org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat",
- "SerdeInfo": {"SerializationLibrary": "org.openx.data.jsonserde.JsonSerDe", "Parameters": {"paths": ","}},
- },
- "PartitionKeys": [],
- "Parameters": {"classification": "json", "lakeformation.aso.status": "true"},
- }
- self.glue_client.create_table(DatabaseName=database_name, TableInput=table_input, TransactionId=txid)
- table = self.glue_client.get_table(DatabaseName=database_name, Name=table_name, TransactionId=txid)
- else:
- err = e.response["Error"]["Code"]
- self.logger.error(f"An error occurred: {err}")
- raise
-
- if table:
- return table
else:
- return None
+ raise Exception(f"Unsupported output format: {self._config.format_type}")
- def update_table(self, database, table_info, transaction_id):
- self.glue_client.update_table(DatabaseName=database, TableInput=table_info, TransactionId=transaction_id)
+ def _create_database_if_not_exists(self, database: str):
+ tag_key = self._config.lakeformation_database_default_tag_key
+ tag_values = self._config.lakeformation_database_default_tag_values
- def preprocess_type(self, property_type):
- if type(property_type) is list:
- not_null_types = list(filter(lambda t: t != "null", property_type))
- if len(not_null_types) > 2:
- return "string"
- else:
- return not_null_types[0]
- else:
- return property_type
-
- def cast_to_athena(self, str_type):
- preprocessed_type = self.preprocess_type(str_type)
- return self.COLUMNS_MAPPING.get(preprocessed_type, preprocessed_type)
-
- def generate_athena_schema(self, schema):
- columns = []
- for (k, v) in schema.items():
- athena_type = self.cast_to_athena(v["type"])
- if athena_type == "object":
- properties = v["properties"]
- type_str = ",".join([f"{k1}:{self.cast_to_athena(v1['type'])}" for (k1, v1) in properties.items()])
- columns.append({"Name": k, "Type": f"struct<{type_str}>"})
- else:
- columns.append({"Name": k, "Type": athena_type})
- return columns
-
- def update_table_schema(self, txid, database, table, schema):
- table_info = table["Table"]
- table_info_keys = list(table_info.keys())
- for k in table_info_keys:
- if k not in [
- "Name",
- "Description",
- "Owner",
- "LastAccessTime",
- "LastAnalyzedTime",
- "Retention",
- "StorageDescriptor",
- "PartitionKeys",
- "ViewOriginalText",
- "ViewExpandedText",
- "TableType",
- "Parameters",
- "TargetTable",
- "IsRowFilteringEnabled",
- ]:
- table_info.pop(k)
-
- self.logger.debug("Schema = " + repr(schema))
-
- columns = self.generate_athena_schema(schema)
- if "StorageDescriptor" in table_info:
- table_info["StorageDescriptor"]["Columns"] = columns
- else:
- table_info["StorageDescriptor"] = {"Columns": columns}
- self.update_table(database, table_info, txid)
- self.glue_client.update_table(DatabaseName=database, TableInput=table_info, TransactionId=txid)
+ wr.catalog.create_database(name=database, boto3_session=self._session, exist_ok=True)
+
+ if tag_key and tag_values:
+ self.lf_client.add_lf_tags_to_resource(
+ Resource={
+ "Database": {"Name": database},
+ },
+ LFTags=[{"TagKey": tag_key, "TagValues": tag_values.split(",")}],
+ )
- def get_all_table_objects(self, txid, database, table):
- table_objects = []
+ @retry(stop_max_attempt_number=10, wait_random_min=2000, wait_random_max=3000)
+ def head_bucket(self):
+ return self.s3_client.head_bucket(Bucket=self._config.bucket_name)
+ def table_exists(self, database: str, table: str) -> bool:
try:
- res = self.lf_client.get_table_objects(DatabaseName=database, TableName=table, TransactionId=txid)
- except ClientError as e:
- if e.response["Error"]["Code"] == "EntityNotFoundException":
- return []
- else:
- err = e.response["Error"]["Code"]
- self.logger.error(f"Could not get table objects due to error: {err}")
- raise
-
- while True:
- next_token = res.get("NextToken", None)
- partition_objects = res.get("Objects")
- table_objects.extend([p["Objects"] for p in partition_objects])
- if next_token:
- res = self.lf_client.get_table_objects(
- DatabaseName=database,
- TableName=table,
- TransactionId=txid,
- NextToken=next_token,
- )
- else:
- break
- flat_list = [item for sublist in table_objects for item in sublist]
- return flat_list
-
- def purge_table(self, txid, database, table):
- self.logger.debug(f"Going to purge table {table}")
- write_ops = []
- all_objects = self.get_all_table_objects(txid, database, table)
- write_ops.extend([{"DeleteObject": {"Uri": o["Uri"]}} for o in all_objects])
- if len(write_ops) > 0:
- self.logger.debug(f"{len(write_ops)} objects to purge")
- for batch in self.batch_iterate(write_ops, 99):
- self.logger.debug("Purging batch")
- try:
- self.lf_client.update_table_objects(
- TransactionId=txid,
- DatabaseName=database,
- TableName=table,
- WriteOperations=batch,
- )
- except ClientError as e:
- self.logger.error(f"Could not delete object due to exception {repr(e)}")
- raise
- else:
- self.logger.debug("Table was empty, nothing to purge.")
-
- def update_governed_table(self, txid, database, table, bucket, object_key, etag, size):
- self.logger.debug(f"Updating governed table {database}:{table}")
- write_ops = [
- {
- "AddObject": {
- "Uri": f"s3://{bucket}/{object_key}",
- "ETag": etag,
- "Size": size,
- }
- }
- ]
-
- self.lf_client.update_table_objects(
- TransactionId=txid,
- DatabaseName=database,
- TableName=table,
- WriteOperations=write_ops,
+ self.glue_client.get_table(DatabaseName=database, Name=table)
+ return True
+ except ClientError:
+ return False
+
+ def delete_table(self, database: str, table: str) -> bool:
+ logger.info(f"Deleting table {database}.{table}")
+ return wr.catalog.delete_table_if_exists(database=database, table=table, boto3_session=self._session)
+
+ def delete_table_objects(self, database: str, table: str) -> None:
+ path = self._get_s3_path(database, table)
+ logger.info(f"Deleting objects in {path}")
+ return wr.s3.delete_objects(path=path, boto3_session=self._session)
+
+ def reset_table(self, database: str, table: str) -> None:
+ logger.info(f"Resetting table {database}.{table}")
+ if self.table_exists(database, table):
+ self.delete_table(database, table)
+ self.delete_table_objects(database, table)
+
+ def write(self, df: pd.DataFrame, database: str, table: str, dtype: Dict[str, str], partition_cols: list):
+ path = self._get_s3_path(database, table)
+ return self._write(
+ df,
+ path,
+ database,
+ table,
+ "overwrite",
+ dtype,
+ partition_cols,
)
+ def append(self, df: pd.DataFrame, database: str, table: str, dtype: Dict[str, str], partition_cols: list):
+ path = self._get_s3_path(database, table)
+ return self._write(
+ df,
+ path,
+ database,
+ table,
+ "append",
+ dtype,
+ partition_cols,
+ )
-class LakeformationTransaction:
- def __init__(self, aws_handler: AwsHandler):
- self._aws_handler = aws_handler
- self._transaction = None
- self._logger = aws_handler.logger
-
- @property
- def txid(self):
- return self._transaction["TransactionId"]
-
- def cancel_transaction(self):
- self._logger.debug("Canceling Lakeformation Transaction")
- self._aws_handler.lf_client.cancel_transaction(TransactionId=self.txid)
-
- def commit_transaction(self):
- self._logger.debug(f"Commiting Lakeformation Transaction {self.txid}")
- self._aws_handler.lf_client.commit_transaction(TransactionId=self.txid)
-
- def extend_transaction(self):
- self._logger.debug("Extending Lakeformation Transaction")
- self._aws_handler.lf_client.extend_transaction(TransactionId=self.txid)
-
- def describe_transaction(self):
- return self._aws_handler.lf_client.describe_transaction(TransactionId=self.txid)
-
- def __enter__(self, transaction_type="READ_AND_WRITE"):
- self._logger.debug("Starting Lakeformation Transaction")
- self._transaction = self._aws_handler.lf_client.start_transaction(TransactionType=transaction_type)
- self._logger.debug(f"Transaction id = {self.txid}")
- return self
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- self._logger.debug("Exiting LakeformationTransaction context manager")
- tx_desc = self.describe_transaction()
- self._logger.debug(json.dumps(tx_desc, default=str))
-
- if exc_type:
- self._logger.error("Exiting LakeformationTransaction context manager due to an exception")
- self._logger.error(repr(exc_type))
- self._logger.error(repr(exc_val))
- self.cancel_transaction()
- self._transaction = None
- else:
- self._logger.debug("Exiting LakeformationTransaction context manager due to reaching end of with block")
- try:
- self.commit_transaction()
- self._transaction = None
- except Exception as e:
- self.cancel_transaction()
- self._logger.error(f"Could not commit the transaction id = {self.txid} because of :\n{repr(e)}")
- self._transaction = None
- raise (e)
+ def upsert(self, df: pd.DataFrame, database: str, table: str, dtype: Dict[str, str], partition_cols: list):
+ path = self._get_s3_path(database, table)
+ return self._write(
+ df,
+ path,
+ database,
+ table,
+ "overwrite_partitions",
+ dtype,
+ partition_cols,
+ )
diff --git a/airbyte-integrations/connectors/destination-aws-datalake/destination_aws_datalake/config_reader.py b/airbyte-integrations/connectors/destination-aws-datalake/destination_aws_datalake/config_reader.py
index db3648e11c62..8559e9d17c4a 100644
--- a/airbyte-integrations/connectors/destination-aws-datalake/destination_aws_datalake/config_reader.py
+++ b/airbyte-integrations/connectors/destination-aws-datalake/destination_aws_datalake/config_reader.py
@@ -5,10 +5,76 @@
import enum
-class AuthMode(enum.Enum):
+class CredentialsType(enum.Enum):
IAM_ROLE = "IAM Role"
IAM_USER = "IAM User"
+ @staticmethod
+ def from_string(s: str):
+ if s == "IAM Role":
+ return CredentialsType.IAM_ROLE
+ elif s == "IAM User":
+ return CredentialsType.IAM_USER
+ else:
+ raise ValueError(f"Unknown auth mode: {s}")
+
+
+class OutputFormat(enum.Enum):
+ PARQUET = "Parquet"
+ JSONL = "JSONL"
+
+ @staticmethod
+ def from_string(s: str):
+ if s == "Parquet":
+ return OutputFormat.PARQUET
+
+ return OutputFormat.JSONL
+
+
+class CompressionCodec(enum.Enum):
+ SNAPPY = "SNAPPY"
+ GZIP = "GZIP"
+ ZSTD = "ZSTD"
+ UNCOMPRESSED = "UNCOMPRESSED"
+
+ @staticmethod
+ def from_config(s: str):
+ if s == "SNAPPY":
+ return CompressionCodec.SNAPPY
+ elif s == "GZIP":
+ return CompressionCodec.GZIP
+ elif s == "ZSTD":
+ return CompressionCodec.ZSTD
+
+ return CompressionCodec.UNCOMPRESSED
+
+
+class PartitionOptions(enum.Enum):
+ NONE = "NO PARTITIONING"
+ DATE = "DATE"
+ YEAR = "YEAR"
+ MONTH = "MONTH"
+ DAY = "DAY"
+ YEAR_MONTH = "YEAR/MONTH"
+ YEAR_MONTH_DAY = "YEAR/MONTH/DAY"
+
+ @staticmethod
+ def from_string(s: str):
+ if s == "DATE":
+ return PartitionOptions.DATE
+ elif s == "YEAR":
+ return PartitionOptions.YEAR
+ elif s == "MONTH":
+ return PartitionOptions.MONTH
+ elif s == "DAY":
+ return PartitionOptions.DAY
+ elif s == "YEAR/MONTH":
+ return PartitionOptions.YEAR_MONTH
+ elif s == "YEAR/MONTH/DAY":
+ return PartitionOptions.YEAR_MONTH_DAY
+
+ return PartitionOptions.NONE
+
class ConnectorConfig:
def __init__(
@@ -19,21 +85,36 @@ def __init__(
bucket_name: str = None,
bucket_prefix: str = None,
lakeformation_database_name: str = None,
+ lakeformation_database_default_tag_key: str = None,
+ lakeformation_database_default_tag_values: str = None,
+ lakeformation_governed_tables: bool = False,
+ glue_catalog_float_as_decimal: bool = False,
table_name: str = None,
+ format: dict = {},
+ partitioning: str = None,
):
self.aws_account_id = aws_account_id
self.credentials = credentials
- self.credentials_type = credentials.get("credentials_title")
+ self.credentials_type = CredentialsType.from_string(credentials.get("credentials_title"))
self.region = region
self.bucket_name = bucket_name
self.bucket_prefix = bucket_prefix
self.lakeformation_database_name = lakeformation_database_name
+ self.lakeformation_database_default_tag_key = lakeformation_database_default_tag_key
+ self.lakeformation_database_default_tag_values = lakeformation_database_default_tag_values
+ self.lakeformation_governed_tables = lakeformation_governed_tables
+ self.glue_catalog_float_as_decimal = glue_catalog_float_as_decimal
self.table_name = table_name
- if self.credentials_type == AuthMode.IAM_USER.value:
+ self.format_type = OutputFormat.from_string(format.get("format_type", OutputFormat.PARQUET.value))
+ self.compression_codec = CompressionCodec.from_config(format.get("compression_codec", CompressionCodec.UNCOMPRESSED.value))
+
+ self.partitioning = PartitionOptions.from_string(partitioning)
+
+ if self.credentials_type == CredentialsType.IAM_USER:
self.aws_access_key = self.credentials.get("aws_access_key_id")
self.aws_secret_key = self.credentials.get("aws_secret_access_key")
- elif self.credentials_type == AuthMode.IAM_ROLE.value:
+ elif self.credentials_type == CredentialsType.IAM_ROLE:
self.role_arn = self.credentials.get("role_arn")
else:
raise Exception("Auth Mode not recognized.")
diff --git a/airbyte-integrations/connectors/destination-aws-datalake/destination_aws_datalake/destination.py b/airbyte-integrations/connectors/destination-aws-datalake/destination_aws_datalake/destination.py
index 577af35670fa..95bce879685b 100644
--- a/airbyte-integrations/connectors/destination-aws-datalake/destination_aws_datalake/destination.py
+++ b/airbyte-integrations/connectors/destination-aws-datalake/destination_aws_datalake/destination.py
@@ -2,21 +2,36 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
+import logging
+import random
+import string
+from typing import Any, Dict, Iterable, Mapping
-import json
-from typing import Any, Iterable, Mapping
-
+import pandas as pd
from airbyte_cdk import AirbyteLogger
from airbyte_cdk.destinations import Destination
from airbyte_cdk.models import AirbyteConnectionStatus, AirbyteMessage, ConfiguredAirbyteCatalog, Status, Type
-from botocore.exceptions import ClientError
+from botocore.exceptions import ClientError, InvalidRegionError
-from .aws import AwsHandler, LakeformationTransaction
+from .aws import AwsHandler
from .config_reader import ConnectorConfig
from .stream_writer import StreamWriter
+logger = logging.getLogger("airbyte")
+
+# Flush records every 25000 records to limit memory consumption
+RECORD_FLUSH_INTERVAL = 25000
+
class DestinationAwsDatalake(Destination):
+ def _flush_streams(self, streams: Dict[str, StreamWriter]) -> None:
+ for stream in streams:
+ streams[stream].flush()
+
+ @staticmethod
+ def _get_random_string(length):
+ return "".join(random.choice(string.ascii_letters) for i in range(length))
+
def write(
self, config: Mapping[str, Any], configured_catalog: ConfiguredAirbyteCatalog, input_messages: Iterable[AirbyteMessage]
) -> Iterable[AirbyteMessage]:
@@ -35,38 +50,63 @@ def write(
:param input_messages: The stream of input messages received from the source
:return: Iterable of AirbyteStateMessages wrapped in AirbyteMessage structs
"""
-
connector_config = ConnectorConfig(**config)
try:
aws_handler = AwsHandler(connector_config, self)
except ClientError as e:
- self.logger.error(f"Could not create session due to exception {repr(e)}")
- raise
- self.logger.debug("AWS session creation OK")
+ logger.error(f"Could not create session due to exception {repr(e)}")
+ raise Exception(f"Could not create session due to exception {repr(e)}")
# creating stream writers
streams = {
- s.stream.name: StreamWriter(
- name=s.stream.name,
- aws_handler=aws_handler,
- connector_config=connector_config,
- schema=s.stream.json_schema["properties"],
- sync_mode=s.destination_sync_mode,
- )
+ s.stream.name: StreamWriter(aws_handler=aws_handler, config=connector_config, configured_stream=s)
for s in configured_catalog.streams
}
for message in input_messages:
if message.type == Type.STATE:
+ if not message.state.data:
+
+ if message.state.stream:
+ stream = message.state.stream.stream_descriptor.name
+ logger.info(f"Received empty state for stream {stream}, resetting stream")
+ if stream in streams:
+ streams[stream].reset()
+ else:
+ logger.warning(f"Trying to reset stream {stream} that is not in the configured catalog")
+
+ if not message.state.stream:
+ logger.info("Received empty state for, resetting all streams including non-incremental streams")
+ for stream in streams:
+ streams[stream].reset()
+
+ # Flush records when state is received
+ if message.state.stream:
+ if message.state.stream.stream_state and hasattr(message.state.stream.stream_state, "stream_name"):
+ stream_name = message.state.stream.stream_state.stream_name
+ if stream_name in streams:
+ logger.info(f"Got state message from source: flushing records for {stream_name}")
+ streams[stream_name].flush(partial=True)
+
yield message
- else:
+
+ elif message.type == Type.RECORD:
data = message.record.data
stream = message.record.stream
- streams[stream].append_message(json.dumps(data, default=str))
+ streams[stream].append_message(data)
+
+ # Flush records every RECORD_FLUSH_INTERVAL records to limit memory consumption
+ # Records will either get flushed when a state message is received or when hitting the RECORD_FLUSH_INTERVAL
+ if len(streams[stream]._messages) > RECORD_FLUSH_INTERVAL:
+ logger.debug(f"Reached size limit: flushing records for {stream}")
+ streams[stream].flush(partial=True)
+
+ else:
+ logger.info(f"Unhandled message type {message.type}: {message}")
- for stream_name, stream in streams.items():
- stream.add_to_datalake()
+ # Flush all or remaining records
+ self._flush_streams(streams)
def check(self, logger: AirbyteLogger, config: Mapping[str, Any]) -> AirbyteConnectionStatus:
"""
@@ -89,6 +129,10 @@ def check(self, logger: AirbyteLogger, config: Mapping[str, Any]) -> AirbyteConn
logger.error(f"""Could not create session on {connector_config.aws_account_id} Exception: {repr(e)}""")
message = f"""Could not authenticate using {connector_config.credentials_type} on Account {connector_config.aws_account_id} Exception: {repr(e)}"""
return AirbyteConnectionStatus(status=Status.FAILED, message=message)
+ except InvalidRegionError:
+ message = f"{connector_config.region} is not a valid AWS region"
+ logger.error(message)
+ return AirbyteConnectionStatus(status=Status.FAILED, message=message)
try:
aws_handler.head_bucket()
@@ -96,16 +140,18 @@ def check(self, logger: AirbyteLogger, config: Mapping[str, Any]) -> AirbyteConn
message = f"""Could not find bucket {connector_config.bucket_name} in aws://{connector_config.aws_account_id}:{connector_config.region} Exception: {repr(e)}"""
return AirbyteConnectionStatus(status=Status.FAILED, message=message)
- with LakeformationTransaction(aws_handler) as tx:
- table_location = "s3://" + connector_config.bucket_name + "/" + connector_config.bucket_prefix + "/" + "airbyte_test/"
- table = aws_handler.get_table(
- txid=tx.txid,
- database_name=connector_config.lakeformation_database_name,
- table_name="airbyte_test",
- location=table_location,
- )
- if table is None:
- message = f"Could not create a table in database {connector_config.lakeformation_database_name}"
+ tbl = f"airbyte_test_{self._get_random_string(5)}"
+ db = connector_config.lakeformation_database_name
+ try:
+ df = pd.DataFrame({"id": [1, 2], "value": ["foo", "bar"]})
+
+ aws_handler.reset_table(db, tbl)
+ aws_handler.write(df, db, tbl, None, None)
+ aws_handler.reset_table(db, tbl)
+
+ except Exception as e:
+ message = f"Could not create table {tbl} in database {db}: {repr(e)}"
+ logger.error(message)
return AirbyteConnectionStatus(status=Status.FAILED, message=message)
return AirbyteConnectionStatus(status=Status.SUCCEEDED)
diff --git a/airbyte-integrations/connectors/destination-aws-datalake/destination_aws_datalake/spec.json b/airbyte-integrations/connectors/destination-aws-datalake/destination_aws_datalake/spec.json
index 7e6f30b131a2..cdf4a1de08c1 100644
--- a/airbyte-integrations/connectors/destination-aws-datalake/destination_aws_datalake/spec.json
+++ b/airbyte-integrations/connectors/destination-aws-datalake/destination_aws_datalake/spec.json
@@ -6,20 +6,20 @@
"$schema": "http://json-schema.org/draft-07/schema#",
"title": "AWS Datalake Destination Spec",
"type": "object",
- "required": ["credentials", "region", "bucket_name", "bucket_prefix"],
+ "required": [
+ "credentials",
+ "region",
+ "bucket_name",
+ "lakeformation_database_name"
+ ],
"additionalProperties": false,
"properties": {
"aws_account_id": {
"type": "string",
"title": "AWS Account Id",
"description": "target aws account id",
- "examples": ["111111111111"]
- },
- "region": {
- "title": "AWS Region",
- "type": "string",
- "description": "Region name",
- "airbyte_secret": false
+ "examples": ["111111111111"],
+ "order": 1
},
"credentials": {
"title": "Authentication mode",
@@ -80,25 +80,151 @@
}
}
}
- ]
+ ],
+ "order": 2
+ },
+ "region": {
+ "title": "S3 Bucket Region",
+ "type": "string",
+ "default": "",
+ "description": "The region of the S3 bucket. See here for all region codes.",
+ "enum": [
+ "",
+ "us-east-1",
+ "us-east-2",
+ "us-west-1",
+ "us-west-2",
+ "af-south-1",
+ "ap-east-1",
+ "ap-south-1",
+ "ap-northeast-1",
+ "ap-northeast-2",
+ "ap-northeast-3",
+ "ap-southeast-1",
+ "ap-southeast-2",
+ "ca-central-1",
+ "cn-north-1",
+ "cn-northwest-1",
+ "eu-central-1",
+ "eu-north-1",
+ "eu-south-1",
+ "eu-west-1",
+ "eu-west-2",
+ "eu-west-3",
+ "sa-east-1",
+ "me-south-1",
+ "us-gov-east-1",
+ "us-gov-west-1"
+ ],
+ "order": 3
},
"bucket_name": {
"title": "S3 Bucket Name",
"type": "string",
- "description": "Name of the bucket",
- "airbyte_secret": false
+ "description": "The name of the S3 bucket. Read more here.",
+ "order": 4
},
"bucket_prefix": {
"title": "Target S3 Bucket Prefix",
"type": "string",
"description": "S3 prefix",
- "airbyte_secret": false
+ "order": 5
},
"lakeformation_database_name": {
- "title": "Lakeformation Database Name",
+ "title": "Lake Formation Database Name",
+ "type": "string",
+ "description": "The default database this destination will use to create tables in per stream. Can be changed per connection by customizing the namespace.",
+ "order": 6
+ },
+ "lakeformation_database_default_tag_key": {
+ "title": "Lake Formation Database Tag Key",
+ "description": "Add a default tag key to databases created by this destination",
+ "examples": ["pii_level"],
"type": "string",
- "description": "Which database to use",
- "airbyte_secret": false
+ "order": 7
+ },
+ "lakeformation_database_default_tag_values": {
+ "title": "Lake Formation Database Tag Values",
+ "description": "Add default values for the `Tag Key` to databases created by this destination. Comma separate for multiple values.",
+ "examples": ["private,public"],
+ "type": "string",
+ "order": 8
+ },
+ "lakeformation_governed_tables": {
+ "title": "Lake Formation Governed Tables",
+ "description": "Whether to create tables as LF governed tables.",
+ "type": "boolean",
+ "default": false,
+ "order": 9
+ },
+ "format": {
+ "title": "Output Format *",
+ "type": "object",
+ "description": "Format of the data output.",
+ "oneOf": [
+ {
+ "title": "JSON Lines: Newline-delimited JSON",
+ "required": ["format_type"],
+ "properties": {
+ "format_type": {
+ "title": "Format Type *",
+ "type": "string",
+ "enum": ["JSONL"],
+ "default": "JSONL"
+ },
+ "compression_codec": {
+ "title": "Compression Codec (Optional)",
+ "description": "The compression algorithm used to compress data.",
+ "type": "string",
+ "enum": ["UNCOMPRESSED", "GZIP"],
+ "default": "UNCOMPRESSED"
+ }
+ }
+ },
+ {
+ "title": "Parquet: Columnar Storage",
+ "required": ["format_type"],
+ "properties": {
+ "format_type": {
+ "title": "Format Type *",
+ "type": "string",
+ "enum": ["Parquet"],
+ "default": "Parquet"
+ },
+ "compression_codec": {
+ "title": "Compression Codec (Optional)",
+ "description": "The compression algorithm used to compress data.",
+ "type": "string",
+ "enum": ["UNCOMPRESSED", "SNAPPY", "GZIP", "ZSTD"],
+ "default": "SNAPPY"
+ }
+ }
+ }
+ ],
+ "order": 10
+ },
+ "partitioning": {
+ "title": "Choose how to partition data",
+ "description": "Partition data by cursor fields when a cursor field is a date",
+ "type": "string",
+ "enum": [
+ "NO PARTITIONING",
+ "DATE",
+ "YEAR",
+ "MONTH",
+ "DAY",
+ "YEAR/MONTH",
+ "YEAR/MONTH/DAY"
+ ],
+ "default": "NO PARTITIONING",
+ "order": 11
+ },
+ "glue_catalog_float_as_decimal": {
+ "title": "Glue Catalog: Float as Decimal",
+ "description": "Cast float/double as decimal(38,18). This can help achieve higher accuracy and represent numbers correctly as received from the source.",
+ "type": "boolean",
+ "default": false,
+ "order": 12
}
}
}
diff --git a/airbyte-integrations/connectors/destination-aws-datalake/destination_aws_datalake/stream_writer.py b/airbyte-integrations/connectors/destination-aws-datalake/destination_aws_datalake/stream_writer.py
index 902ebe9db724..b093f1553ed4 100644
--- a/airbyte-integrations/connectors/destination-aws-datalake/destination_aws_datalake/stream_writer.py
+++ b/airbyte-integrations/connectors/destination-aws-datalake/destination_aws_datalake/stream_writer.py
@@ -2,70 +2,403 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
-from datetime import datetime
+import json
+import logging
+from typing import Any, Dict, List, Optional, Tuple, Union
-import nanoid
-from airbyte_cdk.models import DestinationSyncMode
-from retrying import retry
+import pandas as pd
+from airbyte_cdk.models import ConfiguredAirbyteStream, DestinationSyncMode
+from destination_aws_datalake.config_reader import ConnectorConfig, PartitionOptions
-from .aws import AwsHandler, LakeformationTransaction
+from .aws import AwsHandler
+
+logger = logging.getLogger("airbyte")
class StreamWriter:
- def __init__(self, name, aws_handler: AwsHandler, connector_config, schema, sync_mode):
- self._db = connector_config.lakeformation_database_name
- self._bucket = connector_config.bucket_name
- self._prefix = connector_config.bucket_prefix
- self._table = name
- self._aws_handler = aws_handler
- self._schema = schema
- self._sync_mode = sync_mode
- self._messages = []
- self._logger = aws_handler.logger
-
- self._logger.debug(f"Creating StreamWriter for {self._db}:{self._table}")
- if sync_mode == DestinationSyncMode.overwrite:
- self._logger.debug(f"StreamWriter mode is OVERWRITE, need to purge {self._db}:{self._table}")
- with LakeformationTransaction(self._aws_handler) as tx:
- self._aws_handler.purge_table(tx.txid, self._db, self._table)
-
- def append_message(self, message):
- self._logger.debug(f"Appending message to table {self._table}")
- self._messages.append(message)
-
- def generate_object_key(self, prefix=None):
- salt = nanoid.generate(size=10)
- base = datetime.now().strftime("%Y%m%d%H%M%S")
- path = f"{base}.{salt}.json"
- if prefix:
- path = f"{prefix}/{base}.{salt}.json"
-
- return path
-
- @retry(stop_max_attempt_number=10, wait_random_min=2000, wait_random_max=3000)
- def add_to_datalake(self):
- with LakeformationTransaction(self._aws_handler) as tx:
- self._logger.debug(f"Flushing messages to table {self._table}")
- object_prefix = f"{self._prefix}/{self._table}"
- table_location = "s3://" + self._bucket + "/" + self._prefix + "/" + self._table + "/"
-
- table = self._aws_handler.get_table(tx.txid, self._db, self._table, table_location)
- self._aws_handler.update_table_schema(tx.txid, self._db, table, self._schema)
-
- if len(self._messages) > 0:
- try:
- self._logger.debug(f"There are {len(self._messages)} messages to flush for {self._table}")
- self._logger.debug(f"10 first messages >>> {repr(self._messages[0:10])} <<<")
- object_key = self.generate_object_key(object_prefix)
- self._aws_handler.put_object(object_key, self._messages)
- res = self._aws_handler.head_object(object_key)
- self._aws_handler.update_governed_table(
- tx.txid, self._db, self._table, self._bucket, object_key, res["ETag"], res["ContentLength"]
- )
- self._logger.debug(f"Table {self._table} was updated")
- except Exception as e:
- self._logger.error(f"An exception was raised:\n{repr(e)}")
- raise (e)
- else:
- self._logger.debug(f"There was no message to flush for {self._table}")
+ def __init__(self, aws_handler: AwsHandler, config: ConnectorConfig, configured_stream: ConfiguredAirbyteStream):
+ self._aws_handler: AwsHandler = aws_handler
+ self._config: ConnectorConfig = config
+ self._configured_stream: ConfiguredAirbyteStream = configured_stream
+ self._schema: Dict[str, Any] = configured_stream.stream.json_schema["properties"]
+ self._sync_mode: DestinationSyncMode = configured_stream.destination_sync_mode
+
+ self._table_exists: bool = False
+ self._table: str = configured_stream.stream.name
+ self._database: str = self._configured_stream.stream.namespace or self._config.lakeformation_database_name
+
self._messages = []
+ self._partial_flush_count = 0
+
+ logger.info(f"Creating StreamWriter for {self._database}:{self._table}")
+
+ def _get_date_columns(self) -> list:
+ date_columns = []
+ for key, val in self._schema.items():
+ typ = val.get("type")
+ if (isinstance(typ, str) and typ == "string") or (isinstance(typ, list) and "string" in typ):
+ if val.get("format") in ["date-time", "date"]:
+ date_columns.append(key)
+
+ return date_columns
+
+ def _add_partition_column(self, col: str, df: pd.DataFrame) -> Dict[str, str]:
+ partitioning = self._config.partitioning
+
+ if partitioning == PartitionOptions.NONE:
+ return {}
+
+ partitions = partitioning.value.split("/")
+
+ fields = {}
+ for partition in partitions:
+ date_col = f"{col}_{partition.lower()}"
+ fields[date_col] = "bigint"
+
+ # defaulting to 0 since both governed tables
+ # and pyarrow don't play well with __HIVE_DEFAULT_PARTITION__
+ # - pyarrow will fail to cast the column to any other type than string
+ # - governed tables will fail when trying to query a table with partitions that have __HIVE_DEFAULT_PARTITION__
+ # aside from the above, awswrangler will remove data from a table if the partition value is null
+ # see: https://github.com/aws/aws-sdk-pandas/issues/921
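+ # e.g. with YEAR/MONTH partitioning on a cursor column "created_at", the value "2023-05-01T12:00:00Z"
+ # yields created_at_year=2023 and created_at_month=5, while a null cursor yields 0 for both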
+ if partition == "YEAR":
+ df[date_col] = df[col].dt.strftime("%Y").fillna("0").astype("Int64")
+
+ elif partition == "MONTH":
+ df[date_col] = df[col].dt.strftime("%m").fillna("0").astype("Int64")
+
+ elif partition == "DAY":
+ df[date_col] = df[col].dt.strftime("%d").fillna("0").astype("Int64")
+
+ elif partition == "DATE":
+ fields[date_col] = "date"
+ df[date_col] = df[col].dt.strftime("%Y-%m-%d")
+
+ return fields
+
+ def _drop_additional_top_level_properties(self, record: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Helper that removes any unexpected top-level properties from the record.
+ Since the json schema is used to build the table and cast types correctly,
+ we need to remove any unexpected properties that can't be cast accurately.
+ """
+ schema_keys = self._schema.keys()
+ records_keys = record.keys()
+ difference = list(set(records_keys).difference(set(schema_keys)))
+
+ for key in difference:
+ del record[key]
+
+ return record
+
+ def _fix_obvious_type_violations(self, record: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Helper that fixes obvious type violations in a record's top level keys that may
+ cause issues when casting data to pyarrow types. Such as:
+ - Objects having empty strings or " " or "-" as value instead of null or {}
+ - Arrays having empty strings or " " or "-" as value instead of null or []
+ """
+ schema_keys = self._schema.keys()
+ for key in schema_keys:
+ typ = self._schema[key].get("type")
+ typ = self._get_json_schema_type(typ)
+ if typ in ["object", "array"]:
+ if record.get(key) in ["", " ", "-", "/", "null"]:
+ record[key] = None
+
+ return record
+
+ def _add_missing_columns(self, record: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Helper that adds missing columns to a record's top level keys. Required
+ for awswrangler to create the correct schema in Glue: even with an explicit
+ schema passed in, awswrangler drops columns that are not present
+ in the dataframe.
+ """
+ schema_keys = self._schema.keys()
+ records_keys = record.keys()
+ difference = list(set(schema_keys).difference(set(records_keys)))
+
+ for key in difference:
+ record[key] = None
+
+ return record
+
+ def _get_non_null_json_schema_types(self, typ: Union[str, List[str]]) -> Union[str, List[str]]:
+ if isinstance(typ, list):
+ return list(filter(lambda x: x != "null", typ))
+
+ return typ
+
+ def _json_schema_type_has_mixed_types(self, typ: Union[str, List[str]]) -> bool:
+ if isinstance(typ, list):
+ typ = self._get_non_null_json_schema_types(typ)
+ if len(typ) > 1:
+ return True
+
+ return False
+
+ def _get_json_schema_type(self, types: Union[List[str], str]) -> str:
+ if isinstance(types, str):
+ return types
+
+ if not isinstance(types, list):
+ return "string"
+
+ types = self._get_non_null_json_schema_types(types)
+ # when multiple types, cast to string
+ if self._json_schema_type_has_mixed_types(types):
+ return "string"
+
+ return types[0]
+
+ def _get_pandas_dtypes_from_json_schema(self, df: pd.DataFrame) -> Dict[str, str]:
+ type_mapper = {
+ "string": "string",
+ "integer": "Int64",
+ "number": "float64",
+ "boolean": "bool",
+ "object": "object",
+ "array": "object",
+ }
+
+ column_types = {}
+
+ typ = "string"
+ for col in df.columns:
+ if col in self._schema:
+ typ = self._schema[col].get("type", "string")
+ airbyte_type = self._schema[col].get("airbyte_type")
+
+ # special case where the json schema type contradicts the airbyte type
+ if airbyte_type and typ == "number" and airbyte_type == "integer":
+ typ = "integer"
+
+ typ = self._get_json_schema_type(typ)
+
+ column_types[col] = type_mapper.get(typ, "string")
+
+ return column_types
+
+ def _get_json_schema_types(self):
+ types = {}
+ for key, val in self._schema.items():
+ typ = val.get("type")
+ types[key] = self._get_json_schema_type(typ)
+ return types
+
+ def _is_invalid_struct_or_array(self, schema: Dict[str, Any]) -> bool:
+ """
+ Helper that detects issues with nested objects/arrays in the json schema.
+ When a complex data type is detected (a schema with oneOf) or a nested object has no properties,
+ the column's dtype will be cast to string to avoid pyarrow conversion issues.
+ Returns True when the nested schema can be safely mapped to a Glue struct/array type.
+ """
+ result = True
+
+ def check_properties(schema):
+ nonlocal result
+ for val in schema.values():
+ # Complex types can't be casted to an athena/glue type
+ if val.get("oneOf"):
+ result = False
+ continue
+
+ raw_typ = val.get("type")
+
+ # If the type is a list, check for mixed types
+ # complex objects with mixed types can't be reliably cast
+ if isinstance(raw_typ, list) and self._json_schema_type_has_mixed_types(raw_typ):
+ result = False
+ continue
+
+ typ = self._get_json_schema_type(raw_typ)
+
+ # If object check nested properties
+ if typ == "object":
+ properties = val.get("properties")
+ if not properties:
+ result = False
+ else:
+ check_properties(properties)
+
+ # If array check nested properties
+ if typ == "array":
+ items = val.get("items")
+
+ if not items:
+ result = False
+ continue
+
+ if isinstance(items, list):
+ items = items[0]
+
+ item_properties = items.get("properties")
+ if item_properties:
+ check_properties(item_properties)
+
+ check_properties(schema)
+ return result
+
+ def _get_glue_dtypes_from_json_schema(self, schema: Dict[str, Any]) -> Tuple[Dict[str, str], List[str]]:
+ """
+ Helper that infers glue dtypes from a json schema.
+ """
+
+ type_mapper = {
+ "string": "string",
+ "integer": "bigint",
+ "number": "decimal(38, 25)" if self._config.glue_catalog_float_as_decimal else "double",
+ "boolean": "boolean",
+ "null": "string",
+ }
+
+ column_types = {}
+ json_columns = set()
+ for (col, definition) in schema.items():
+
+ result_typ = None
+ col_typ = definition.get("type")
+ airbyte_type = definition.get("airbyte_type")
+ col_format = definition.get("format")
+
+ col_typ = self._get_json_schema_type(col_typ)
+
+ # special case where the json schema type contradicts the airbyte type
+ if airbyte_type and col_typ == "number" and airbyte_type == "integer":
+ col_typ = "integer"
+
+ if col_typ == "string" and col_format == "date-time":
+ result_typ = "timestamp"
+
+ if col_typ == "string" and col_format == "date":
+ result_typ = "date"
+
+ if col_typ == "object":
+ properties = definition.get("properties")
+ if properties and self._is_invalid_struct_or_array(properties):
+ object_props, _ = self._get_glue_dtypes_from_json_schema(properties)
+ result_typ = f"struct<{','.join([f'{k}:{v}' for k, v in object_props.items()])}>"
+ else:
+ json_columns.add(col)
+ result_typ = "string"
+
+ if col_typ == "array":
+ items = definition.get("items", {})
+
+ if isinstance(items, list):
+ items = items[0]
+
+ raw_item_type = items.get("type")
+ item_type = self._get_json_schema_type(raw_item_type)
+ item_properties = items.get("properties")
+
+ # if array has no "items", cast to string
+ if not items:
+ json_columns.add(col)
+ result_typ = "string"
+
+ # if array with objects
+ elif isinstance(items, dict) and item_properties:
+ # Check if nested object has properties and no mixed type objects
+ if self._is_invalid_struct_or_array(item_properties):
+ item_dtypes, _ = self._get_glue_dtypes_from_json_schema(item_properties)
+ inner_struct = f"struct<{','.join([f'{k}:{v}' for k, v in item_dtypes.items()])}>"
+ result_typ = f"array<{inner_struct}>"
+ else:
+ json_columns.add(col)
+ result_typ = "string"
+
+ elif item_type and self._json_schema_type_has_mixed_types(raw_item_type):
+ json_columns.add(col)
+ result_typ = "string"
+
+ # array with single type
+ elif item_type and not self._json_schema_type_has_mixed_types(raw_item_type):
+ result_typ = f"array<{type_mapper[item_type]}>"
+
+ if result_typ is None:
+ result_typ = type_mapper.get(col_typ, "string")
+
+ column_types[col] = result_typ
+
+ return column_types, json_columns
+
+ @property
+ def _cursor_fields(self) -> Optional[List[str]]:
+ return self._configured_stream.cursor_field
+
+ def append_message(self, message: Dict[str, Any]):
+ clean_message = self._drop_additional_top_level_properties(message)
+ clean_message = self._fix_obvious_type_violations(clean_message)
+ clean_message = self._add_missing_columns(clean_message)
+ self._messages.append(clean_message)
+
+ def reset(self):
+ logger.info(f"Deleting table {self._database}:{self._table}")
+ success = self._aws_handler.delete_table(self._database, self._table)
+
+ if not success:
+ logger.warning(f"Failed to reset table {self._database}:{self._table}")
+
+ def flush(self, partial: bool = False):
+ logger.debug(f"Flushing {len(self._messages)} messages to table {self._database}:{self._table}")
+
+ df = pd.DataFrame(self._messages)
+ # best effort to convert pandas types
+ df = df.astype(self._get_pandas_dtypes_from_json_schema(df), errors="ignore")
+
+ if len(df) < 1:
+ logger.info(f"No messages to write to {self._database}:{self._table}")
+ return
+
+ partition_fields = {}
+ date_columns = self._get_date_columns()
+ for col in date_columns:
+ if col in df.columns:
+ df[col] = pd.to_datetime(df[col])
+
+ # Create date column for partitioning
+ if self._cursor_fields and col in self._cursor_fields:
+ fields = self._add_partition_column(col, df)
+ partition_fields.update(fields)
+
+ dtype, json_casts = self._get_glue_dtypes_from_json_schema(self._schema)
+ dtype = {**dtype, **partition_fields}
+ partition_fields = list(partition_fields.keys())
+
+ # Make sure complex types that can't be converted
+ # to a struct or array are converted to a json string
+ # so they can be queried with json_extract
+ for col in json_casts:
+ if col in df.columns:
+ df[col] = df[col].apply(json.dumps)
+
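+ # Only the first flush of an overwrite sync replaces the table; once a
+ # partial flush has happened, later batches must be appended so records
+ # written earlier in the same sync are preserved.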
+ if self._sync_mode == DestinationSyncMode.overwrite and self._partial_flush_count < 1:
+ logger.debug(f"Overwriting {len(df)} records to {self._database}:{self._table}")
+ self._aws_handler.write(
+ df,
+ self._database,
+ self._table,
+ dtype,
+ partition_fields,
+ )
+
+ elif self._sync_mode == DestinationSyncMode.append or self._partial_flush_count > 0:
+ logger.debug(f"Appending {len(df)} records to {self._database}:{self._table}")
+ self._aws_handler.append(
+ df,
+ self._database,
+ self._table,
+ dtype,
+ partition_fields,
+ )
+
+ else:
+ self._messages = []
+ raise Exception(f"Unsupported sync mode: {self._sync_mode}")
+
+ if partial:
+ self._partial_flush_count += 1
+
+ del df
+ self._messages.clear()
diff --git a/airbyte-integrations/connectors/destination-aws-datalake/unit_tests/unit_test.py b/airbyte-integrations/connectors/destination-aws-datalake/integration_tests/__init__.py
similarity index 57%
rename from airbyte-integrations/connectors/destination-aws-datalake/unit_tests/unit_test.py
rename to airbyte-integrations/connectors/destination-aws-datalake/integration_tests/__init__.py
index 219ae0142c72..c941b3045795 100644
--- a/airbyte-integrations/connectors/destination-aws-datalake/unit_tests/unit_test.py
+++ b/airbyte-integrations/connectors/destination-aws-datalake/integration_tests/__init__.py
@@ -1,7 +1,3 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
-
-
-def test_example_method():
- assert True
diff --git a/airbyte-integrations/connectors/destination-aws-datalake/integration_tests/integration_test.py b/airbyte-integrations/connectors/destination-aws-datalake/integration_tests/integration_test.py
new file mode 100644
index 000000000000..8e8c9fd72532
--- /dev/null
+++ b/airbyte-integrations/connectors/destination-aws-datalake/integration_tests/integration_test.py
@@ -0,0 +1,160 @@
+#
+# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
+#
+
+import json
+import logging
+from datetime import datetime
+from typing import Any, Dict, Mapping
+
+import awswrangler as wr
+import pytest
+from airbyte_cdk.models import (
+ AirbyteMessage,
+ AirbyteRecordMessage,
+ AirbyteStateMessage,
+ AirbyteStream,
+ ConfiguredAirbyteCatalog,
+ ConfiguredAirbyteStream,
+ DestinationSyncMode,
+ Status,
+ SyncMode,
+ Type,
+)
+from destination_aws_datalake import DestinationAwsDatalake
+from destination_aws_datalake.aws import AwsHandler
+from destination_aws_datalake.config_reader import ConnectorConfig
+
+logger = logging.getLogger("airbyte")
+
+
+@pytest.fixture(name="config")
+def config_fixture() -> Mapping[str, Any]:
+ with open("secrets/config.json", "r") as f:
+ return json.loads(f.read())
+
+
+@pytest.fixture(name="invalid_region_config")
+def invalid_region_config() -> Mapping[str, Any]:
+ with open("integration_tests/invalid_region_config.json", "r") as f:
+ return json.loads(f.read())
+
+
+@pytest.fixture(name="invalid_account_config")
+def invalid_account_config() -> Mapping[str, Any]:
+ with open("integration_tests/invalid_account_config.json", "r") as f:
+ return json.loads(f.read())
+
+
+@pytest.fixture(name="configured_catalog")
+def configured_catalog_fixture() -> ConfiguredAirbyteCatalog:
+ stream_schema = {
+ "type": "object",
+ "properties": {
+ "string_col": {"type": "str"},
+ "int_col": {"type": "integer"},
+ "date_col": {"type": "string", "format": "date-time"},
+ },
+ }
+
+ append_stream = ConfiguredAirbyteStream(
+ stream=AirbyteStream(
+ name="append_stream", json_schema=stream_schema, supported_sync_modes=[SyncMode.full_refresh, SyncMode.incremental]
+ ),
+ sync_mode=SyncMode.incremental,
+ destination_sync_mode=DestinationSyncMode.append,
+ )
+
+ overwrite_stream = ConfiguredAirbyteStream(
+ stream=AirbyteStream(
+ name="overwrite_stream", json_schema=stream_schema, supported_sync_modes=[SyncMode.full_refresh, SyncMode.incremental]
+ ),
+ sync_mode=SyncMode.incremental,
+ destination_sync_mode=DestinationSyncMode.overwrite,
+ )
+
+ return ConfiguredAirbyteCatalog(streams=[append_stream, overwrite_stream])
+
+
+def test_check_valid_config(config: Mapping):
+ outcome = DestinationAwsDatalake().check(logger, config)
+ assert outcome.status == Status.SUCCEEDED
+
+
+def test_check_invalid_aws_region_config(invalid_region_config: Mapping):
+ outcome = DestinationAwsDatalake().check(logger, invalid_region_config)
+ assert outcome.status == Status.FAILED
+
+
+def test_check_invalid_aws_account_config(invalid_account_config: Mapping):
+ outcome = DestinationAwsDatalake().check(logger, invalid_account_config)
+ assert outcome.status == Status.FAILED
+
+
+def _state(data: Dict[str, Any]) -> AirbyteMessage:
+ return AirbyteMessage(type=Type.STATE, state=AirbyteStateMessage(data=data))
+
+
+def _record(stream: str, str_value: str, int_value: int, date_value: datetime) -> AirbyteMessage:
+ return AirbyteMessage(
+ type=Type.RECORD,
+ record=AirbyteRecordMessage(stream=stream, data={"str_col": str_value, "int_col": int_value, "date_col": date_value}, emitted_at=0),
+ )
+
+
+def test_write(config: Mapping, configured_catalog: ConfiguredAirbyteCatalog):
+ """
+ This test verifies that:
+ 1. writing a stream in "overwrite" mode overwrites any existing data for that stream
+ 2. writing a stream in "append" mode appends new records without deleting the old ones
+ 3. the correct state message is output by the connector at the end of the sync
+ """
+ append_stream, overwrite_stream = configured_catalog.streams[0].stream.name, configured_catalog.streams[1].stream.name
+
+ destination = DestinationAwsDatalake()
+
+ connector_config = ConnectorConfig(**config)
+ aws_handler = AwsHandler(connector_config, destination)
+
+ database = connector_config.lakeformation_database_name
+
+ # make sure we start with empty tables
+ for tbl in [append_stream, overwrite_stream]:
+ aws_handler.reset_table(database, tbl)
+
+ first_state_message = _state({"state": "1"})
+
+ first_record_chunk = [_record(append_stream, str(i), i, datetime.now()) for i in range(5)] + [
+ _record(overwrite_stream, str(i), i, datetime.now()) for i in range(5)
+ ]
+
+ second_state_message = _state({"state": "2"})
+ second_record_chunk = [_record(append_stream, str(i), i, datetime.now()) for i in range(5, 10)] + [
+ _record(overwrite_stream, str(i), i, datetime.now()) for i in range(5, 10)
+ ]
+
+ expected_states = [first_state_message, second_state_message]
+ output_states = list(
+ destination.write(
+ config, configured_catalog, [*first_record_chunk, first_state_message, *second_record_chunk, second_state_message]
+ )
+ )
+ assert expected_states == output_states, "Checkpoint state messages were expected from the destination"
+
+ # Check if table was created
+ for tbl in [append_stream, overwrite_stream]:
+ table = wr.catalog.table(database=database, table=tbl, boto3_session=aws_handler._session)
+ expected_types = {"string_col": "string", "int_col": "bigint", "date_col": "timestamp"}
+
+ # Check table format
+ for col in table.to_dict("records"):
+ assert col["Column Name"] in ["string_col", "int_col", "date_col"]
+ assert col["Type"] == expected_types[col["Column Name"]]
+
+ # Check table data
+ # cannot use wr.lakeformation.read_sql_query because of this issue: https://github.com/aws/aws-sdk-pandas/issues/2007
+ df = wr.athena.read_sql_query(f"SELECT * FROM {tbl}", database=database, boto3_session=aws_handler._session)
+ assert len(df) == 10
+
+ # Reset table
+ aws_handler.reset_table(database, tbl)
diff --git a/airbyte-integrations/connectors/destination-aws-datalake/integration_tests/invalid_account_config.json b/airbyte-integrations/connectors/destination-aws-datalake/integration_tests/invalid_account_config.json
new file mode 100644
index 000000000000..db12398564dd
--- /dev/null
+++ b/airbyte-integrations/connectors/destination-aws-datalake/integration_tests/invalid_account_config.json
@@ -0,0 +1,11 @@
+{
+ "aws_account_id": "111111111111",
+ "region": "us-east-1",
+ "credentials": {
+ "credentials_title": "IAM User",
+ "aws_access_key_id": "dummy",
+ "aws_secret_access_key": "dummykey"
+ },
+ "bucket_name": "my-bucket",
+ "partitioning": "NO PARTITIONING"
+}
diff --git a/airbyte-integrations/connectors/destination-aws-datalake/integration_tests/invalid_region_config.json b/airbyte-integrations/connectors/destination-aws-datalake/integration_tests/invalid_region_config.json
new file mode 100644
index 000000000000..fd28e3b32ca4
--- /dev/null
+++ b/airbyte-integrations/connectors/destination-aws-datalake/integration_tests/invalid_region_config.json
@@ -0,0 +1,10 @@
+{
+ "aws_account_id": "111111111111",
+ "region": "not_a_real_region",
+ "credentials": {
+ "credentials_title": "IAM User",
+ "aws_access_key_id": "dummy",
+ "aws_secret_access_key": "dummykey"
+ },
+ "partitioning": "NO PARTITIONING"
+}
diff --git a/airbyte-integrations/connectors/destination-aws-datalake/setup.py b/airbyte-integrations/connectors/destination-aws-datalake/setup.py
index fb3d48b4dab6..4dccbefb7619 100644
--- a/airbyte-integrations/connectors/destination-aws-datalake/setup.py
+++ b/airbyte-integrations/connectors/destination-aws-datalake/setup.py
@@ -6,17 +6,17 @@
from setuptools import find_packages, setup
MAIN_REQUIREMENTS = [
- "airbyte-cdk==0.1.6-rc1",
- "boto3",
+ "airbyte-cdk~=0.1",
"retrying",
- "nanoid",
+ "awswrangler==2.17.0",
+ "pandas==1.4.4",
]
TEST_REQUIREMENTS = ["pytest~=6.1"]
setup(
name="destination_aws_datalake",
- description="Destination implementation for Aws Datalake.",
+ description="Destination implementation for AWS Datalake.",
author="Airbyte",
author_email="contact@airbyte.io",
packages=find_packages(),
diff --git a/airbyte-integrations/connectors/destination-aws-datalake/src/main/java/io/airbyte/integrations/destination/aws_datalake/AthenaHelper.java b/airbyte-integrations/connectors/destination-aws-datalake/src/main/java/io/airbyte/integrations/destination/aws_datalake/AthenaHelper.java
deleted file mode 100644
index 556c7f6a3f2c..000000000000
--- a/airbyte-integrations/connectors/destination-aws-datalake/src/main/java/io/airbyte/integrations/destination/aws_datalake/AthenaHelper.java
+++ /dev/null
@@ -1,129 +0,0 @@
-/*
- * Copyright (c) 2023 Airbyte, Inc., all rights reserved.
- */
-
-package io.airbyte.integrations.destination.aws_datalake;
-
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-import software.amazon.awssdk.auth.credentials.AwsCredentials;
-import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
-import software.amazon.awssdk.regions.Region;
-import software.amazon.awssdk.services.athena.AthenaClient;
-import software.amazon.awssdk.services.athena.model.AthenaException;
-import software.amazon.awssdk.services.athena.model.GetQueryExecutionRequest;
-import software.amazon.awssdk.services.athena.model.GetQueryExecutionResponse;
-import software.amazon.awssdk.services.athena.model.GetQueryResultsRequest;
-import software.amazon.awssdk.services.athena.model.QueryExecutionContext;
-import software.amazon.awssdk.services.athena.model.QueryExecutionState;
-import software.amazon.awssdk.services.athena.model.ResultConfiguration;
-import software.amazon.awssdk.services.athena.model.StartQueryExecutionRequest;
-import software.amazon.awssdk.services.athena.model.StartQueryExecutionResponse;
-import software.amazon.awssdk.services.athena.paginators.GetQueryResultsIterable;
-
-public class AthenaHelper {
-
- private AthenaClient athenaClient;
- private String outputBucket;
- private String workGroup;
- private static final Logger LOGGER = LoggerFactory.getLogger(AthenaHelper.class);
-
- public AthenaHelper(AwsCredentials credentials, Region region, String outputBucket, String workGroup) {
- LOGGER.debug(String.format("region = %s, outputBucket = %s, workGroup = %s", region, outputBucket, workGroup));
- var credProvider = StaticCredentialsProvider.create(credentials);
- this.athenaClient = AthenaClient.builder().region(region).credentialsProvider(credProvider).build();
- this.outputBucket = outputBucket;
- this.workGroup = workGroup;
- }
-
- public String submitAthenaQuery(String database, String query) {
- try {
-
- // The QueryExecutionContext allows us to set the database
- QueryExecutionContext queryExecutionContext = QueryExecutionContext.builder()
- .database(database).build();
-
- // The result configuration specifies where the results of the query should go
- ResultConfiguration resultConfiguration = ResultConfiguration.builder()
- .outputLocation(outputBucket)
- .build();
-
- StartQueryExecutionRequest startQueryExecutionRequest = StartQueryExecutionRequest.builder()
- .queryString(query)
- .queryExecutionContext(queryExecutionContext)
- .resultConfiguration(resultConfiguration)
- .workGroup(workGroup)
- .build();
-
- StartQueryExecutionResponse startQueryExecutionResponse = athenaClient.startQueryExecution(startQueryExecutionRequest);
- return startQueryExecutionResponse.queryExecutionId();
- } catch (AthenaException e) {
- e.printStackTrace();
- System.exit(1);
- }
- return "";
- }
-
- public void waitForQueryToComplete(String queryExecutionId) throws InterruptedException {
- GetQueryExecutionRequest getQueryExecutionRequest = GetQueryExecutionRequest.builder()
- .queryExecutionId(queryExecutionId).build();
-
- GetQueryExecutionResponse getQueryExecutionResponse;
- boolean isQueryStillRunning = true;
- while (isQueryStillRunning) {
- getQueryExecutionResponse = athenaClient.getQueryExecution(getQueryExecutionRequest);
- String queryState = getQueryExecutionResponse.queryExecution().status().state().toString();
- if (queryState.equals(QueryExecutionState.FAILED.toString())) {
- throw new RuntimeException("The Amazon Athena query failed to run with error message: " + getQueryExecutionResponse
- .queryExecution().status().stateChangeReason());
- } else if (queryState.equals(QueryExecutionState.CANCELLED.toString())) {
- throw new RuntimeException("The Amazon Athena query was cancelled.");
- } else if (queryState.equals(QueryExecutionState.SUCCEEDED.toString())) {
- isQueryStillRunning = false;
- } else {
- // Sleep an amount of time before retrying again
- Thread.sleep(1000);
- }
- }
- }
-
- public GetQueryResultsIterable getResults(String queryExecutionId) {
-
- try {
-
- // Max Results can be set but if its not set,
- // it will choose the maximum page size
- GetQueryResultsRequest getQueryResultsRequest = GetQueryResultsRequest.builder()
- .queryExecutionId(queryExecutionId)
- .build();
-
- GetQueryResultsIterable getQueryResultsResults = athenaClient.getQueryResultsPaginator(getQueryResultsRequest);
- return getQueryResultsResults;
-
- } catch (AthenaException e) {
- e.printStackTrace();
- System.exit(1);
- }
- return null;
- }
-
- public GetQueryResultsIterable runQuery(String database, String query) throws InterruptedException {
- int retryCount = 0;
-
- while (retryCount < 10) {
- var execId = submitAthenaQuery(database, query);
- try {
- waitForQueryToComplete(execId);
- } catch (RuntimeException e) {
- e.printStackTrace();
- LOGGER.info("Athena query failed once. Retrying.");
- retryCount++;
- continue;
- }
- return getResults(execId);
- }
- LOGGER.info("Athena query failed and we are out of retries.");
- return null;
- }
-
-}
diff --git a/airbyte-integrations/connectors/destination-aws-datalake/src/main/java/io/airbyte/integrations/destination/aws_datalake/AwsDatalakeDestinationConfig.java b/airbyte-integrations/connectors/destination-aws-datalake/src/main/java/io/airbyte/integrations/destination/aws_datalake/AwsDatalakeDestinationConfig.java
deleted file mode 100644
index 4d66d91f5fdd..000000000000
--- a/airbyte-integrations/connectors/destination-aws-datalake/src/main/java/io/airbyte/integrations/destination/aws_datalake/AwsDatalakeDestinationConfig.java
+++ /dev/null
@@ -1,79 +0,0 @@
-/*
- * Copyright (c) 2023 Airbyte, Inc., all rights reserved.
- */
-
-package io.airbyte.integrations.destination.aws_datalake;
-
-import com.fasterxml.jackson.databind.JsonNode;
-
-public class AwsDatalakeDestinationConfig {
-
- private final String awsAccountId;
- private final String region;
- private final String accessKeyId;
- private final String secretAccessKey;
- private final String bucketName;
- private final String prefix;
- private final String databaseName;
-
- public AwsDatalakeDestinationConfig(String awsAccountId,
- String region,
- String accessKeyId,
- String secretAccessKey,
- String bucketName,
- String prefix,
- String databaseName) {
- this.awsAccountId = awsAccountId;
- this.region = region;
- this.accessKeyId = accessKeyId;
- this.secretAccessKey = secretAccessKey;
- this.bucketName = bucketName;
- this.prefix = prefix;
- this.databaseName = databaseName;
-
- }
-
- public static AwsDatalakeDestinationConfig getAwsDatalakeDestinationConfig(JsonNode config) {
-
- final String aws_access_key_id = config.path("credentials").get("aws_access_key_id").asText();
- final String aws_secret_access_key = config.path("credentials").get("aws_secret_access_key").asText();
-
- return new AwsDatalakeDestinationConfig(
- config.get("aws_account_id").asText(),
- config.get("region").asText(),
- aws_access_key_id,
- aws_secret_access_key,
- config.get("bucket_name").asText(),
- config.get("bucket_prefix").asText(),
- config.get("lakeformation_database_name").asText());
- }
-
- public String getAwsAccountId() {
- return awsAccountId;
- }
-
- public String getRegion() {
- return region;
- }
-
- public String getAccessKeyId() {
- return accessKeyId;
- }
-
- public String getSecretAccessKey() {
- return secretAccessKey;
- }
-
- public String getBucketName() {
- return bucketName;
- }
-
- public String getPrefix() {
- return prefix;
- }
-
- public String getDatabaseName() {
- return databaseName;
- }
-
-}
diff --git a/airbyte-integrations/connectors/destination-aws-datalake/src/main/java/io/airbyte/integrations/destination/aws_datalake/GlueHelper.java b/airbyte-integrations/connectors/destination-aws-datalake/src/main/java/io/airbyte/integrations/destination/aws_datalake/GlueHelper.java
deleted file mode 100644
index bd81bb515ace..000000000000
--- a/airbyte-integrations/connectors/destination-aws-datalake/src/main/java/io/airbyte/integrations/destination/aws_datalake/GlueHelper.java
+++ /dev/null
@@ -1,70 +0,0 @@
-/*
- * Copyright (c) 2023 Airbyte, Inc., all rights reserved.
- */
-
-package io.airbyte.integrations.destination.aws_datalake;
-
-import java.util.ArrayList;
-import java.util.Iterator;
-import java.util.List;
-import software.amazon.awssdk.auth.credentials.AwsCredentials;
-import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
-import software.amazon.awssdk.regions.Region;
-import software.amazon.awssdk.services.glue.GlueClient;
-import software.amazon.awssdk.services.glue.model.BatchDeleteTableRequest;
-import software.amazon.awssdk.services.glue.model.GetTablesRequest;
-import software.amazon.awssdk.services.glue.model.GetTablesResponse;
-import software.amazon.awssdk.services.glue.model.Table;
-import software.amazon.awssdk.services.glue.paginators.GetTablesIterable;
-
-public class GlueHelper {
-
- private AwsCredentials awsCredentials;
- private Region region;
- private GlueClient glueClient;
-
- public GlueHelper(AwsCredentials credentials, Region region) {
- this.awsCredentials = credentials;
- this.region = region;
-
- var credProvider = StaticCredentialsProvider.create(credentials);
- this.glueClient = GlueClient.builder().region(region).credentialsProvider(credProvider).build();
- }
-
- private GetTablesIterable getAllTables(String DatabaseName) {
-
- GetTablesRequest getTablesRequest = GetTablesRequest.builder().databaseName(DatabaseName).build();
- GetTablesIterable getTablesPaginator = glueClient.getTablesPaginator(getTablesRequest);
-
- return getTablesPaginator;
- }
-
- private BatchDeleteTableRequest getBatchDeleteRequest(String databaseName, GetTablesIterable getTablesPaginator) {
- List<String> tablesToDelete = new ArrayList<String>();
- for (GetTablesResponse response : getTablesPaginator) {
- List<Table> tablePage = response.tableList();
- Iterator<Table> tableIterator = tablePage.iterator();
- while (tableIterator.hasNext()) {
- Table table = tableIterator.next();
- tablesToDelete.add(table.name());
- }
- }
- BatchDeleteTableRequest batchDeleteRequest = BatchDeleteTableRequest.builder().databaseName(databaseName).tablesToDelete(tablesToDelete).build();
- return batchDeleteRequest;
- }
-
- public void purgeDatabase(String databaseName) {
- int countRetries = 0;
- while (countRetries < 5) {
- try {
- GetTablesIterable allTables = getAllTables(databaseName);
- BatchDeleteTableRequest batchDeleteTableRequest = getBatchDeleteRequest(databaseName, allTables);
- glueClient.batchDeleteTable(batchDeleteTableRequest);
- return;
- } catch (Exception e) {
- countRetries++;
- }
- }
- }
-
-}
diff --git a/airbyte-integrations/connectors/destination-aws-datalake/src/test-integration/java/io/airbyte/integrations/destination/aws_datalake/AwsDatalakeDestinationAcceptanceTest.java b/airbyte-integrations/connectors/destination-aws-datalake/src/test-integration/java/io/airbyte/integrations/destination/aws_datalake/AwsDatalakeDestinationAcceptanceTest.java
deleted file mode 100644
index f60620ff142d..000000000000
--- a/airbyte-integrations/connectors/destination-aws-datalake/src/test-integration/java/io/airbyte/integrations/destination/aws_datalake/AwsDatalakeDestinationAcceptanceTest.java
+++ /dev/null
@@ -1,211 +0,0 @@
-/*
- * Copyright (c) 2023 Airbyte, Inc., all rights reserved.
- */
-
-package io.airbyte.integrations.destination.aws_datalake;
-
-import static org.junit.jupiter.api.Assertions.assertEquals;
-
-import com.fasterxml.jackson.databind.JsonNode;
-import com.google.common.collect.ImmutableMap;
-import com.google.common.collect.Maps;
-import io.airbyte.commons.io.IOs;
-import io.airbyte.commons.json.Jsons;
-import io.airbyte.integrations.standardtest.destination.DestinationAcceptanceTest;
-import io.airbyte.integrations.standardtest.destination.comparator.TestDataComparator;
-import java.io.IOException;
-import java.nio.file.Path;
-import java.util.ArrayList;
-import java.util.Iterator;
-import java.util.List;
-import java.util.Map;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
-import software.amazon.awssdk.regions.Region;
-import software.amazon.awssdk.services.athena.model.ColumnInfo;
-import software.amazon.awssdk.services.athena.model.Datum;
-import software.amazon.awssdk.services.athena.model.GetQueryResultsResponse;
-import software.amazon.awssdk.services.athena.model.Row;
-import software.amazon.awssdk.services.athena.paginators.GetQueryResultsIterable;
-
-public class AwsDatalakeDestinationAcceptanceTest extends DestinationAcceptanceTest {
-
- private static final String CONFIG_PATH = "secrets/config.json";
- private static final Logger LOGGER = LoggerFactory.getLogger(AwsDatalakeDestinationAcceptanceTest.class);
-
- private JsonNode configJson;
- private JsonNode configInvalidCredentialsJson;
- protected AwsDatalakeDestinationConfig config;
- private AthenaHelper athenaHelper;
- private GlueHelper glueHelper;
-
- @Override
- protected String getImageName() {
- return "airbyte/destination-aws-datalake:dev";
- }
-
- @Override
- protected JsonNode getConfig() {
- // TODO: Generate the configuration JSON file to be used for running the destination during the test
- // configJson can either be static and read from secrets/config.json directly
- // or created in the setup method
- return configJson;
- }
-
- @Override
- protected JsonNode getFailCheckConfig() {
- JsonNode credentials = Jsons.jsonNode(ImmutableMap.builder()
- .put("credentials_title", "IAM User")
- .put("aws_access_key_id", "wrong-access-key")
- .put("aws_secret_access_key", "wrong-secret")
- .build());
-
- JsonNode config = Jsons.jsonNode(ImmutableMap.builder()
- .put("aws_account_id", "112233")
- .put("region", "us-east-1")
- .put("bucket_name", "test-bucket")
- .put("bucket_prefix", "test")
- .put("lakeformation_database_name", "lf_db")
- .put("credentials", credentials)
- .build());
-
- return config;
- }
-
- @Override
- protected List<JsonNode> retrieveRecords(TestDestinationEnv testEnv,
- String streamName,
- String namespace,
- JsonNode streamSchema)
- throws IOException, InterruptedException {
- String query = String.format("SELECT * FROM \"%s\".\"%s\"", config.getDatabaseName(), streamName);
- GetQueryResultsIterable results = athenaHelper.runQuery(config.getDatabaseName(), query);
- return parseResults(results);
- }
-
- protected List<JsonNode> parseResults(GetQueryResultsIterable queryResults) {
-
- List<JsonNode> processedResults = new ArrayList<>();
-
- for (GetQueryResultsResponse result : queryResults) {
- List<ColumnInfo> columnInfoList = result.resultSet().resultSetMetadata().columnInfo();
- Iterator<Row> results = result.resultSet().rows().iterator();
- Row colNamesRow = results.next();
- while (results.hasNext()) {
- Map<String, Object> jsonMap = Maps.newHashMap();
- Row r = results.next();
- Iterator<ColumnInfo> colInfoIterator = columnInfoList.iterator();
- Iterator<Datum> datum = r.data().iterator();
- while (colInfoIterator.hasNext() && datum.hasNext()) {
- ColumnInfo colInfo = colInfoIterator.next();
- Datum value = datum.next();
- Object typedFieldValue = getTypedFieldValue(colInfo, value);
- if (typedFieldValue != null) {
- jsonMap.put(colInfo.name(), typedFieldValue);
- }
- }
- processedResults.add(Jsons.jsonNode(jsonMap));
- }
- }
- return processedResults;
- }
-
- private static Object getTypedFieldValue(ColumnInfo colInfo, Datum value) {
- var typeName = colInfo.type();
- var varCharValue = value.varCharValue();
-
- if (varCharValue == null)
- return null;
- var returnType = switch (typeName) {
- case "real", "double", "float" -> Double.parseDouble(varCharValue);
- case "varchar" -> varCharValue;
- case "boolean" -> Boolean.parseBoolean(varCharValue);
- case "integer" -> Integer.parseInt(varCharValue);
- case "row" -> varCharValue;
- default -> null;
- };
- if (returnType == null) {
- LOGGER.warn(String.format("Unsupported type = %s", typeName));
- }
- return returnType;
- }
-
- @Override
- protected List<String> resolveIdentifier(String identifier) {
- final List<String> result = new ArrayList<>();
- result.add(identifier);
- result.add(identifier.toLowerCase());
- return result;
- }
-
- private JsonNode loadJsonFile(String fileName) throws IOException {
- final JsonNode configFromSecrets = Jsons.deserialize(IOs.readFile(Path.of(fileName)));
- return (configFromSecrets);
- }
-
- @Override
- protected void setup(TestDestinationEnv testEnv) throws IOException {
- configJson = loadJsonFile(CONFIG_PATH);
-
- this.config = AwsDatalakeDestinationConfig.getAwsDatalakeDestinationConfig(configJson);
-
- Region region = Region.of(config.getRegion());
-
- AwsBasicCredentials awsCreds = AwsBasicCredentials.create(config.getAccessKeyId(), config.getSecretAccessKey());
- athenaHelper = new AthenaHelper(awsCreds, region, String.format("s3://%s/airbyte_athena/", config.getBucketName()),
- "AmazonAthenaLakeFormation");
- glueHelper = new GlueHelper(awsCreds, region);
- glueHelper.purgeDatabase(config.getDatabaseName());
- }
-
- private String toAthenaObject(JsonNode value) {
- StringBuilder sb = new StringBuilder("\"{");
- List<String> elements = new ArrayList<>();
- var it = value.fields();
- while (it.hasNext()) {
- Map.Entry<String, JsonNode> f = it.next();
- final String k = f.getKey();
- final String v = f.getValue().asText();
- elements.add(String.format("%s=%s", k, v));
- }
- sb.append(String.join(",", elements));
- sb.append("}\"");
- return sb.toString();
- }
-
- protected void assertSameValue(final String key, final JsonNode expectedValue, final JsonNode actualValue) {
- if (expectedValue.isObject()) {
- assertEquals(toAthenaObject(expectedValue), actualValue.toString());
- } else {
- assertEquals(expectedValue, actualValue);
- }
- }
-
- @Override
- protected void tearDown(TestDestinationEnv testEnv) {
- // TODO Implement this method to run any cleanup actions needed after every test case
- // glueHelper.purgeDatabase(config.getDatabaseName());
- }
-
- @Override
- protected TestDataComparator getTestDataComparator() {
- return new AwsDatalakeTestDataComparator();
- }
-
- @Override
- protected boolean supportBasicDataTypeTest() {
- return true;
- }
-
- @Override
- protected boolean supportArrayDataTypeTest() {
- return false;
- }
-
- @Override
- protected boolean supportObjectDataTypeTest() {
- return true;
- }
-
-}
diff --git a/airbyte-integrations/connectors/destination-aws-datalake/src/test-integration/java/io/airbyte/integrations/destination/aws_datalake/AwsDatalakeTestDataComparator.java b/airbyte-integrations/connectors/destination-aws-datalake/src/test-integration/java/io/airbyte/integrations/destination/aws_datalake/AwsDatalakeTestDataComparator.java
deleted file mode 100644
index 20736ebb1647..000000000000
--- a/airbyte-integrations/connectors/destination-aws-datalake/src/test-integration/java/io/airbyte/integrations/destination/aws_datalake/AwsDatalakeTestDataComparator.java
+++ /dev/null
@@ -1,22 +0,0 @@
-/*
- * Copyright (c) 2023 Airbyte, Inc., all rights reserved.
- */
-
-package io.airbyte.integrations.destination.aws_datalake;
-
-import io.airbyte.integrations.standardtest.destination.comparator.AdvancedTestDataComparator;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Locale;
-
-public class AwsDatalakeTestDataComparator extends AdvancedTestDataComparator {
-
- @Override
- protected List<String> resolveIdentifier(String identifier) {
- final List<String> result = new ArrayList<>();
- result.add(identifier);
- result.add(identifier.toLowerCase(Locale.ROOT));
- return result;
- }
-
-}
diff --git a/airbyte-integrations/connectors/destination-aws-datalake/src/test/AwsDatalakeDestinationTest.java b/airbyte-integrations/connectors/destination-aws-datalake/src/test/AwsDatalakeDestinationTest.java
deleted file mode 100644
index 6c55a807d2ed..000000000000
--- a/airbyte-integrations/connectors/destination-aws-datalake/src/test/AwsDatalakeDestinationTest.java
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Copyright (c) 2023 Airbyte, Inc., all rights reserved.
- */
-
-package io.airbyte.integrations.destination.aws_datalake;
-
-import static org.junit.jupiter.api.Assertions.assertEquals;
-
-import com.fasterxml.jackson.databind.JsonNode;
-import io.airbyte.commons.json.Jsons;
-import org.junit.jupiter.api.Test;
-
-class AwsDatalakeDestinationTest {
-
- /*
- * @Test void testGetOutputTableNameWithString() throws Exception { var actual =
- * DynamodbOutputTableHelper.getOutputTableName("test_table", "test_namespace", "test_stream");
- * assertEquals("test_table_test_namespace_test_stream", actual); }
- *
- * @Test void testGetOutputTableNameWithStream() throws Exception { var stream = new
- * AirbyteStream(); stream.setName("test_stream"); stream.setNamespace("test_namespace"); var actual
- * = DynamodbOutputTableHelper.getOutputTableName("test_table", stream);
- * assertEquals("test_table_test_namespace_test_stream", actual); }
- */
- @Test
- void testGetAwsDatalakeDestinationdbConfig() throws Exception {
- JsonNode json = Jsons.deserialize("""
- {
- "bucket_prefix": "test_prefix",
- "region": "test_region",
- "auth_mode": "USER",
- "bucket_name": "test_bucket",
- "aws_access_key_id": "test_access_key",
- "aws_account_id": "test_account_id",
- "lakeformation_database_name": "test_database",
- "aws_secret_access_key": "test_secret"
- }""");
-
- var config = AwsDatalakeDestinationConfig.getAwsDatalakeDestinationConfig(json);
-
- assertEquals(config.getBucketPrefix(), "test_prefix");
- assertEquals(config.getRegion(), "test_region");
- assertEquals(config.getAccessKeyId(), "test_access_key");
- assertEquals(config.getSecretAccessKey(), "test_secret");
- }
-
-}
diff --git a/airbyte-integrations/connectors/destination-aws-datalake/unit_tests/__init__.py b/airbyte-integrations/connectors/destination-aws-datalake/unit_tests/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/airbyte-integrations/connectors/destination-aws-datalake/unit_tests/aws_handler_test.py b/airbyte-integrations/connectors/destination-aws-datalake/unit_tests/aws_handler_test.py
new file mode 100644
index 000000000000..ce84337631c0
--- /dev/null
+++ b/airbyte-integrations/connectors/destination-aws-datalake/unit_tests/aws_handler_test.py
@@ -0,0 +1,55 @@
+#
+# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
+#
+
+import json
+from typing import Any, Mapping
+
+import pytest
+from destination_aws_datalake import DestinationAwsDatalake
+from destination_aws_datalake.aws import AwsHandler
+from destination_aws_datalake.config_reader import CompressionCodec, ConnectorConfig
+
+
+@pytest.fixture(name="config")
+def config() -> Mapping[str, Any]:
+ with open("unit_tests/fixtures/config.json", "r") as f:
+ return json.loads(f.read())
+
+
+@pytest.fixture(name="config_prefix")
+def config_prefix() -> Mapping[str, Any]:
+ with open("unit_tests/fixtures/config_prefix.json", "r") as f:
+ return json.loads(f.read())
+
+
+def test_get_compression_type(config: Mapping[str, Any]):
+ aws_handler = AwsHandler(ConnectorConfig(**config), DestinationAwsDatalake())
+
+ tests = {
+ CompressionCodec.GZIP: "gzip",
+ CompressionCodec.SNAPPY: "snappy",
+ CompressionCodec.ZSTD: "zstd",
+ "LZO": None,
+ }
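+ # Codecs without a mapping (e.g. LZO) are expected to resolve to None.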
+
+ for codec, expected in tests.items():
+ assert aws_handler._get_compression_type(codec) == expected
+
+
+def test_get_path(config: Mapping[str, Any]):
+ conf = ConnectorConfig(**config)
+ aws_handler = AwsHandler(conf, DestinationAwsDatalake())
+
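+ # With no bucket_prefix configured, the expected layout is s3://<bucket_name>/<database>/<table>/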
+ tbl = "append_stream"
+ db = conf.lakeformation_database_name
+ assert aws_handler._get_s3_path(db, tbl) == "s3://datalake-bucket/test/append_stream/"
+
+
+def test_get_path_prefix(config_prefix: Mapping[str, Any]):
+ conf = ConnectorConfig(**config_prefix)
+ aws_handler = AwsHandler(conf, DestinationAwsDatalake())
+
+ tbl = "append_stream"
+ db = conf.lakeformation_database_name
+ assert aws_handler._get_s3_path(db, tbl) == "s3://datalake-bucket/prefix/test/append_stream/"
diff --git a/airbyte-integrations/connectors/destination-aws-datalake/unit_tests/fixtures/config.json b/airbyte-integrations/connectors/destination-aws-datalake/unit_tests/fixtures/config.json
new file mode 100644
index 000000000000..8c660ecca6d9
--- /dev/null
+++ b/airbyte-integrations/connectors/destination-aws-datalake/unit_tests/fixtures/config.json
@@ -0,0 +1,16 @@
+{
+ "aws_account_id": "111111111111",
+ "credentials": {
+ "credentials_title": "IAM User",
+ "aws_access_key_id": "aws_key_id",
+ "aws_secret_access_key": "aws_secret_key"
+ },
+ "region": "us-east-1",
+ "bucket_name": "datalake-bucket",
+ "lakeformation_database_name": "test",
+ "format": {
+ "format_type": "Parquet",
+ "compression_codec": "SNAPPY"
+ },
+ "partitioning": "NO PARTITIONING"
+}
diff --git a/airbyte-integrations/connectors/destination-aws-datalake/unit_tests/fixtures/config_prefix.json b/airbyte-integrations/connectors/destination-aws-datalake/unit_tests/fixtures/config_prefix.json
new file mode 100644
index 000000000000..1f6cca04100f
--- /dev/null
+++ b/airbyte-integrations/connectors/destination-aws-datalake/unit_tests/fixtures/config_prefix.json
@@ -0,0 +1,17 @@
+{
+ "aws_account_id": "111111111111",
+ "credentials": {
+ "credentials_title": "IAM User",
+ "aws_access_key_id": "aws_key_id",
+ "aws_secret_access_key": "aws_secret_key"
+ },
+ "region": "us-east-1",
+ "bucket_name": "datalake-bucket",
+ "bucket_prefix": "prefix",
+ "lakeformation_database_name": "test",
+ "format": {
+ "format_type": "Parquet",
+ "compression_codec": "SNAPPY"
+ },
+ "partitioning": "NO PARTITIONING"
+}
diff --git a/airbyte-integrations/connectors/destination-aws-datalake/unit_tests/stream_writer_test.py b/airbyte-integrations/connectors/destination-aws-datalake/unit_tests/stream_writer_test.py
new file mode 100644
index 000000000000..7e82be769f1d
--- /dev/null
+++ b/airbyte-integrations/connectors/destination-aws-datalake/unit_tests/stream_writer_test.py
@@ -0,0 +1,308 @@
+#
+# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
+#
+
+import json
+from datetime import datetime
+from typing import Any, Dict, Mapping
+
+import pandas as pd
+from airbyte_cdk.models import AirbyteStream, ConfiguredAirbyteStream, DestinationSyncMode, SyncMode
+from destination_aws_datalake import DestinationAwsDatalake
+from destination_aws_datalake.aws import AwsHandler
+from destination_aws_datalake.config_reader import ConnectorConfig
+from destination_aws_datalake.stream_writer import StreamWriter
+
+
+def get_config() -> Mapping[str, Any]:
+ with open("unit_tests/fixtures/config.json", "r") as f:
+ return json.loads(f.read())
+
+
+def get_configured_stream():
+ stream_name = "append_stream"
+ stream_schema = {
+ "type": "object",
+ "properties": {
+ "string_col": {"type": "str"},
+ "int_col": {"type": "integer"},
+ "datetime_col": {"type": "string", "format": "date-time"},
+ "date_col": {"type": "string", "format": "date"},
+ },
+ }
+
+ return ConfiguredAirbyteStream(
+ stream=AirbyteStream(
+ name=stream_name,
+ json_schema=stream_schema,
+ default_cursor_field=["datetime_col"],
+ supported_sync_modes=[SyncMode.incremental, SyncMode.full_refresh],
+ ),
+ sync_mode=SyncMode.incremental,
+ destination_sync_mode=DestinationSyncMode.append,
+ cursor_field=["datetime_col"],
+ )
+
+
+def get_writer(config: Dict[str, Any]):
+ connector_config = ConnectorConfig(**config)
+ aws_handler = AwsHandler(connector_config, DestinationAwsDatalake())
+ return StreamWriter(aws_handler, connector_config, get_configured_stream())
+
+
+def get_big_schema_configured_stream():
+ stream_name = "append_stream_big"
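+ # Schema intentionally mixes clean nested objects/arrays with mixed-type,
+ # property-less and empty-array fields to exercise the JSON-string fallback
+ # in _get_glue_dtypes_from_json_schema.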
+ stream_schema = {
+ "$schema": "http://json-schema.org/draft-07/schema#",
+ "type": ["null", "object"],
+ "properties": {
+ "appId": {"type": ["null", "integer"]},
+ "appName": {"type": ["null", "string"]},
+ "bounced": {"type": ["null", "boolean"]},
+ "browser": {
+ "type": ["null", "object"],
+ "properties": {
+ "family": {"type": ["null", "string"]},
+ "name": {"type": ["null", "string"]},
+ "producer": {"type": ["null", "string"]},
+ "producerUrl": {"type": ["null", "string"]},
+ "type": {"type": ["null", "string"]},
+ "url": {"type": ["null", "string"]},
+ "version": {"type": ["null", "array"], "items": {"type": ["null", "string"]}},
+ },
+ },
+ "causedBy": {
+ "type": ["null", "object"],
+ "properties": {"created": {"type": ["null", "integer"]}, "id": {"type": ["null", "string"]}},
+ },
+ "percentage": {"type": ["null", "number"]},
+ "location": {
+ "type": ["null", "object"],
+ "properties": {
+ "city": {"type": ["null", "string"]},
+ "country": {"type": ["null", "string"]},
+ "latitude": {"type": ["null", "number"]},
+ "longitude": {"type": ["null", "number"]},
+ "state": {"type": ["null", "string"]},
+ "zipcode": {"type": ["null", "string"]},
+ },
+ },
+ "nestedJson": {
+ "type": ["null", "object"],
+ "properties": {
+ "city": {"type": "object", "properties": {"name": {"type": "string"}}},
+ },
+ },
+ "sentBy": {
+ "type": ["null", "object"],
+ "properties": {"created": {"type": ["null", "integer"]}, "id": {"type": ["null", "string"]}},
+ },
+ "sentAt": {"type": ["null", "string"], "format": "date-time"},
+ "receivedAt": {"type": ["null", "string"], "format": "date"},
+ "sourceId": {"type": "string"},
+ "status": {"type": "integer"},
+ "read": {"type": "boolean"},
+ "questions": {
+ "type": "array",
+ "items": {
+ "type": "object",
+ "properties": {"id": {"type": ["null", "integer"]}, "question": {"type": "string"}, "answer": {"type": "string"}},
+ },
+ },
+ "questions_nested": {
+ "type": "array",
+ "items": {
+ "type": "object",
+ "properties": {
+ "id": {"type": ["null", "integer"]},
+ "questions": {"type": "object", "properties": {"title": {"type": "string"}, "option": {"type": "integer"}}},
+ "answer": {"type": "string"},
+ },
+ },
+ },
+ "nested_mixed_types": {
+ "type": ["null", "object"],
+ "properties": {
+ "city": {"type": ["string", "integer", "null"]},
+ },
+ },
+ "nested_bad_object": {
+ "type": ["null", "object"],
+ "properties": {
+ "city": {"type": "object", "properties": {}},
+ },
+ },
+ "nested_nested_bad_object": {
+ "type": ["null", "object"],
+ "properties": {
+ "city": {
+ "type": "object",
+ "properties": {
+ "name": {"type": "object", "properties": {}},
+ },
+ },
+ },
+ },
+ "answers": {
+ "type": "array",
+ "items": {
+ "type": "string",
+ },
+ },
+ "answers_nested_bad": {
+ "type": "array",
+ "items": {
+ "type": "object",
+ "properties": {
+ "id": {"type": ["string", "integer"]},
+ },
+ },
+ },
+ "phone_number_ids": {"type": ["null", "array"], "items": {"type": ["string", "integer"]}},
+ "mixed_type_simple": {
+ "type": ["integer", "number"],
+ },
+ "empty_array": {"type": ["null", "array"]},
+ "airbyte_type_object": {"type": "number", "airbyte_type": "integer"},
+ },
+ }
+
+ return ConfiguredAirbyteStream(
+ stream=AirbyteStream(
+ name=stream_name,
+ json_schema=stream_schema,
+ default_cursor_field=["datetime_col"],
+ supported_sync_modes=[SyncMode.incremental, SyncMode.full_refresh],
+ ),
+ sync_mode=SyncMode.incremental,
+ destination_sync_mode=DestinationSyncMode.append,
+ cursor_field=["datetime_col"],
+ )
+
+
+def get_big_schema_writer(config: Dict[str, Any]):
+ connector_config = ConnectorConfig(**config)
+ aws_handler = AwsHandler(connector_config, DestinationAwsDatalake())
+ return StreamWriter(aws_handler, connector_config, get_big_schema_configured_stream())
+
+
+def test_get_date_columns():
+ writer = get_writer(get_config())
+ assert writer._get_date_columns() == ["datetime_col", "date_col"]
+
+
+def test_append_message():
+ writer = get_writer(get_config())
+ message = {"string_col": "test", "int_col": 1, "datetime_col": "2021-01-01T00:00:00Z", "date_col": "2021-01-01"}
+ writer.append_message(message)
+ assert len(writer._messages) == 1
+ assert writer._messages[0] == message
+
+
+def test_get_cursor_field():
+ writer = get_writer(get_config())
+ assert writer._cursor_fields == ["datetime_col"]
+
+
+def test_add_partition_column():
+ tests = {
+ "NO PARTITIONING": {},
+ "DATE": {"datetime_col_date": "date"},
+ "MONTH": {"datetime_col_month": "bigint"},
+ "YEAR": {"datetime_col_year": "bigint"},
+ "DAY": {"datetime_col_day": "bigint"},
+ "YEAR/MONTH/DAY": {"datetime_col_year": "bigint", "datetime_col_month": "bigint", "datetime_col_day": "bigint"},
+ }
+
+ for partitioning, expected_columns in tests.items():
+ config = get_config()
+ config["partitioning"] = partitioning
+
+ writer = get_writer(config)
+ df = pd.DataFrame(
+ {
+ "datetime_col": [datetime.now()],
+ }
+ )
+ assert writer._add_partition_column("datetime_col", df) == expected_columns
+ assert all([col in df.columns for col in expected_columns])
+
+
+def test_get_glue_dtypes_from_json_schema():
+ writer = get_big_schema_writer(get_config())
+ result, json_casts = writer._get_glue_dtypes_from_json_schema(writer._schema)
+ assert result == {
+ "airbyte_type_object": "bigint",
+ "answers": "array",
+ "answers_nested_bad": "string",
+ "appId": "bigint",
+ "appName": "string",
+ "bounced": "boolean",
+ "browser": "struct>",
+ "causedBy": "struct",
+ "empty_array": "string",
+ "location": "struct",
+ "mixed_type_simple": "string",
+ "nestedJson": "struct>",
+ "nested_bad_object": "string",
+ "nested_mixed_types": "string",
+ "nested_nested_bad_object": "string",
+ "percentage": "double",
+ "phone_number_ids": "string",
+ "questions": "array>",
+ "questions_nested": "array,answer:string>>",
+ "read": "boolean",
+ "receivedAt": "date",
+ "sentAt": "timestamp",
+ "sentBy": "struct",
+ "sourceId": "string",
+ "status": "bigint",
+ }
+
+ assert json_casts == {
+ "answers_nested_bad",
+ "empty_array",
+ "nested_bad_object",
+ "nested_mixed_types",
+ "nested_nested_bad_object",
+ "phone_number_ids",
+ }
+
+
+def test_has_objects_with_no_properties_good():
+ writer = get_big_schema_writer(get_config())
+ assert writer._is_invalid_struct_or_array(
+ {
+ "nestedJson": {
+ "type": ["null", "object"],
+ "properties": {
+ "city": {"type": "object", "properties": {"name": {"type": "string"}}},
+ },
+ }
+ }
+ )
+
+
+def test_has_objects_with_no_properties_bad():
+ writer = get_big_schema_writer(get_config())
+ assert not writer._is_invalid_struct_or_array(
+ {
+ "nestedJson": {
+ "type": ["null", "object"],
+ }
+ }
+ )
+
+
+def test_has_objects_with_no_properties_nested_bad():
+ writer = get_big_schema_writer(get_config())
+ assert not writer._is_invalid_struct_or_array(
+ {
+ "nestedJson": {
+ "type": ["null", "object"],
+ "properties": {
+ "city": {"type": "object", "properties": {}},
+ },
+ }
+ }
+ )
diff --git a/connectors.md b/connectors.md
index 10ef495da885..4160adb57f65 100644
--- a/connectors.md
+++ b/connectors.md
@@ -286,7 +286,7 @@
| Name | Icon | Type | Image | Release Stage | Docs | Code | ID |
|----|----|----|----|----|----|----|----|
-| **AWS Datalake** | x | Destination | airbyte/destination-aws-datalake:0.1.1 | alpha | [link](https://docs.airbyte.com/integrations/destinations/aws-datalake) | [code](https://github.com/airbytehq/airbyte/tree/master/airbyte-integrations/connectors/destination-aws-datalake) | `99878c90-0fbd-46d3-9d98-ffde879d17fc` |
+| **AWS Datalake** | | Destination | airbyte/destination-aws-datalake:0.1.2 | alpha | [link](https://docs.airbyte.com/integrations/destinations/aws-datalake) | [code](https://github.com/airbytehq/airbyte/tree/master/airbyte-integrations/connectors/destination-aws-datalake) | `99878c90-0fbd-46d3-9d98-ffde879d17fc` |
| **Amazon SQS** | | Destination | airbyte/destination-amazon-sqs:0.1.0 | alpha | [link](https://docs.airbyte.com/integrations/destinations/amazon-sqs) | [code](https://github.com/airbytehq/airbyte/tree/master/airbyte-integrations/connectors/destination-amazon-sqs) | `0eeee7fb-518f-4045-bacc-9619e31c43ea` |
| **Apache Doris** | | Destination | airbyte/destination-doris:0.1.0 | alpha | [link](https://docs.airbyte.com/integrations/destinations/doris) | [code](https://github.com/airbytehq/airbyte/tree/master/airbyte-integrations/connectors/destination-doris) | `05c161bf-ca73-4d48-b524-d392be417002` |
| **Apache Iceberg** | x | Destination | airbyte/destination-iceberg:0.1.0 | alpha | [link](https://docs.airbyte.com/integrations/destinations/iceberg) | [code](https://github.com/airbytehq/airbyte/tree/master/airbyte-integrations/connectors/destination-iceberg) | `df65a8f3-9908-451b-aa9b-445462803560` |
diff --git a/docs/integrations/destinations/aws-datalake.md b/docs/integrations/destinations/aws-datalake.md
index dbfa1db6d42e..e3a69b737b5c 100644
--- a/docs/integrations/destinations/aws-datalake.md
+++ b/docs/integrations/destinations/aws-datalake.md
@@ -2,7 +2,7 @@
This page contains the setup guide and reference information for the AWS Datalake destination connector.
-The AWS Datalake destination connector allows you to sync data to AWS. It will write data as JSON files in S3 and
+The AWS Datalake destination connector allows you to sync data to AWS. It will write data as JSON files in S3 and
will make it available through a [Lake Formation Governed Table](https://docs.aws.amazon.com/lake-formation/latest/dg/governed-tables.html) in the Glue Data Catalog so that the data is available throughout other AWS services such as Athena, Glue jobs, EMR, Redshift, etc.
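+
+For example, once a sync has completed, the resulting table can be queried with Athena. The snippet below is illustrative only: it uses the `awswrangler` package (AWS SDK for pandas) and assumes a database and stream name that match your connector configuration.
+
+```python
+import awswrangler as wr
+
+# Hypothetical names: replace with your Lake Formation database and stream table
+df = wr.athena.read_sql_query(
+    "SELECT * FROM my_stream",
+    database="my_lakeformation_database",
+)
+print(df.head())
+```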
## Prerequisites
@@ -69,5 +69,6 @@ and types in the destination table as in the source except for the following typ
## Changelog
+| 0.1.2 | 2022-09-26 | [\#17193](https://github.com/airbytehq/airbyte/pull/17193) | Fix schema KeyError and add Parquet support |
| 0.1.1 | 2022-04-20 | [\#11811](https://github.com/airbytehq/airbyte/pull/11811) | Fix name of required param in specification |
| 0.1.0 | 2022-03-29 | [\#10760](https://github.com/airbytehq/airbyte/pull/10760) | Initial release |
\ No newline at end of file