Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create get_partition and create_partition methods in GlueCatalogHook #23857

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 57 additions & 1 deletion airflow/providers/amazon/aws/hooks/glue_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@

"""This module contains AWS Glue Catalog Hook"""
import warnings
from typing import Optional, Set
from typing import Dict, List, Optional, Set

from botocore.exceptions import ClientError

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook


Expand Down Expand Up @@ -123,6 +126,59 @@ def get_table_location(self, database_name: str, table_name: str) -> str:

return table['StorageDescriptor']['Location']

def get_partition(self, database_name: str, table_name: str, partition_values: List[str]) -> Dict:
"""
Gets a Partition

:param database_name: Database name
:param table_name: Database's Table name
:param partition_values: List of utf-8 strings that define the partition
Please see official AWS documentation for further information.
https://docs.aws.amazon.com/glue/latest/dg/aws-glue-api-catalog-partitions.html#aws-glue-api-catalog-partitions-GetPartition

:rtype: dict

:raises: AirflowException

>>> hook = GlueCatalogHook()
>>> partition = hook.get_partition('db', 'table', ['string'])
>>> partition['Values']
"""
try:
response = self.get_conn().get_partition(
DatabaseName=database_name, TableName=table_name, PartitionValues=partition_values
)
return response["Partition"]
except ClientError as e:
self.log.error("Client error: %s", e)
raise AirflowException("AWS request failed, check logs for more info")

def create_partition(self, database_name: str, table_name: str, partition_input: Dict) -> Dict:
"""
Creates a new Partition

:param database_name: Database name
:param table_name: Database's Table name
:param partition_input: Definition of how the partition is created
Please see official AWS documentation for further information.
https://docs.aws.amazon.com/glue/latest/dg/aws-glue-api-catalog-partitions.html#aws-glue-api-catalog-partitions-CreatePartition

:rtype: dict

:raises: AirflowException

>>> hook = GlueCatalogHook()
>>> partition_input = {"Values": []}
>>> hook.create_partition(database_name="db", table_name="table", partition_input=partition_input)
"""
try:
return self.get_conn().create_partition(
DatabaseName=database_name, TableName=table_name, PartitionInput=partition_input
)
except ClientError as e:
self.log.error("Client error: %s", e)
raise AirflowException("AWS request failed, check logs for more info")


class AwsGlueCatalogHook(GlueCatalogHook):
"""
Expand Down
56 changes: 56 additions & 0 deletions tests/providers/amazon/aws/hooks/test_glue_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@

import boto3
import pytest
from botocore.exceptions import ClientError

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.glue_catalog import GlueCatalogHook

try:
Expand All @@ -38,6 +40,9 @@
"Location": f"s3://mybucket/{DB_NAME}/{TABLE_NAME}",
},
}
PARTITION_INPUT: dict = {
"Values": [],
}


@unittest.skipIf(mock_glue is None, "Skipping test because moto.mock_glue is not available")
Expand Down Expand Up @@ -134,3 +139,54 @@ def test_get_table_location(self):

result = self.hook.get_table_location(DB_NAME, TABLE_NAME)
assert result == TABLE_INPUT['StorageDescriptor']['Location']

@mock_glue
def test_get_partition(self):
self.client.create_database(DatabaseInput={'Name': DB_NAME})
self.client.create_table(DatabaseName=DB_NAME, TableInput=TABLE_INPUT)
self.client.create_partition(
DatabaseName=DB_NAME, TableName=TABLE_NAME, PartitionInput=PARTITION_INPUT
)

result = self.hook.get_partition(DB_NAME, TABLE_NAME, PARTITION_INPUT['Values'])

assert result["Values"] == PARTITION_INPUT['Values']
assert result["DatabaseName"] == DB_NAME
assert result["TableName"] == TABLE_INPUT["Name"]

@mock_glue
@mock.patch.object(GlueCatalogHook, 'get_conn')
def test_get_partition_with_client_error(self, mocked_connection):
mocked_client = mock.Mock()
mocked_client.get_partition.side_effect = ClientError({}, "get_partition")
mocked_connection.return_value = mocked_client

with pytest.raises(AirflowException):
self.hook.get_partition(DB_NAME, TABLE_NAME, PARTITION_INPUT['Values'])

mocked_client.get_partition.assert_called_once_with(
DatabaseName=DB_NAME, TableName=TABLE_NAME, PartitionValues=PARTITION_INPUT['Values']
)

@mock_glue
def test_create_partition(self):
self.client.create_database(DatabaseInput={'Name': DB_NAME})
self.client.create_table(DatabaseName=DB_NAME, TableInput=TABLE_INPUT)

result = self.hook.create_partition(DB_NAME, TABLE_NAME, PARTITION_INPUT)

assert result

@mock_glue
@mock.patch.object(GlueCatalogHook, 'get_conn')
def test_create_partition_with_client_error(self, mocked_connection):
mocked_client = mock.Mock()
mocked_client.create_partition.side_effect = ClientError({}, "create_partition")
mocked_connection.return_value = mocked_client

with pytest.raises(AirflowException):
self.hook.create_partition(DB_NAME, TABLE_NAME, PARTITION_INPUT)

mocked_client.create_partition.assert_called_once_with(
DatabaseName=DB_NAME, TableName=TABLE_NAME, PartitionInput=PARTITION_INPUT
)