diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/sql_to_s3.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/sql_to_s3.py
index e18c3d35fd3f1..7d1f2656c4f2b 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/sql_to_s3.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/sql_to_s3.py
@@ -18,9 +18,10 @@
 from __future__ import annotations
 
 import enum
+import gzip
+import io
 from collections import namedtuple
 from collections.abc import Iterable, Mapping, Sequence
-from tempfile import NamedTemporaryFile
 from typing import TYPE_CHECKING, Any, cast
 
 from typing_extensions import Literal
@@ -191,16 +192,29 @@ def execute(self, context: Context) -> None:
         self.log.info("Data from SQL obtained")
         self._fix_dtypes(data_df, self.file_format)
         file_options = FILE_OPTIONS_MAP[self.file_format]
+
         for group_name, df in self._partition_dataframe(df=data_df):
-            with NamedTemporaryFile(mode=file_options.mode, suffix=file_options.suffix) as tmp_file:
-                self.log.info("Writing data to temp file")
-                getattr(df, file_options.function)(tmp_file.name, **self.pd_kwargs)
-
-                self.log.info("Uploading data to S3")
-                object_key = f"{self.s3_key}_{group_name}" if group_name else self.s3_key
-                s3_conn.load_file(
-                    filename=tmp_file.name, key=object_key, bucket_name=self.s3_bucket, replace=self.replace
-                )
+            buf = io.BytesIO()
+            self.log.info("Writing data to in-memory buffer")
+            object_key = f"{self.s3_key}_{group_name}" if group_name else self.s3_key
+
+            if self.pd_kwargs.get("compression") == "gzip":
+                pd_kwargs = {k: v for k, v in self.pd_kwargs.items() if k != "compression"}
+                with gzip.GzipFile(fileobj=buf, mode="wb", filename=object_key) as gz:
+                    getattr(df, file_options.function)(gz, **pd_kwargs)
+            else:
+                if self.file_format == FILE_FORMAT.PARQUET:
+                    getattr(df, file_options.function)(buf, **self.pd_kwargs)
+                else:
+                    text_buf = io.TextIOWrapper(buf, encoding="utf-8", write_through=True)
+                    getattr(df, file_options.function)(text_buf, **self.pd_kwargs)
+                    text_buf.flush()
+            buf.seek(0)
+
+            self.log.info("Uploading data to S3")
+            s3_conn.load_file_obj(
+                file_obj=buf, key=object_key, bucket_name=self.s3_bucket, replace=self.replace
+            )
 
     def _partition_dataframe(self, df: pd.DataFrame) -> Iterable[tuple[str, pd.DataFrame]]:
         """Partition dataframe using pandas groupby() method."""
diff --git a/providers/amazon/tests/unit/amazon/aws/transfers/test_sql_to_s3.py b/providers/amazon/tests/unit/amazon/aws/transfers/test_sql_to_s3.py
index ecff0d03363ef..cb1d5828df4e1 100644
--- a/providers/amazon/tests/unit/amazon/aws/transfers/test_sql_to_s3.py
+++ b/providers/amazon/tests/unit/amazon/aws/transfers/test_sql_to_s3.py
@@ -17,7 +17,8 @@
 # under the License.
 from __future__ import annotations
 
-from tempfile import NamedTemporaryFile
+import gzip
+import io
 from unittest import mock
 
 import pandas as pd
@@ -29,9 +30,8 @@
 
 
 class TestSqlToS3Operator:
-    @mock.patch("airflow.providers.amazon.aws.transfers.sql_to_s3.NamedTemporaryFile")
     @mock.patch("airflow.providers.amazon.aws.transfers.sql_to_s3.S3Hook")
-    def test_execute_csv(self, mock_s3_hook, temp_mock):
+    def test_execute_csv(self, mock_s3_hook):
         query = "query"
         s3_bucket = "bucket"
         s3_key = "key"
@@ -40,110 +40,136 @@ def test_execute_csv(self, mock_s3_hook, temp_mock):
         test_df = pd.DataFrame({"a": "1", "b": "2"}, index=[0, 1])
         get_df_mock = mock_dbapi_hook.return_value.get_df
         get_df_mock.return_value = test_df
-        with NamedTemporaryFile() as f:
-            temp_mock.return_value.__enter__.return_value.name = f.name
-            op = SqlToS3Operator(
-                query=query,
-                s3_bucket=s3_bucket,
-                s3_key=s3_key,
-                sql_conn_id="mysql_conn_id",
-                aws_conn_id="aws_conn_id",
-                task_id="task_id",
-                replace=True,
-                pd_kwargs={"index": False, "header": False},
-                dag=None,
-            )
-            op._get_hook = mock_dbapi_hook
-            op.execute(None)
-            mock_s3_hook.assert_called_once_with(aws_conn_id="aws_conn_id", verify=None)
-
-            get_df_mock.assert_called_once_with(sql=query, parameters=None, df_type="pandas")
-
-            temp_mock.assert_called_once_with(mode="r+", suffix=".csv")
-            mock_s3_hook.return_value.load_file.assert_called_once_with(
-                filename=f.name,
-                key=s3_key,
-                bucket_name=s3_bucket,
-                replace=True,
-            )
+        op = SqlToS3Operator(
+            query=query,
+            s3_bucket=s3_bucket,
+            s3_key=s3_key,
+            sql_conn_id="mysql_conn_id",
+            aws_conn_id="aws_conn_id",
+            task_id="task_id",
+            replace=True,
+            pd_kwargs={"index": False, "header": False},
+            dag=None,
+        )
+        op._get_hook = mock_dbapi_hook
+        op.execute(None)
+
+        mock_s3_hook.assert_called_once_with(aws_conn_id="aws_conn_id", verify=None)
+        get_df_mock.assert_called_once_with(sql=query, parameters=None, df_type="pandas")
+        file_obj = mock_s3_hook.return_value.load_file_obj.call_args[1]["file_obj"]
+        assert isinstance(file_obj, io.BytesIO)
+        mock_s3_hook.return_value.load_file_obj.assert_called_once_with(
+            file_obj=file_obj, key=s3_key, bucket_name=s3_bucket, replace=True
+        )
 
-    @mock.patch("airflow.providers.amazon.aws.transfers.sql_to_s3.NamedTemporaryFile")
     @mock.patch("airflow.providers.amazon.aws.transfers.sql_to_s3.S3Hook")
-    def test_execute_parquet(self, mock_s3_hook, temp_mock):
+    def test_execute_parquet(self, mock_s3_hook):
         query = "query"
         s3_bucket = "bucket"
         s3_key = "key"
 
         mock_dbapi_hook = mock.Mock()
+
         test_df = pd.DataFrame({"a": "1", "b": "2"}, index=[0, 1])
         get_df_mock = mock_dbapi_hook.return_value.get_df
         get_df_mock.return_value = test_df
-        with NamedTemporaryFile() as f:
-            temp_mock.return_value.__enter__.return_value.name = f.name
-            op = SqlToS3Operator(
-                query=query,
-                s3_bucket=s3_bucket,
-                s3_key=s3_key,
-                sql_conn_id="mysql_conn_id",
-                aws_conn_id="aws_conn_id",
-                task_id="task_id",
-                file_format="parquet",
-                replace=False,
-                dag=None,
-            )
-            op._get_hook = mock_dbapi_hook
-            op.execute(None)
-            mock_s3_hook.assert_called_once_with(aws_conn_id="aws_conn_id", verify=None)
+        op = SqlToS3Operator(
+            query=query,
+            s3_bucket=s3_bucket,
+            s3_key=s3_key,
+            sql_conn_id="mysql_conn_id",
+            aws_conn_id="aws_conn_id",
+            task_id="task_id",
+            file_format="parquet",
+            replace=True,
+            dag=None,
+        )
+        op._get_hook = mock_dbapi_hook
+        op.execute(None)
+
+        mock_s3_hook.assert_called_once_with(aws_conn_id="aws_conn_id", verify=None)
+        get_df_mock.assert_called_once_with(sql=query, parameters=None, df_type="pandas")
+        file_obj = mock_s3_hook.return_value.load_file_obj.call_args[1]["file_obj"]
+        assert isinstance(file_obj, io.BytesIO)
+        mock_s3_hook.return_value.load_file_obj.assert_called_once_with(
+            file_obj=file_obj, key=s3_key, bucket_name=s3_bucket, replace=True
+        )
 
-        get_df_mock.assert_called_once_with(sql=query, parameters=None, df_type="pandas")
+    @mock.patch("airflow.providers.amazon.aws.transfers.sql_to_s3.S3Hook")
+    def test_execute_json(self, mock_s3_hook):
+        query = "query"
+        s3_bucket = "bucket"
+        s3_key = "key"
 
-        temp_mock.assert_called_once_with(mode="rb+", suffix=".parquet")
-        mock_s3_hook.return_value.load_file.assert_called_once_with(
-            filename=f.name, key=s3_key, bucket_name=s3_bucket, replace=False
-        )
+        mock_dbapi_hook = mock.Mock()
+        test_df = pd.DataFrame({"a": "1", "b": "2"}, index=[0, 1])
+        get_df_mock = mock_dbapi_hook.return_value.get_df
+        get_df_mock.return_value = test_df
+
+        op = SqlToS3Operator(
+            query=query,
+            s3_bucket=s3_bucket,
+            s3_key=s3_key,
+            sql_conn_id="mysql_conn_id",
+            aws_conn_id="aws_conn_id",
+            task_id="task_id",
+            file_format="json",
+            replace=True,
+            pd_kwargs={"date_format": "iso", "lines": True, "orient": "records"},
+            dag=None,
+        )
+        op._get_hook = mock_dbapi_hook
+        op.execute(None)
+
+        mock_s3_hook.assert_called_once_with(aws_conn_id="aws_conn_id", verify=None)
+        get_df_mock.assert_called_once_with(sql=query, parameters=None, df_type="pandas")
+        file_obj = mock_s3_hook.return_value.load_file_obj.call_args[1]["file_obj"]
+        assert isinstance(file_obj, io.BytesIO)
+        mock_s3_hook.return_value.load_file_obj.assert_called_once_with(
+            file_obj=file_obj, key=s3_key, bucket_name=s3_bucket, replace=True
+        )
 
-    @mock.patch("airflow.providers.amazon.aws.transfers.sql_to_s3.NamedTemporaryFile")
     @mock.patch("airflow.providers.amazon.aws.transfers.sql_to_s3.S3Hook")
-    def test_execute_json(self, mock_s3_hook, temp_mock):
+    def test_execute_gzip_with_bytesio(self, mock_s3_hook):
         query = "query"
         s3_bucket = "bucket"
-        s3_key = "key"
+        s3_key = "key.csv.gz"
 
         mock_dbapi_hook = mock.Mock()
         test_df = pd.DataFrame({"a": "1", "b": "2"}, index=[0, 1])
         get_df_mock = mock_dbapi_hook.return_value.get_df
         get_df_mock.return_value = test_df
-        with NamedTemporaryFile() as f:
-            temp_mock.return_value.__enter__.return_value.name = f.name
-            op = SqlToS3Operator(
-                query=query,
-                s3_bucket=s3_bucket,
-                s3_key=s3_key,
-                sql_conn_id="mysql_conn_id",
-                aws_conn_id="aws_conn_id",
-                task_id="task_id",
-                file_format="json",
-                replace=True,
-                pd_kwargs={"date_format": "iso", "lines": True, "orient": "records"},
-                dag=None,
-            )
-            op._get_hook = mock_dbapi_hook
-            op.execute(None)
-            mock_s3_hook.assert_called_once_with(aws_conn_id="aws_conn_id", verify=None)
-
-            get_df_mock.assert_called_once_with(sql=query, parameters=None, df_type="pandas")
-
-            temp_mock.assert_called_once_with(mode="r+", suffix=".json")
-            mock_s3_hook.return_value.load_file.assert_called_once_with(
-                filename=f.name,
-                key=s3_key,
-                bucket_name=s3_bucket,
-                replace=True,
-            )
+        op = SqlToS3Operator(
+            query=query,
+            s3_bucket=s3_bucket,
+            s3_key=s3_key,
+            sql_conn_id="mysql_conn_id",
+            aws_conn_id="aws_conn_id",
+            task_id="task_id",
+            replace=True,
+            pd_kwargs={"index": False, "compression": "gzip"},
+            dag=None,
+        )
+        op._get_hook = mock_dbapi_hook
+        op.execute(None)
+
+        mock_s3_hook.assert_called_once_with(aws_conn_id="aws_conn_id", verify=None)
+        get_df_mock.assert_called_once_with(sql=query, parameters=None, df_type="pandas")
+        file_obj = mock_s3_hook.return_value.load_file_obj.call_args[1]["file_obj"]
+        assert isinstance(file_obj, io.BytesIO)
+        mock_s3_hook.return_value.load_file_obj.assert_called_once_with(
+            file_obj=file_obj, key=s3_key, bucket_name=s3_bucket, replace=True
+        )
+
+        file_obj.seek(0)
+        with gzip.GzipFile(fileobj=file_obj, mode="rb") as gz:
+            decompressed_buf = io.BytesIO(gz.read())
+        decompressed_buf.seek(0)
+        read_df = pd.read_csv(decompressed_buf, dtype={"a": str, "b": str})
+        assert read_df.equals(test_df)
 
     @pytest.mark.parametrize(
         "params",