Skip to content

Commit

Permalink
Return str in execute if local_filepath was passed as str
Browse files Browse the repository at this point in the history
  • Loading branch information
pauldalewilliams committed Sep 27, 2022
1 parent 09c5f59 commit 48f12e6
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 2 deletions.
6 changes: 4 additions & 2 deletions airflow/providers/sftp/operators/sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,10 @@ def __init__(
self.confirm = confirm
self.create_intermediate_dirs = create_intermediate_dirs

self.local_filepath_was_str = False
if isinstance(local_filepath, str):
self.local_filepath = [local_filepath]
self.local_filepath_was_str = True
else:
self.local_filepath = local_filepath

Expand Down Expand Up @@ -143,7 +145,7 @@ def __init__(
)
self.sftp_hook = SFTPHook(ssh_hook=self.ssh_hook)

def execute(self, context: Any) -> list[str] | None:
def execute(self, context: Any) -> str | list[str] | None:
file_msg = None
try:
if self.ssh_conn_id:
Expand Down Expand Up @@ -185,4 +187,4 @@ def execute(self, context: Any) -> list[str] | None:
except Exception as e:
raise AirflowException(f"Error while transferring {file_msg}, error: {str(e)}")

return self.local_filepath
return self.local_filepath[0] if self.local_filepath_was_str else self.local_filepath
15 changes: 15 additions & 0 deletions tests/providers/sftp/operators/test_sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,3 +442,18 @@ def test_multiple_paths_put(self, mock_put):
args1, _ = mock_put.call_args_list[1]
assert args0 == (remote_filepath[0], local_filepath[0])
assert args1 == (remote_filepath[1], local_filepath[1])

@mock.patch('airflow.providers.sftp.operators.sftp.SFTPHook.retrieve_file')
def test_return_str_when_local_filepath_was_str(self, mock_get):
local_filepath = '/tmp/ltest1'
remote_filepath = '/tmp/rtest1'
sftp_op = SFTPOperator(
task_id='test_returns_str',
sftp_hook=self.sftp_hook,
local_filepath=local_filepath,
remote_filepath=remote_filepath,
operation=SFTPOperation.GET,
)
return_value = sftp_op.execute(None)
assert isinstance(return_value, str)
assert return_value == local_filepath

0 comments on commit 48f12e6

Please sign in to comment.