diff --git a/providers/samba/docs/connections.rst b/providers/samba/docs/connections.rst index e49477cc145d8..20661a525f09f 100644 --- a/providers/samba/docs/connections.rst +++ b/providers/samba/docs/connections.rst @@ -42,3 +42,6 @@ Login Password The password of the user that will be used for authentication against the Samba server. + +Share Type + The share OS type (``posix`` or ``windows``). Used to determine the formatting of file and folder paths. diff --git a/providers/samba/src/airflow/providers/samba/hooks/samba.py b/providers/samba/src/airflow/providers/samba/hooks/samba.py index c328ca0d6f0bb..90da67d720b5b 100644 --- a/providers/samba/src/airflow/providers/samba/hooks/samba.py +++ b/providers/samba/src/airflow/providers/samba/hooks/samba.py @@ -17,10 +17,10 @@ # under the License. from __future__ import annotations -import posixpath from functools import wraps +from pathlib import PurePosixPath, PureWindowsPath from shutil import copyfileobj -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal import smbclient @@ -41,6 +41,8 @@ class SambaHook(BaseHook): :param share: An optional share name. If this is unset then the "schema" field of the connection is used in its place. + :param share_type: + An optional share type name. If this is unset then it will assume a posix share type. """ conn_name_attr = "samba_conn_id" @@ -48,7 +50,12 @@ class SambaHook(BaseHook): conn_type = "samba" hook_name = "Samba" - def __init__(self, samba_conn_id: str = default_conn_name, share: str | None = None) -> None: + def __init__( + self, + samba_conn_id: str = default_conn_name, + share: str | None = None, + share_type: Literal["posix", "windows"] | None = None, + ) -> None: super().__init__() conn = self.get_connection(samba_conn_id) @@ -58,6 +65,13 @@ def __init__(self, samba_conn_id: str = default_conn_name, share: str | None = N if not conn.password: self.log.info("Password not provided") + self._share_type = share_type or conn.extra_dejson.get("share_type", "posix") + if self._share_type not in {"posix", "windows"}: + self._share_type = "posix" + self.log.warning( + "Invalid share_type specified. It must be either 'posix' or 'windows'. Falling back to 'posix'" + ) + connection_cache: dict[str, smbprotocol.connection.Connection] = {} self._host = conn.host @@ -84,8 +98,18 @@ def __exit__(self, exc_type, exc_value, traceback): connection.disconnect() self._connection_cache.clear() + @staticmethod + def _join_posix_path(host: str, share: str, path: str) -> str: + return str(PurePosixPath("//" + host, share, path.lstrip("/"))) + + @staticmethod + def _join_windows_path(host: str, share: str, path: str) -> str: + return "\\{}".format(PureWindowsPath(rf"\\{host}\\{share}", path.lstrip(r"\/"))) + def _join_path(self, path): - return f"//{posixpath.join(self._host, self._share, path.lstrip('/'))}" + if self._share_type == "windows": + return self._join_windows_path(self._host, self._share, path) + return self._join_posix_path(self._host, self._share, path) @wraps(smbclient.link) def link(self, src, dst, follow_symlinks=True): @@ -293,6 +317,20 @@ def push_from_local(self, destination_filepath: str, local_filepath: str, buffer def get_ui_field_behaviour(cls) -> dict[str, Any]: """Return custom field behaviour.""" return { - "hidden_fields": ["extra"], + "hidden_fields": [], "relabeling": {"schema": "Share"}, } + + @classmethod + def get_connection_form_widgets(cls) -> dict[str, Any]: + """Return connection widgets to add to connection form.""" + from flask_babel import lazy_gettext + from wtforms import StringField + + return { + "share_type": StringField( + label=lazy_gettext("Share Type"), + description="The share OS type (`posix` or `windows`). Used to determine the formatting of file and folder paths.", + default="posix", + ) + } diff --git a/providers/samba/tests/unit/samba/hooks/test_samba.py b/providers/samba/tests/unit/samba/hooks/test_samba.py index 7d1e8d8366a46..1a3c1b838e849 100644 --- a/providers/samba/tests/unit/samba/hooks/test_samba.py +++ b/providers/samba/tests/unit/samba/hooks/test_samba.py @@ -148,14 +148,38 @@ def test_method(self, get_conn_mock, name): assert dict(kwargs, **connection_settings) == p_kwargs @pytest.mark.parametrize( - "path, full_path", + "path, path_type, full_path", [ - ("/start/path/with/slash", "//ip/share/start/path/with/slash"), - ("start/path/without/slash", "//ip/share/start/path/without/slash"), + # Linux path -> Linux path, no path_type (default) + ("/start/path/with/slash", None, "//ip/share/start/path/with/slash"), + ("start/path/without/slash", None, "//ip/share/start/path/without/slash"), + # Linux path -> Linux path, explicit path_type (posix) + ("/start/path/with/slash/posix", "posix", "//ip/share/start/path/with/slash/posix"), + ("start/path/without/slash/posix", "posix", "//ip/share/start/path/without/slash/posix"), + # Linux path -> Windows path, explicit path_type (windows) + ("/start/path/with/slash/windows", "windows", r"\\ip\share\start\path\with\slash\windows"), + ("start/path/without/slash/windows", "windows", r"\\ip\share\start\path\without\slash\windows"), + # Windows path -> Windows path, explicit path_type (windows) + ( + r"\start\path\with\backslash\windows", + "windows", + r"\\ip\share\start\path\with\backslash\windows", + ), + ( + r"start\path\without\backslash\windows", + "windows", + r"\\ip\share\start\path\without\backslash\windows", + ), ], ) @mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") - def test__join_path(self, get_conn_mock, path, full_path): + def test__join_path( + self, + get_conn_mock, + path, + path_type, + full_path, + ): CONNECTION = Connection( host="ip", schema="share", @@ -164,7 +188,7 @@ def test__join_path(self, get_conn_mock, path, full_path): ) get_conn_mock.return_value = CONNECTION - hook = SambaHook("samba_default") + hook = SambaHook("samba_default", share_type=path_type) assert hook._join_path(path) == full_path @mock.patch("airflow.providers.samba.hooks.samba.smbclient.open_file", return_value=mock.Mock())