From 01634029ef6d5d62442ba50675c6cc55a9da1061 Mon Sep 17 00:00:00 2001
From: Guilherme Martins Crocetti <24530683+gmcrocetti@users.noreply.github.com>
Date: Sun, 22 May 2022 14:14:35 -0300
Subject: [PATCH] Add partition related methods into GlueCatalogHook:

* "get_partition" to retrieve a Partition
* "create_partition" to create a Partition
---
 .../amazon/aws/hooks/glue_catalog.py      | 58 ++++++++++++++++++-
 .../amazon/aws/hooks/test_glue_catalog.py | 56 ++++++++++++++++++
 2 files changed, 113 insertions(+), 1 deletion(-)

diff --git a/airflow/providers/amazon/aws/hooks/glue_catalog.py b/airflow/providers/amazon/aws/hooks/glue_catalog.py
index fc9c353e084b0..e77916d09e3b9 100644
--- a/airflow/providers/amazon/aws/hooks/glue_catalog.py
+++ b/airflow/providers/amazon/aws/hooks/glue_catalog.py
@@ -18,8 +18,11 @@
 
 """This module contains AWS Glue Catalog Hook"""
 import warnings
-from typing import Optional, Set
+from typing import Dict, List, Optional, Set
 
+from botocore.exceptions import ClientError
+
+from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 
 
@@ -123,6 +126,59 @@ def get_table_location(self, database_name: str, table_name: str) -> str:
 
         return table['StorageDescriptor']['Location']
 
+    def get_partition(self, database_name: str, table_name: str, partition_values: List[str]) -> Dict:
+        """
+        Gets a single Partition of a Glue Catalog table
+
+        :param database_name: Database name
+        :param table_name: Database's Table name
+        :param partition_values: List of utf-8 strings that define the partition
+            Please see official AWS documentation for further information.
+            https://docs.aws.amazon.com/glue/latest/dg/aws-glue-api-catalog-partitions.html#aws-glue-api-catalog-partitions-GetPartition
+
+        :return: The "Partition" entry of the client's ``get_partition`` response
+
+        :raises AirflowException: If the client call raises a ``ClientError`` (chained as the cause)
+
+        >>> hook = GlueCatalogHook()
+        >>> partition = hook.get_partition('db', 'table', ['string'])
+        >>> partition['Values']
+        """
+        try:
+            response = self.get_conn().get_partition(
+                DatabaseName=database_name, TableName=table_name, PartitionValues=partition_values
+            )
+            return response["Partition"]
+        except ClientError as e:
+            self.log.error("Client error: %s", e)
+            raise AirflowException("AWS request failed, check logs for more info") from e
+
+    def create_partition(self, database_name: str, table_name: str, partition_input: Dict) -> Dict:
+        """
+        Creates a new Partition in a Glue Catalog table
+
+        :param database_name: Database name
+        :param table_name: Database's Table name
+        :param partition_input: Definition of how the partition is created
+            Please see official AWS documentation for further information.
+            https://docs.aws.amazon.com/glue/latest/dg/aws-glue-api-catalog-partitions.html#aws-glue-api-catalog-partitions-CreatePartition
+
+        :return: The unmodified response of the client's ``create_partition`` call
+
+        :raises AirflowException: If the client call raises a ``ClientError`` (chained as the cause)
+
+        >>> hook = GlueCatalogHook()
+        >>> partition_input = {"Values": []}
+        >>> hook.create_partition(database_name="db", table_name="table", partition_input=partition_input)
+        """
+        try:
+            return self.get_conn().create_partition(
+                DatabaseName=database_name, TableName=table_name, PartitionInput=partition_input
+            )
+        except ClientError as e:
+            self.log.error("Client error: %s", e)
+            raise AirflowException("AWS request failed, check logs for more info") from e
+
 
 class AwsGlueCatalogHook(GlueCatalogHook):
     """
diff --git a/tests/providers/amazon/aws/hooks/test_glue_catalog.py b/tests/providers/amazon/aws/hooks/test_glue_catalog.py
index 29730a12e60fe..adbe3da29365d 100644
--- a/tests/providers/amazon/aws/hooks/test_glue_catalog.py
+++ b/tests/providers/amazon/aws/hooks/test_glue_catalog.py
@@ -21,7 +21,9 @@
 
 import boto3
 import pytest
+from botocore.exceptions import ClientError
 
+from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.glue_catalog import GlueCatalogHook
 
 try:
@@ -38,6 +40,9 @@
         "Location": f"s3://mybucket/{DB_NAME}/{TABLE_NAME}",
     },
 }
+PARTITION_INPUT: dict = {
+    "Values": [],
+}
 
 
 @unittest.skipIf(mock_glue is None, "Skipping test because moto.mock_glue is not available")
@@ -134,3 +139,54 @@ def test_get_table_location(self):
 
         result = self.hook.get_table_location(DB_NAME, TABLE_NAME)
         assert result == TABLE_INPUT['StorageDescriptor']['Location']
+
+    @mock_glue
+    def test_get_partition(self):
+        self.client.create_database(DatabaseInput={'Name': DB_NAME})
+        self.client.create_table(DatabaseName=DB_NAME, TableInput=TABLE_INPUT)
+        self.client.create_partition(
+            DatabaseName=DB_NAME, TableName=TABLE_NAME, PartitionInput=PARTITION_INPUT
+        )
+
+        result = self.hook.get_partition(DB_NAME, TABLE_NAME, PARTITION_INPUT['Values'])
+
+        assert result["Values"] == PARTITION_INPUT['Values']
+        assert result["DatabaseName"] == DB_NAME
+        assert result["TableName"] == TABLE_INPUT["Name"]
+
+    @mock_glue
+    @mock.patch.object(GlueCatalogHook, 'get_conn')
+    def test_get_partition_with_client_error(self, mocked_connection):
+        mocked_client = mock.Mock()
+        mocked_client.get_partition.side_effect = ClientError({}, "get_partition")
+        mocked_connection.return_value = mocked_client
+
+        with pytest.raises(AirflowException):
+            self.hook.get_partition(DB_NAME, TABLE_NAME, PARTITION_INPUT['Values'])
+
+        mocked_client.get_partition.assert_called_once_with(
+            DatabaseName=DB_NAME, TableName=TABLE_NAME, PartitionValues=PARTITION_INPUT['Values']
+        )
+
+    @mock_glue
+    def test_create_partition(self):
+        self.client.create_database(DatabaseInput={'Name': DB_NAME})
+        self.client.create_table(DatabaseName=DB_NAME, TableInput=TABLE_INPUT)
+
+        result = self.hook.create_partition(DB_NAME, TABLE_NAME, PARTITION_INPUT)
+
+        assert result
+
+    @mock_glue
+    @mock.patch.object(GlueCatalogHook, 'get_conn')
+    def test_create_partition_with_client_error(self, mocked_connection):
+        mocked_client = mock.Mock()
+        mocked_client.create_partition.side_effect = ClientError({}, "create_partition")
+        mocked_connection.return_value = mocked_client
+
+        with pytest.raises(AirflowException):
+            self.hook.create_partition(DB_NAME, TABLE_NAME, PARTITION_INPUT)
+
+        mocked_client.create_partition.assert_called_once_with(
+            DatabaseName=DB_NAME, TableName=TABLE_NAME, PartitionInput=PARTITION_INPUT
+        )