Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions providers/samba/docs/connections.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,6 @@ Login

Password
The password of the user that will be used for authentication against the Samba server.

Share Type
The share OS type (``posix`` or ``windows``). Used to determine the formatting of file and folder paths.
48 changes: 43 additions & 5 deletions providers/samba/src/airflow/providers/samba/hooks/samba.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
# under the License.
from __future__ import annotations

import posixpath
from functools import wraps
from pathlib import PurePosixPath, PureWindowsPath
from shutil import copyfileobj
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Literal

import smbclient

Expand All @@ -41,14 +41,21 @@ class SambaHook(BaseHook):
:param share:
An optional share name. If this is unset then the "schema" field of
the connection is used in its place.
:param share_type:
An optional share type, either ``posix`` or ``windows``. If this is unset then the ``share_type`` value from the connection extras is used, falling back to ``posix``.
"""

conn_name_attr = "samba_conn_id"
default_conn_name = "samba_default"
conn_type = "samba"
hook_name = "Samba"

def __init__(self, samba_conn_id: str = default_conn_name, share: str | None = None) -> None:
def __init__(
self,
samba_conn_id: str = default_conn_name,
share: str | None = None,
share_type: Literal["posix", "windows"] | None = None,
) -> None:
super().__init__()
conn = self.get_connection(samba_conn_id)

Expand All @@ -58,6 +65,13 @@ def __init__(self, samba_conn_id: str = default_conn_name, share: str | None = N
if not conn.password:
self.log.info("Password not provided")

self._share_type = share_type or conn.extra_dejson.get("share_type", "posix")
if self._share_type not in {"posix", "windows"}:
self._share_type = "posix"
self.log.warning(
"Invalid share_type specified. It must be either 'posix' or 'windows'. Falling back to 'posix'"
)

connection_cache: dict[str, smbprotocol.connection.Connection] = {}

self._host = conn.host
Expand All @@ -84,8 +98,18 @@ def __exit__(self, exc_type, exc_value, traceback):
connection.disconnect()
self._connection_cache.clear()

@staticmethod
def _join_posix_path(host: str, share: str, path: str) -> str:
    """
    Build the POSIX-style location of ``path`` on a share.

    :param host: Name or address of the server hosting the share.
    :param share: Name of the share.
    :param path: Path inside the share; leading forward slashes are ignored.
    :return: A string of the form ``//host/share/path``.
    """
    # Strip leading slashes so the path is joined beneath the share rather
    # than being treated as absolute and replacing the ``//host/share`` anchor.
    relative_path = path.lstrip("/")
    share_root = PurePosixPath(f"//{host}", share)
    return str(share_root / relative_path)

@staticmethod
def _join_windows_path(host: str, share: str, path: str) -> str:
    """
    Build the Windows (UNC) location of ``path`` on a share.

    :param host: Name or address of the server hosting the share.
    :param share: Name of the share.
    :param path: Path inside the share. Leading forward slashes and
        backslashes are ignored, and forward slashes inside the path are
        converted to backslashes.
    :return: A string of the form ``\\\\host\\share\\path``.
    """
    # Hand pathlib a well-formed UNC anchor (``\\host\share``). The result then
    # does not depend on how a malformed prefix with a doubled separator is
    # parsed, which is not consistent across Python versions, and no leading
    # backslash has to be patched on afterwards.
    unc_root = f"\\\\{host}\\{share}"
    # Strip leading separators so the path is joined beneath the share instead
    # of being treated as rooted and discarding part of the anchor.
    relative_path = path.lstrip("\\/")
    return str(PureWindowsPath(unc_root, relative_path))

def _join_path(self, path):
return f"//{posixpath.join(self._host, self._share, path.lstrip('/'))}"
if self._share_type == "windows":
return self._join_windows_path(self._host, self._share, path)
return self._join_posix_path(self._host, self._share, path)

@wraps(smbclient.link)
def link(self, src, dst, follow_symlinks=True):
Expand Down Expand Up @@ -293,6 +317,20 @@ def push_from_local(self, destination_filepath: str, local_filepath: str, buffer
def get_ui_field_behaviour(cls) -> dict[str, Any]:
"""Return custom field behaviour."""
return {
"hidden_fields": ["extra"],
"hidden_fields": [],
"relabeling": {"schema": "Share"},
}

@classmethod
def get_connection_form_widgets(cls) -> dict[str, Any]:
    """
    Return the extra widgets to add to the connection form.

    :return: A mapping with a single ``share_type`` entry holding a text
        field that defaults to ``posix``.
    """
    # Imported lazily so these UI-only packages are needed only when the
    # connection form is actually built.
    from flask_babel import lazy_gettext
    from wtforms import StringField

    share_type_field = StringField(
        label=lazy_gettext("Share Type"),
        description="The share OS type (`posix` or `windows`). Used to determine the formatting of file and folder paths.",
        default="posix",
    )
    return {"share_type": share_type_field}
34 changes: 29 additions & 5 deletions providers/samba/tests/unit/samba/hooks/test_samba.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,14 +148,38 @@ def test_method(self, get_conn_mock, name):
assert dict(kwargs, **connection_settings) == p_kwargs

@pytest.mark.parametrize(
"path, full_path",
"path, path_type, full_path",
[
("/start/path/with/slash", "//ip/share/start/path/with/slash"),
("start/path/without/slash", "//ip/share/start/path/without/slash"),
# Linux path -> Linux path, no path_type (default)
("/start/path/with/slash", None, "//ip/share/start/path/with/slash"),
("start/path/without/slash", None, "//ip/share/start/path/without/slash"),
# Linux path -> Linux path, explicit path_type (posix)
("/start/path/with/slash/posix", "posix", "//ip/share/start/path/with/slash/posix"),
("start/path/without/slash/posix", "posix", "//ip/share/start/path/without/slash/posix"),
# Linux path -> Windows path, explicit path_type (windows)
("/start/path/with/slash/windows", "windows", r"\\ip\share\start\path\with\slash\windows"),
("start/path/without/slash/windows", "windows", r"\\ip\share\start\path\without\slash\windows"),
# Windows path -> Windows path, explicit path_type (windows)
(
r"\start\path\with\backslash\windows",
"windows",
r"\\ip\share\start\path\with\backslash\windows",
),
(
r"start\path\without\backslash\windows",
"windows",
r"\\ip\share\start\path\without\backslash\windows",
),
],
)
@mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection")
def test__join_path(self, get_conn_mock, path, full_path):
def test__join_path(
self,
get_conn_mock,
path,
path_type,
full_path,
):
CONNECTION = Connection(
host="ip",
schema="share",
Expand All @@ -164,7 +188,7 @@ def test__join_path(self, get_conn_mock, path, full_path):
)

get_conn_mock.return_value = CONNECTION
hook = SambaHook("samba_default")
hook = SambaHook("samba_default", share_type=path_type)
assert hook._join_path(path) == full_path

@mock.patch("airflow.providers.samba.hooks.samba.smbclient.open_file", return_value=mock.Mock())
Expand Down