diff --git a/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/hooks/winrm.py b/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/hooks/winrm.py index edc08ef5b851b..0fb1e75a3f434 100644 --- a/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/hooks/winrm.py +++ b/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/hooks/winrm.py @@ -226,6 +226,7 @@ def run( ps_path: str | None = None, output_encoding: str = "utf-8", return_output: bool = True, + working_directory: str | None = None, ) -> tuple[int, list[bytes], list[bytes]]: """ Run a command. @@ -235,12 +236,13 @@ def run( If specified, it will execute the command as powershell script. :param output_encoding: the encoding used to decode stout and stderr. :param return_output: Whether to accumulate and return the stdout or not. + :param working_directory: specify working directory. :return: returns a tuple containing return_code, stdout and stderr in order. """ winrm_client = self.get_conn() self.log.info("Establishing WinRM connection to host: %s", self.remote_host) try: - shell_id = winrm_client.open_shell() + shell_id = winrm_client.open_shell(working_directory=working_directory) except Exception as error: error_msg = f"Error connecting to host: {self.remote_host}, error: {error}" self.log.error(error_msg) diff --git a/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/operators/winrm.py b/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/operators/winrm.py index 917297c898ea5..c72dd3bf1140e 100644 --- a/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/operators/winrm.py +++ b/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/operators/winrm.py @@ -50,10 +50,14 @@ class WinRMOperator(BaseOperator): :param output_encoding: the encoding used to decode stout and stderr :param timeout: timeout for executing the command. :param expected_return_code: expected return code value(s) of command. + :param working_directory: specify working directory. """ - template_fields: Sequence[str] = ("command",) - template_fields_renderers = {"command": "powershell"} + template_fields: Sequence[str] = ( + "command", + "working_directory", + ) + template_fields_renderers = {"command": "powershell", "working_directory": "powershell"} def __init__( self, @@ -66,6 +70,7 @@ def __init__( output_encoding: str = "utf-8", timeout: int = 10, expected_return_code: int | list[int] | range = 0, + working_directory: str | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -77,6 +82,7 @@ def __init__( self.output_encoding = output_encoding self.timeout = timeout self.expected_return_code = expected_return_code + self.working_directory = working_directory def execute(self, context: Context) -> list | str: if self.ssh_conn_id and not self.winrm_hook: @@ -97,6 +103,7 @@ def execute(self, context: Context) -> list | str: ps_path=self.ps_path, output_encoding=self.output_encoding, return_output=self.do_xcom_push, + working_directory=self.working_directory, ) success = False diff --git a/providers/microsoft/winrm/tests/unit/microsoft/winrm/hooks/test_winrm.py b/providers/microsoft/winrm/tests/unit/microsoft/winrm/hooks/test_winrm.py index b4ff975b7eb7c..2403fd6628e07 100644 --- a/providers/microsoft/winrm/tests/unit/microsoft/winrm/hooks/test_winrm.py +++ b/providers/microsoft/winrm/tests/unit/microsoft/winrm/hooks/test_winrm.py @@ -22,9 +22,13 @@ import pytest from airflow.exceptions import AirflowException -from airflow.models import Connection from airflow.providers.microsoft.winrm.hooks.winrm import WinRMHook +try: + from airflow.sdk import Connection # type: ignore +except ImportError: + from airflow.models import Connection # type: ignore + class TestWinRMHook: def test_get_conn_missing_remote_host(self): @@ -42,6 +46,8 @@ def test_get_conn_error(self, mock_protocol): @patch( "airflow.providers.microsoft.winrm.hooks.winrm.WinRMHook.get_connection", return_value=Connection( + conn_id="", + conn_type="", login="username", password="password", host="remote_host", @@ -113,6 +119,8 @@ def test_get_conn_no_endpoint(self, mock_protocol): @patch( "airflow.providers.microsoft.winrm.hooks.winrm.WinRMHook.get_connection", return_value=Connection( + conn_id="", + conn_type="", login="username", password="password", host="remote_host", @@ -154,6 +162,8 @@ def test_run_with_stdout(self, mock_get_connection, mock_protocol): @patch( "airflow.providers.microsoft.winrm.hooks.winrm.WinRMHook.get_connection", return_value=Connection( + conn_id="", + conn_type="", login="username", password="password", host="remote_host", @@ -177,16 +187,20 @@ def test_run_with_stdout(self, mock_get_connection, mock_protocol): }""", ), ) - def test_run_without_stdout(self, mock_get_connection, mock_protocol): + def test_run_without_stdout_and_working_dir(self, mock_get_connection, mock_protocol): winrm_hook = WinRMHook(ssh_conn_id="conn_id") - + working_dir = "c:\\test" mock_protocol.return_value.run_command = MagicMock(return_value="command_id") mock_protocol.return_value.get_command_output_raw = MagicMock( return_value=(b"stdout", b"stderr", 0, True) ) + mock_protocol.return_value.open_shell = MagicMock() - return_code, stdout_buffer, stderr_buffer = winrm_hook.run("dir", return_output=False) + return_code, stdout_buffer, stderr_buffer = winrm_hook.run( + "dir", return_output=False, working_directory=working_dir + ) + mock_protocol.return_value.open_shell.assert_called_once_with(working_directory=working_dir) assert return_code == 0 assert not stdout_buffer assert stderr_buffer == [b"stderr"] diff --git a/providers/microsoft/winrm/tests/unit/microsoft/winrm/operators/test_winrm.py b/providers/microsoft/winrm/tests/unit/microsoft/winrm/operators/test_winrm.py index 8997395b0bf5c..95650799a3be9 100644 --- a/providers/microsoft/winrm/tests/unit/microsoft/winrm/operators/test_winrm.py +++ b/providers/microsoft/winrm/tests/unit/microsoft/winrm/operators/test_winrm.py @@ -44,8 +44,11 @@ def test_no_command(self, mock_hook): def test_default_returning_0_command(self, mock_hook): stdout = [b"O", b"K"] command = "not_empty" + working_dir = "c:\\temp" mock_hook.run.return_value = (0, stdout, []) - op = WinRMOperator(task_id="test_task_id", winrm_hook=mock_hook, command=command) + op = WinRMOperator( + task_id="test_task_id", winrm_hook=mock_hook, command=command, working_directory=working_dir + ) execute_result = op.execute(None) assert execute_result == b64encode(b"".join(stdout)).decode("utf-8") mock_hook.run.assert_called_once_with( @@ -53,6 +56,7 @@ def test_default_returning_0_command(self, mock_hook): ps_path=None, output_encoding="utf-8", return_output=True, + working_directory=working_dir, ) @mock.patch("airflow.providers.microsoft.winrm.operators.winrm.WinRMHook") @@ -94,6 +98,7 @@ def test_expected_return_code_command(self, mock_hook, expected_return_code, rea ps_path=None, output_encoding="utf-8", return_output=True, + working_directory=None, ) else: exception_msg = f"Error running cmd: {command}, return code: {real_return_code}, error: KO"