Skip to content

Commit

Permalink
Add partition related methods to GlueCatalogHook: (#23857)
Browse files Browse the repository at this point in the history
* "get_partition" to retrieve a Partition
* "create_partition" to create a Partition
  • Loading branch information
gmcrocetti authored May 30, 2022
1 parent 8f3a9b8 commit 94f2ce9
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 1 deletion.
58 changes: 57 additions & 1 deletion airflow/providers/amazon/aws/hooks/glue_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@

"""This module contains AWS Glue Catalog Hook"""
import warnings
from typing import Optional, Set
from typing import Dict, List, Optional, Set

from botocore.exceptions import ClientError

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook


Expand Down Expand Up @@ -123,6 +126,59 @@ def get_table_location(self, database_name: str, table_name: str) -> str:

return table['StorageDescriptor']['Location']

def get_partition(self, database_name: str, table_name: str, partition_values: List[str]) -> Dict:
"""
Gets a Partition
:param database_name: Database name
:param table_name: Database's Table name
:param partition_values: List of utf-8 strings that define the partition
Please see official AWS documentation for further information.
https://docs.aws.amazon.com/glue/latest/dg/aws-glue-api-catalog-partitions.html#aws-glue-api-catalog-partitions-GetPartition
:rtype: dict
:raises: AirflowException
>>> hook = GlueCatalogHook()
>>> partition = hook.get_partition('db', 'table', ['string'])
>>> partition['Values']
"""
try:
response = self.get_conn().get_partition(
DatabaseName=database_name, TableName=table_name, PartitionValues=partition_values
)
return response["Partition"]
except ClientError as e:
self.log.error("Client error: %s", e)
raise AirflowException("AWS request failed, check logs for more info")

def create_partition(self, database_name: str, table_name: str, partition_input: Dict) -> Dict:
"""
Creates a new Partition
:param database_name: Database name
:param table_name: Database's Table name
:param partition_input: Definition of how the partition is created
Please see official AWS documentation for further information.
https://docs.aws.amazon.com/glue/latest/dg/aws-glue-api-catalog-partitions.html#aws-glue-api-catalog-partitions-CreatePartition
:rtype: dict
:raises: AirflowException
>>> hook = GlueCatalogHook()
>>> partition_input = {"Values": []}
>>> hook.create_partition(database_name="db", table_name="table", partition_input=partition_input)
"""
try:
return self.get_conn().create_partition(
DatabaseName=database_name, TableName=table_name, PartitionInput=partition_input
)
except ClientError as e:
self.log.error("Client error: %s", e)
raise AirflowException("AWS request failed, check logs for more info")


class AwsGlueCatalogHook(GlueCatalogHook):
"""
Expand Down
56 changes: 56 additions & 0 deletions tests/providers/amazon/aws/hooks/test_glue_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@

import boto3
import pytest
from botocore.exceptions import ClientError

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.glue_catalog import GlueCatalogHook

try:
Expand All @@ -38,6 +40,9 @@
"Location": f"s3://mybucket/{DB_NAME}/{TABLE_NAME}",
},
}
PARTITION_INPUT: dict = {
"Values": [],
}


@unittest.skipIf(mock_glue is None, "Skipping test because moto.mock_glue is not available")
Expand Down Expand Up @@ -134,3 +139,54 @@ def test_get_table_location(self):

result = self.hook.get_table_location(DB_NAME, TABLE_NAME)
assert result == TABLE_INPUT['StorageDescriptor']['Location']

@mock_glue
def test_get_partition(self):
self.client.create_database(DatabaseInput={'Name': DB_NAME})
self.client.create_table(DatabaseName=DB_NAME, TableInput=TABLE_INPUT)
self.client.create_partition(
DatabaseName=DB_NAME, TableName=TABLE_NAME, PartitionInput=PARTITION_INPUT
)

result = self.hook.get_partition(DB_NAME, TABLE_NAME, PARTITION_INPUT['Values'])

assert result["Values"] == PARTITION_INPUT['Values']
assert result["DatabaseName"] == DB_NAME
assert result["TableName"] == TABLE_INPUT["Name"]

@mock_glue
@mock.patch.object(GlueCatalogHook, 'get_conn')
def test_get_partition_with_client_error(self, mocked_connection):
mocked_client = mock.Mock()
mocked_client.get_partition.side_effect = ClientError({}, "get_partition")
mocked_connection.return_value = mocked_client

with pytest.raises(AirflowException):
self.hook.get_partition(DB_NAME, TABLE_NAME, PARTITION_INPUT['Values'])

mocked_client.get_partition.assert_called_once_with(
DatabaseName=DB_NAME, TableName=TABLE_NAME, PartitionValues=PARTITION_INPUT['Values']
)

@mock_glue
def test_create_partition(self):
self.client.create_database(DatabaseInput={'Name': DB_NAME})
self.client.create_table(DatabaseName=DB_NAME, TableInput=TABLE_INPUT)

result = self.hook.create_partition(DB_NAME, TABLE_NAME, PARTITION_INPUT)

assert result

@mock_glue
@mock.patch.object(GlueCatalogHook, 'get_conn')
def test_create_partition_with_client_error(self, mocked_connection):
mocked_client = mock.Mock()
mocked_client.create_partition.side_effect = ClientError({}, "create_partition")
mocked_connection.return_value = mocked_client

with pytest.raises(AirflowException):
self.hook.create_partition(DB_NAME, TABLE_NAME, PARTITION_INPUT)

mocked_client.create_partition.assert_called_once_with(
DatabaseName=DB_NAME, TableName=TABLE_NAME, PartitionInput=PARTITION_INPUT
)

0 comments on commit 94f2ce9

Please sign in to comment.