diff --git a/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py b/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py index f5e3c798b1b01..337410f07f6f0 100644 --- a/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py +++ b/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py @@ -148,6 +148,8 @@ def get_managed_conn(self) -> Generator[SFTPClient, None, None]: self._sftp_conn = None self._ssh_conn.close() self._ssh_conn = None + if hasattr(self, "host_proxy"): + del self.host_proxy def get_conn_count(self) -> int: """Get the number of open connections.""" diff --git a/providers/sftp/tests/unit/sftp/hooks/test_sftp.py b/providers/sftp/tests/unit/sftp/hooks/test_sftp.py index bfd83b81e97c0..a99550e3a4550 100644 --- a/providers/sftp/tests/unit/sftp/hooks/test_sftp.py +++ b/providers/sftp/tests/unit/sftp/hooks/test_sftp.py @@ -138,6 +138,48 @@ def test_get_managed_conn(self): assert self.hook.get_conn_count() == 0 assert self.hook.conn is None + @patch("paramiko.SSHClient") + @patch("paramiko.ProxyCommand") + @patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") + def test_proxy_command_cache_invalidated_after_connection_closed( + self, mock_get_connection, mock_proxy_command, mock_ssh_client + ): + """ + Assert that the ProxyCommand gets invalidated after the connection is closed + """ + + mock_connection = MagicMock() + mock_connection.login = "user" + mock_connection.password = None + mock_connection.host = "example.com" + mock_connection.port = 22 + mock_connection.extra = None + mock_get_connection.return_value = mock_connection + + mock_sftp_client = MagicMock(spec=SFTPClient) + mock_ssh_client.open_sftp.return_value = mock_sftp_client + + mock_transport = MagicMock() + mock_ssh_client.return_value.get_transport.return_value = mock_transport + mock_proxy_command.return_value = MagicMock() + + host_proxy_cmd = "ncat --proxy-auth proxy_user:**** --proxy proxy_host:port %h %p" + prev_proxy_command = None + + hook = SFTPHook( + remote_host="example.com", + username="user", + host_proxy_cmd=host_proxy_cmd, + ) + + with hook.get_managed_conn() as _: + assert hasattr(self.hook, "host_proxy") + prev_proxy_command = hook.host_proxy + + mock_proxy_command.return_value = MagicMock() + + assert prev_proxy_command != hook.host_proxy + @patch("airflow.providers.ssh.hooks.ssh.SSHHook.get_conn") def test_get_close_conn(self, mock_get_conn): mock_sftp_client = MagicMock(spec=SFTPClient)