diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/sql_to_s3.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/sql_to_s3.py
index 344935215b785..1befb2a4dc1b9 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/sql_to_s3.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/sql_to_s3.py
@@ -20,16 +20,18 @@
 import enum
 import gzip
 import io
+import warnings
 from collections import namedtuple
 from collections.abc import Iterable, Mapping, Sequence
 from typing import TYPE_CHECKING, Any, Literal, cast
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
 from airflow.providers.amazon.version_compat import BaseHook, BaseOperator
 
 if TYPE_CHECKING:
     import pandas as pd
+    import polars as pl
 
     from airflow.providers.common.sql.hooks.sql import DbApiHook
     from airflow.utils.context import Context
@@ -69,7 +71,8 @@ class SqlToS3Operator(BaseOperator):
     :param sql_hook_params: Extra config params to be passed to the underlying hook.
         Should match the desired hook constructor params.
     :param parameters: (optional) the parameters to render the SQL query with.
-    :param read_pd_kwargs: arguments to include in DataFrame when ``pd.read_sql()`` is called.
+    :param read_kwargs: arguments to include in DataFrame when reading from SQL (supports both pandas and polars).
+    :param df_type: the type of DataFrame to use ('pandas' or 'polars'). Defaults to 'pandas'.
     :param aws_conn_id: reference to a specific S3 connection
     :param verify: Whether or not to verify SSL certificates for S3 connection.
         By default SSL certificates are verified.
@@ -84,7 +87,7 @@ class SqlToS3Operator(BaseOperator):
     :param max_rows_per_file: (optional) argument to set destination file number of rows limit, if
         source data is larger than that, it will be dispatched into multiple files.
         Will be ignored if ``groupby_kwargs`` argument is specified.
-    :param pd_kwargs: arguments to include in DataFrame ``.to_parquet()``, ``.to_json()`` or ``.to_csv()``.
+    :param df_kwargs: arguments to include in DataFrame ``.to_parquet()``, ``.to_json()`` or ``.to_csv()``.
     :param groupby_kwargs: argument to include in DataFrame ``groupby()``.
""" @@ -97,8 +100,9 @@ class SqlToS3Operator(BaseOperator): template_ext: Sequence[str] = (".sql",) template_fields_renderers = { "query": "sql", + "df_kwargs": "json", "pd_kwargs": "json", - "read_pd_kwargs": "json", + "read_kwargs": "json", } def __init__( @@ -110,12 +114,15 @@ def __init__( sql_conn_id: str, sql_hook_params: dict | None = None, parameters: None | Mapping[str, Any] | list | tuple = None, + read_kwargs: dict | None = None, read_pd_kwargs: dict | None = None, + df_type: Literal["pandas", "polars"] = "pandas", replace: bool = False, aws_conn_id: str | None = "aws_default", verify: bool | str | None = None, file_format: Literal["csv", "json", "parquet"] = "csv", max_rows_per_file: int = 0, + df_kwargs: dict | None = None, pd_kwargs: dict | None = None, groupby_kwargs: dict | None = None, **kwargs, @@ -128,14 +135,30 @@ def __init__( self.aws_conn_id = aws_conn_id self.verify = verify self.replace = replace - self.pd_kwargs = pd_kwargs or {} self.parameters = parameters - self.read_pd_kwargs = read_pd_kwargs or {} self.max_rows_per_file = max_rows_per_file self.groupby_kwargs = groupby_kwargs or {} self.sql_hook_params = sql_hook_params + self.df_type = df_type - if "path_or_buf" in self.pd_kwargs: + if read_pd_kwargs is not None: + warnings.warn( + "The 'read_pd_kwargs' parameter is deprecated. Use 'read_kwargs' instead.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + self.read_kwargs = read_kwargs if read_kwargs is not None else read_pd_kwargs or {} + + if pd_kwargs is not None: + warnings.warn( + "The 'pd_kwargs' parameter is deprecated. Use 'df_kwargs' instead.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + + self.df_kwargs = df_kwargs if df_kwargs is not None else pd_kwargs or {} + + if "path_or_buf" in self.df_kwargs: raise AirflowException("The argument path_or_buf is not allowed, please remove it") if self.max_rows_per_file and self.groupby_kwargs: @@ -190,11 +213,12 @@ def execute(self, context: Context) -> None: sql_hook = self._get_hook() s3_conn = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) data_df = sql_hook.get_df( - sql=self.query, parameters=self.parameters, df_type="pandas", **self.read_pd_kwargs + sql=self.query, parameters=self.parameters, df_type=self.df_type, **self.read_kwargs ) self.log.info("Data from SQL obtained") - if ("dtype_backend", "pyarrow") not in self.read_pd_kwargs.items(): - self._fix_dtypes(data_df, self.file_format) + # Only apply dtype fixes to pandas DataFrames since Polars doesn't have the same NaN/None inconsistencies as panda + if ("dtype_backend", "pyarrow") not in self.read_kwargs.items() and self.df_type == "pandas": + self._fix_dtypes(data_df, self.file_format) # type: ignore[arg-type] file_options = FILE_OPTIONS_MAP[self.file_format] for group_name, df in self._partition_dataframe(df=data_df): @@ -202,16 +226,16 @@ def execute(self, context: Context) -> None: self.log.info("Writing data to in-memory buffer") object_key = f"{self.s3_key}_{group_name}" if group_name else self.s3_key - if self.pd_kwargs.get("compression") == "gzip": - pd_kwargs = {k: v for k, v in self.pd_kwargs.items() if k != "compression"} + if self.df_kwargs.get("compression") == "gzip": + df_kwargs = {k: v for k, v in self.df_kwargs.items() if k != "compression"} with gzip.GzipFile(fileobj=buf, mode="wb", filename=object_key) as gz: - getattr(df, file_options.function)(gz, **pd_kwargs) + getattr(df, file_options.function)(gz, **df_kwargs) else: if self.file_format == FILE_FORMAT.PARQUET: - getattr(df, 
+                    getattr(df, file_options.function)(buf, **self.df_kwargs)
                 else:
                     text_buf = io.TextIOWrapper(buf, encoding="utf-8", write_through=True)
-                    getattr(df, file_options.function)(text_buf, **self.pd_kwargs)
+                    getattr(df, file_options.function)(text_buf, **self.df_kwargs)
                     text_buf.flush()
             buf.seek(0)
 
@@ -220,17 +244,23 @@ def execute(self, context: Context) -> None:
                 file_obj=buf, key=object_key, bucket_name=self.s3_bucket, replace=self.replace
             )
 
-    def _partition_dataframe(self, df: pd.DataFrame) -> Iterable[tuple[str, pd.DataFrame]]:
+    def _partition_dataframe(self, df: pd.DataFrame | pl.DataFrame) -> Iterable[tuple[str, pd.DataFrame]]:
         """Partition dataframe using pandas groupby() method."""
         try:
             import secrets
             import string
 
             import numpy as np
+            import polars as pl
         except ImportError:
             pass
+
+        if isinstance(df, pl.DataFrame):
+            df = df.to_pandas()
+
         # if max_rows_per_file argument is specified, a temporary column with a random unusual name will be
         # added to the dataframe. This column is used to dispatch the dataframe into smaller ones using groupby()
+        random_column_name = ""
         if self.max_rows_per_file and not self.groupby_kwargs:
             random_column_name = "".join(secrets.choice(string.ascii_letters) for _ in range(20))
diff --git a/providers/amazon/tests/unit/amazon/aws/transfers/test_sql_to_s3.py b/providers/amazon/tests/unit/amazon/aws/transfers/test_sql_to_s3.py
index 0d3791cc8d513..4a8fc21c49fcd 100644
--- a/providers/amazon/tests/unit/amazon/aws/transfers/test_sql_to_s3.py
+++ b/providers/amazon/tests/unit/amazon/aws/transfers/test_sql_to_s3.py
@@ -24,7 +24,7 @@
 import pandas as pd
 import pytest
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.models import Connection
 from airflow.providers.amazon.aws.transfers.sql_to_s3 import SqlToS3Operator
 
@@ -50,8 +50,8 @@ def test_execute_csv(self, mock_s3_hook, dtype_backend):
             aws_conn_id="aws_conn_id",
             task_id="task_id",
             replace=True,
-            read_pd_kwargs={"dtype_backend": dtype_backend},
-            pd_kwargs={"index": False, "header": False},
+            read_kwargs={"dtype_backend": dtype_backend},
+            df_kwargs={"index": False, "header": False},
             dag=None,
         )
         op._get_hook = mock_dbapi_hook
@@ -87,7 +87,7 @@ def test_execute_parquet(self, mock_s3_hook, dtype_backend):
             sql_conn_id="mysql_conn_id",
             aws_conn_id="aws_conn_id",
             task_id="task_id",
-            read_pd_kwargs={"dtype_backend": dtype_backend},
+            read_kwargs={"dtype_backend": dtype_backend},
             file_format="parquet",
             replace=True,
             dag=None,
@@ -126,7 +126,7 @@ def test_execute_json(self, mock_s3_hook):
             task_id="task_id",
             file_format="json",
             replace=True,
-            pd_kwargs={"date_format": "iso", "lines": True, "orient": "records"},
+            df_kwargs={"date_format": "iso", "lines": True, "orient": "records"},
             dag=None,
         )
         op._get_hook = mock_dbapi_hook
@@ -159,7 +159,7 @@ def test_execute_gzip_with_bytesio(self, mock_s3_hook):
             aws_conn_id="aws_conn_id",
             task_id="task_id",
             replace=True,
-            pd_kwargs={"index": False, "compression": "gzip"},
+            df_kwargs={"index": False, "compression": "gzip"},
             dag=None,
         )
         op._get_hook = mock_dbapi_hook
@@ -220,7 +220,7 @@ def test_fix_dtypes_not_called(self, mock_s3_hook):
             sql_conn_id="mysql_conn_id",
             aws_conn_id="aws_conn_id",
             task_id="task_id",
-            read_pd_kwargs={"dtype_backend": "pyarrow"},
+            read_kwargs={"dtype_backend": "pyarrow"},
             file_format="parquet",
             replace=True,
             dag=None,
@@ -261,7 +261,7 @@ def test_with_groupby_kwarg(self):
             aws_conn_id="aws_conn_id",
task_id="task_id", replace=True, - pd_kwargs={"index": False, "header": False}, + df_kwargs={"index": False, "header": False}, groupby_kwargs={"by": "Team"}, dag=None, ) @@ -313,7 +313,7 @@ def test_without_groupby_kwarg(self): aws_conn_id="aws_conn_id", task_id="task_id", replace=True, - pd_kwargs={"index": False, "header": False}, + df_kwargs={"index": False, "header": False}, dag=None, ) example = { @@ -355,7 +355,7 @@ def test_with_max_rows_per_file(self): aws_conn_id="aws_conn_id", task_id="task_id", replace=True, - pd_kwargs={"index": False, "header": False}, + df_kwargs={"index": False, "header": False}, max_rows_per_file=3, dag=None, ) @@ -407,3 +407,249 @@ def test_hook_params(self, mock_get_conn): ) hook = op._get_hook() assert hook.log_sql == op.sql_hook_params["log_sql"] + + @pytest.mark.parametrize( + "df_type_param,expected_df_type", + [ + pytest.param("polars", "polars", id="with-polars"), + pytest.param("pandas", "pandas", id="with-pandas"), + pytest.param(None, "pandas", id="with-default"), + ], + ) + @mock.patch("airflow.providers.amazon.aws.transfers.sql_to_s3.S3Hook") + def test_execute_with_df_type(self, mock_s3_hook, df_type_param, expected_df_type): + query = "query" + s3_bucket = "bucket" + s3_key = "key" + + mock_dbapi_hook = mock.Mock() + test_df = pd.DataFrame({"a": "1", "b": "2"}, index=[0, 1]) + get_df_mock = mock_dbapi_hook.return_value.get_df + get_df_mock.return_value = test_df + + kwargs = { + "query": query, + "s3_bucket": s3_bucket, + "s3_key": s3_key, + "sql_conn_id": "mysql_conn_id", + "aws_conn_id": "aws_conn_id", + "task_id": "task_id", + "replace": True, + "dag": None, + } + if df_type_param is not None: + kwargs["df_type"] = df_type_param + + op = SqlToS3Operator(**kwargs) + op._get_hook = mock_dbapi_hook + op.execute(None) + + mock_s3_hook.assert_called_once_with(aws_conn_id="aws_conn_id", verify=None) + get_df_mock.assert_called_once_with(sql=query, parameters=None, df_type=expected_df_type) + file_obj = mock_s3_hook.return_value.load_file_obj.call_args[1]["file_obj"] + assert isinstance(file_obj, io.BytesIO) + mock_s3_hook.return_value.load_file_obj.assert_called_once_with( + file_obj=file_obj, key=s3_key, bucket_name=s3_bucket, replace=True + ) + + @pytest.mark.parametrize( + "df_type,input_df_creator", + [ + pytest.param( + "pandas", + lambda: pd.DataFrame({"category": ["A", "A", "B", "B"], "value": [1, 2, 3, 4]}), + id="with-pandas-dataframe", + ), + pytest.param( + "polars", + lambda: pytest.importorskip("polars").DataFrame( + {"category": ["A", "A", "B", "B"], "value": [1, 2, 3, 4]} + ), + id="with-polars-dataframe", + ), + ], + ) + def test_partition_dataframe(self, df_type, input_df_creator): + """Test that _partition_dataframe works with both pandas and polars DataFrames.""" + op = SqlToS3Operator( + query="query", + s3_bucket="bucket", + s3_key="key", + sql_conn_id="mysql_conn_id", + task_id="task_id", + df_type=df_type, + groupby_kwargs={"by": "category"}, + ) + + input_df = input_df_creator() + partitions = list(op._partition_dataframe(input_df)) + + assert len(partitions) == 2 + for group_name, df in partitions: + assert isinstance(df, pd.DataFrame) + assert group_name in ["A", "B"] + + @pytest.mark.parametrize( + "kwargs,expected_warning,expected_error,expected_read_kwargs", + [ + pytest.param( + {"read_pd_kwargs": {"dtype_backend": "pyarrow"}}, + "The 'read_pd_kwargs' parameter is deprecated", + None, + {"dtype_backend": "pyarrow"}, + id="deprecated-read-pd-kwargs-warning", + ), + pytest.param( + { + "read_kwargs": {"dtype_backend": 
"pyarrow"}, + "read_pd_kwargs": {"dtype_backend": "numpy_nullable"}, + }, + "The 'read_pd_kwargs' parameter is deprecated", + None, + {"dtype_backend": "pyarrow"}, + id="read-kwargs-priority-over-deprecated", + ), + pytest.param( + {"max_rows_per_file": 2, "groupby_kwargs": {"by": "category"}}, + None, + "can not be both specified", + None, + id="max-rows-groupby-conflict-error", + ), + pytest.param( + {"pd_kwargs": {"index": False}}, + "The 'pd_kwargs' parameter is deprecated", + None, + None, + id="deprecated-pd-kwargs-warning", + ), + pytest.param( + {"df_kwargs": {"index": False}, "pd_kwargs": {"header": False}}, + "The 'pd_kwargs' parameter is deprecated", + None, + None, + id="df-kwargs-priority-over-deprecated", + ), + ], + ) + def test_parameter_validation(self, kwargs, expected_warning, expected_error, expected_read_kwargs): + """Test parameter validation and deprecation warnings.""" + base_kwargs = { + "query": "query", + "s3_bucket": "bucket", + "s3_key": "key", + "sql_conn_id": "mysql_conn_id", + "task_id": "task_id", + } + base_kwargs.update(kwargs) + + if expected_error: + with pytest.raises(AirflowException, match=expected_error): + SqlToS3Operator(**base_kwargs) + elif expected_warning: + with pytest.warns(AirflowProviderDeprecationWarning, match=expected_warning): + op = SqlToS3Operator(**base_kwargs) + if expected_read_kwargs: + assert op.read_kwargs == expected_read_kwargs + else: + op = SqlToS3Operator(**base_kwargs) + if expected_read_kwargs: + assert op.read_kwargs == expected_read_kwargs + + @pytest.mark.parametrize( + "df_type,should_call_fix_dtypes", + [ + pytest.param("pandas", True, id="pandas-calls-fix-dtypes"), + pytest.param("polars", False, id="polars-skips-fix-dtypes"), + ], + ) + @mock.patch("airflow.providers.amazon.aws.transfers.sql_to_s3.S3Hook") + def test_fix_dtypes_behavior_by_df_type(self, mock_s3_hook, df_type, should_call_fix_dtypes): + """Test that _fix_dtypes is called/not called based on df_type.""" + query = "query" + s3_bucket = "bucket" + s3_key = "key" + + mock_dbapi_hook = mock.Mock() + test_df = pd.DataFrame({"a": "1", "b": "2"}, index=[0, 1]) + get_df_mock = mock_dbapi_hook.return_value.get_df + get_df_mock.return_value = test_df + + op = SqlToS3Operator( + query=query, + s3_bucket=s3_bucket, + s3_key=s3_key, + sql_conn_id="mysql_conn_id", + aws_conn_id="aws_conn_id", + task_id="task_id", + df_type=df_type, + replace=True, + dag=None, + ) + op._get_hook = mock_dbapi_hook + + with mock.patch.object(SqlToS3Operator, "_fix_dtypes") as mock_fix_dtypes: + op.execute(None) + + if should_call_fix_dtypes: + mock_fix_dtypes.assert_called_once() + else: + mock_fix_dtypes.assert_not_called() + + @pytest.mark.parametrize( + "kwargs,expected_warning,expected_read_kwargs,expected_df_kwargs", + [ + pytest.param( + { + "read_kwargs": {"dtype_backend": "pyarrow"}, + "read_pd_kwargs": {"dtype_backend": "numpy_nullable"}, + }, + "The 'read_pd_kwargs' parameter is deprecated", + {"dtype_backend": "pyarrow"}, + {}, + id="read-kwargs-priority-over-deprecated", + ), + pytest.param( + {"read_pd_kwargs": {"dtype_backend": "numpy_nullable"}}, + "The 'read_pd_kwargs' parameter is deprecated", + {"dtype_backend": "numpy_nullable"}, + {}, + id="read-pd-kwargs-used-when-read-kwargs-none", + ), + pytest.param( + { + "df_kwargs": {"index": False}, + "pd_kwargs": {"header": False}, + }, + "The 'pd_kwargs' parameter is deprecated", + {}, + {"index": False}, + id="df-kwargs-priority-over-deprecated", + ), + pytest.param( + {"pd_kwargs": {"header": False}}, + "The 
+                {},
+                {"header": False},
+                id="pd-kwargs-used-when-df-kwargs-none",
+            ),
+        ],
+    )
+    def test_deprecated_kwargs_priority_behavior(
+        self, kwargs, expected_warning, expected_read_kwargs, expected_df_kwargs
+    ):
+        """Test priority behavior and deprecation warnings for deprecated parameters."""
+        base_kwargs = {
+            "query": "query",
+            "s3_bucket": "bucket",
+            "s3_key": "key",
+            "sql_conn_id": "mysql_conn_id",
+            "task_id": "task_id",
+        }
+        base_kwargs.update(kwargs)
+
+        with pytest.warns(AirflowProviderDeprecationWarning, match=expected_warning):
+            op = SqlToS3Operator(**base_kwargs)
+
+        assert op.read_kwargs == expected_read_kwargs
+        assert op.df_kwargs == expected_df_kwargs
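
For context, a minimal usage sketch of the renamed parameters introduced by this change (read_kwargs, df_kwargs, df_type). The query, bucket, key, and connection IDs below are illustrative placeholders, not values taken from this PR:

    from airflow.providers.amazon.aws.transfers.sql_to_s3 import SqlToS3Operator

    # Hypothetical task definition: query, bucket, key, and connection IDs are placeholders.
    export_orders = SqlToS3Operator(
        task_id="export_orders",
        query="SELECT * FROM orders",
        s3_bucket="my-bucket",
        s3_key="exports/orders.csv",
        sql_conn_id="mysql_default",
        aws_conn_id="aws_default",
        file_format="csv",
        df_type="pandas",  # new: pass "polars" to read the query result into a polars DataFrame
        read_kwargs={"dtype_backend": "pyarrow"},  # new name; replaces the deprecated read_pd_kwargs
        df_kwargs={"index": False, "header": False},  # new name; replaces the deprecated pd_kwargs
        replace=True,
    )

Passing read_pd_kwargs or pd_kwargs still works but raises AirflowProviderDeprecationWarning, and the new-style arguments take priority when both are supplied.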