Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ class S3ToSFTPOperator(BaseOperator):
where the file is downloaded.
:param s3_key: The targeted s3 key. This is the specified file path for
downloading the file from S3.
:param confirm: specify if the SFTP operation should be confirmed, defaults to True.
When True, a stat will be performed on the remote file after upload to verify
the file size matches and confirm successful transfer.
"""

template_fields: Sequence[str] = ("s3_key", "sftp_path", "s3_bucket")
Expand All @@ -63,6 +66,7 @@ def __init__(
sftp_path: str,
sftp_conn_id: str = "ssh_default",
aws_conn_id: str | None = "aws_default",
confirm: bool = True,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -71,6 +75,7 @@ def __init__(
self.s3_bucket = s3_bucket
self.s3_key = s3_key
self.aws_conn_id = aws_conn_id
self.confirm = confirm

@staticmethod
def get_s3_key(s3_key: str) -> str:
Expand All @@ -88,4 +93,4 @@ def execute(self, context: Context) -> None:

with NamedTemporaryFile("w") as f:
s3_client.download_file(self.s3_bucket, self.s3_key, f.name)
sftp_client.put(f.name, self.sftp_path)
sftp_client.put(f.name, self.sftp_path, confirm=self.confirm)
117 changes: 117 additions & 0 deletions providers/amazon/tests/unit/amazon/aws/transfers/test_s3_to_sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,5 +139,122 @@ def delete_remote_resource(self):
assert remove_file_task is not None
remove_file_task.execute(None)

@mock_aws
@conf_vars({("core", "enable_xcom_pickling"): "True"})
def test_s3_to_sftp_operation_confirm_true_default(self):
    """Transfer an S3 object over SFTP without passing ``confirm``.

    The operator is built with its default ``confirm`` value (True), and the
    remote file is then read back over SSH and compared with what was
    uploaded to the bucket.
    """
    hook = S3Hook(aws_conn_id=None)
    expected_content = (
        "This is remote file content for confirm=True test \n which is also multiline "
        "another line here \n this is last line. EOF"
    )

    # Create the bucket and make sure the hook can see it.
    s3_client = boto3.client("s3")
    s3_client.create_bucket(Bucket=self.s3_bucket)
    assert hook.check_for_bucket(self.s3_bucket)

    # Stage the payload in a local file, then upload it once the file is closed.
    with open(LOCAL_FILE_PATH, "w") as handle:
        handle.write(expected_content)
    hook.load_file(LOCAL_FILE_PATH, self.s3_key, bucket_name=BUCKET)

    # Exactly one object must exist under the key prefix, and it must be ours.
    listing = s3_client.list_objects(Bucket=self.s3_bucket, Prefix=self.s3_key)
    contents = listing["Contents"]
    assert len(contents) == 1
    assert contents[0]["Key"] == self.s3_key

    # Push the S3 object to the SFTP server; ``confirm`` is left at its default.
    transfer_task = S3ToSFTPOperator(
        s3_bucket=BUCKET,
        s3_key=S3_KEY,
        sftp_path=SFTP_PATH,
        sftp_conn_id=SFTP_CONN_ID,
        task_id=TASK_ID + "_confirm_true",
        dag=self.dag,
    )
    assert transfer_task is not None

    transfer_task.execute(None)

    # Read the remote file back over SSH and compare it with the payload.
    cat_task = SSHOperator(
        task_id="test_check_file_confirm_true",
        ssh_hook=self.hook,
        command=f"cat {self.sftp_path}",
        do_xcom_push=True,
        dag=self.dag,
    )
    assert cat_task is not None
    remote_output = cat_task.execute(None)
    expected_bytes = expected_content.encode("utf-8")
    assert remote_output.strip() == expected_bytes

    # Remove the object and the bucket, then verify the bucket is gone.
    s3_client.delete_object(Bucket=self.s3_bucket, Key=self.s3_key)
    s3_client.delete_bucket(Bucket=self.s3_bucket)
    assert not hook.check_for_bucket(self.s3_bucket)

@mock_aws
@conf_vars({("core", "enable_xcom_pickling"): "True"})
def test_s3_to_sftp_operation_confirm_false(self):
    """Transfer an S3 object over SFTP with ``confirm=False``.

    The operator is built with ``confirm`` explicitly disabled, and the
    remote file is then read back over SSH and compared with what was
    uploaded to the bucket.
    """
    hook = S3Hook(aws_conn_id=None)
    expected_content = (
        "This is remote file content for confirm=False test \n which is also multiline "
        "another line here \n this is last line. EOF"
    )

    # Create the bucket and make sure the hook can see it.
    s3_client = boto3.client("s3")
    s3_client.create_bucket(Bucket=self.s3_bucket)
    assert hook.check_for_bucket(self.s3_bucket)

    # Stage the payload in a local file, then upload it once the file is closed.
    with open(LOCAL_FILE_PATH, "w") as handle:
        handle.write(expected_content)
    hook.load_file(LOCAL_FILE_PATH, self.s3_key, bucket_name=BUCKET)

    # Exactly one object must exist under the key prefix, and it must be ours.
    listing = s3_client.list_objects(Bucket=self.s3_bucket, Prefix=self.s3_key)
    contents = listing["Contents"]
    assert len(contents) == 1
    assert contents[0]["Key"] == self.s3_key

    # Push the S3 object to the SFTP server with confirmation switched off.
    transfer_task = S3ToSFTPOperator(
        s3_bucket=BUCKET,
        s3_key=S3_KEY,
        sftp_path=SFTP_PATH,
        sftp_conn_id=SFTP_CONN_ID,
        task_id=TASK_ID + "_confirm_false",
        confirm=False,
        dag=self.dag,
    )
    assert transfer_task is not None

    transfer_task.execute(None)

    # Read the remote file back over SSH and compare it with the payload.
    cat_task = SSHOperator(
        task_id="test_check_file_confirm_false",
        ssh_hook=self.hook,
        command=f"cat {self.sftp_path}",
        do_xcom_push=True,
        dag=self.dag,
    )
    assert cat_task is not None
    remote_output = cat_task.execute(None)
    expected_bytes = expected_content.encode("utf-8")
    assert remote_output.strip() == expected_bytes

    # Remove the object and the bucket, then verify the bucket is gone.
    s3_client.delete_object(Bucket=self.s3_bucket, Key=self.s3_key)
    s3_client.delete_bucket(Bucket=self.s3_bucket)
    assert not hook.check_for_bucket(self.s3_bucket)

def teardown_method(self):
    """Per-test teardown hook; delegates all cleanup to ``delete_remote_resource``."""
    # NOTE(review): presumably this removes the file the tests wrote to the SFTP
    # server -- confirm against the body of delete_remote_resource.
    self.delete_remote_resource()