From efcec106f5756a159ce317896923bdb30caadf3b Mon Sep 17 00:00:00 2001 From: Greg Solovyev Date: Mon, 27 Mar 2023 10:52:04 -0700 Subject: [PATCH] Greg/community pr 17193 (#23855) * [UPD] Add format and partitioning spec Signed-off-by: Henri Blancke * [UPD] Add parquet format, compression and partitioning Signed-off-by: Henri Blancke * [ADD] Integration and unit tests Signed-off-by: Henri Blancke * [UPD] Update docs and add bootstrap Signed-off-by: Henri Blancke * [UPD] bump version to 0.1.2 Signed-off-by: Henri Blancke * [ADD] Changelog entry Signed-off-by: Henri Blancke * [FIX] typo Signed-off-by: Henri Blancke * [FIX] cast arrays with mixed types to json string Signed-off-by: Henri Blancke * [FIX] issues when casting athena to pandas types Signed-off-by: Henri Blancke * [UPD] cleanup Signed-off-by: Henri Blancke * [UPD] flush interval to reduce memory usage Signed-off-by: Henri Blancke * [UPD] allow state reset per stream Signed-off-by: Henri Blancke * [UPD] capitalize AWS Signed-off-by: Henri Blancke * [ADD] decimal support and default db LF-tags Signed-off-by: Henri Blancke * [UPD] account for type error Signed-off-by: Henri Blancke * [FIX] partition field duplication Signed-off-by: Henri Blancke * [UPD] bump awswrangler (fixes json compression issue) Signed-off-by: Henri Blancke * [UPD] refactor, infer pandas and glue types from json schema Signed-off-by: Henri Blancke * [FIX] default for items get Signed-off-by: Henri Blancke * [UPD] account for mixed type properties Signed-off-by: Henri Blancke * [UPD] bad complex types to json string Signed-off-by: Henri Blancke * [UPD] drop top keys when not in json schema Signed-off-by: Henri Blancke * [UPD] fix partitioning, add airbyte type, fix keyerror concurrent partitioning Signed-off-by: Henri Blancke * [UPD] make table type configurable Signed-off-by: Henri Blancke * [UPD] fix obvious type violations Signed-off-by: Henri Blancke * [FIX] add missing columns to create correct schema Signed-off-by: Henri Blancke * 
[FIX] integration test Signed-off-by: Henri Blancke * [UPD] formatting and flake Signed-off-by: Henri Blancke * [FIX] overwrite partial flush bug Signed-off-by: Henri Blancke * [FIX] integration tests Signed-off-by: Henri Blancke * fix formatting * [FIX] check and typing Signed-off-by: Henri Blancke * [UPD] rmv fillna Signed-off-by: Henri Blancke * [UPD] warn on reset Signed-off-by: Henri Blancke * [FIX] log on failed reset Signed-off-by: Henri Blancke * [UPD] cast bool Signed-off-by: Henri Blancke * [FIX] cast pandas columns bool casting Signed-off-by: Henri Blancke * [FIX] required spec and format defaults Signed-off-by: Henri Blancke * [ADD] icon Signed-off-by: Henri Blancke * [UPD] address review comments Signed-off-by: Henri Blancke * Automated Change * auto-bump connector version --------- Signed-off-by: Henri Blancke Co-authored-by: Henri Blancke Co-authored-by: Marcos Marx Co-authored-by: Sunny <6833405+sh4sh@users.noreply.github.com> Co-authored-by: Octavia Squidington III --- .../src/main/resources/icons/awsdatalake.svg | 10 + .../seed/destination_definitions.yaml | 3 +- .../resources/seed/destination_specs.yaml | 146 +++++- .../destination-aws-datalake/BOOTSTRAP.md | 5 + .../destination-aws-datalake/Dockerfile | 2 +- .../destination-aws-datalake/README.md | 28 +- .../destination_aws_datalake/aws.py | 455 ++++++++---------- .../destination_aws_datalake/config_reader.py | 89 +++- .../destination_aws_datalake/destination.py | 106 ++-- .../destination_aws_datalake/spec.json | 156 +++++- .../destination_aws_datalake/stream_writer.py | 455 +++++++++++++++--- .../__init__.py} | 4 - .../integration_tests/integration_test.py | 160 ++++++ .../invalid_account_config.json | 11 + .../invalid_region_config.json | 10 + .../destination-aws-datalake/setup.py | 8 +- .../aws_datalake/AthenaHelper.java | 129 ----- .../AwsDatalakeDestinationConfig.java | 79 --- .../destination/aws_datalake/GlueHelper.java | 70 --- .../AwsDatalakeDestinationAcceptanceTest.java | 211 -------- 
.../AwsDatalakeTestDataComparator.java | 22 - .../src/test/AwsDatalakeDestinationTest.java | 47 -- .../unit_tests/__init__.py | 0 .../unit_tests/aws_handler_test.py | 55 +++ .../unit_tests/fixtures/config.json | 16 + .../unit_tests/fixtures/config_prefix.json | 17 + .../unit_tests/stream_writer_test.py | 308 ++++++++++++ connectors.md | 2 +- .../integrations/destinations/aws-datalake.md | 3 +- 29 files changed, 1661 insertions(+), 946 deletions(-) create mode 100644 airbyte-config/init/src/main/resources/icons/awsdatalake.svg create mode 100644 airbyte-integrations/connectors/destination-aws-datalake/BOOTSTRAP.md rename airbyte-integrations/connectors/destination-aws-datalake/{unit_tests/unit_test.py => integration_tests/__init__.py} (57%) create mode 100644 airbyte-integrations/connectors/destination-aws-datalake/integration_tests/integration_test.py create mode 100644 airbyte-integrations/connectors/destination-aws-datalake/integration_tests/invalid_account_config.json create mode 100644 airbyte-integrations/connectors/destination-aws-datalake/integration_tests/invalid_region_config.json delete mode 100644 airbyte-integrations/connectors/destination-aws-datalake/src/main/java/io/airbyte/integrations/destination/aws_datalake/AthenaHelper.java delete mode 100644 airbyte-integrations/connectors/destination-aws-datalake/src/main/java/io/airbyte/integrations/destination/aws_datalake/AwsDatalakeDestinationConfig.java delete mode 100644 airbyte-integrations/connectors/destination-aws-datalake/src/main/java/io/airbyte/integrations/destination/aws_datalake/GlueHelper.java delete mode 100644 airbyte-integrations/connectors/destination-aws-datalake/src/test-integration/java/io/airbyte/integrations/destination/aws_datalake/AwsDatalakeDestinationAcceptanceTest.java delete mode 100644 airbyte-integrations/connectors/destination-aws-datalake/src/test-integration/java/io/airbyte/integrations/destination/aws_datalake/AwsDatalakeTestDataComparator.java delete mode 100644 
airbyte-integrations/connectors/destination-aws-datalake/src/test/AwsDatalakeDestinationTest.java create mode 100644 airbyte-integrations/connectors/destination-aws-datalake/unit_tests/__init__.py create mode 100644 airbyte-integrations/connectors/destination-aws-datalake/unit_tests/aws_handler_test.py create mode 100644 airbyte-integrations/connectors/destination-aws-datalake/unit_tests/fixtures/config.json create mode 100644 airbyte-integrations/connectors/destination-aws-datalake/unit_tests/fixtures/config_prefix.json create mode 100644 airbyte-integrations/connectors/destination-aws-datalake/unit_tests/stream_writer_test.py diff --git a/airbyte-config/init/src/main/resources/icons/awsdatalake.svg b/airbyte-config/init/src/main/resources/icons/awsdatalake.svg new file mode 100644 index 000000000000..5ea8d6d6aab5 --- /dev/null +++ b/airbyte-config/init/src/main/resources/icons/awsdatalake.svg @@ -0,0 +1,10 @@ + + + + + + + + diff --git a/airbyte-config/init/src/main/resources/seed/destination_definitions.yaml b/airbyte-config/init/src/main/resources/seed/destination_definitions.yaml index 3238f29c174b..9a714a1ca387 100644 --- a/airbyte-config/init/src/main/resources/seed/destination_definitions.yaml +++ b/airbyte-config/init/src/main/resources/seed/destination_definitions.yaml @@ -34,8 +34,9 @@ - name: AWS Datalake destinationDefinitionId: 99878c90-0fbd-46d3-9d98-ffde879d17fc dockerRepository: airbyte/destination-aws-datalake - dockerImageTag: 0.1.1 + dockerImageTag: 0.1.2 documentationUrl: https://docs.airbyte.com/integrations/destinations/aws-datalake + icon: awsdatalake.svg releaseStage: alpha - name: BigQuery destinationDefinitionId: 22f6c74f-5699-40ff-833c-4a879ea40133 diff --git a/airbyte-config/init/src/main/resources/seed/destination_specs.yaml b/airbyte-config/init/src/main/resources/seed/destination_specs.yaml index 6a56f7ff7b6c..f213f702217f 100644 --- a/airbyte-config/init/src/main/resources/seed/destination_specs.yaml +++ 
b/airbyte-config/init/src/main/resources/seed/destination_specs.yaml @@ -533,7 +533,7 @@ supported_destination_sync_modes: - "overwrite" - "append" -- dockerImage: "airbyte/destination-aws-datalake:0.1.1" +- dockerImage: "airbyte/destination-aws-datalake:0.1.2" spec: documentationUrl: "https://docs.airbyte.com/integrations/destinations/aws-datalake" connectionSpecification: @@ -544,7 +544,7 @@ - "credentials" - "region" - "bucket_name" - - "bucket_prefix" + - "lakeformation_database_name" additionalProperties: false properties: aws_account_id: @@ -553,11 +553,7 @@ description: "target aws account id" examples: - "111111111111" - region: - title: "AWS Region" - type: "string" - description: "Region name" - airbyte_secret: false + order: 1 credentials: title: "Authentication mode" description: "Choose How to Authenticate to AWS." @@ -609,21 +605,145 @@ type: "string" description: "Secret Access Key" airbyte_secret: true + order: 2 + region: + title: "S3 Bucket Region" + type: "string" + default: "" + description: "The region of the S3 bucket. See here for all region codes." + enum: + - "" + - "us-east-1" + - "us-east-2" + - "us-west-1" + - "us-west-2" + - "af-south-1" + - "ap-east-1" + - "ap-south-1" + - "ap-northeast-1" + - "ap-northeast-2" + - "ap-northeast-3" + - "ap-southeast-1" + - "ap-southeast-2" + - "ca-central-1" + - "cn-north-1" + - "cn-northwest-1" + - "eu-central-1" + - "eu-north-1" + - "eu-south-1" + - "eu-west-1" + - "eu-west-2" + - "eu-west-3" + - "sa-east-1" + - "me-south-1" + - "us-gov-east-1" + - "us-gov-west-1" + order: 3 bucket_name: title: "S3 Bucket Name" type: "string" - description: "Name of the bucket" - airbyte_secret: false + description: "The name of the S3 bucket. Read more here." 
+ order: 4 bucket_prefix: title: "Target S3 Bucket Prefix" type: "string" description: "S3 prefix" - airbyte_secret: false + order: 5 lakeformation_database_name: - title: "Lakeformation Database Name" + title: "Lake Formation Database Name" type: "string" - description: "Which database to use" - airbyte_secret: false + description: "The default database this destination will use to create tables\ + \ in per stream. Can be changed per connection by customizing the namespace." + order: 6 + lakeformation_database_default_tag_key: + title: "Lake Formation Database Tag Key" + description: "Add a default tag key to databases created by this destination" + examples: + - "pii_level" + type: "string" + order: 7 + lakeformation_database_default_tag_values: + title: "Lake Formation Database Tag Values" + description: "Add default values for the `Tag Key` to databases created\ + \ by this destination. Comma separate for multiple values." + examples: + - "private,public" + type: "string" + order: 8 + lakeformation_governed_tables: + title: "Lake Formation Governed Tables" + description: "Whether to create tables as LF governed tables." + type: "boolean" + default: false + order: 9 + format: + title: "Output Format *" + type: "object" + description: "Format of the data output." + oneOf: + - title: "JSON Lines: Newline-delimited JSON" + required: + - "format_type" + properties: + format_type: + title: "Format Type *" + type: "string" + enum: + - "JSONL" + default: "JSONL" + compression_codec: + title: "Compression Codec (Optional)" + description: "The compression algorithm used to compress data." 
+ type: "string" + enum: + - "UNCOMPRESSED" + - "GZIP" + default: "UNCOMPRESSED" + - title: "Parquet: Columnar Storage" + required: + - "format_type" + properties: + format_type: + title: "Format Type *" + type: "string" + enum: + - "Parquet" + default: "Parquet" + compression_codec: + title: "Compression Codec (Optional)" + description: "The compression algorithm used to compress data." + type: "string" + enum: + - "UNCOMPRESSED" + - "SNAPPY" + - "GZIP" + - "ZSTD" + default: "SNAPPY" + order: 10 + partitioning: + title: "Choose how to partition data" + description: "Partition data by cursor fields when a cursor field is a date" + type: "string" + enum: + - "NO PARTITIONING" + - "DATE" + - "YEAR" + - "MONTH" + - "DAY" + - "YEAR/MONTH" + - "YEAR/MONTH/DAY" + default: "NO PARTITIONING" + order: 11 + glue_catalog_float_as_decimal: + title: "Glue Catalog: Float as Decimal" + description: "Cast float/double as decimal(38,18). This can help achieve\ + \ higher accuracy and represent numbers correctly as received from the\ + \ source." + type: "boolean" + default: false + order: 12 supportsIncremental: true supportsNormalization: false supportsDBT: false diff --git a/airbyte-integrations/connectors/destination-aws-datalake/BOOTSTRAP.md b/airbyte-integrations/connectors/destination-aws-datalake/BOOTSTRAP.md new file mode 100644 index 000000000000..16998e8a3326 --- /dev/null +++ b/airbyte-integrations/connectors/destination-aws-datalake/BOOTSTRAP.md @@ -0,0 +1,5 @@ +# AWS Lake Formation Destination Connector Bootstrap + +This destination syncs your data to S3 and an AWS data lake, and automatically creates Glue Catalog databases and tables for you. + +See [this](https://docs.aws.amazon.com/lake-formation/latest/dg/how-it-works.html) to learn more about AWS Lake Formation. 
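The `format` and `partitioning` options added to `destination_specs.yaml` above restrict which compression codecs go with which output format. As a minimal sketch of those constraints, the helper below checks a config dict against the enum lists shown in the spec. The function name `check_output_options` and its return shape are hypothetical and are not part of the connector; the allowed values and defaults are copied from the spec hunk.

```python
# Illustrative only: a hypothetical helper, not code from this patch.
from typing import Any, Dict, List

# Allowed codecs per format_type, copied from the spec's enum lists.
ALLOWED_CODECS = {
    "JSONL": ["UNCOMPRESSED", "GZIP"],
    "Parquet": ["UNCOMPRESSED", "SNAPPY", "GZIP", "ZSTD"],
}
# Spec defaults: JSONL is uncompressed, Parquet uses SNAPPY.
DEFAULT_CODEC = {"JSONL": "UNCOMPRESSED", "Parquet": "SNAPPY"}
PARTITIONING = [
    "NO PARTITIONING",
    "DATE",
    "YEAR",
    "MONTH",
    "DAY",
    "YEAR/MONTH",
    "YEAR/MONTH/DAY",
]


def check_output_options(config: Dict[str, Any]) -> List[str]:
    """Return a list of human-readable problems; empty means the options look valid."""
    problems: List[str] = []

    # Only validate the format block when one is supplied.
    fmt = config.get("format")
    if fmt:
        format_type = fmt.get("format_type")
        if format_type not in ALLOWED_CODECS:
            problems.append(f"unknown format_type: {format_type!r}")
        else:
            codec = fmt.get("compression_codec", DEFAULT_CODEC[format_type])
            if codec not in ALLOWED_CODECS[format_type]:
                problems.append(f"codec {codec!r} is not allowed for {format_type}")

    # A missing partitioning value falls back to the spec default.
    partitioning = config.get("partitioning", "NO PARTITIONING")
    if partitioning not in PARTITIONING:
        problems.append(f"unknown partitioning: {partitioning!r}")

    return problems
```

For example, a JSONL format with `ZSTD` compression would be reported as a problem, since the spec only lists `UNCOMPRESSED` and `GZIP` for JSON Lines.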
diff --git a/airbyte-integrations/connectors/destination-aws-datalake/Dockerfile b/airbyte-integrations/connectors/destination-aws-datalake/Dockerfile index ebe2844b7244..a036661f7da7 100644 --- a/airbyte-integrations/connectors/destination-aws-datalake/Dockerfile +++ b/airbyte-integrations/connectors/destination-aws-datalake/Dockerfile @@ -13,5 +13,5 @@ RUN pip install . ENV AIRBYTE_ENTRYPOINT "python /airbyte/integration_code/main.py" ENTRYPOINT ["python", "/airbyte/integration_code/main.py"] -LABEL io.airbyte.version=0.1.1 +LABEL io.airbyte.version=0.1.2 LABEL io.airbyte.name=airbyte/destination-aws-datalake diff --git a/airbyte-integrations/connectors/destination-aws-datalake/README.md b/airbyte-integrations/connectors/destination-aws-datalake/README.md index 81a16e03a8ad..e4e7f1d858b2 100644 --- a/airbyte-integrations/connectors/destination-aws-datalake/README.md +++ b/airbyte-integrations/connectors/destination-aws-datalake/README.md @@ -140,23 +140,28 @@ To run acceptance and custom integration tests: ./gradlew :airbyte-integrations:connectors:destination-aws-datalake:integrationTest ``` -#### Running the Destination Acceptance Tests +#### Running the Destination Integration Tests To successfully run the Destination Acceptance Tests, you need a `secrets/config.json` file with appropriate information. For example: ```json { - "bucket_name": "your-bucket-name", - "bucket_prefix": "your-prefix", - "region": "your-region", - "aws_account_id": "111111111111", - "lakeformation_database_name": "an_lf_database", - "credentials": { - "credentials_title": "IAM User", - "aws_access_key_id": ".....", - "aws_secret_access_key": "....." 
- } + "aws_account_id": "111111111111", + "credentials": { + "credentials_title": "IAM User", + "aws_access_key_id": "aws_key_id", + "aws_secret_access_key": "aws_secret_key" + }, + "region": "us-east-1", + "bucket_name": "datalake-bucket", + "lakeformation_database_name": "test", + "format": { + "format_type": "Parquet", + "compression_codec": "SNAPPY" + }, + "partitioning": "NO PARTITIONING" } + ``` In the AWS account, you need to have the following elements in place: @@ -167,7 +172,6 @@ In the AWS account, you need to have the following elements in place: * The user must have appropriate permissions to the Lake Formation database to perform the tests (For example see: [Granting Database Permissions Using the Lake Formation Console and the Named Resource Method](https://docs.aws.amazon.com/lake-formation/latest/dg/granting-database-permissions.html)) - ## Dependency Management All of your dependencies should go in `setup.py`, NOT `requirements.txt`. The requirements file is only used to connect internal Airbyte dependencies in the monorepo for local development. diff --git a/airbyte-integrations/connectors/destination-aws-datalake/destination_aws_datalake/aws.py b/airbyte-integrations/connectors/destination-aws-datalake/destination_aws_datalake/aws.py index c03b8305e5d7..1e1a679938d7 100644 --- a/airbyte-integrations/connectors/destination-aws-datalake/destination_aws_datalake/aws.py +++ b/airbyte-integrations/connectors/destination-aws-datalake/destination_aws_datalake/aws.py @@ -2,42 +2,95 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# -import json +import logging +from decimal import Decimal +from typing import Dict, Optional +import awswrangler as wr import boto3 +import pandas as pd from airbyte_cdk.destinations import Destination +from awswrangler import _data_types from botocore.exceptions import ClientError from retrying import retry -from .config_reader import AuthMode, ConnectorConfig +from .config_reader import CompressionCodec, ConnectorConfig, CredentialsType, OutputFormat + +logger = logging.getLogger("airbyte") + +null_values = ["", " ", "#N/A", "#N/A N/A", "#NA", "", "N/A", "NA", "NULL", "none", "None", "NaN", "n/a", "nan", "null"] + + +def _cast_pandas_column(df: pd.DataFrame, col: str, current_type: str, desired_type: str) -> pd.DataFrame: + if desired_type == "datetime64": + df[col] = pd.to_datetime(df[col]) + elif desired_type == "date": + df[col] = df[col].apply(lambda x: _data_types._cast2date(value=x)).replace(to_replace={pd.NaT: None}) + elif desired_type == "bytes": + df[col] = df[col].astype("string").str.encode(encoding="utf-8").replace(to_replace={pd.NA: None}) + elif desired_type == "decimal": + # First cast to string + df = _cast_pandas_column(df=df, col=col, current_type=current_type, desired_type="string") + # Then cast to decimal + df[col] = df[col].apply(lambda x: Decimal(str(x)) if str(x) not in null_values else None) + elif desired_type.lower() in ["float64", "int64"]: + df[col] = df[col].fillna("") + df[col] = pd.to_numeric(df[col]) + elif desired_type in ["boolean", "bool"]: + if df[col].dtype in ["string", "O"]: + df[col] = df[col].fillna("false").apply(lambda x: str(x).lower() in ["true", "1", "1.0", "t", "y", "yes"]) + + df[col] = df[col].astype(bool) + else: + try: + df[col] = df[col].astype(desired_type) + except (TypeError, ValueError) as ex: + if "object cannot be converted to an IntegerDtype" not in str(ex): + raise ex + logger.warn( + "Object cannot be converted to an IntegerDtype. Integer columns in Python cannot contain " + "missing values. 
If your input data contains missing values, it will be encoded as floats" + "which may cause precision loss.", + UserWarning, + ) + df[col] = df[col].apply(lambda x: int(x) if str(x) not in null_values else None).astype(desired_type) + return df -class AwsHandler: - COLUMNS_MAPPING = {"number": "float", "string": "string", "integer": "int"} +# Overwrite to fix type conversion issues from athena to pandas +# These happen when appending data to an existing table. awswrangler +# tries to cast the data types to the existing table schema, examples include: +# Fixes: ValueError: could not convert string to float: '' +# Fixes: TypeError: Need to pass bool-like values +_data_types._cast_pandas_column = _cast_pandas_column + - def __init__(self, connector_config, destination: Destination): - self._connector_config: ConnectorConfig = connector_config +class AwsHandler: + def __init__(self, connector_config: ConnectorConfig, destination: Destination): + self._config: ConnectorConfig = connector_config self._destination: Destination = destination - self._bucket_name = connector_config.bucket_name - self.logger = self._destination.logger + self._session: boto3.Session = None self.create_session() - self.s3_client = self.session.client("s3", region_name=connector_config.region) - self.glue_client = self.session.client("glue") - self.lf_client = self.session.client("lakeformation") + self.glue_client = self._session.client("glue") + self.s3_client = self._session.client("s3") + self.lf_client = self._session.client("lakeformation") + + self._table_type = "GOVERNED" if self._config.lakeformation_governed_tables else "EXTERNAL_TABLE" @retry(stop_max_attempt_number=10, wait_random_min=1000, wait_random_max=2000) def create_session(self): - if self._connector_config.credentials_type == AuthMode.IAM_USER.value: + if self._config.credentials_type == CredentialsType.IAM_USER: self._session = boto3.Session( - aws_access_key_id=self._connector_config.aws_access_key, - 
aws_secret_access_key=self._connector_config.aws_secret_key, - region_name=self._connector_config.region, + aws_access_key_id=self._config.aws_access_key, + aws_secret_access_key=self._config.aws_secret_key, + region_name=self._config.region, ) - elif self._connector_config.credentials_type == AuthMode.IAM_ROLE.value: + + elif self._config.credentials_type == CredentialsType.IAM_ROLE: client = boto3.client("sts") role = client.assume_role( - RoleArn=self._connector_config.role_arn, + RoleArn=self._config.role_arn, RoleSessionName="airbyte-destination-aws-datalake", ) creds = role.get("Credentials", {}) @@ -45,244 +98,164 @@ def create_session(self): aws_access_key_id=creds.get("AccessKeyId"), aws_secret_access_key=creds.get("SecretAccessKey"), aws_session_token=creds.get("SessionToken"), - region_name=self._connector_config.region, + region_name=self._config.region, ) + + def _get_s3_path(self, database: str, table: str) -> str: + bucket = f"s3://{self._config.bucket_name}" + if self._config.bucket_prefix: + bucket += f"/{self._config.bucket_prefix}" + + return f"{bucket}/{database}/{table}/" + + def _get_compression_type(self, compression: CompressionCodec): + if compression == CompressionCodec.GZIP: + return "gzip" + elif compression == CompressionCodec.SNAPPY: + return "snappy" + elif compression == CompressionCodec.ZSTD: + return "zstd" else: - raise Exception("Session wasn't created") + return None - @property - def session(self) -> boto3.Session: - return self._session + def _write_parquet( + self, + df: pd.DataFrame, + path: str, + database: str, + table: str, + mode: str, + dtype: Optional[Dict[str, str]], + partition_cols: list = None, + ): + return wr.s3.to_parquet( + df=df, + path=path, + dataset=True, + database=database, + table=table, + table_type=self._table_type, + mode=mode, + use_threads=False, # True causes s3 NoCredentialsError error + catalog_versioning=True, + boto3_session=self._session, + partition_cols=partition_cols, + 
compression=self._get_compression_type(self._config.compression_codec), + dtype=dtype, + ) - @retry(stop_max_attempt_number=10, wait_random_min=2000, wait_random_max=3000) - def head_bucket(self): - self.s3_client.head_bucket(Bucket=self._bucket_name) + def _write_json( + self, + df: pd.DataFrame, + path: str, + database: str, + table: str, + mode: str, + dtype: Optional[Dict[str, str]], + partition_cols: list = None, + ): + return wr.s3.to_json( + df=df, + path=path, + dataset=True, + database=database, + table=table, + table_type=self._table_type, + mode=mode, + use_threads=False, # True causes s3 NoCredentialsError error + orient="records", + lines=True, + catalog_versioning=True, + boto3_session=self._session, + partition_cols=partition_cols, + dtype=dtype, + compression=self._get_compression_type(self._config.compression_codec), + ) - @retry(stop_max_attempt_number=10, wait_random_min=2000, wait_random_max=3000) - def head_object(self, object_key): - return self.s3_client.head_object(Bucket=self._bucket_name, Key=object_key) + def _write(self, df: pd.DataFrame, path: str, database: str, table: str, mode: str, dtype: Dict[str, str], partition_cols: list = None): + self._create_database_if_not_exists(database) - @retry(stop_max_attempt_number=10, wait_random_min=2000, wait_random_max=3000) - def put_object(self, object_key, body): - self.s3_client.put_object(Bucket=self._bucket_name, Key=object_key, Body="\n".join(body)) + if self._config.format_type == OutputFormat.JSONL: + return self._write_json(df, path, database, table, mode, dtype, partition_cols) - @staticmethod - def batch_iterate(iterable, n=1): - size = len(iterable) - for ndx in range(0, size, n): - yield iterable[ndx : min(ndx + n, size)] + elif self._config.format_type == OutputFormat.PARQUET: + return self._write_parquet(df, path, database, table, mode, dtype, partition_cols) - def get_table(self, txid, database_name: str, table_name: str, location: str): - table = None - try: - table = 
self.glue_client.get_table(DatabaseName=database_name, Name=table_name, TransactionId=txid) - except ClientError as e: - if e.response["Error"]["Code"] == "EntityNotFoundException": - table_input = { - "Name": table_name, - "TableType": "GOVERNED", - "StorageDescriptor": { - "Location": location, - "InputFormat": "org.apache.hadoop.mapred.TextInputFormat", - "OutputFormat": "org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat", - "SerdeInfo": {"SerializationLibrary": "org.openx.data.jsonserde.JsonSerDe", "Parameters": {"paths": ","}}, - }, - "PartitionKeys": [], - "Parameters": {"classification": "json", "lakeformation.aso.status": "true"}, - } - self.glue_client.create_table(DatabaseName=database_name, TableInput=table_input, TransactionId=txid) - table = self.glue_client.get_table(DatabaseName=database_name, Name=table_name, TransactionId=txid) - else: - err = e.response["Error"]["Code"] - self.logger.error(f"An error occurred: {err}") - raise - - if table: - return table else: - return None + raise Exception(f"Unsupported output format: {self._config.format_type}") - def update_table(self, database, table_info, transaction_id): - self.glue_client.update_table(DatabaseName=database, TableInput=table_info, TransactionId=transaction_id) + def _create_database_if_not_exists(self, database: str): + tag_key = self._config.lakeformation_database_default_tag_key + tag_values = self._config.lakeformation_database_default_tag_values - def preprocess_type(self, property_type): - if type(property_type) is list: - not_null_types = list(filter(lambda t: t != "null", property_type)) - if len(not_null_types) > 2: - return "string" - else: - return not_null_types[0] - else: - return property_type - - def cast_to_athena(self, str_type): - preprocessed_type = self.preprocess_type(str_type) - return self.COLUMNS_MAPPING.get(preprocessed_type, preprocessed_type) - - def generate_athena_schema(self, schema): - columns = [] - for (k, v) in schema.items(): - athena_type = 
self.cast_to_athena(v["type"]) - if athena_type == "object": - properties = v["properties"] - type_str = ",".join([f"{k1}:{self.cast_to_athena(v1['type'])}" for (k1, v1) in properties.items()]) - columns.append({"Name": k, "Type": f"struct<{type_str}>"}) - else: - columns.append({"Name": k, "Type": athena_type}) - return columns - - def update_table_schema(self, txid, database, table, schema): - table_info = table["Table"] - table_info_keys = list(table_info.keys()) - for k in table_info_keys: - if k not in [ - "Name", - "Description", - "Owner", - "LastAccessTime", - "LastAnalyzedTime", - "Retention", - "StorageDescriptor", - "PartitionKeys", - "ViewOriginalText", - "ViewExpandedText", - "TableType", - "Parameters", - "TargetTable", - "IsRowFilteringEnabled", - ]: - table_info.pop(k) - - self.logger.debug("Schema = " + repr(schema)) - - columns = self.generate_athena_schema(schema) - if "StorageDescriptor" in table_info: - table_info["StorageDescriptor"]["Columns"] = columns - else: - table_info["StorageDescriptor"] = {"Columns": columns} - self.update_table(database, table_info, txid) - self.glue_client.update_table(DatabaseName=database, TableInput=table_info, TransactionId=txid) + wr.catalog.create_database(name=database, boto3_session=self._session, exist_ok=True) + + if tag_key and tag_values: + self.lf_client.add_lf_tags_to_resource( + Resource={ + "Database": {"Name": database}, + }, + LFTags=[{"TagKey": tag_key, "TagValues": tag_values.split(",")}], + ) - def get_all_table_objects(self, txid, database, table): - table_objects = [] + @retry(stop_max_attempt_number=10, wait_random_min=2000, wait_random_max=3000) + def head_bucket(self): + return self.s3_client.head_bucket(Bucket=self._config.bucket_name) + def table_exists(self, database: str, table: str) -> bool: try: - res = self.lf_client.get_table_objects(DatabaseName=database, TableName=table, TransactionId=txid) - except ClientError as e: - if e.response["Error"]["Code"] == "EntityNotFoundException": - 
return [] - else: - err = e.response["Error"]["Code"] - self.logger.error(f"Could not get table objects due to error: {err}") - raise - - while True: - next_token = res.get("NextToken", None) - partition_objects = res.get("Objects") - table_objects.extend([p["Objects"] for p in partition_objects]) - if next_token: - res = self.lf_client.get_table_objects( - DatabaseName=database, - TableName=table, - TransactionId=txid, - NextToken=next_token, - ) - else: - break - flat_list = [item for sublist in table_objects for item in sublist] - return flat_list - - def purge_table(self, txid, database, table): - self.logger.debug(f"Going to purge table {table}") - write_ops = [] - all_objects = self.get_all_table_objects(txid, database, table) - write_ops.extend([{"DeleteObject": {"Uri": o["Uri"]}} for o in all_objects]) - if len(write_ops) > 0: - self.logger.debug(f"{len(write_ops)} objects to purge") - for batch in self.batch_iterate(write_ops, 99): - self.logger.debug("Purging batch") - try: - self.lf_client.update_table_objects( - TransactionId=txid, - DatabaseName=database, - TableName=table, - WriteOperations=batch, - ) - except ClientError as e: - self.logger.error(f"Could not delete object due to exception {repr(e)}") - raise - else: - self.logger.debug("Table was empty, nothing to purge.") - - def update_governed_table(self, txid, database, table, bucket, object_key, etag, size): - self.logger.debug(f"Updating governed table {database}:{table}") - write_ops = [ - { - "AddObject": { - "Uri": f"s3://{bucket}/{object_key}", - "ETag": etag, - "Size": size, - } - } - ] - - self.lf_client.update_table_objects( - TransactionId=txid, - DatabaseName=database, - TableName=table, - WriteOperations=write_ops, + self.glue_client.get_table(DatabaseName=database, Name=table) + return True + except ClientError: + return False + + def delete_table(self, database: str, table: str) -> bool: + logger.info(f"Deleting table {database}.{table}") + return 
wr.catalog.delete_table_if_exists(database=database, table=table, boto3_session=self._session) + + def delete_table_objects(self, database: str, table: str) -> None: + path = self._get_s3_path(database, table) + logger.info(f"Deleting objects in {path}") + return wr.s3.delete_objects(path=path, boto3_session=self._session) + + def reset_table(self, database: str, table: str) -> None: + logger.info(f"Resetting table {database}.{table}") + if self.table_exists(database, table): + self.delete_table(database, table) + self.delete_table_objects(database, table) + + def write(self, df: pd.DataFrame, database: str, table: str, dtype: Dict[str, str], partition_cols: list): + path = self._get_s3_path(database, table) + return self._write( + df, + path, + database, + table, + "overwrite", + dtype, + partition_cols, ) + def append(self, df: pd.DataFrame, database: str, table: str, dtype: Dict[str, str], partition_cols: list): + path = self._get_s3_path(database, table) + return self._write( + df, + path, + database, + table, + "append", + dtype, + partition_cols, + ) -class LakeformationTransaction: - def __init__(self, aws_handler: AwsHandler): - self._aws_handler = aws_handler - self._transaction = None - self._logger = aws_handler.logger - - @property - def txid(self): - return self._transaction["TransactionId"] - - def cancel_transaction(self): - self._logger.debug("Canceling Lakeformation Transaction") - self._aws_handler.lf_client.cancel_transaction(TransactionId=self.txid) - - def commit_transaction(self): - self._logger.debug(f"Commiting Lakeformation Transaction {self.txid}") - self._aws_handler.lf_client.commit_transaction(TransactionId=self.txid) - - def extend_transaction(self): - self._logger.debug("Extending Lakeformation Transaction") - self._aws_handler.lf_client.extend_transaction(TransactionId=self.txid) - - def describe_transaction(self): - return self._aws_handler.lf_client.describe_transaction(TransactionId=self.txid) - - def __enter__(self, 
transaction_type="READ_AND_WRITE"): - self._logger.debug("Starting Lakeformation Transaction") - self._transaction = self._aws_handler.lf_client.start_transaction(TransactionType=transaction_type) - self._logger.debug(f"Transaction id = {self.txid}") - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self._logger.debug("Exiting LakeformationTransaction context manager") - tx_desc = self.describe_transaction() - self._logger.debug(json.dumps(tx_desc, default=str)) - - if exc_type: - self._logger.error("Exiting LakeformationTransaction context manager due to an exception") - self._logger.error(repr(exc_type)) - self._logger.error(repr(exc_val)) - self.cancel_transaction() - self._transaction = None - else: - self._logger.debug("Exiting LakeformationTransaction context manager due to reaching end of with block") - try: - self.commit_transaction() - self._transaction = None - except Exception as e: - self.cancel_transaction() - self._logger.error(f"Could not commit the transaction id = {self.txid} because of :\n{repr(e)}") - self._transaction = None - raise (e) + def upsert(self, df: pd.DataFrame, database: str, table: str, dtype: Dict[str, str], partition_cols: list): + path = self._get_s3_path(database, table) + return self._write( + df, + path, + database, + table, + "overwrite_partitions", + dtype, + partition_cols, + ) diff --git a/airbyte-integrations/connectors/destination-aws-datalake/destination_aws_datalake/config_reader.py b/airbyte-integrations/connectors/destination-aws-datalake/destination_aws_datalake/config_reader.py index db3648e11c62..8559e9d17c4a 100644 --- a/airbyte-integrations/connectors/destination-aws-datalake/destination_aws_datalake/config_reader.py +++ b/airbyte-integrations/connectors/destination-aws-datalake/destination_aws_datalake/config_reader.py @@ -5,10 +5,76 @@ import enum -class AuthMode(enum.Enum): +class CredentialsType(enum.Enum): IAM_ROLE = "IAM Role" IAM_USER = "IAM User" + @staticmethod + def from_string(s: 
str): + if s == "IAM Role": + return CredentialsType.IAM_ROLE + elif s == "IAM User": + return CredentialsType.IAM_USER + else: + raise ValueError(f"Unknown auth mode: {s}") + + +class OutputFormat(enum.Enum): + PARQUET = "Parquet" + JSONL = "JSONL" + + @staticmethod + def from_string(s: str): + if s == "Parquet": + return OutputFormat.PARQUET + + return OutputFormat.JSONL + + +class CompressionCodec(enum.Enum): + SNAPPY = "SNAPPY" + GZIP = "GZIP" + ZSTD = "ZSTD" + UNCOMPRESSED = "UNCOMPRESSED" + + @staticmethod + def from_config(str: str): + if str == "SNAPPY": + return CompressionCodec.SNAPPY + elif str == "GZIP": + return CompressionCodec.GZIP + elif str == "ZSTD": + return CompressionCodec.ZSTD + + return CompressionCodec.UNCOMPRESSED + + +class PartitionOptions(enum.Enum): + NONE = "NO PARTITIONING" + DATE = "DATE" + YEAR = "YEAR" + MONTH = "MONTH" + DAY = "DAY" + YEAR_MONTH = "YEAR/MONTH" + YEAR_MONTH_DAY = "YEAR/MONTH/DAY" + + @staticmethod + def from_string(s: str): + if s == "DATE": + return PartitionOptions.DATE + elif s == "YEAR": + return PartitionOptions.YEAR + elif s == "MONTH": + return PartitionOptions.MONTH + elif s == "DAY": + return PartitionOptions.DAY + elif s == "YEAR/MONTH": + return PartitionOptions.YEAR_MONTH + elif s == "YEAR/MONTH/DAY": + return PartitionOptions.YEAR_MONTH_DAY + + return PartitionOptions.NONE + class ConnectorConfig: def __init__( @@ -19,21 +85,36 @@ def __init__( bucket_name: str = None, bucket_prefix: str = None, lakeformation_database_name: str = None, + lakeformation_database_default_tag_key: str = None, + lakeformation_database_default_tag_values: str = None, + lakeformation_governed_tables: bool = False, + glue_catalog_float_as_decimal: bool = False, table_name: str = None, + format: dict = {}, + partitioning: str = None, ): self.aws_account_id = aws_account_id self.credentials = credentials - self.credentials_type = credentials.get("credentials_title") + self.credentials_type = 
CredentialsType.from_string(credentials.get("credentials_title")) self.region = region self.bucket_name = bucket_name self.bucket_prefix = bucket_prefix self.lakeformation_database_name = lakeformation_database_name + self.lakeformation_database_default_tag_key = lakeformation_database_default_tag_key + self.lakeformation_database_default_tag_values = lakeformation_database_default_tag_values + self.lakeformation_governed_tables = lakeformation_governed_tables + self.glue_catalog_float_as_decimal = glue_catalog_float_as_decimal self.table_name = table_name - if self.credentials_type == AuthMode.IAM_USER.value: + self.format_type = OutputFormat.from_string(format.get("format_type", OutputFormat.PARQUET.value)) + self.compression_codec = CompressionCodec.from_config(format.get("compression_codec", CompressionCodec.UNCOMPRESSED.value)) + + self.partitioning = PartitionOptions.from_string(partitioning) + + if self.credentials_type == CredentialsType.IAM_USER: self.aws_access_key = self.credentials.get("aws_access_key_id") self.aws_secret_key = self.credentials.get("aws_secret_access_key") - elif self.credentials_type == AuthMode.IAM_ROLE.value: + elif self.credentials_type == CredentialsType.IAM_ROLE: self.role_arn = self.credentials.get("role_arn") else: raise Exception("Auth Mode not recognized.") diff --git a/airbyte-integrations/connectors/destination-aws-datalake/destination_aws_datalake/destination.py b/airbyte-integrations/connectors/destination-aws-datalake/destination_aws_datalake/destination.py index 577af35670fa..95bce879685b 100644 --- a/airbyte-integrations/connectors/destination-aws-datalake/destination_aws_datalake/destination.py +++ b/airbyte-integrations/connectors/destination-aws-datalake/destination_aws_datalake/destination.py @@ -2,21 +2,36 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +import logging +import random +import string +from typing import Any, Dict, Iterable, Mapping -import json -from typing import Any, Iterable, Mapping - +import pandas as pd from airbyte_cdk import AirbyteLogger from airbyte_cdk.destinations import Destination from airbyte_cdk.models import AirbyteConnectionStatus, AirbyteMessage, ConfiguredAirbyteCatalog, Status, Type -from botocore.exceptions import ClientError +from botocore.exceptions import ClientError, InvalidRegionError -from .aws import AwsHandler, LakeformationTransaction +from .aws import AwsHandler from .config_reader import ConnectorConfig from .stream_writer import StreamWriter +logger = logging.getLogger("airbyte") + +# Flush records every 25000 records to limit memory consumption +RECORD_FLUSH_INTERVAL = 25000 + class DestinationAwsDatalake(Destination): + def _flush_streams(self, streams: Dict[str, StreamWriter]) -> None: + for stream in streams: + streams[stream].flush() + + @staticmethod + def _get_random_string(length): + return "".join(random.choice(string.ascii_letters) for i in range(length)) + def write( self, config: Mapping[str, Any], configured_catalog: ConfiguredAirbyteCatalog, input_messages: Iterable[AirbyteMessage] ) -> Iterable[AirbyteMessage]: @@ -35,38 +50,63 @@ def write( :param input_messages: The stream of input messages received from the source :return: Iterable of AirbyteStateMessages wrapped in AirbyteMessage structs """ - connector_config = ConnectorConfig(**config) try: aws_handler = AwsHandler(connector_config, self) except ClientError as e: - self.logger.error(f"Could not create session due to exception {repr(e)}") - raise - self.logger.debug("AWS session creation OK") + logger.error(f"Could not create session due to exception {repr(e)}") + raise Exception(f"Could not create session due to exception {repr(e)}") # creating stream writers streams = { - s.stream.name: StreamWriter( - name=s.stream.name, - aws_handler=aws_handler, - connector_config=connector_config, - 
schema=s.stream.json_schema["properties"], - sync_mode=s.destination_sync_mode, - ) + s.stream.name: StreamWriter(aws_handler=aws_handler, config=connector_config, configured_stream=s) for s in configured_catalog.streams } for message in input_messages: if message.type == Type.STATE: + if not message.state.data: + + if message.state.stream: + stream = message.state.stream.stream_descriptor.name + logger.info(f"Received empty state for stream {stream}, resetting stream") + if stream in streams: + streams[stream].reset() + else: + logger.warning(f"Trying to reset stream {stream} that is not in the configured catalog") + + if not message.state.stream: + logger.info("Received empty state, resetting all streams including non-incremental streams") + for stream in streams: + streams[stream].reset() + + # Flush records when state is received + if message.state.stream: + if message.state.stream.stream_state and hasattr(message.state.stream.stream_state, "stream_name"): + stream_name = message.state.stream.stream_state.stream_name + if stream_name in streams: + logger.info(f"Got state message from source: flushing records for {stream_name}") + streams[stream_name].flush(partial=True) + yield message - else: + + elif message.type == Type.RECORD: data = message.record.data stream = message.record.stream - streams[stream].append_message(json.dumps(data, default=str)) + streams[stream].append_message(data) + + # Flush records every RECORD_FLUSH_INTERVAL records to limit memory consumption + # Records will either get flushed when a state message is received or when hitting the RECORD_FLUSH_INTERVAL + if len(streams[stream]._messages) > RECORD_FLUSH_INTERVAL: + logger.debug(f"Reached size limit: flushing records for {stream}") + streams[stream].flush(partial=True) + + else: + logger.info(f"Unhandled message type {message.type}: {message}") - for stream_name, stream in streams.items(): - stream.add_to_datalake() + # Flush all or remaining records + 
self._flush_streams(streams)
def check(self, logger: AirbyteLogger, config: Mapping[str, Any]) -> AirbyteConnectionStatus: """ @@ -89,6 +129,10 @@ def check(self, logger: AirbyteLogger, config: Mapping[str, Any]) -> AirbyteConn logger.error(f"""Could not create session on {connector_config.aws_account_id} Exception: {repr(e)}""") message = f"""Could not authenticate using {connector_config.credentials_type} on Account {connector_config.aws_account_id} Exception: {repr(e)}""" return AirbyteConnectionStatus(status=Status.FAILED, message=message) + except InvalidRegionError: + message = f"{connector_config.region} is not a valid AWS region" + logger.error(message) + return AirbyteConnectionStatus(status=Status.FAILED, message=message) try: aws_handler.head_bucket() @@ -96,16 +140,18 @@ def check(self, logger: AirbyteLogger, config: Mapping[str, Any]) -> AirbyteConn message = f"""Could not find bucket {connector_config.bucket_name} in aws://{connector_config.aws_account_id}:{connector_config.region} Exception: {repr(e)}""" return AirbyteConnectionStatus(status=Status.FAILED, message=message) - with LakeformationTransaction(aws_handler) as tx: - table_location = "s3://" + connector_config.bucket_name + "/" + connector_config.bucket_prefix + "/" + "airbyte_test/" - table = aws_handler.get_table( - txid=tx.txid, - database_name=connector_config.lakeformation_database_name, - table_name="airbyte_test", - location=table_location, - ) - if table is None: - message = f"Could not create a table in database {connector_config.lakeformation_database_name}" + tbl = f"airbyte_test_{self._get_random_string(5)}" + db = connector_config.lakeformation_database_name + try: + df = pd.DataFrame({"id": [1, 2], "value": ["foo", "bar"]}) + + aws_handler.reset_table(db, tbl) + aws_handler.write(df, db, tbl, None, None) + aws_handler.reset_table(db, tbl) + + except Exception as e: + message = f"Could not create table {tbl} in database {db}: {repr(e)}" + logger.error(message) return 
AirbyteConnectionStatus(status=Status.FAILED, message=message) return AirbyteConnectionStatus(status=Status.SUCCEEDED) diff --git a/airbyte-integrations/connectors/destination-aws-datalake/destination_aws_datalake/spec.json b/airbyte-integrations/connectors/destination-aws-datalake/destination_aws_datalake/spec.json index 7e6f30b131a2..cdf4a1de08c1 100644 --- a/airbyte-integrations/connectors/destination-aws-datalake/destination_aws_datalake/spec.json +++ b/airbyte-integrations/connectors/destination-aws-datalake/destination_aws_datalake/spec.json @@ -6,20 +6,20 @@ "$schema": "http://json-schema.org/draft-07/schema#", "title": "AWS Datalake Destination Spec", "type": "object", - "required": ["credentials", "region", "bucket_name", "bucket_prefix"], + "required": [ + "credentials", + "region", + "bucket_name", + "lakeformation_database_name" + ], "additionalProperties": false, "properties": { "aws_account_id": { "type": "string", "title": "AWS Account Id", "description": "target aws account id", - "examples": ["111111111111"] - }, - "region": { - "title": "AWS Region", - "type": "string", - "description": "Region name", - "airbyte_secret": false + "examples": ["111111111111"], + "order": 1 }, "credentials": { "title": "Authentication mode", @@ -80,25 +80,151 @@ } } } - ] + ], + "order": 2 + }, + "region": { + "title": "S3 Bucket Region", + "type": "string", + "default": "", + "description": "The region of the S3 bucket. 
See here for all region codes.", + "enum": [ + "", + "us-east-1", + "us-east-2", + "us-west-1", + "us-west-2", + "af-south-1", + "ap-east-1", + "ap-south-1", + "ap-northeast-1", + "ap-northeast-2", + "ap-northeast-3", + "ap-southeast-1", + "ap-southeast-2", + "ca-central-1", + "cn-north-1", + "cn-northwest-1", + "eu-central-1", + "eu-north-1", + "eu-south-1", + "eu-west-1", + "eu-west-2", + "eu-west-3", + "sa-east-1", + "me-south-1", + "us-gov-east-1", + "us-gov-west-1" + ], + "order": 3 }, "bucket_name": { "title": "S3 Bucket Name", "type": "string", - "description": "Name of the bucket", - "airbyte_secret": false + "description": "The name of the S3 bucket. Read more here.", + "order": 4 }, "bucket_prefix": { "title": "Target S3 Bucket Prefix", "type": "string", "description": "S3 prefix", - "airbyte_secret": false + "order": 5 }, "lakeformation_database_name": { - "title": "Lakeformation Database Name", + "title": "Lake Formation Database Name", + "type": "string", + "description": "The default database this destination will use to create tables in per stream. Can be changed per connection by customizing the namespace.", + "order": 6 + }, + "lakeformation_database_default_tag_key": { + "title": "Lake Formation Database Tag Key", + "description": "Add a default tag key to databases created by this destination", + "examples": ["pii_level"], "type": "string", - "description": "Which database to use", - "airbyte_secret": false + "order": 7 + }, + "lakeformation_database_default_tag_values": { + "title": "Lake Formation Database Tag Values", + "description": "Add default values for the `Tag Key` to databases created by this destination. 
Comma separate for multiple values.", + "examples": ["private,public"], + "type": "string", + "order": 8 + }, + "lakeformation_governed_tables": { + "title": "Lake Formation Governed Tables", + "description": "Whether to create tables as LF governed tables.", + "type": "boolean", + "default": false, + "order": 9 + }, + "format": { + "title": "Output Format *", + "type": "object", + "description": "Format of the data output.", + "oneOf": [ + { + "title": "JSON Lines: Newline-delimited JSON", + "required": ["format_type"], + "properties": { + "format_type": { + "title": "Format Type *", + "type": "string", + "enum": ["JSONL"], + "default": "JSONL" + }, + "compression_codec": { + "title": "Compression Codec (Optional)", + "description": "The compression algorithm used to compress data.", + "type": "string", + "enum": ["UNCOMPRESSED", "GZIP"], + "default": "UNCOMPRESSED" + } + } + }, + { + "title": "Parquet: Columnar Storage", + "required": ["format_type"], + "properties": { + "format_type": { + "title": "Format Type *", + "type": "string", + "enum": ["Parquet"], + "default": "Parquet" + }, + "compression_codec": { + "title": "Compression Codec (Optional)", + "description": "The compression algorithm used to compress data.", + "type": "string", + "enum": ["UNCOMPRESSED", "SNAPPY", "GZIP", "ZSTD"], + "default": "SNAPPY" + } + } + } + ], + "order": 10 + }, + "partitioning": { + "title": "Choose how to partition data", + "description": "Partition data by cursor fields when a cursor field is a date", + "type": "string", + "enum": [ + "NO PARTITIONING", + "DATE", + "YEAR", + "MONTH", + "DAY", + "YEAR/MONTH", + "YEAR/MONTH/DAY" + ], + "default": "NO PARTITIONING", + "order": 11 + }, + "glue_catalog_float_as_decimal": { + "title": "Glue Catalog: Float as Decimal", + "description": "Cast float/double as decimal(38,18). 
This can help achieve higher accuracy and represent numbers correctly as received from the source.", + "type": "boolean", + "default": false, + "order": 12 } } } diff --git a/airbyte-integrations/connectors/destination-aws-datalake/destination_aws_datalake/stream_writer.py b/airbyte-integrations/connectors/destination-aws-datalake/destination_aws_datalake/stream_writer.py index 902ebe9db724..b093f1553ed4 100644 --- a/airbyte-integrations/connectors/destination-aws-datalake/destination_aws_datalake/stream_writer.py +++ b/airbyte-integrations/connectors/destination-aws-datalake/destination_aws_datalake/stream_writer.py @@ -2,70 +2,403 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # -from datetime import datetime +import json +import logging +from typing import Any, Dict, List, Optional, Tuple, Union -import nanoid -from airbyte_cdk.models import DestinationSyncMode -from retrying import retry +import pandas as pd +from airbyte_cdk.models import ConfiguredAirbyteStream, DestinationSyncMode +from destination_aws_datalake.config_reader import ConnectorConfig, PartitionOptions -from .aws import AwsHandler, LakeformationTransaction +from .aws import AwsHandler + +logger = logging.getLogger("airbyte") class StreamWriter: - def __init__(self, name, aws_handler: AwsHandler, connector_config, schema, sync_mode): - self._db = connector_config.lakeformation_database_name - self._bucket = connector_config.bucket_name - self._prefix = connector_config.bucket_prefix - self._table = name - self._aws_handler = aws_handler - self._schema = schema - self._sync_mode = sync_mode - self._messages = [] - self._logger = aws_handler.logger - - self._logger.debug(f"Creating StreamWriter for {self._db}:{self._table}") - if sync_mode == DestinationSyncMode.overwrite: - self._logger.debug(f"StreamWriter mode is OVERWRITE, need to purge {self._db}:{self._table}") - with LakeformationTransaction(self._aws_handler) as tx: - self._aws_handler.purge_table(tx.txid, self._db, 
self._table) - - def append_message(self, message): - self._logger.debug(f"Appending message to table {self._table}") - self._messages.append(message) - - def generate_object_key(self, prefix=None): - salt = nanoid.generate(size=10) - base = datetime.now().strftime("%Y%m%d%H%M%S") - path = f"{base}.{salt}.json" - if prefix: - path = f"{prefix}/{base}.{salt}.json" - - return path - - @retry(stop_max_attempt_number=10, wait_random_min=2000, wait_random_max=3000) - def add_to_datalake(self): - with LakeformationTransaction(self._aws_handler) as tx: - self._logger.debug(f"Flushing messages to table {self._table}") - object_prefix = f"{self._prefix}/{self._table}" - table_location = "s3://" + self._bucket + "/" + self._prefix + "/" + self._table + "/" - - table = self._aws_handler.get_table(tx.txid, self._db, self._table, table_location) - self._aws_handler.update_table_schema(tx.txid, self._db, table, self._schema) - - if len(self._messages) > 0: - try: - self._logger.debug(f"There are {len(self._messages)} messages to flush for {self._table}") - self._logger.debug(f"10 first messages >>> {repr(self._messages[0:10])} <<<") - object_key = self.generate_object_key(object_prefix) - self._aws_handler.put_object(object_key, self._messages) - res = self._aws_handler.head_object(object_key) - self._aws_handler.update_governed_table( - tx.txid, self._db, self._table, self._bucket, object_key, res["ETag"], res["ContentLength"] - ) - self._logger.debug(f"Table {self._table} was updated") - except Exception as e: - self._logger.error(f"An exception was raised:\n{repr(e)}") - raise (e) - else: - self._logger.debug(f"There was no message to flush for {self._table}") + def __init__(self, aws_handler: AwsHandler, config: ConnectorConfig, configured_stream: ConfiguredAirbyteStream): + self._aws_handler: AwsHandler = aws_handler + self._config: ConnectorConfig = config + self._configured_stream: ConfiguredAirbyteStream = configured_stream + self._schema: Dict[str, Any] = 
configured_stream.stream.json_schema["properties"] + self._sync_mode: DestinationSyncMode = configured_stream.destination_sync_mode + + self._table_exists: bool = False + self._table: str = configured_stream.stream.name + self._database: str = self._configured_stream.stream.namespace or self._config.lakeformation_database_name + self._messages = [] + self._partial_flush_count = 0 + + logger.info(f"Creating StreamWriter for {self._database}:{self._table}") + + def _get_date_columns(self) -> list: + date_columns = [] + for key, val in self._schema.items(): + typ = val.get("type") + if (isinstance(typ, str) and typ == "string") or (isinstance(typ, list) and "string" in typ): + if val.get("format") in ["date-time", "date"]: + date_columns.append(key) + + return date_columns + + def _add_partition_column(self, col: str, df: pd.DataFrame) -> Dict[str, str]: + partitioning = self._config.partitioning + + if partitioning == PartitionOptions.NONE: + return {} + + partitions = partitioning.value.split("/") + + fields = {} + for partition in partitions: + date_col = f"{col}_{partition.lower()}" + fields[date_col] = "bigint" + + # defaulting to 0 since both governed tables + # and pyarrow don't play well with __HIVE_DEFAULT_PARTITION__ + # - pyarrow will fail to cast the column to any other type than string + # - governed tables will fail when trying to query a table with partitions that have __HIVE_DEFAULT_PARTITION__ + # aside from the above, awswrangler will remove data from a table if the partition value is null + # see: https://github.com/aws/aws-sdk-pandas/issues/921 + if partition == "YEAR": + df[date_col] = df[col].dt.strftime("%Y").fillna("0").astype("Int64") + + elif partition == "MONTH": + df[date_col] = df[col].dt.strftime("%m").fillna("0").astype("Int64") + + elif partition == "DAY": + df[date_col] = df[col].dt.strftime("%d").fillna("0").astype("Int64") + + elif partition == "DATE": + fields[date_col] = "date" + df[date_col] = 
df[col].dt.strftime("%Y-%m-%d") + + return
fields + + def _drop_additional_top_level_properties(self, record: Dict[str, Any]) -> Dict[str, Any]: + """ + Helper that removes any unexpected top-level properties from the record. + Since the json schema is used to build the table and cast types correctly, + we need to remove any unexpected properties that can't be cast accurately. + """ + schema_keys = self._schema.keys() + records_keys = record.keys() + difference = list(set(records_keys).difference(set(schema_keys))) + + for key in difference: + del record[key] + + return record + + def _fix_obvious_type_violations(self, record: Dict[str, Any]) -> Dict[str, Any]: + """ + Helper that fixes obvious type violations in a record's top level keys that may + cause issues when casting data to pyarrow types, such as: + - Objects having empty strings or " " or "-" as value instead of null or {} + - Arrays having empty strings or " " or "-" as value instead of null or [] + """ + schema_keys = self._schema.keys() + for key in schema_keys: + typ = self._schema[key].get("type") + typ = self._get_json_schema_type(typ) + if typ in ["object", "array"]: + if record.get(key) in ["", " ", "-", "/", "null"]: + record[key] = None + + return record + + def _add_missing_columns(self, record: Dict[str, Any]) -> Dict[str, Any]: + """ + Helper that adds missing columns to a record's top level keys.
Required + for awswrangler to create the correct schema in glue: even with the explicit + schema passed in, awswrangler will remove those columns when they are not present + in the dataframe. + """ + schema_keys = self._schema.keys() + records_keys = record.keys() + difference = list(set(schema_keys).difference(set(records_keys))) + + for key in difference: + record[key] = None + + return record + + def _get_non_null_json_schema_types(self, typ: Union[str, List[str]]) -> Union[str, List[str]]: + if isinstance(typ, list): + return list(filter(lambda x: x != "null", typ)) + + return typ + + def _json_schema_type_has_mixed_types(self, typ: Union[str, List[str]]) -> bool: + if isinstance(typ, list): + typ = self._get_non_null_json_schema_types(typ) + if len(typ) > 1: + return True + + return False + + def _get_json_schema_type(self, types: Union[List[str], str]) -> str: + if isinstance(types, str): + return types + + if not isinstance(types, list): + return "string" + + types = self._get_non_null_json_schema_types(types) + # when multiple types, cast to string + if self._json_schema_type_has_mixed_types(types): + return "string" + + return types[0] + + def _get_pandas_dtypes_from_json_schema(self, df: pd.DataFrame) -> Dict[str, str]: + type_mapper = { + "string": "string", + "integer": "Int64", + "number": "float64", + "boolean": "bool", + "object": "object", + "array": "object", + } + + column_types = {} + + typ = "string" + for col in df.columns: + if col in self._schema: + typ = self._schema[col].get("type", "string") + airbyte_type = self._schema[col].get("airbyte_type") + + # special case where the json schema type contradicts the airbyte type + if airbyte_type and typ == "number" and airbyte_type == "integer": + typ = "integer" + + typ = self._get_json_schema_type(typ) + + column_types[col] = type_mapper.get(typ, "string") + + return column_types + + def _get_json_schema_types(self): + types = {} + for key, val in self._schema.items(): + typ = val.get("type") 
+ types[key]
= self._get_json_schema_type(typ) + return types + + def _is_invalid_struct_or_array(self, schema: Dict[str, Any]) -> bool: + """ + Helper that detects issues with nested objects/arrays in the json schema. Despite its name, it returns True when no issue is found. + When a complex data type is detected (schema with oneOf or mixed types), a nested object has no properties, or an array has no items, + the column's dtype will be cast to string to avoid pyarrow conversion issues. + """ + result = True + + def check_properties(schema): + nonlocal result + for val in schema.values(): + # Complex types can't be cast to an athena/glue type + if val.get("oneOf"): + result = False + continue + + raw_typ = val.get("type") + + # If the type is a list, check for mixed types + # complex objects with mixed types can't be reliably cast + if isinstance(raw_typ, list) and self._json_schema_type_has_mixed_types(raw_typ): + result = False + continue + + typ = self._get_json_schema_type(raw_typ) + + # If object check nested properties + if typ == "object": + properties = val.get("properties") + if not properties: + result = False + else: + check_properties(properties) + + # If array check nested properties + if typ == "array": + items = val.get("items") + + if not items: + result = False + continue + + if isinstance(items, list): + items = items[0] + + item_properties = items.get("properties") + if item_properties: + check_properties(item_properties) + + check_properties(schema) + return result + + def _get_glue_dtypes_from_json_schema(self, schema: Dict[str, Any]) -> Tuple[Dict[str, str], List[str]]: + """ + Helper that infers glue dtypes from a json schema.
+ """ + + type_mapper = { + "string": "string", + "integer": "bigint", + "number": "decimal(38, 25)" if self._config.glue_catalog_float_as_decimal else "double", + "boolean": "boolean", + "null": "string", + } + + column_types = {} + json_columns = set() + for (col, definition) in schema.items(): + + result_typ = None + col_typ = definition.get("type") + airbyte_type = definition.get("airbyte_type") + col_format = definition.get("format") + + col_typ = self._get_json_schema_type(col_typ) + + # special case where the json schema type contradicts the airbyte type + if airbyte_type and col_typ == "number" and airbyte_type == "integer": + col_typ = "integer" + + if col_typ == "string" and col_format == "date-time": + result_typ = "timestamp" + + if col_typ == "string" and col_format == "date": + result_typ = "date" + + if col_typ == "object": + properties = definition.get("properties") + if properties and self._is_invalid_struct_or_array(properties): + object_props, _ = self._get_glue_dtypes_from_json_schema(properties) + result_typ = f"struct<{','.join([f'{k}:{v}' for k, v in object_props.items()])}>" + else: + json_columns.add(col) + result_typ = "string" + + if col_typ == "array": + items = definition.get("items", {}) + + if isinstance(items, list): + items = items[0] + + raw_item_type = items.get("type") + item_type = self._get_json_schema_type(raw_item_type) + item_properties = items.get("properties") + + # if array has no "items", cast to string + if not items: + json_columns.add(col) + result_typ = "string" + + # if array with objects + elif isinstance(items, dict) and item_properties: + # Check if nested object has properties and no mixed type objects + if self._is_invalid_struct_or_array(item_properties): + item_dtypes, _ = self._get_glue_dtypes_from_json_schema(item_properties) + inner_struct = f"struct<{','.join([f'{k}:{v}' for k, v in item_dtypes.items()])}>" + result_typ = f"array<{inner_struct}>" + else: + json_columns.add(col) + result_typ = "string" + + 
elif item_type and self._json_schema_type_has_mixed_types(raw_item_type): + json_columns.add(col) + result_typ = "string" + + # array with single type + elif item_type and not self._json_schema_type_has_mixed_types(raw_item_type): + result_typ = f"array<{type_mapper[item_type]}>" + + if result_typ is None: + result_typ = type_mapper.get(col_typ, "string") + + column_types[col] = result_typ + + return column_types, json_columns + + @property + def _cursor_fields(self) -> Optional[List[str]]: + return self._configured_stream.cursor_field + + def append_message(self, message: Dict[str, Any]): + clean_message = self._drop_additional_top_level_properties(message) + clean_message = self._fix_obvious_type_violations(clean_message) + clean_message = self._add_missing_columns(clean_message) + self._messages.append(clean_message) + + def reset(self): + logger.info(f"Deleting table {self._database}:{self._table}") + success = self._aws_handler.delete_table(self._database, self._table) + + if not success: + logger.warning(f"Failed to reset table {self._database}:{self._table}") + + def flush(self, partial: bool = False): + logger.debug(f"Flushing {len(self._messages)} messages to table {self._database}:{self._table}") + + df = pd.DataFrame(self._messages) + # best effort to convert pandas types + df = df.astype(self._get_pandas_dtypes_from_json_schema(df), errors="ignore") + + if len(df) < 1: + logger.info(f"No messages to write to {self._database}:{self._table}") + return + + partition_fields = {} + date_columns = self._get_date_columns() + for col in date_columns: + if col in df.columns: + df[col] = pd.to_datetime(df[col]) + + # Create date column for partitioning + if self._cursor_fields and col in self._cursor_fields: + fields = self._add_partition_column(col, df) + partition_fields.update(fields) + + dtype, json_casts = self._get_glue_dtypes_from_json_schema(self._schema) + dtype = {**dtype, **partition_fields} + partition_fields = list(partition_fields.keys()) + + # Make 
sure complex types that can't be converted + # to a struct or array are converted to a json string + # so they can be queried with json_extract + for col in json_casts: + if col in df.columns: + df[col] = df[col].apply(json.dumps) + + if self._sync_mode == DestinationSyncMode.overwrite and self._partial_flush_count < 1: + logger.debug(f"Overwriting {len(df)} records to {self._database}:{self._table}") + self._aws_handler.write( + df, + self._database, + self._table, + dtype, + partition_fields, + ) + + elif self._sync_mode == DestinationSyncMode.append or self._partial_flush_count > 0: + logger.debug(f"Appending {len(df)} records to {self._database}:{self._table}") + self._aws_handler.append( + df, + self._database, + self._table, + dtype, + partition_fields, + ) + + else: + self._messages = [] + raise Exception(f"Unsupported sync mode: {self._sync_mode}") + + if partial: + self._partial_flush_count += 1 + + del df + self._messages.clear() diff --git a/airbyte-integrations/connectors/destination-aws-datalake/unit_tests/unit_test.py b/airbyte-integrations/connectors/destination-aws-datalake/integration_tests/__init__.py similarity index 57% rename from airbyte-integrations/connectors/destination-aws-datalake/unit_tests/unit_test.py rename to airbyte-integrations/connectors/destination-aws-datalake/integration_tests/__init__.py index 219ae0142c72..c941b3045795 100644 --- a/airbyte-integrations/connectors/destination-aws-datalake/unit_tests/unit_test.py +++ b/airbyte-integrations/connectors/destination-aws-datalake/integration_tests/__init__.py @@ -1,7 +1,3 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# - - -def test_example_method(): - assert True diff --git a/airbyte-integrations/connectors/destination-aws-datalake/integration_tests/integration_test.py b/airbyte-integrations/connectors/destination-aws-datalake/integration_tests/integration_test.py new file mode 100644 index 000000000000..8e8c9fd72532 --- /dev/null +++ b/airbyte-integrations/connectors/destination-aws-datalake/integration_tests/integration_test.py @@ -0,0 +1,160 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +# + +import json +import logging +from datetime import datetime +from typing import Any, Dict, Mapping + +import awswrangler as wr +import pytest +from airbyte_cdk.models import ( + AirbyteMessage, + AirbyteRecordMessage, + AirbyteStateMessage, + AirbyteStream, + ConfiguredAirbyteCatalog, + ConfiguredAirbyteStream, + DestinationSyncMode, + Status, + SyncMode, + Type, +) +from destination_aws_datalake import DestinationAwsDatalake +from destination_aws_datalake.aws import AwsHandler +from destination_aws_datalake.config_reader import ConnectorConfig + +logger = logging.getLogger("airbyte") + + +@pytest.fixture(name="config") +def config_fixture() -> Mapping[str, Any]: + with open("secrets/config.json", "r") as f: + return json.loads(f.read()) + + +@pytest.fixture(name="invalid_region_config") +def invalid_region_config() -> Mapping[str, Any]: + with open("integration_tests/invalid_region_config.json", "r") as f: + return json.loads(f.read()) + + +@pytest.fixture(name="invalid_account_config") +def invalid_account_config() -> Mapping[str, Any]: + with open("integration_tests/invalid_account_config.json", "r") as f: + return json.loads(f.read()) + + +@pytest.fixture(name="configured_catalog") +def configured_catalog_fixture() -> ConfiguredAirbyteCatalog: + stream_schema = { + "type": "object", + "properties": { + "string_col": {"type": "string"}, + "int_col": {"type": "integer"}, + "date_col": {"type": "string", "format": "date-time"}, + }, + } + + append_stream = 
ConfiguredAirbyteStream( + stream=AirbyteStream( + name="append_stream", json_schema=stream_schema, supported_sync_modes=[SyncMode.full_refresh, SyncMode.incremental] + ), + sync_mode=SyncMode.incremental, + destination_sync_mode=DestinationSyncMode.append, + ) + + overwrite_stream = ConfiguredAirbyteStream( + stream=AirbyteStream( + name="overwrite_stream", json_schema=stream_schema, supported_sync_modes=[SyncMode.full_refresh, SyncMode.incremental] + ), + sync_mode=SyncMode.incremental, + destination_sync_mode=DestinationSyncMode.overwrite, + ) + + return ConfiguredAirbyteCatalog(streams=[append_stream, overwrite_stream]) + + +def test_check_valid_config(config: Mapping): + outcome = DestinationAwsDatalake().check(logger, config) + assert outcome.status == Status.SUCCEEDED + + +def test_check_invalid_aws_region_config(invalid_region_config: Mapping): + outcome = DestinationAwsDatalake().check(logger, invalid_region_config) + assert outcome.status == Status.FAILED + + +def test_check_invalid_aws_account_config(invalid_account_config: Mapping): + outcome = DestinationAwsDatalake().check(logger, invalid_account_config) + assert outcome.status == Status.FAILED + + +def _state(data: Dict[str, Any]) -> AirbyteMessage: + return AirbyteMessage(type=Type.STATE, state=AirbyteStateMessage(data=data)) + + +def _record(stream: str, str_value: str, int_value: int, date_value: datetime) -> AirbyteMessage: + return AirbyteMessage( + type=Type.RECORD, + record=AirbyteRecordMessage(stream=stream, data={"str_col": str_value, "int_col": int_value, "date_col": date_value}, emitted_at=0), + ) + + +def test_write(config: Mapping, configured_catalog: ConfiguredAirbyteCatalog): + """ + This test verifies that: + 1. writing a stream in "overwrite" mode overwrites any existing data for that stream + 2. writing a stream in "append" mode appends new records without deleting the old ones + 3. 
The correct state message is output by the connector at the end of the sync + """ + append_stream, overwrite_stream = configured_catalog.streams[0].stream.name, configured_catalog.streams[1].stream.name + + destination = DestinationAwsDatalake() + + connector_config = ConnectorConfig(**config) + aws_handler = AwsHandler(connector_config, destination) + + database = connector_config.lakeformation_database_name + + # make sure we start with empty tables + for tbl in [append_stream, overwrite_stream]: + aws_handler.reset_table(database, tbl) + + first_state_message = _state({"state": "1"}) + + first_record_chunk = [_record(append_stream, str(i), i, datetime.now()) for i in range(5)] + [ + _record(overwrite_stream, str(i), i, datetime.now()) for i in range(5) + ] + + second_state_message = _state({"state": "2"}) + second_record_chunk = [_record(append_stream, str(i), i, datetime.now()) for i in range(5, 10)] + [ + _record(overwrite_stream, str(i), i, datetime.now()) for i in range(5, 10) + ] + + expected_states = [first_state_message, second_state_message] + output_states = list( + destination.write( + config, configured_catalog, [*first_record_chunk, first_state_message, *second_record_chunk, second_state_message] + ) + ) + assert expected_states == output_states, "Checkpoint state messages were expected from the destination" + + # Check if table was created + for tbl in [append_stream, overwrite_stream]: + table = wr.catalog.table(database=database, table=tbl, boto3_session=aws_handler._session) + expected_types = {"string_col": "string", "int_col": "bigint", "date_col": "timestamp"} + + # Check table format + for col in table.to_dict("records"): + assert col["Column Name"] in ["string_col", "int_col", "date_col"] + assert col["Type"] == expected_types[col["Column Name"]] + + # Check table data + # cannot use wr.lakeformation.read_sql_query because of this issue: https://github.com/aws/aws-sdk-pandas/issues/2007 + df = wr.athena.read_sql_query(f"SELECT * FROM {tbl}", 
database=database, boto3_session=aws_handler._session) + assert len(df) == 10 + + # Reset table + aws_handler.reset_table(database, tbl) diff --git a/airbyte-integrations/connectors/destination-aws-datalake/integration_tests/invalid_account_config.json b/airbyte-integrations/connectors/destination-aws-datalake/integration_tests/invalid_account_config.json new file mode 100644 index 000000000000..db12398564dd --- /dev/null +++ b/airbyte-integrations/connectors/destination-aws-datalake/integration_tests/invalid_account_config.json @@ -0,0 +1,11 @@ +{ + "aws_account_id": "111111111111", + "region": "us-east-1", + "credentials": { + "credentials_title": "IAM User", + "aws_access_key_id": "dummy", + "aws_secret_access_key": "dummykey" + }, + "bucket_name": "my-bucket", + "partitioning": "NO PARTITIONING" +} diff --git a/airbyte-integrations/connectors/destination-aws-datalake/integration_tests/invalid_region_config.json b/airbyte-integrations/connectors/destination-aws-datalake/integration_tests/invalid_region_config.json new file mode 100644 index 000000000000..fd28e3b32ca4 --- /dev/null +++ b/airbyte-integrations/connectors/destination-aws-datalake/integration_tests/invalid_region_config.json @@ -0,0 +1,10 @@ +{ + "aws_account_id": "111111111111", + "region": "not_a_real_region", + "credentials": { + "credentials_title": "IAM User", + "aws_access_key_id": "dummy", + "aws_secret_access_key": "dummykey" + }, + "partitioning": "NO PARTITIONING" +} diff --git a/airbyte-integrations/connectors/destination-aws-datalake/setup.py b/airbyte-integrations/connectors/destination-aws-datalake/setup.py index fb3d48b4dab6..4dccbefb7619 100644 --- a/airbyte-integrations/connectors/destination-aws-datalake/setup.py +++ b/airbyte-integrations/connectors/destination-aws-datalake/setup.py @@ -6,17 +6,17 @@ from setuptools import find_packages, setup MAIN_REQUIREMENTS = [ - "airbyte-cdk==0.1.6-rc1", - "boto3", + "airbyte-cdk~=0.1", "retrying", - "nanoid", + "awswrangler==2.17.0", + 
"pandas==1.4.4", ] TEST_REQUIREMENTS = ["pytest~=6.1"] setup( name="destination_aws_datalake", - description="Destination implementation for Aws Datalake.", + description="Destination implementation for AWS Datalake.", author="Airbyte", author_email="contact@airbyte.io", packages=find_packages(), diff --git a/airbyte-integrations/connectors/destination-aws-datalake/src/main/java/io/airbyte/integrations/destination/aws_datalake/AthenaHelper.java b/airbyte-integrations/connectors/destination-aws-datalake/src/main/java/io/airbyte/integrations/destination/aws_datalake/AthenaHelper.java deleted file mode 100644 index 556c7f6a3f2c..000000000000 --- a/airbyte-integrations/connectors/destination-aws-datalake/src/main/java/io/airbyte/integrations/destination/aws_datalake/AthenaHelper.java +++ /dev/null @@ -1,129 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.integrations.destination.aws_datalake; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import software.amazon.awssdk.auth.credentials.AwsCredentials; -import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; -import software.amazon.awssdk.regions.Region; -import software.amazon.awssdk.services.athena.AthenaClient; -import software.amazon.awssdk.services.athena.model.AthenaException; -import software.amazon.awssdk.services.athena.model.GetQueryExecutionRequest; -import software.amazon.awssdk.services.athena.model.GetQueryExecutionResponse; -import software.amazon.awssdk.services.athena.model.GetQueryResultsRequest; -import software.amazon.awssdk.services.athena.model.QueryExecutionContext; -import software.amazon.awssdk.services.athena.model.QueryExecutionState; -import software.amazon.awssdk.services.athena.model.ResultConfiguration; -import software.amazon.awssdk.services.athena.model.StartQueryExecutionRequest; -import software.amazon.awssdk.services.athena.model.StartQueryExecutionResponse; -import 
software.amazon.awssdk.services.athena.paginators.GetQueryResultsIterable; - -public class AthenaHelper { - - private AthenaClient athenaClient; - private String outputBucket; - private String workGroup; - private static final Logger LOGGER = LoggerFactory.getLogger(AthenaHelper.class); - - public AthenaHelper(AwsCredentials credentials, Region region, String outputBucket, String workGroup) { - LOGGER.debug(String.format("region = %s, outputBucket = %s, workGroup = %s", region, outputBucket, workGroup)); - var credProvider = StaticCredentialsProvider.create(credentials); - this.athenaClient = AthenaClient.builder().region(region).credentialsProvider(credProvider).build(); - this.outputBucket = outputBucket; - this.workGroup = workGroup; - } - - public String submitAthenaQuery(String database, String query) { - try { - - // The QueryExecutionContext allows us to set the database - QueryExecutionContext queryExecutionContext = QueryExecutionContext.builder() - .database(database).build(); - - // The result configuration specifies where the results of the query should go - ResultConfiguration resultConfiguration = ResultConfiguration.builder() - .outputLocation(outputBucket) - .build(); - - StartQueryExecutionRequest startQueryExecutionRequest = StartQueryExecutionRequest.builder() - .queryString(query) - .queryExecutionContext(queryExecutionContext) - .resultConfiguration(resultConfiguration) - .workGroup(workGroup) - .build(); - - StartQueryExecutionResponse startQueryExecutionResponse = athenaClient.startQueryExecution(startQueryExecutionRequest); - return startQueryExecutionResponse.queryExecutionId(); - } catch (AthenaException e) { - e.printStackTrace(); - System.exit(1); - } - return ""; - } - - public void waitForQueryToComplete(String queryExecutionId) throws InterruptedException { - GetQueryExecutionRequest getQueryExecutionRequest = GetQueryExecutionRequest.builder() - .queryExecutionId(queryExecutionId).build(); - - GetQueryExecutionResponse 
getQueryExecutionResponse; - boolean isQueryStillRunning = true; - while (isQueryStillRunning) { - getQueryExecutionResponse = athenaClient.getQueryExecution(getQueryExecutionRequest); - String queryState = getQueryExecutionResponse.queryExecution().status().state().toString(); - if (queryState.equals(QueryExecutionState.FAILED.toString())) { - throw new RuntimeException("The Amazon Athena query failed to run with error message: " + getQueryExecutionResponse - .queryExecution().status().stateChangeReason()); - } else if (queryState.equals(QueryExecutionState.CANCELLED.toString())) { - throw new RuntimeException("The Amazon Athena query was cancelled."); - } else if (queryState.equals(QueryExecutionState.SUCCEEDED.toString())) { - isQueryStillRunning = false; - } else { - // Sleep an amount of time before retrying again - Thread.sleep(1000); - } - } - } - - public GetQueryResultsIterable getResults(String queryExecutionId) { - - try { - - // Max Results can be set but if its not set, - // it will choose the maximum page size - GetQueryResultsRequest getQueryResultsRequest = GetQueryResultsRequest.builder() - .queryExecutionId(queryExecutionId) - .build(); - - GetQueryResultsIterable getQueryResultsResults = athenaClient.getQueryResultsPaginator(getQueryResultsRequest); - return getQueryResultsResults; - - } catch (AthenaException e) { - e.printStackTrace(); - System.exit(1); - } - return null; - } - - public GetQueryResultsIterable runQuery(String database, String query) throws InterruptedException { - int retryCount = 0; - - while (retryCount < 10) { - var execId = submitAthenaQuery(database, query); - try { - waitForQueryToComplete(execId); - } catch (RuntimeException e) { - e.printStackTrace(); - LOGGER.info("Athena query failed once. 
Retrying."); - retryCount++; - continue; - } - return getResults(execId); - } - LOGGER.info("Athena query failed and we are out of retries."); - return null; - } - -} diff --git a/airbyte-integrations/connectors/destination-aws-datalake/src/main/java/io/airbyte/integrations/destination/aws_datalake/AwsDatalakeDestinationConfig.java b/airbyte-integrations/connectors/destination-aws-datalake/src/main/java/io/airbyte/integrations/destination/aws_datalake/AwsDatalakeDestinationConfig.java deleted file mode 100644 index 4d66d91f5fdd..000000000000 --- a/airbyte-integrations/connectors/destination-aws-datalake/src/main/java/io/airbyte/integrations/destination/aws_datalake/AwsDatalakeDestinationConfig.java +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.integrations.destination.aws_datalake; - -import com.fasterxml.jackson.databind.JsonNode; - -public class AwsDatalakeDestinationConfig { - - private final String awsAccountId; - private final String region; - private final String accessKeyId; - private final String secretAccessKey; - private final String bucketName; - private final String prefix; - private final String databaseName; - - public AwsDatalakeDestinationConfig(String awsAccountId, - String region, - String accessKeyId, - String secretAccessKey, - String bucketName, - String prefix, - String databaseName) { - this.awsAccountId = awsAccountId; - this.region = region; - this.accessKeyId = accessKeyId; - this.secretAccessKey = secretAccessKey; - this.bucketName = bucketName; - this.prefix = prefix; - this.databaseName = databaseName; - - } - - public static AwsDatalakeDestinationConfig getAwsDatalakeDestinationConfig(JsonNode config) { - - final String aws_access_key_id = config.path("credentials").get("aws_access_key_id").asText(); - final String aws_secret_access_key = config.path("credentials").get("aws_secret_access_key").asText(); - - return new AwsDatalakeDestinationConfig( - 
config.get("aws_account_id").asText(), - config.get("region").asText(), - aws_access_key_id, - aws_secret_access_key, - config.get("bucket_name").asText(), - config.get("bucket_prefix").asText(), - config.get("lakeformation_database_name").asText()); - } - - public String getAwsAccountId() { - return awsAccountId; - } - - public String getRegion() { - return region; - } - - public String getAccessKeyId() { - return accessKeyId; - } - - public String getSecretAccessKey() { - return secretAccessKey; - } - - public String getBucketName() { - return bucketName; - } - - public String getPrefix() { - return prefix; - } - - public String getDatabaseName() { - return databaseName; - } - -} diff --git a/airbyte-integrations/connectors/destination-aws-datalake/src/main/java/io/airbyte/integrations/destination/aws_datalake/GlueHelper.java b/airbyte-integrations/connectors/destination-aws-datalake/src/main/java/io/airbyte/integrations/destination/aws_datalake/GlueHelper.java deleted file mode 100644 index bd81bb515ace..000000000000 --- a/airbyte-integrations/connectors/destination-aws-datalake/src/main/java/io/airbyte/integrations/destination/aws_datalake/GlueHelper.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
- */ - -package io.airbyte.integrations.destination.aws_datalake; - -import java.util.ArrayList; -import java.util.Iterator; -import java.util.List; -import software.amazon.awssdk.auth.credentials.AwsCredentials; -import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; -import software.amazon.awssdk.regions.Region; -import software.amazon.awssdk.services.glue.GlueClient; -import software.amazon.awssdk.services.glue.model.BatchDeleteTableRequest; -import software.amazon.awssdk.services.glue.model.GetTablesRequest; -import software.amazon.awssdk.services.glue.model.GetTablesResponse; -import software.amazon.awssdk.services.glue.model.Table; -import software.amazon.awssdk.services.glue.paginators.GetTablesIterable; - -public class GlueHelper { - - private AwsCredentials awsCredentials; - private Region region; - private GlueClient glueClient; - - public GlueHelper(AwsCredentials credentials, Region region) { - this.awsCredentials = credentials; - this.region = region; - - var credProvider = StaticCredentialsProvider.create(credentials); - this.glueClient = GlueClient.builder().region(region).credentialsProvider(credProvider).build(); - } - - private GetTablesIterable getAllTables(String DatabaseName) { - - GetTablesRequest getTablesRequest = GetTablesRequest.builder().databaseName(DatabaseName).build(); - GetTablesIterable getTablesPaginator = glueClient.getTablesPaginator(getTablesRequest); - - return getTablesPaginator; - } - - private BatchDeleteTableRequest getBatchDeleteRequest(String databaseName, GetTablesIterable getTablesPaginator) { - List tablesToDelete = new ArrayList(); - for (GetTablesResponse response : getTablesPaginator) { - List tablePage = response.tableList(); - Iterator
tableIterator = tablePage.iterator(); - while (tableIterator.hasNext()) { - Table table = tableIterator.next(); - tablesToDelete.add(table.name()); - } - } - BatchDeleteTableRequest batchDeleteRequest = BatchDeleteTableRequest.builder().databaseName(databaseName).tablesToDelete(tablesToDelete).build(); - return batchDeleteRequest; - } - - public void purgeDatabase(String databaseName) { - int countRetries = 0; - while (countRetries < 5) { - try { - GetTablesIterable allTables = getAllTables(databaseName); - BatchDeleteTableRequest batchDeleteTableRequest = getBatchDeleteRequest(databaseName, allTables); - glueClient.batchDeleteTable(batchDeleteTableRequest); - return; - } catch (Exception e) { - countRetries++; - } - } - } - -} diff --git a/airbyte-integrations/connectors/destination-aws-datalake/src/test-integration/java/io/airbyte/integrations/destination/aws_datalake/AwsDatalakeDestinationAcceptanceTest.java b/airbyte-integrations/connectors/destination-aws-datalake/src/test-integration/java/io/airbyte/integrations/destination/aws_datalake/AwsDatalakeDestinationAcceptanceTest.java deleted file mode 100644 index f60620ff142d..000000000000 --- a/airbyte-integrations/connectors/destination-aws-datalake/src/test-integration/java/io/airbyte/integrations/destination/aws_datalake/AwsDatalakeDestinationAcceptanceTest.java +++ /dev/null @@ -1,211 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
- */ - -package io.airbyte.integrations.destination.aws_datalake; - -import static org.junit.jupiter.api.Assertions.assertEquals; - -import com.fasterxml.jackson.databind.JsonNode; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Maps; -import io.airbyte.commons.io.IOs; -import io.airbyte.commons.json.Jsons; -import io.airbyte.integrations.standardtest.destination.DestinationAcceptanceTest; -import io.airbyte.integrations.standardtest.destination.comparator.TestDataComparator; -import java.io.IOException; -import java.nio.file.Path; -import java.util.ArrayList; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; -import software.amazon.awssdk.regions.Region; -import software.amazon.awssdk.services.athena.model.ColumnInfo; -import software.amazon.awssdk.services.athena.model.Datum; -import software.amazon.awssdk.services.athena.model.GetQueryResultsResponse; -import software.amazon.awssdk.services.athena.model.Row; -import software.amazon.awssdk.services.athena.paginators.GetQueryResultsIterable; - -public class AwsDatalakeDestinationAcceptanceTest extends DestinationAcceptanceTest { - - private static final String CONFIG_PATH = "secrets/config.json"; - private static final Logger LOGGER = LoggerFactory.getLogger(AwsDatalakeDestinationAcceptanceTest.class); - - private JsonNode configJson; - private JsonNode configInvalidCredentialsJson; - protected AwsDatalakeDestinationConfig config; - private AthenaHelper athenaHelper; - private GlueHelper glueHelper; - - @Override - protected String getImageName() { - return "airbyte/destination-aws-datalake:dev"; - } - - @Override - protected JsonNode getConfig() { - // TODO: Generate the configuration JSON file to be used for running the destination during the test - // configJson can either be static and read from secrets/config.json directly - 
// or created in the setup method - return configJson; - } - - @Override - protected JsonNode getFailCheckConfig() { - JsonNode credentials = Jsons.jsonNode(ImmutableMap.builder() - .put("credentials_title", "IAM User") - .put("aws_access_key_id", "wrong-access-key") - .put("aws_secret_access_key", "wrong-secret") - .build()); - - JsonNode config = Jsons.jsonNode(ImmutableMap.builder() - .put("aws_account_id", "112233") - .put("region", "us-east-1") - .put("bucket_name", "test-bucket") - .put("bucket_prefix", "test") - .put("lakeformation_database_name", "lf_db") - .put("credentials", credentials) - .build()); - - return config; - } - - @Override - protected List retrieveRecords(TestDestinationEnv testEnv, - String streamName, - String namespace, - JsonNode streamSchema) - throws IOException, InterruptedException { - String query = String.format("SELECT * FROM \"%s\".\"%s\"", config.getDatabaseName(), streamName); - GetQueryResultsIterable results = athenaHelper.runQuery(config.getDatabaseName(), query); - return parseResults(results); - } - - protected List parseResults(GetQueryResultsIterable queryResults) { - - List processedResults = new ArrayList<>(); - - for (GetQueryResultsResponse result : queryResults) { - List columnInfoList = result.resultSet().resultSetMetadata().columnInfo(); - Iterator results = result.resultSet().rows().iterator(); - Row colNamesRow = results.next(); - while (results.hasNext()) { - Map jsonMap = Maps.newHashMap(); - Row r = results.next(); - Iterator colInfoIterator = columnInfoList.iterator(); - Iterator datum = r.data().iterator(); - while (colInfoIterator.hasNext() && datum.hasNext()) { - ColumnInfo colInfo = colInfoIterator.next(); - Datum value = datum.next(); - Object typedFieldValue = getTypedFieldValue(colInfo, value); - if (typedFieldValue != null) { - jsonMap.put(colInfo.name(), typedFieldValue); - } - } - processedResults.add(Jsons.jsonNode(jsonMap)); - } - } - return processedResults; - } - - private static Object 
getTypedFieldValue(ColumnInfo colInfo, Datum value) { - var typeName = colInfo.type(); - var varCharValue = value.varCharValue(); - - if (varCharValue == null) - return null; - var returnType = switch (typeName) { - case "real", "double", "float" -> Double.parseDouble(varCharValue); - case "varchar" -> varCharValue; - case "boolean" -> Boolean.parseBoolean(varCharValue); - case "integer" -> Integer.parseInt(varCharValue); - case "row" -> varCharValue; - default -> null; - }; - if (returnType == null) { - LOGGER.warn(String.format("Unsupported type = %s", typeName)); - } - return returnType; - } - - @Override - protected List resolveIdentifier(String identifier) { - final List result = new ArrayList<>(); - result.add(identifier); - result.add(identifier.toLowerCase()); - return result; - } - - private JsonNode loadJsonFile(String fileName) throws IOException { - final JsonNode configFromSecrets = Jsons.deserialize(IOs.readFile(Path.of(fileName))); - return (configFromSecrets); - } - - @Override - protected void setup(TestDestinationEnv testEnv) throws IOException { - configJson = loadJsonFile(CONFIG_PATH); - - this.config = AwsDatalakeDestinationConfig.getAwsDatalakeDestinationConfig(configJson); - - Region region = Region.of(config.getRegion()); - - AwsBasicCredentials awsCreds = AwsBasicCredentials.create(config.getAccessKeyId(), config.getSecretAccessKey()); - athenaHelper = new AthenaHelper(awsCreds, region, String.format("s3://%s/airbyte_athena/", config.getBucketName()), - "AmazonAthenaLakeFormation"); - glueHelper = new GlueHelper(awsCreds, region); - glueHelper.purgeDatabase(config.getDatabaseName()); - } - - private String toAthenaObject(JsonNode value) { - StringBuilder sb = new StringBuilder("\"{"); - List elements = new ArrayList<>(); - var it = value.fields(); - while (it.hasNext()) { - Map.Entry f = it.next(); - final String k = f.getKey(); - final String v = f.getValue().asText(); - elements.add(String.format("%s=%s", k, v)); - } - 
sb.append(String.join(",", elements)); - sb.append("}\""); - return sb.toString(); - } - - protected void assertSameValue(final String key, final JsonNode expectedValue, final JsonNode actualValue) { - if (expectedValue.isObject()) { - assertEquals(toAthenaObject(expectedValue), actualValue.toString()); - } else { - assertEquals(expectedValue, actualValue); - } - } - - @Override - protected void tearDown(TestDestinationEnv testEnv) { - // TODO Implement this method to run any cleanup actions needed after every test case - // glueHelper.purgeDatabase(config.getDatabaseName()); - } - - @Override - protected TestDataComparator getTestDataComparator() { - return new AwsDatalakeTestDataComparator(); - } - - @Override - protected boolean supportBasicDataTypeTest() { - return true; - } - - @Override - protected boolean supportArrayDataTypeTest() { - return false; - } - - @Override - protected boolean supportObjectDataTypeTest() { - return true; - } - -} diff --git a/airbyte-integrations/connectors/destination-aws-datalake/src/test-integration/java/io/airbyte/integrations/destination/aws_datalake/AwsDatalakeTestDataComparator.java b/airbyte-integrations/connectors/destination-aws-datalake/src/test-integration/java/io/airbyte/integrations/destination/aws_datalake/AwsDatalakeTestDataComparator.java deleted file mode 100644 index 20736ebb1647..000000000000 --- a/airbyte-integrations/connectors/destination-aws-datalake/src/test-integration/java/io/airbyte/integrations/destination/aws_datalake/AwsDatalakeTestDataComparator.java +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
- */ - -package io.airbyte.integrations.destination.aws_datalake; - -import io.airbyte.integrations.standardtest.destination.comparator.AdvancedTestDataComparator; -import java.util.ArrayList; -import java.util.List; -import java.util.Locale; - -public class AwsDatalakeTestDataComparator extends AdvancedTestDataComparator { - - @Override - protected List resolveIdentifier(String identifier) { - final List result = new ArrayList<>(); - result.add(identifier); - result.add(identifier.toLowerCase(Locale.ROOT)); - return result; - } - -} diff --git a/airbyte-integrations/connectors/destination-aws-datalake/src/test/AwsDatalakeDestinationTest.java b/airbyte-integrations/connectors/destination-aws-datalake/src/test/AwsDatalakeDestinationTest.java deleted file mode 100644 index 6c55a807d2ed..000000000000 --- a/airbyte-integrations/connectors/destination-aws-datalake/src/test/AwsDatalakeDestinationTest.java +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
- */ - -package io.airbyte.integrations.destination.aws_datalake; - -import static org.junit.jupiter.api.Assertions.assertEquals; - -import com.fasterxml.jackson.databind.JsonNode; -import io.airbyte.commons.json.Jsons; -import org.junit.jupiter.api.Test; - -class AwsDatalakeDestinationTest { - - /* - * @Test void testGetOutputTableNameWithString() throws Exception { var actual = - * DynamodbOutputTableHelper.getOutputTableName("test_table", "test_namespace", "test_stream"); - * assertEquals("test_table_test_namespace_test_stream", actual); } - * - * @Test void testGetOutputTableNameWithStream() throws Exception { var stream = new - * AirbyteStream(); stream.setName("test_stream"); stream.setNamespace("test_namespace"); var actual - * = DynamodbOutputTableHelper.getOutputTableName("test_table", stream); - * assertEquals("test_table_test_namespace_test_stream", actual); } - */ - @Test - void testGetAwsDatalakeDestinationdbConfig() throws Exception { - JsonNode json = Jsons.deserialize(""" - { - "bucket_prefix": "test_prefix", - "region": "test_region", - "auth_mode": "USER", - "bucket_name": "test_bucket", - "aws_access_key_id": "test_access_key", - "aws_account_id": "test_account_id", - "lakeformation_database_name": "test_database", - "aws_secret_access_key": "test_secret" - }"""); - - var config = AwsDatalakeDestinationConfig.getAwsDatalakeDestinationConfig(json); - - assertEquals(config.getBucketPrefix(), "test_prefix"); - assertEquals(config.getRegion(), "test_region"); - assertEquals(config.getAccessKeyId(), "test_access_key"); - assertEquals(config.getSecretAccessKey(), "test_secret"); - } - -} diff --git a/airbyte-integrations/connectors/destination-aws-datalake/unit_tests/__init__.py b/airbyte-integrations/connectors/destination-aws-datalake/unit_tests/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/airbyte-integrations/connectors/destination-aws-datalake/unit_tests/aws_handler_test.py 
b/airbyte-integrations/connectors/destination-aws-datalake/unit_tests/aws_handler_test.py new file mode 100644 index 000000000000..ce84337631c0 --- /dev/null +++ b/airbyte-integrations/connectors/destination-aws-datalake/unit_tests/aws_handler_test.py @@ -0,0 +1,55 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +# + +import json +from typing import Any, Mapping + +import pytest +from destination_aws_datalake import DestinationAwsDatalake +from destination_aws_datalake.aws import AwsHandler +from destination_aws_datalake.config_reader import CompressionCodec, ConnectorConfig + + +@pytest.fixture(name="config") +def config() -> Mapping[str, Any]: + with open("unit_tests/fixtures/config.json", "r") as f: + return json.loads(f.read()) + + +@pytest.fixture(name="config_prefix") +def config_prefix() -> Mapping[str, Any]: + with open("unit_tests/fixtures/config_prefix.json", "r") as f: + return json.loads(f.read()) + + +def test_get_compression_type(config: Mapping[str, Any]): + aws_handler = AwsHandler(ConnectorConfig(**config), DestinationAwsDatalake()) + + tests = { + CompressionCodec.GZIP: "gzip", + CompressionCodec.SNAPPY: "snappy", + CompressionCodec.ZSTD: "zstd", + "LZO": None, + } + + for codec, expected in tests.items(): + assert aws_handler._get_compression_type(codec) == expected + + +def test_get_path(config: Mapping[str, Any]): + conf = ConnectorConfig(**config) + aws_handler = AwsHandler(conf, DestinationAwsDatalake()) + + tbl = "append_stream" + db = conf.lakeformation_database_name + assert aws_handler._get_s3_path(db, tbl) == "s3://datalake-bucket/test/append_stream/" + + +def test_get_path_prefix(config_prefix: Mapping[str, Any]): + conf = ConnectorConfig(**config_prefix) + aws_handler = AwsHandler(conf, DestinationAwsDatalake()) + + tbl = "append_stream" + db = conf.lakeformation_database_name + assert aws_handler._get_s3_path(db, tbl) == "s3://datalake-bucket/prefix/test/append_stream/" diff --git 
a/airbyte-integrations/connectors/destination-aws-datalake/unit_tests/fixtures/config.json b/airbyte-integrations/connectors/destination-aws-datalake/unit_tests/fixtures/config.json new file mode 100644 index 000000000000..8c660ecca6d9 --- /dev/null +++ b/airbyte-integrations/connectors/destination-aws-datalake/unit_tests/fixtures/config.json @@ -0,0 +1,16 @@ +{ + "aws_account_id": "111111111111", + "credentials": { + "credentials_title": "IAM User", + "aws_access_key_id": "aws_key_id", + "aws_secret_access_key": "aws_secret_key" + }, + "region": "us-east-1", + "bucket_name": "datalake-bucket", + "lakeformation_database_name": "test", + "format": { + "format_type": "Parquet", + "compression_codec": "SNAPPY" + }, + "partitioning": "NO PARTITIONING" +} diff --git a/airbyte-integrations/connectors/destination-aws-datalake/unit_tests/fixtures/config_prefix.json b/airbyte-integrations/connectors/destination-aws-datalake/unit_tests/fixtures/config_prefix.json new file mode 100644 index 000000000000..1f6cca04100f --- /dev/null +++ b/airbyte-integrations/connectors/destination-aws-datalake/unit_tests/fixtures/config_prefix.json @@ -0,0 +1,17 @@ +{ + "aws_account_id": "111111111111", + "credentials": { + "credentials_title": "IAM User", + "aws_access_key_id": "aws_key_id", + "aws_secret_access_key": "aws_secret_key" + }, + "region": "us-east-1", + "bucket_name": "datalake-bucket", + "bucket_prefix": "prefix", + "lakeformation_database_name": "test", + "format": { + "format_type": "Parquet", + "compression_codec": "SNAPPY" + }, + "partitioning": "NO PARTITIONING" +} diff --git a/airbyte-integrations/connectors/destination-aws-datalake/unit_tests/stream_writer_test.py b/airbyte-integrations/connectors/destination-aws-datalake/unit_tests/stream_writer_test.py new file mode 100644 index 000000000000..7e82be769f1d --- /dev/null +++ b/airbyte-integrations/connectors/destination-aws-datalake/unit_tests/stream_writer_test.py @@ -0,0 +1,308 @@ +# +# Copyright (c) 2023 Airbyte, Inc., 
all rights reserved. +# + +import json +from datetime import datetime +from typing import Any, Dict, Mapping + +import pandas as pd +from airbyte_cdk.models import AirbyteStream, ConfiguredAirbyteStream, DestinationSyncMode, SyncMode +from destination_aws_datalake import DestinationAwsDatalake +from destination_aws_datalake.aws import AwsHandler +from destination_aws_datalake.config_reader import ConnectorConfig +from destination_aws_datalake.stream_writer import StreamWriter + + +def get_config() -> Mapping[str, Any]: + with open("unit_tests/fixtures/config.json", "r") as f: + return json.loads(f.read()) + + +def get_configured_stream(): + stream_name = "append_stream" + stream_schema = { + "type": "object", + "properties": { + "string_col": {"type": "str"}, + "int_col": {"type": "integer"}, + "datetime_col": {"type": "string", "format": "date-time"}, + "date_col": {"type": "string", "format": "date"}, + }, + } + + return ConfiguredAirbyteStream( + stream=AirbyteStream( + name=stream_name, + json_schema=stream_schema, + default_cursor_field=["datetime_col"], + supported_sync_modes=[SyncMode.incremental, SyncMode.full_refresh], + ), + sync_mode=SyncMode.incremental, + destination_sync_mode=DestinationSyncMode.append, + cursor_field=["datetime_col"], + ) + + +def get_writer(config: Dict[str, Any]): + connector_config = ConnectorConfig(**config) + aws_handler = AwsHandler(connector_config, DestinationAwsDatalake()) + return StreamWriter(aws_handler, connector_config, get_configured_stream()) + + +def get_big_schema_configured_stream(): + stream_name = "append_stream_big" + stream_schema = { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": ["null", "object"], + "properties": { + "appId": {"type": ["null", "integer"]}, + "appName": {"type": ["null", "string"]}, + "bounced": {"type": ["null", "boolean"]}, + "browser": { + "type": ["null", "object"], + "properties": { + "family": {"type": ["null", "string"]}, + "name": {"type": ["null", "string"]}, + 
"producer": {"type": ["null", "string"]}, + "producerUrl": {"type": ["null", "string"]}, + "type": {"type": ["null", "string"]}, + "url": {"type": ["null", "string"]}, + "version": {"type": ["null", "array"], "items": {"type": ["null", "string"]}}, + }, + }, + "causedBy": { + "type": ["null", "object"], + "properties": {"created": {"type": ["null", "integer"]}, "id": {"type": ["null", "string"]}}, + }, + "percentage": {"type": ["null", "number"]}, + "location": { + "type": ["null", "object"], + "properties": { + "city": {"type": ["null", "string"]}, + "country": {"type": ["null", "string"]}, + "latitude": {"type": ["null", "number"]}, + "longitude": {"type": ["null", "number"]}, + "state": {"type": ["null", "string"]}, + "zipcode": {"type": ["null", "string"]}, + }, + }, + "nestedJson": { + "type": ["null", "object"], + "properties": { + "city": {"type": "object", "properties": {"name": {"type": "string"}}}, + }, + }, + "sentBy": { + "type": ["null", "object"], + "properties": {"created": {"type": ["null", "integer"]}, "id": {"type": ["null", "string"]}}, + }, + "sentAt": {"type": ["null", "string"], "format": "date-time"}, + "receivedAt": {"type": ["null", "string"], "format": "date"}, + "sourceId": {"type": "string"}, + "status": {"type": "integer"}, + "read": {"type": "boolean"}, + "questions": { + "type": "array", + "items": { + "type": "object", + "properties": {"id": {"type": ["null", "integer"]}, "question": {"type": "string"}, "answer": {"type": "string"}}, + }, + }, + "questions_nested": { + "type": "array", + "items": { + "type": "object", + "properties": { + "id": {"type": ["null", "integer"]}, + "questions": {"type": "object", "properties": {"title": {"type": "string"}, "option": {"type": "integer"}}}, + "answer": {"type": "string"}, + }, + }, + }, + "nested_mixed_types": { + "type": ["null", "object"], + "properties": { + "city": {"type": ["string", "integer", "null"]}, + }, + }, + "nested_bad_object": { + "type": ["null", "object"], + "properties": { 
+ "city": {"type": "object", "properties": {}}, + }, + }, + "nested_nested_bad_object": { + "type": ["null", "object"], + "properties": { + "city": { + "type": "object", + "properties": { + "name": {"type": "object", "properties": {}}, + }, + }, + }, + }, + "answers": { + "type": "array", + "items": { + "type": "string", + }, + }, + "answers_nested_bad": { + "type": "array", + "items": { + "type": "object", + "properties": { + "id": {"type": ["string", "integer"]}, + }, + }, + }, + "phone_number_ids": {"type": ["null", "array"], "items": {"type": ["string", "integer"]}}, + "mixed_type_simple": { + "type": ["integer", "number"], + }, + "empty_array": {"type": ["null", "array"]}, + "airbyte_type_object": {"type": "number", "airbyte_type": "integer"}, + }, + } + + return ConfiguredAirbyteStream( + stream=AirbyteStream( + name=stream_name, + json_schema=stream_schema, + default_cursor_field=["datetime_col"], + supported_sync_modes=[SyncMode.incremental, SyncMode.full_refresh], + ), + sync_mode=SyncMode.incremental, + destination_sync_mode=DestinationSyncMode.append, + cursor_field=["datetime_col"], + ) + + +def get_big_schema_writer(config: Dict[str, Any]): + connector_config = ConnectorConfig(**config) + aws_handler = AwsHandler(connector_config, DestinationAwsDatalake()) + return StreamWriter(aws_handler, connector_config, get_big_schema_configured_stream()) + + +def test_get_date_columns(): + writer = get_writer(get_config()) + assert writer._get_date_columns() == ["datetime_col", "date_col"] + + +def test_append_messsage(): + writer = get_writer(get_config()) + message = {"string_col": "test", "int_col": 1, "datetime_col": "2021-01-01T00:00:00Z", "date_col": "2021-01-01"} + writer.append_message(message) + assert len(writer._messages) == 1 + assert writer._messages[0] == message + + +def test_get_cursor_field(): + writer = get_writer(get_config()) + assert writer._cursor_fields == ["datetime_col"] + + +def test_add_partition_column(): + tests = { + "NO 
PARTITIONING": {}, + "DATE": {"datetime_col_date": "date"}, + "MONTH": {"datetime_col_month": "bigint"}, + "YEAR": {"datetime_col_year": "bigint"}, + "DAY": {"datetime_col_day": "bigint"}, + "YEAR/MONTH/DAY": {"datetime_col_year": "bigint", "datetime_col_month": "bigint", "datetime_col_day": "bigint"}, + } + + for partitioning, expected_columns in tests.items(): + config = get_config() + config["partitioning"] = partitioning + + writer = get_writer(config) + df = pd.DataFrame( + { + "datetime_col": [datetime.now()], + } + ) + assert writer._add_partition_column("datetime_col", df) == expected_columns + assert all([col in df.columns for col in expected_columns]) + + +def test_get_glue_dtypes_from_json_schema(): + writer = get_big_schema_writer(get_config()) + result, json_casts = writer._get_glue_dtypes_from_json_schema(writer._schema) + assert result == { + "airbyte_type_object": "bigint", + "answers": "array<string>", + "answers_nested_bad": "string", + "appId": "bigint", + "appName": "string", + "bounced": "boolean", + "browser": "struct<family:string,name:string,producer:string,producerUrl:string,type:string,url:string,version:array<string>>", + "causedBy": "struct<created:bigint,id:string>", + "empty_array": "string", + "location": "struct<city:string,country:string,latitude:double,longitude:double,state:string,zipcode:string>", + "mixed_type_simple": "string", + "nestedJson": "struct<city:struct<name:string>>", + "nested_bad_object": "string", + "nested_mixed_types": "string", + "nested_nested_bad_object": "string", + "percentage": "double", + "phone_number_ids": "string", + "questions": "array<struct<id:bigint,question:string,answer:string>>", + "questions_nested": "array<struct<id:bigint,questions:struct<title:string,option:bigint>,answer:string>>", + "read": "boolean", + "receivedAt": "date", + "sentAt": "timestamp", + "sentBy": "struct<created:bigint,id:string>", + "sourceId": "string", + "status": "bigint", + } + + assert json_casts == { + "answers_nested_bad", + "empty_array", + "nested_bad_object", + "nested_mixed_types", + "nested_nested_bad_object", + "phone_number_ids", + } + + +def test_has_objects_with_no_properties_good(): + writer = get_big_schema_writer(get_config()) + assert writer._is_invalid_struct_or_array( + { + "nestedJson": { + "type": ["null", "object"], + "properties": { + "city": {"type": "object", "properties":
{"name": {"type": "string"}}}, + }, + } + } + ) + + +def test_has_objects_with_no_properties_bad(): + writer = get_big_schema_writer(get_config()) + assert not writer._is_invalid_struct_or_array( + { + "nestedJson": { + "type": ["null", "object"], + } + } + ) + + +def test_has_objects_with_no_properties_nested_bad(): + writer = get_big_schema_writer(get_config()) + assert not writer._is_invalid_struct_or_array( + { + "nestedJson": { + "type": ["null", "object"], + "properties": { + "city": {"type": "object", "properties": {}}, + }, + } + } + ) diff --git a/connectors.md b/connectors.md index 10ef495da885..4160adb57f65 100644 --- a/connectors.md +++ b/connectors.md @@ -286,7 +286,7 @@ | Name | Icon | Type | Image | Release Stage | Docs | Code | ID | |----|----|----|----|----|----|----|----| -| **AWS Datalake** | x | Destination | airbyte/destination-aws-datalake:0.1.1 | alpha | [link](https://docs.airbyte.com/integrations/destinations/aws-datalake) | [code](https://github.com/airbytehq/airbyte/tree/master/airbyte-integrations/connectors/destination-aws-datalake) | `99878c90-0fbd-46d3-9d98-ffde879d17fc` | +| **AWS Datalake** | AWS Datalake icon | Destination | airbyte/destination-aws-datalake:0.1.2 | alpha | [link](https://docs.airbyte.com/integrations/destinations/aws-datalake) | [code](https://github.com/airbytehq/airbyte/tree/master/airbyte-integrations/connectors/destination-aws-datalake) | `99878c90-0fbd-46d3-9d98-ffde879d17fc` | | **Amazon SQS** | Amazon SQS icon | Destination | airbyte/destination-amazon-sqs:0.1.0 | alpha | [link](https://docs.airbyte.com/integrations/destinations/amazon-sqs) | [code](https://github.com/airbytehq/airbyte/tree/master/airbyte-integrations/connectors/destination-amazon-sqs) | `0eeee7fb-518f-4045-bacc-9619e31c43ea` | | **Apache Doris** | Apache Doris icon | Destination | airbyte/destination-doris:0.1.0 | alpha | [link](https://docs.airbyte.com/integrations/destinations/doris) | 
[code](https://github.com/airbytehq/airbyte/tree/master/airbyte-integrations/connectors/destination-doris) | `05c161bf-ca73-4d48-b524-d392be417002` | | **Apache Iceberg** | x | Destination | airbyte/destination-iceberg:0.1.0 | alpha | [link](https://docs.airbyte.com/integrations/destinations/iceberg) | [code](https://github.com/airbytehq/airbyte/tree/master/airbyte-integrations/connectors/destination-iceberg) | `df65a8f3-9908-451b-aa9b-445462803560` | diff --git a/docs/integrations/destinations/aws-datalake.md b/docs/integrations/destinations/aws-datalake.md index dbfa1db6d42e..e3a69b737b5c 100644 --- a/docs/integrations/destinations/aws-datalake.md +++ b/docs/integrations/destinations/aws-datalake.md @@ -2,7 +2,7 @@ This page contains the setup guide and reference information for the AWS Datalake destination connector. -The AWS Datalake destination connector allows you to sync data to AWS. It will write data as JSON files in S3 and +The AWS Datalake destination connector allows you to sync data to AWS. It will write data as JSON files in S3 and will make it available through a [Lake Formation Governed Table](https://docs.aws.amazon.com/lake-formation/latest/dg/governed-tables.html) in the Glue Data Catalog so that the data is available throughout other AWS services such as Athena, Glue jobs, EMR, Redshift, etc. ## Prerequisites @@ -69,5 +69,6 @@ and types in the destination table as in the source except for the following typ ## Changelog +| 0.1.2 | 2022-09-26 | [\#17193](https://github.com/airbytehq/airbyte/pull/17193) | Fix schema keyerror and add parquet support | | 0.1.1 | 2022-04-20 | [\#11811](https://github.com/airbytehq/airbyte/pull/11811) | Fix name of required param in specification | | 0.1.0 | 2022-03-29 | [\#10760](https://github.com/airbytehq/airbyte/pull/10760) | Initial release | \ No newline at end of file
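
Note on the type mapping exercised by `test_get_glue_dtypes_from_json_schema`: the expected values in that test suggest a rule where well-typed JSON schema properties map to Glue primitives, `struct<...>`, or `array<...>`, and anything mixed-type or missing its nested schema falls back to `string`. The following is a minimal sketch of that rule only, inferred from the test expectations. It is not the connector's `_get_glue_dtypes_from_json_schema`; the names `glue_column_type`, `_glue_type`, and `_non_null_types` are hypothetical, and the sketch does not track the separate `json_casts` set.

```python
from typing import Any, Dict, List, Optional

# Assumed primitive mapping, taken from the expected values in the unit test.
_PRIMITIVES = {"integer": "bigint", "number": "double", "boolean": "boolean"}


def _non_null_types(prop: Dict[str, Any]) -> List[str]:
    """Return the declared JSON schema types with "null" removed."""
    raw = prop.get("type", [])
    types = [raw] if isinstance(raw, str) else list(raw)
    return [t for t in types if t != "null"]


def _glue_type(prop: Dict[str, Any]) -> Optional[str]:
    """Return a Glue type string, or None when the schema cannot be represented."""
    types = _non_null_types(prop)
    if len(types) != 1:
        return None  # mixed or missing types
    kind = types[0]
    if kind == "string":
        return {"date-time": "timestamp", "date": "date"}.get(prop.get("format"), "string")
    if kind == "number" and prop.get("airbyte_type") == "integer":
        return "bigint"
    if kind in _PRIMITIVES:
        return _PRIMITIVES[kind]
    if kind == "object":
        fields = [(name, _glue_type(sub)) for name, sub in (prop.get("properties") or {}).items()]
        if not fields or any(t is None for _, t in fields):
            return None  # empty object, or a nested field that cannot be typed
        return "struct<" + ",".join(f"{name}:{t}" for name, t in fields) + ">"
    if kind == "array":
        items = prop.get("items")
        inner = _glue_type(items) if items else None
        return None if inner is None else f"array<{inner}>"
    return None


def glue_column_type(prop: Dict[str, Any]) -> str:
    """Top-level column type: fall back to a plain string when no typed form exists."""
    return _glue_type(prop) or "string"


if __name__ == "__main__":
    caused_by = {
        "type": ["null", "object"],
        "properties": {"created": {"type": ["null", "integer"]}, "id": {"type": ["null", "string"]}},
    }
    print(glue_column_type(caused_by))
    print(glue_column_type({"type": ["null", "array"]}))
```

Returning `None` from the recursive helper and converting it to `string` only at the top level keeps one untypable nested field from producing a half-typed struct, which matches how `nested_nested_bad_object` is expected to become a plain `string` column in the test.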