Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion providers/sftp/src/airflow/providers/sftp/hooks/sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,11 @@ def retrieve_directory(self, remote_full_path: str, local_full_path: str, prefet
self.retrieve_file(file_path, new_local_path, prefetch)

def retrieve_directory_concurrently(
self, remote_full_path: str, local_full_path: str, workers: int = os.cpu_count() or 2
self,
remote_full_path: str,
local_full_path: str,
workers: int = os.cpu_count() or 2,
prefetch: bool = True,
) -> None:
"""
Transfer the remote directory to a local location concurrently.
Expand Down Expand Up @@ -405,6 +409,7 @@ def retrieve_file_chunk(
conns[i],
local_file_chunks[i],
remote_file_chunks[i],
prefetch,
)
for i in range(workers)
]
Expand Down
4 changes: 4 additions & 0 deletions providers/sftp/src/airflow/providers/sftp/operators/sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ class SFTPOperator(BaseOperator):
)
:param concurrency: Number of threads when transferring directories. Each thread opens a new SFTP connection.
This parameter is used only when transferring directories, not individual files. (Default is 1)
:param prefetch: controls whether prefetch is performed (default: True)

"""

Expand All @@ -93,6 +94,7 @@ def __init__(
confirm: bool = True,
create_intermediate_dirs: bool = False,
concurrency: int = 1,
prefetch: bool = True,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -105,6 +107,7 @@ def __init__(
self.local_filepath = local_filepath
self.remote_filepath = remote_filepath
self.concurrency = concurrency
self.prefetch = prefetch

def execute(self, context: Any) -> str | list[str] | None:
if self.local_filepath is None:
Expand Down Expand Up @@ -173,6 +176,7 @@ def execute(self, context: Any) -> str | list[str] | None:
_remote_filepath,
_local_filepath,
workers=self.concurrency,
prefetch=self.prefetch,
)
elif self.concurrency == 1:
self.sftp_hook.retrieve_directory(_remote_filepath, _local_filepath)
Expand Down
2 changes: 1 addition & 1 deletion providers/sftp/src/airflow/providers/sftp/triggers/sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
check whether the last modified time is greater, if true return file if false it polls again.
"""
hook = self._get_async_hook()
exc = None

if isinstance(self.newer_than, str):
self.newer_than = parse_date(self.newer_than)
_newer_than = convert_to_utc(self.newer_than) if self.newer_than else None
Expand Down
18 changes: 16 additions & 2 deletions providers/sftp/tests/unit/sftp/operators/test_sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,21 @@ def teardown_method(self):
if os.path.exists(self.test_remote_dir):
os.rmdir(self.test_remote_dir)

def test_default_args(self):
operator = SFTPOperator(
task_id="test_default_args",
remote_filepath="/tmp/remote_file",
)
assert operator.operation == SFTPOperation.PUT
assert operator.confirm is True
assert operator.create_intermediate_dirs is False
assert operator.concurrency == 1
assert operator.prefetch is True
assert operator.local_filepath is None
assert operator.sftp_hook is None
assert operator.ssh_conn_id is None
assert operator.remote_host is None

@pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Pickle support is removed in Airflow 3")
@conf_vars({("core", "enable_xcom_pickling"): "True"})
def test_pickle_file_transfer_put(self, dag_maker):
Expand Down Expand Up @@ -442,8 +457,7 @@ def test_str_dirpaths_get_concurrently(self, mock_get):
concurrency=2,
).execute(None)
assert mock_get.call_count == 1
args, _ = mock_get.call_args_list[0]
assert args == (remote_dirpath, local_dirpath)
assert mock_get.call_args == mock.call(remote_dirpath, local_dirpath, workers=2, prefetch=True)

@mock.patch("airflow.providers.sftp.operators.sftp.SFTPHook.store_file")
def test_str_filepaths_put(self, mock_get):
Expand Down