diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_sftp.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_sftp.py
index 817e88b8a4850..849c12bdc2af5 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_sftp.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_sftp.py
@@ -51,6 +51,9 @@ class S3ToSFTPOperator(BaseOperator):
         where the file is downloaded.
     :param s3_key: The targeted s3 key. This is the specified file path for
         downloading the file from S3.
+    :param confirm: specify whether the SFTP operation should be confirmed, defaults to True.
+        When True, a stat is performed on the remote file after upload to verify that the
+        remote file size matches the local file size, confirming a successful transfer.
     """

     template_fields: Sequence[str] = ("s3_key", "sftp_path", "s3_bucket")
@@ -63,6 +66,7 @@ def __init__(
         sftp_path: str,
         sftp_conn_id: str = "ssh_default",
         aws_conn_id: str | None = "aws_default",
+        confirm: bool = True,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -71,6 +75,7 @@ def __init__(
         self.s3_bucket = s3_bucket
         self.s3_key = s3_key
         self.aws_conn_id = aws_conn_id
+        self.confirm = confirm

     @staticmethod
     def get_s3_key(s3_key: str) -> str:
@@ -88,4 +93,4 @@ def execute(self, context: Context) -> None:

         with NamedTemporaryFile("w") as f:
             s3_client.download_file(self.s3_bucket, self.s3_key, f.name)
-            sftp_client.put(f.name, self.sftp_path)
+            sftp_client.put(f.name, self.sftp_path, confirm=self.confirm)
diff --git a/providers/amazon/tests/unit/amazon/aws/transfers/test_s3_to_sftp.py b/providers/amazon/tests/unit/amazon/aws/transfers/test_s3_to_sftp.py
index 088f95e3fc925..fecf207c6f429 100644
--- a/providers/amazon/tests/unit/amazon/aws/transfers/test_s3_to_sftp.py
+++ b/providers/amazon/tests/unit/amazon/aws/transfers/test_s3_to_sftp.py
@@ -139,5 +139,122 @@ def delete_remote_resource(self):
         assert remove_file_task is not None
         remove_file_task.execute(None)

+    @mock_aws
+    @conf_vars({("core", "enable_xcom_pickling"): "True"})
+    def test_s3_to_sftp_operation_confirm_true_default(self):
+        """Test that S3ToSFTPOperator works with the default confirm=True (real SSH connection)."""
+        s3_hook = S3Hook(aws_conn_id=None)
+        # Set up the test file content
+        test_remote_file_content = (
+            "This is remote file content for confirm=True test \n which is also multiline "
+            "another line here \n this is last line. EOF"
+        )
+
+        # Create the s3 bucket
+        conn = boto3.client("s3")
+        conn.create_bucket(Bucket=self.s3_bucket)
+        assert s3_hook.check_for_bucket(self.s3_bucket)
+
+        with open(LOCAL_FILE_PATH, "w") as file:
+            file.write(test_remote_file_content)
+        s3_hook.load_file(LOCAL_FILE_PATH, self.s3_key, bucket_name=BUCKET)
+
+        # Check that the object was created in s3
+        objects_in_dest_bucket = conn.list_objects(Bucket=self.s3_bucket, Prefix=self.s3_key)
+        # there should be exactly one object found
+        assert len(objects_in_dest_bucket["Contents"]) == 1
+
+        # the object found should be consistent with the dest_key specified earlier
+        assert objects_in_dest_bucket["Contents"][0]["Key"] == self.s3_key
+
+        # transfer the file from s3 to the SFTP server with the default confirm=True
+        run_task = S3ToSFTPOperator(
+            s3_bucket=BUCKET,
+            s3_key=S3_KEY,
+            sftp_path=SFTP_PATH,
+            sftp_conn_id=SFTP_CONN_ID,
+            task_id=TASK_ID + "_confirm_true",
+            dag=self.dag,
+        )
+        assert run_task is not None
+
+        run_task.execute(None)
+
+        # Check that the file was created remotely with the correct content
+        check_file_task = SSHOperator(
+            task_id="test_check_file_confirm_true",
+            ssh_hook=self.hook,
+            command=f"cat {self.sftp_path}",
+            do_xcom_push=True,
+            dag=self.dag,
+        )
+        assert check_file_task is not None
+        result = check_file_task.execute(None)
+        assert result.strip() == test_remote_file_content.encode("utf-8")
+
+        # Clean up after finishing with test
+        conn.delete_object(Bucket=self.s3_bucket, Key=self.s3_key)
+        conn.delete_bucket(Bucket=self.s3_bucket)
+        assert not s3_hook.check_for_bucket(self.s3_bucket)
+
+    @mock_aws
+    @conf_vars({("core", "enable_xcom_pickling"): "True"})
+    def test_s3_to_sftp_operation_confirm_false(self):
+        """Test that S3ToSFTPOperator works with an explicit confirm=False (real SSH connection)."""
+        s3_hook = S3Hook(aws_conn_id=None)
+        # Set up the test file content
+        test_remote_file_content = (
+            "This is remote file content for confirm=False test \n which is also multiline "
+            "another line here \n this is last line. EOF"
+        )
+
+        # Create the s3 bucket
+        conn = boto3.client("s3")
+        conn.create_bucket(Bucket=self.s3_bucket)
+        assert s3_hook.check_for_bucket(self.s3_bucket)
+
+        with open(LOCAL_FILE_PATH, "w") as file:
+            file.write(test_remote_file_content)
+        s3_hook.load_file(LOCAL_FILE_PATH, self.s3_key, bucket_name=BUCKET)
+
+        # Check that the object was created in s3
+        objects_in_dest_bucket = conn.list_objects(Bucket=self.s3_bucket, Prefix=self.s3_key)
+        # there should be exactly one object found
+        assert len(objects_in_dest_bucket["Contents"]) == 1
+
+        # the object found should be consistent with the dest_key specified earlier
+        assert objects_in_dest_bucket["Contents"][0]["Key"] == self.s3_key
+
+        # transfer the file from s3 to the SFTP server with an explicit confirm=False
+        run_task = S3ToSFTPOperator(
+            s3_bucket=BUCKET,
+            s3_key=S3_KEY,
+            sftp_path=SFTP_PATH,
+            sftp_conn_id=SFTP_CONN_ID,
+            task_id=TASK_ID + "_confirm_false",
+            confirm=False,  # explicitly disable the post-upload confirmation
+            dag=self.dag,
+        )
+        assert run_task is not None
+
+        run_task.execute(None)
+
+        # Check that the file was created remotely with the correct content
+        check_file_task = SSHOperator(
+            task_id="test_check_file_confirm_false",
+            ssh_hook=self.hook,
+            command=f"cat {self.sftp_path}",
+            do_xcom_push=True,
+            dag=self.dag,
+        )
+        assert check_file_task is not None
+        result = check_file_task.execute(None)
+        assert result.strip() == test_remote_file_content.encode("utf-8")
+
+        # Clean up after finishing with test
+        conn.delete_object(Bucket=self.s3_bucket, Key=self.s3_key)
+        conn.delete_bucket(Bucket=self.s3_bucket)
+        assert not s3_hook.check_for_bucket(self.s3_bucket)
+
     def teardown_method(self):
         self.delete_remote_resource()