Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import json
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Literal

from airflow.providers.amazon.aws.hooks.dynamodb import DynamoDBHook
from airflow.providers.amazon.version_compat import BaseOperator
Expand Down Expand Up @@ -53,6 +53,7 @@ class HiveToDynamoDBOperator(BaseOperator):
:param hiveserver2_conn_id: Reference to the
:ref: `Hive Server2 thrift service connection id <howto/connection:hiveserver2>`.
:param aws_conn_id: aws connection
:param df_type: DataFrame type to use ("pandas" or "polars").
"""

template_fields: Sequence[str] = ("sql",)
Expand All @@ -73,6 +74,7 @@ def __init__(
schema: str = "default",
hiveserver2_conn_id: str = "hiveserver2_default",
aws_conn_id: str | None = "aws_default",
df_type: Literal["pandas", "polars"] = "pandas",
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -86,14 +88,15 @@ def __init__(
self.schema = schema
self.hiveserver2_conn_id = hiveserver2_conn_id
self.aws_conn_id = aws_conn_id
self.df_type = df_type

def execute(self, context: Context):
hive = HiveServer2Hook(hiveserver2_conn_id=self.hiveserver2_conn_id)

self.log.info("Extracting data from Hive")
self.log.info(self.sql)

data = hive.get_df(self.sql, schema=self.schema, df_type="pandas")
data = hive.get_df(self.sql, schema=self.schema, df_type=self.df_type)
dynamodb = DynamoDBHook(
aws_conn_id=self.aws_conn_id,
table_name=self.table_name,
Expand All @@ -104,7 +107,10 @@ def execute(self, context: Context):
self.log.info("Inserting rows into dynamodb")

if self.pre_process is None:
dynamodb.write_batch_data(json.loads(data.to_json(orient="records")))
if self.df_type == "polars":
dynamodb.write_batch_data(data.to_dicts()) # type:ignore[operator]
elif self.df_type == "pandas":
dynamodb.write_batch_data(json.loads(data.to_json(orient="records"))) # type:ignore[union-attr]
else:
dynamodb.write_batch_data(
self.pre_process(data=data, args=self.pre_process_args, kwargs=self.pre_process_kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from unittest import mock

import pandas as pd
import polars as pl
import pytest
from moto import mock_aws

import airflow.providers.amazon.aws.transfers.hive_to_dynamodb
Expand Down Expand Up @@ -110,3 +112,43 @@ def test_pre_process_records_with_schema(self, mock_get_df):
table = self.hook.get_conn().Table("test_airflow")
table.meta.client.get_waiter("table_exists").wait(TableName="test_airflow")
assert table.item_count == 1

@pytest.mark.parametrize("df_type", ["pandas", "polars"])
@mock_aws
def test_df_type_parameter(self, df_type):
if df_type == "polars" and pl is None:
pytest.skip("Polars not installed")

if df_type == "pandas":
test_df = pd.DataFrame(data=[("1", "sid")], columns=["id", "name"])
else:
test_df = pl.DataFrame({"id": ["1"], "name": ["sid"]})

with mock.patch(
"airflow.providers.apache.hive.hooks.hive.HiveServer2Hook.get_df",
return_value=test_df,
) as mock_get_df:
self.hook.get_conn().create_table(
TableName="test_airflow",
KeySchema=[
{"AttributeName": "id", "KeyType": "HASH"},
],
AttributeDefinitions=[{"AttributeName": "id", "AttributeType": "S"}],
ProvisionedThroughput={"ReadCapacityUnits": 10, "WriteCapacityUnits": 10},
)

operator = airflow.providers.amazon.aws.transfers.hive_to_dynamodb.HiveToDynamoDBOperator(
sql=self.sql,
table_name="test_airflow",
task_id="hive_to_dynamodb_check",
table_keys=["id"],
df_type=df_type,
dag=self.dag,
)

operator.execute(None)
mock_get_df.assert_called_once_with(self.sql, schema="default", df_type=df_type)

table = self.hook.get_conn().Table("test_airflow")
table.meta.client.get_waiter("table_exists").wait(TableName="test_airflow")
assert table.item_count == 1
Loading