diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/sql_to_s3.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/sql_to_s3.py
index e13d5ca6b4ec9..45224059e994f 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/sql_to_s3.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/sql_to_s3.py
@@ -261,38 +261,56 @@ def execute(self, context: Context) -> None:
                 file_obj=buf, key=object_key, bucket_name=self.s3_bucket, replace=self.replace
             )
 
-    def _partition_dataframe(self, df: pd.DataFrame | pl.DataFrame) -> Iterable[tuple[str, pd.DataFrame]]:
-        """Partition dataframe using pandas groupby() method."""
+    def _partition_dataframe(
+        self, df: pd.DataFrame | pl.DataFrame
+    ) -> Iterable[tuple[str, pd.DataFrame | pl.DataFrame]]:
+        """Partition dataframe using pandas or polars groupby() method."""
         try:
             import secrets
             import string
 
             import numpy as np
+            import pandas as pd
             import polars as pl
         except ImportError:
             pass
 
-        if isinstance(df, pl.DataFrame):
-            df = df.to_pandas()
-
         # if max_rows_per_file argument is specified, a temporary column with a random unusual name will be
         # added to the dataframe. This column is used to dispatch the dataframe into smaller ones using groupby()
-
-        random_column_name = ""
+        random_column_name = None
         if self.max_rows_per_file and not self.groupby_kwargs:
             random_column_name = "".join(secrets.choice(string.ascii_letters) for _ in range(20))
-            df[random_column_name] = np.arange(len(df)) // self.max_rows_per_file
             self.groupby_kwargs = {"by": random_column_name}
+
+        if random_column_name:
+            if isinstance(df, pd.DataFrame):
+                df[random_column_name] = np.arange(len(df)) // self.max_rows_per_file
+            elif isinstance(df, pl.DataFrame):
+                df = df.with_columns(
+                    (pl.int_range(pl.len()) // self.max_rows_per_file).alias(random_column_name)
+                )
+
         if not self.groupby_kwargs:
             yield "", df
             return
-        for group_label in (grouped_df := df.groupby(**self.groupby_kwargs)).groups:
-            yield (
-                cast("str", group_label),
-                grouped_df.get_group(group_label)
-                .drop(random_column_name, axis=1, errors="ignore")
-                .reset_index(drop=True),
-            )
+
+        if isinstance(df, pd.DataFrame):
+            for group_label in (grouped_df := df.groupby(**self.groupby_kwargs)).groups:
+                group_df = grouped_df.get_group(group_label)
+                if random_column_name:
+                    group_df = group_df.drop(random_column_name, axis=1, errors="ignore")
+                yield (
+                    cast("str", group_label[0] if isinstance(group_label, tuple) else group_label),
+                    group_df.reset_index(drop=True),
+                )
+        elif isinstance(df, pl.DataFrame):
+            for group_label, group_df in df.group_by(**self.groupby_kwargs):  # type: ignore[assignment]
+                if random_column_name:
+                    group_df = group_df.drop(random_column_name)
+                yield (
+                    cast("str", group_label[0] if isinstance(group_label, tuple) else group_label),
+                    group_df,
+                )
 
     def _get_hook(self) -> DbApiHook:
         self.log.debug("Get connection for %s", self.sql_conn_id)
diff --git a/providers/amazon/tests/unit/amazon/aws/transfers/test_sql_to_s3.py b/providers/amazon/tests/unit/amazon/aws/transfers/test_sql_to_s3.py
index 402803640e6f5..1a1e424928fc2 100644
--- a/providers/amazon/tests/unit/amazon/aws/transfers/test_sql_to_s3.py
+++ b/providers/amazon/tests/unit/amazon/aws/transfers/test_sql_to_s3.py
@@ -22,6 +22,7 @@
 from unittest import mock
 
 import pandas as pd
+import polars as pl
 import pytest
 
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
@@ -442,7 +443,10 @@ def test_partition_dataframe(self, df_type, input_df_creator):
 
         assert len(partitions) == 2
         for group_name, df in partitions:
-            assert isinstance(df, pd.DataFrame)
+            if df_type == "polars":
+                assert isinstance(df, pl.DataFrame)
+            else:
+                assert isinstance(df, pd.DataFrame)
             assert group_name in ["A", "B"]
 
     @pytest.mark.parametrize(
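A minimal, self-contained sketch (not part of the PR; the data and the `_chunk` column name are made up) of the two polars idioms the new branch relies on: deriving a chunk index with `pl.int_range(pl.len()) // n` for `max_rows_per_file`, and iterating `group_by()`, whose group key is a tuple in recent polars releases, which is why both branches unwrap `group_label[0]`.

```python
import polars as pl

MAX_ROWS_PER_FILE = 2  # stand-in for the operator's max_rows_per_file
df = pl.DataFrame({"team": ["A", "A", "B", "B", "B"], "x": [1, 2, 3, 4, 5]})

# Derive a chunk index: row number integer-divided by the chunk size,
# mirroring np.arange(len(df)) // max_rows_per_file on the pandas side.
df = df.with_columns((pl.int_range(pl.len()) // MAX_ROWS_PER_FILE).alias("_chunk"))

# Iterating group_by() yields (key, sub_frame) pairs; with recent polars the
# key is a tuple even for a single grouping column, e.g. (0,), so unwrap it.
for group_label, group_df in df.group_by("_chunk", maintain_order=True):
    label = group_label[0] if isinstance(group_label, tuple) else group_label
    print(label, group_df.drop("_chunk").shape)  # 0 (2, 2), 1 (2, 2), 2 (1, 2)
```

Keeping the polars frame native avoids the `to_pandas()` round-trip the old code forced, at the cost of the duplicated grouping loop in `_partition_dataframe`.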