diff --git a/airflow/providers/amazon/aws/transfers/sql_to_s3.py b/airflow/providers/amazon/aws/transfers/sql_to_s3.py
index c00784ad4a71..f8691fa4a256 100644
--- a/airflow/providers/amazon/aws/transfers/sql_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/sql_to_s3.py
@@ -81,6 +81,9 @@ class SqlToS3Operator(BaseOperator):
                  You can specify this argument if you want to use a different
                  CA cert bundle than the one used by botocore.
     :param file_format: the destination file format, only string 'csv', 'json' or 'parquet' is accepted.
+    :param max_rows_per_file: (optional) argument to set destination file number of rows limit, if source data
+        is larger than that, it will be dispatched into multiple files.
+        Will be ignored if ``groupby_kwargs`` argument is specified.
     :param pd_kwargs: arguments to include in DataFrame ``.to_parquet()``, ``.to_json()`` or ``.to_csv()``.
     :param groupby_kwargs: argument to include in DataFrame ``groupby()``.
     """
@@ -110,6 +113,7 @@ def __init__(
         aws_conn_id: str = "aws_default",
         verify: bool | str | None = None,
         file_format: Literal["csv", "json", "parquet"] = "csv",
+        max_rows_per_file: int = 0,
         pd_kwargs: dict | None = None,
         groupby_kwargs: dict | None = None,
         **kwargs,
@@ -124,12 +128,19 @@ def __init__(
         self.replace = replace
         self.pd_kwargs = pd_kwargs or {}
         self.parameters = parameters
+        self.max_rows_per_file = max_rows_per_file
         self.groupby_kwargs = groupby_kwargs or {}
         self.sql_hook_params = sql_hook_params
 
         if "path_or_buf" in self.pd_kwargs:
             raise AirflowException("The argument path_or_buf is not allowed, please remove it")
 
+        if self.max_rows_per_file and self.groupby_kwargs:
+            raise AirflowException(
+                "SqlToS3Operator arguments max_rows_per_file and groupby_kwargs "
+                "can not be both specified. Please choose one."
+            )
+
         try:
             self.file_format = FILE_FORMAT[file_format.upper()]
         except KeyError:
@@ -177,10 +188,8 @@ def execute(self, context: Context) -> None:
         s3_conn = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
         data_df = sql_hook.get_pandas_df(sql=self.query, parameters=self.parameters)
         self.log.info("Data from SQL obtained")
-
         self._fix_dtypes(data_df, self.file_format)
         file_options = FILE_OPTIONS_MAP[self.file_format]
-
         for group_name, df in self._partition_dataframe(df=data_df):
             with NamedTemporaryFile(mode=file_options.mode, suffix=file_options.suffix) as tmp_file:
                 self.log.info("Writing data to temp file")
@@ -194,13 +203,32 @@ def execute(self, context: Context) -> None:
 
     def _partition_dataframe(self, df: pd.DataFrame) -> Iterable[tuple[str, pd.DataFrame]]:
         """Partition dataframe using pandas groupby() method."""
+        try:
+            import secrets
+            import string
+
+            import numpy as np
+        except ImportError:
+            pass
+        # if max_rows_per_file argument is specified, a temporary column with a random unusual name will be
+        # added to the dataframe. This column is used to dispatch the dataframe into smaller ones using groupby()
+        random_column_name = ""
+        if self.max_rows_per_file and not self.groupby_kwargs:
+            random_column_name = "".join(secrets.choice(string.ascii_letters) for _ in range(20))
+            df[random_column_name] = np.arange(len(df)) // self.max_rows_per_file
+            self.groupby_kwargs = {"by": random_column_name}
         if not self.groupby_kwargs:
             yield "", df
             return
         for group_label in (grouped_df := df.groupby(**self.groupby_kwargs)).groups:
             yield (
                 cast(str, group_label),
-                cast("pd.DataFrame", grouped_df.get_group(group_label).reset_index(drop=True)),
+                cast(
+                    "pd.DataFrame",
+                    grouped_df.get_group(group_label)
+                    .drop(random_column_name, axis=1, errors="ignore")
+                    .reset_index(drop=True),
+                ),
             )
 
     def _get_hook(self) -> DbApiHook:
diff --git a/tests/providers/amazon/aws/transfers/test_sql_to_s3.py b/tests/providers/amazon/aws/transfers/test_sql_to_s3.py
index cc56fd064a6b..feee688d462b 100644
--- a/tests/providers/amazon/aws/transfers/test_sql_to_s3.py
+++ b/tests/providers/amazon/aws/transfers/test_sql_to_s3.py
@@ -271,6 +271,58 @@ def test_without_groupby_kwarg(self):
             )
         )
 
+    def test_with_max_rows_per_file(self):
+        """
+        Test operator when the max_rows_per_file is specified
+        """
+        query = "query"
+        s3_bucket = "bucket"
+        s3_key = "key"
+
+        op = SqlToS3Operator(
+            query=query,
+            s3_bucket=s3_bucket,
+            s3_key=s3_key,
+            sql_conn_id="mysql_conn_id",
+            aws_conn_id="aws_conn_id",
+            task_id="task_id",
+            replace=True,
+            pd_kwargs={"index": False, "header": False},
+            max_rows_per_file=3,
+            dag=None,
+        )
+        example = {
+            "Team": ["Australia", "Australia", "India", "India"],
+            "Player": ["Ricky", "David Warner", "Virat Kohli", "Rohit Sharma"],
+            "Runs": [345, 490, 672, 560],
+        }
+
+        df = pd.DataFrame(example)
+        data = []
+        for group_name, df in op._partition_dataframe(df):
+            data.append((group_name, df))
+        data.sort(key=lambda d: d[0])
+        team, df = data[0]
+        assert df.equals(
+            pd.DataFrame(
+                {
+                    "Team": ["Australia", "Australia", "India"],
+                    "Player": ["Ricky", "David Warner", "Virat Kohli"],
+                    "Runs": [345, 490, 672],
+                }
+            )
+        )
+        team, df = data[1]
+        assert df.equals(
+            pd.DataFrame(
+                {
+                    "Team": ["India"],
+                    "Player": ["Rohit Sharma"],
+                    "Runs": [560],
+                }
+            )
+        )
+
     @mock.patch("airflow.providers.common.sql.operators.sql.BaseHook.get_connection")
     def test_hook_params(self, mock_get_conn):
         mock_get_conn.return_value = Connection(conn_id="postgres_test", conn_type="postgres")
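For context (not part of the patch itself): a minimal usage sketch of the new ``max_rows_per_file`` parameter, splitting a query result into CSV files of at most 100,000 rows each. The task id, connection ids, bucket, key, and query below are placeholder values, not anything defined by this change; only the operator and its parameters come from the patch.

    from airflow.providers.amazon.aws.transfers.sql_to_s3 import SqlToS3Operator

    # Hypothetical task: export a query result to S3, writing at most 100_000
    # rows per output file. Setting groupby_kwargs together with
    # max_rows_per_file would raise AirflowException in __init__, per the
    # validation added by this patch.
    export_orders = SqlToS3Operator(
        task_id="export_orders",
        sql_conn_id="mysql_conn_id",   # placeholder SQL connection id
        aws_conn_id="aws_conn_id",     # placeholder AWS connection id
        query="SELECT * FROM orders",  # placeholder query
        s3_bucket="my-bucket",         # placeholder bucket
        s3_key="exports/orders.csv",   # placeholder key
        file_format="csv",
        max_rows_per_file=100_000,
        replace=True,
    )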