diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/s3.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/s3.py index 24e0e012e6dc0..d71a8383e8ad7 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/s3.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/s3.py @@ -1616,6 +1616,7 @@ def download_file( ExtraArgs=extra_args, Config=self.transfer_config, ) + file.flush() get_hook_lineage_collector().add_input_asset( context=self, scheme="s3", asset_kwargs={"bucket": bucket_name, "key": key} ) diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_s3.py b/providers/amazon/tests/unit/amazon/aws/hooks/test_s3.py index 74c8eb3ab3b06..f84f7b8ba0adc 100644 --- a/providers/amazon/tests/unit/amazon/aws/hooks/test_s3.py +++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_s3.py @@ -1449,8 +1449,8 @@ def test_function_with_test_key(self, test_key, bucket_name=None): @mock.patch("airflow.providers.amazon.aws.hooks.s3.NamedTemporaryFile") def test_download_file(self, mock_temp_file, tmp_path): - path = tmp_path / "airflow_tmp_test_s3_hook" - mock_temp_file.return_value = path + mock_file = mock_temp_file.return_value + mock_file.name = str(tmp_path / "airflow_tmp_test_s3_hook") s3_hook = S3Hook(aws_conn_id="s3_test", requester_pays=True) s3_hook.check_for_key = Mock(return_value=True) s3_obj = Mock() @@ -1463,17 +1463,17 @@ def test_download_file(self, mock_temp_file, tmp_path): s3_hook.get_key.assert_called_once_with(key, bucket) s3_obj.download_fileobj.assert_called_once_with( - path, + mock_file, Config=s3_hook.transfer_config, ExtraArgs={"RequestPayer": "requester"}, ) - assert path.name == output_file + assert mock_file.name == output_file @mock.patch("airflow.providers.amazon.aws.hooks.s3.NamedTemporaryFile") def test_download_file_exposes_lineage(self, mock_temp_file, tmp_path, hook_lineage_collector): - path = tmp_path / "airflow_tmp_test_s3_hook" - mock_temp_file.return_value = path + mock_file = mock_temp_file.return_value + mock_file.name = str(tmp_path / "airflow_tmp_test_s3_hook") s3_hook = S3Hook(aws_conn_id="s3_test") s3_hook.check_for_key = Mock(return_value=True) s3_obj = Mock()