diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py
index 1cd5e4ce22d15..b8ab98ba21da1 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py
@@ -21,7 +21,7 @@
 
 import json
 from collections.abc import Callable, Sequence
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Literal
 
 from airflow.providers.amazon.aws.hooks.dynamodb import DynamoDBHook
 from airflow.providers.amazon.version_compat import BaseOperator
@@ -53,6 +53,7 @@ class HiveToDynamoDBOperator(BaseOperator):
     :param hiveserver2_conn_id: Reference to the
         :ref: `Hive Server2 thrift service connection id <howto/connection:hiveserver2>`.
     :param aws_conn_id: aws connection
+    :param df_type: DataFrame type to use ("pandas" or "polars").
     """
 
     template_fields: Sequence[str] = ("sql",)
@@ -73,6 +74,7 @@ def __init__(
         schema: str = "default",
         hiveserver2_conn_id: str = "hiveserver2_default",
         aws_conn_id: str | None = "aws_default",
+        df_type: Literal["pandas", "polars"] = "pandas",
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -86,6 +88,7 @@ def __init__(
         self.schema = schema
         self.hiveserver2_conn_id = hiveserver2_conn_id
         self.aws_conn_id = aws_conn_id
+        self.df_type = df_type
 
     def execute(self, context: Context):
         hive = HiveServer2Hook(hiveserver2_conn_id=self.hiveserver2_conn_id)
@@ -93,7 +96,7 @@ def execute(self, context: Context):
         self.log.info("Extracting data from Hive")
         self.log.info(self.sql)
 
-        data = hive.get_df(self.sql, schema=self.schema, df_type="pandas")
+        data = hive.get_df(self.sql, schema=self.schema, df_type=self.df_type)
         dynamodb = DynamoDBHook(
             aws_conn_id=self.aws_conn_id,
             table_name=self.table_name,
@@ -104,7 +107,10 @@ def execute(self, context: Context):
         self.log.info("Inserting rows into dynamodb")
 
         if self.pre_process is None:
-            dynamodb.write_batch_data(json.loads(data.to_json(orient="records")))
+            if self.df_type == "polars":
+                dynamodb.write_batch_data(data.to_dicts())  # type:ignore[operator]
+            elif self.df_type == "pandas":
+                dynamodb.write_batch_data(json.loads(data.to_json(orient="records")))  # type:ignore[union-attr]
         else:
             dynamodb.write_batch_data(
                 self.pre_process(data=data, args=self.pre_process_args, kwargs=self.pre_process_kwargs)
diff --git a/providers/amazon/tests/unit/amazon/aws/transfers/test_hive_to_dynamodb.py b/providers/amazon/tests/unit/amazon/aws/transfers/test_hive_to_dynamodb.py
index 82183a444345b..4249ce093a6d6 100644
--- a/providers/amazon/tests/unit/amazon/aws/transfers/test_hive_to_dynamodb.py
+++ b/providers/amazon/tests/unit/amazon/aws/transfers/test_hive_to_dynamodb.py
@@ -22,6 +22,12 @@
 from unittest import mock
 
 import pandas as pd
+import pytest
 from moto import mock_aws
 
+try:
+    import polars as pl
+except ImportError:
+    pl = None
+
 import airflow.providers.amazon.aws.transfers.hive_to_dynamodb
@@ -110,3 +116,43 @@ def test_pre_process_records_with_schema(self, mock_get_df):
         table = self.hook.get_conn().Table("test_airflow")
         table.meta.client.get_waiter("table_exists").wait(TableName="test_airflow")
         assert table.item_count == 1
+
+    @pytest.mark.parametrize("df_type", ["pandas", "polars"])
+    @mock_aws
+    def test_df_type_parameter(self, df_type):
+        if df_type == "polars" and pl is None:
+            pytest.skip("Polars not installed")
+
+        if df_type == "pandas":
+            test_df = pd.DataFrame(data=[("1", "sid")], columns=["id", "name"])
+        else:
+            test_df = pl.DataFrame({"id": ["1"], "name": ["sid"]})
+
+        with mock.patch(
+            "airflow.providers.apache.hive.hooks.hive.HiveServer2Hook.get_df",
+            return_value=test_df,
+        ) as mock_get_df:
+            self.hook.get_conn().create_table(
+                TableName="test_airflow",
+                KeySchema=[
+                    {"AttributeName": "id", "KeyType": "HASH"},
+                ],
+                AttributeDefinitions=[{"AttributeName": "id", "AttributeType": "S"}],
+                ProvisionedThroughput={"ReadCapacityUnits": 10, "WriteCapacityUnits": 10},
+            )
+
+            operator = airflow.providers.amazon.aws.transfers.hive_to_dynamodb.HiveToDynamoDBOperator(
+                sql=self.sql,
+                table_name="test_airflow",
+                task_id="hive_to_dynamodb_check",
+                table_keys=["id"],
+                df_type=df_type,
+                dag=self.dag,
+            )
+
+            operator.execute(None)
+            mock_get_df.assert_called_once_with(self.sql, schema="default", df_type=df_type)
+
+            table = self.hook.get_conn().Table("test_airflow")
+            table.meta.client.get_waiter("table_exists").wait(TableName="test_airflow")
+            assert table.item_count == 1
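
For context, a minimal sketch of how the new `df_type` parameter would be used in a DAG task; the task id, SQL, and table name below are illustrative, not taken from the patch:

```python
from airflow.providers.amazon.aws.transfers.hive_to_dynamodb import HiveToDynamoDBOperator

# Illustrative usage of the new df_type parameter. With df_type="polars",
# the operator writes rows via DataFrame.to_dicts(); with the default
# "pandas", it keeps the existing json.loads(to_json(orient="records")) path.
hive_to_dynamodb = HiveToDynamoDBOperator(
    task_id="hive_to_dynamodb",
    sql="SELECT id, name FROM default.users",
    table_name="users",
    table_keys=["id"],
    df_type="polars",
)
```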