diff --git a/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py b/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py index 33e994daaf9b7..376aa01f8d010 100644 --- a/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py +++ b/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py @@ -358,7 +358,11 @@ def retrieve_directory(self, remote_full_path: str, local_full_path: str, prefet self.retrieve_file(file_path, new_local_path, prefetch) def retrieve_directory_concurrently( - self, remote_full_path: str, local_full_path: str, workers: int = os.cpu_count() or 2 + self, + remote_full_path: str, + local_full_path: str, + workers: int = os.cpu_count() or 2, + prefetch: bool = True, ) -> None: """ Transfer the remote directory to a local location concurrently. @@ -405,6 +409,7 @@ def retrieve_file_chunk( conns[i], local_file_chunks[i], remote_file_chunks[i], + prefetch, ) for i in range(workers) ] diff --git a/providers/sftp/src/airflow/providers/sftp/operators/sftp.py b/providers/sftp/src/airflow/providers/sftp/operators/sftp.py index 8e1e3c85592cf..a896cfb5495a1 100644 --- a/providers/sftp/src/airflow/providers/sftp/operators/sftp.py +++ b/providers/sftp/src/airflow/providers/sftp/operators/sftp.py @@ -76,6 +76,7 @@ class SFTPOperator(BaseOperator): ) :param concurrency: Number of threads when transferring directories. Each thread opens a new SFTP connection. This parameter is used only when transferring directories, not individual files. (Default is 1) + :param prefetch: controls whether prefetch is performed (default: True) """ @@ -93,6 +94,7 @@ def __init__( confirm: bool = True, create_intermediate_dirs: bool = False, concurrency: int = 1, + prefetch: bool = True, **kwargs, ) -> None: super().__init__(**kwargs) @@ -105,6 +107,7 @@ def __init__( self.local_filepath = local_filepath self.remote_filepath = remote_filepath self.concurrency = concurrency + self.prefetch = prefetch def execute(self, context: Any) -> str | list[str] | None: if self.local_filepath is None: @@ -173,6 +176,7 @@ def execute(self, context: Any) -> str | list[str] | None: _remote_filepath, _local_filepath, workers=self.concurrency, + prefetch=self.prefetch, ) elif self.concurrency == 1: self.sftp_hook.retrieve_directory(_remote_filepath, _local_filepath) diff --git a/providers/sftp/src/airflow/providers/sftp/triggers/sftp.py b/providers/sftp/src/airflow/providers/sftp/triggers/sftp.py index 08a047975fdf5..96af8ba73b2a7 100644 --- a/providers/sftp/src/airflow/providers/sftp/triggers/sftp.py +++ b/providers/sftp/src/airflow/providers/sftp/triggers/sftp.py @@ -84,7 +84,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: check whether the last modified time is greater, if true return file if false it polls again. """ hook = self._get_async_hook() - exc = None + if isinstance(self.newer_than, str): self.newer_than = parse_date(self.newer_than) _newer_than = convert_to_utc(self.newer_than) if self.newer_than else None diff --git a/providers/sftp/tests/unit/sftp/operators/test_sftp.py b/providers/sftp/tests/unit/sftp/operators/test_sftp.py index b7e94c376e1f9..fb4c5ed599252 100644 --- a/providers/sftp/tests/unit/sftp/operators/test_sftp.py +++ b/providers/sftp/tests/unit/sftp/operators/test_sftp.py @@ -96,6 +96,21 @@ def teardown_method(self): if os.path.exists(self.test_remote_dir): os.rmdir(self.test_remote_dir) + def test_default_args(self): + operator = SFTPOperator( + task_id="test_default_args", + remote_filepath="/tmp/remote_file", + ) + assert operator.operation == SFTPOperation.PUT + assert operator.confirm is True + assert operator.create_intermediate_dirs is False + assert operator.concurrency == 1 + assert operator.prefetch is True + assert operator.local_filepath is None + assert operator.sftp_hook is None + assert operator.ssh_conn_id is None + assert operator.remote_host is None + @pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Pickle support is removed in Airflow 3") @conf_vars({("core", "enable_xcom_pickling"): "True"}) def test_pickle_file_transfer_put(self, dag_maker): @@ -442,8 +457,7 @@ def test_str_dirpaths_get_concurrently(self, mock_get): concurrency=2, ).execute(None) assert mock_get.call_count == 1 - args, _ = mock_get.call_args_list[0] - assert args == (remote_dirpath, local_dirpath) + assert mock_get.call_args == mock.call(remote_dirpath, local_dirpath, workers=2, prefetch=True) @mock.patch("airflow.providers.sftp.operators.sftp.SFTPHook.store_file") def test_str_filepaths_put(self, mock_get):