diff --git a/.github/settings.yml b/.github/settings.yml index 50ad365f..4e9d96d1 100644 --- a/.github/settings.yml +++ b/.github/settings.yml @@ -79,6 +79,7 @@ labels: - { name: '📦 package: postgres', color: '#0052CC', description: '' } - { name: '📦 package: rabbitmq', color: '#0052CC', description: '' } - { name: '📦 package: selenium', color: '#0052CC', description: '' } + - { name: '📦 package: sftp', color: '#0052CC', description: '' } - { name: '🔀 requires triage', color: '#bfdadc', description: '' } - { name: '🔧 maintenance', color: '#c2f759', description: '' } - { name: '🚀 enhancement', color: '#84b6eb', description: '' } diff --git a/modules/sftp/README.rst b/modules/sftp/README.rst new file mode 100644 index 00000000..2287d59c --- /dev/null +++ b/modules/sftp/README.rst @@ -0,0 +1,3 @@ +.. autoclass:: testcontainers.sftp.SFTPContainer +.. autoclass:: testcontainers.sftp.SFTPUser +.. title:: testcontainers.sftp.SFTPContainer diff --git a/modules/sftp/testcontainers/sftp/__init__.py b/modules/sftp/testcontainers/sftp/__init__.py new file mode 100644 index 00000000..0e073ea1 --- /dev/null +++ b/modules/sftp/testcontainers/sftp/__init__.py @@ -0,0 +1,301 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import os +import tempfile +from typing import TYPE_CHECKING, Any, NamedTuple + +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa + +from testcontainers.core.container import DockerContainer +from testcontainers.core.waiting_utils import wait_for_logs + +if TYPE_CHECKING: + from typing_extensions import Self + + +class SFTPUser: + """ + Helper class to define a user for SFTPContainer authentication. + + Constructor args/kwargs: + + * ``name``: (req.) username + * ``public_key``: (opt.) bytes of publickey + * ``private_key``: (opt.) bytes of privatekey (useful if you want to access \ + them later in test code) + * ``password``: (opt.) password + * ``uid``: (opt.) user ID + * ``gid``: (opt.) group ID + * ``folders``: (opt.) folders to create inside the user's directory (e.g. upload/) + * ``mount_dir``: (opt.) a local folder to mount to the user's root directory + + Properties: + + * ``public_key_file``: str path of public key tempfile (gets mounted to \ + SFTPContainer as a volume) + * ``private_key_file``: str path of private key tempfile (useful to pass to \ + paramiko when connecting to the sftp server using ssh + + Methods: + + * ``with_keypair``: classmethod to create a new user with an auto-generated RSA keypair + * ``conf``: str configuration string to register user on server + + + Example: + + .. doctest:: + + >>> from testcontainers.sftp import SFTPUser + + >>> users = [ + ... SFTPUser("jane", password="secret"), + ... SFTPUser.with_keypair("ron", folders=["stuff"]), + ... ] + + >>> for user in users: + ... print(user.name, user.folders[0]) + ... + jane upload + ron stuff + + >>> assert users[0].password == "secret" + + >>> assert users[1].public_key is not None + + >>> assert users[1].public_key.decode().startswith("ssh-rsa ") + + >>> assert users[1].private_key is not None + + >>> assert users[1].private_key.decode().startswith("-----BEGIN RSA PRIVATE KEY-----") + """ + + def __init__( + self, + name: str, + *, + public_key: bytes | None = None, + private_key: bytes | None = None, + password: str | None = None, + uid: str | None = None, + gid: str | None = None, + folders: list[str] | None = None, + mount_dir: str | None = None, + ) -> None: + if folders is None: + folders = ["upload"] + self.name = name + self.public_key = public_key + self.private_key = private_key + self.password = password + self.uid = uid + self.gid = gid + self.folders = folders + self.mount_dir = mount_dir + + self.public_key_file: str | None = None + if self.public_key is not None: + with tempfile.NamedTemporaryFile(delete=False) as f: + f.write(self.public_key) + self.public_key_file = f.name + + self.private_key_file: str | None = None + if self.private_key is not None: + with tempfile.NamedTemporaryFile(delete=False) as f: + f.write(self.private_key) + self.private_key_file = f.name + + def __del__(self) -> None: + """Clean up keypair temp files""" + if self.public_key_file is not None: + os.unlink(self.public_key_file) + if self.private_key_file is not None: + os.unlink(self.private_key_file) + + @property + def conf(self) -> str: + """Configuration string to register user on server""" + return ":".join( + [ + self.name, + self.password or "", + self.uid or "", + self.gid or "", + ",".join(self.folders), + ] + ) + + @classmethod + def with_keypair( + cls, + name: str, + password: str | None = None, + uid: str | None = None, + gid: str | None = None, + folders: list[str] | None = None, + mount_dir: str | None = None, + ) -> SFTPUser: + """Construct a new SFTPUser with an auto-generated RSA keypair""" + keypair = _generate_keypair() + return SFTPUser( + name=name, + public_key=keypair.public_key, + private_key=keypair.private_key, + password=password, + uid=uid, + gid=gid, + folders=folders, + mount_dir=mount_dir, + ) + + def __repr__(self) -> str: + return ( + f"SFTPUser({self.name}, password={self.password}, uid={self.uid}," + f" gid={self.gid}, folders={self.folders}," + f" public_key_file={self.public_key_file}," + f" private_key_file={self.private_key_file})" + ) + + +class SFTPContainer(DockerContainer): + """Test container for an SFTP server. + + Default configuration creates two users, ``basic:password`` and ``keypair`` + which has no password but should use the private key accessible at + ``my_container.users[1].private_key``. + + **Users can only download from their root user folder, but can upload & + download from any subfolder** (``upload/`` by default). + + Options: + + * ``users = [SFTPUser("jane", password="secret"), SFTPUser.with_keypair("ron")]`` \ + creates ``jane:secret`` or ``ron`` who uses the private key accessible at \ + ``users[1].private_key``. + + Simple example with basic auth: + + .. doctest:: + + >>> import paramiko + + >>> from testcontainers.sftp import SFTPContainer + + >>> with SFTPContainer() as sftp_container: + ... host_ip = sftp_container.get_container_host_ip() + ... host_port = sftp_container.get_exposed_sftp_port() + ... ssh = paramiko.SSHClient() + ... ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + ... ssh.connect(host_ip, host_port, "basic", "password") + ... # ssh.get(...) + ... # ssh.listdir() + ... # ssh.chdir("upload") + ... # ssh.put(...) + + Example with keypair auth: + + .. doctest:: + + >>> import paramiko + + >>> from testcontainers.sftp import SFTPContainer + + >>> with SFTPContainer() as sftp_container: + ... host_ip = sftp_container.get_container_host_ip() + ... host_port = sftp_container.get_exposed_sftp_port() + ... ssh = paramiko.SSHClient() + ... ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + ... private_key_file = sftp_container.users[1].private_key_file + ... ssh.connect(host_ip, host_port, "keypair", key_filename=private_key_file) + ... # ssh.listdir() + ... # ssh.get(...) + ... # ssh.chdir("upload") + ... # ssh.put(...) + """ + + def __init__( + self, + image: str = "atmoz/sftp:alpine", + port: int = 22, + *, + users: list[SFTPUser] | None = None, + **kwargs: Any, + ) -> None: + if users is None: + users = [ + SFTPUser(name="basic", password="password"), + SFTPUser.with_keypair(name="keypair"), + ] + + super().__init__(image=image, **kwargs) + self.port = port + self.users = users + + @property + def _users_conf(self) -> str: + return " ".join(user.conf for user in self.users) + + def _configure(self) -> None: + for user in self.users: + if user.public_key_file is not None: + self.with_volume_mapping( + user.public_key_file, + f"/home/{user.name}/.ssh/keys/{user.name}.pub", + ) + if user.mount_dir is not None: + self.with_volume_mapping( + user.mount_dir, + f"/home/{user.name}/", + "rw", + ) + self.with_env("SFTP_USERS", self._users_conf) + self.with_exposed_ports(self.port) + + def start(self) -> Self: + super().start() + wait_for_logs(self, f".*Server listening on 0.0.0.0 port {self.port}.*") + return self + + def get_exposed_sftp_port(self) -> int: + return int(self.get_exposed_port(self.port)) + + +class _Keypair(NamedTuple): + """RSA keypair as bytes""" + + private_key: bytes + public_key: bytes + + +def _generate_keypair() -> _Keypair: + """Generate RSA keypair as bytes in OpenSSH format.""" + private_key = rsa.generate_private_key( + public_exponent=65537, + key_size=4096, + ) + private_key_bytes = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + public_key_bytes = private_key.public_key().public_bytes( + encoding=serialization.Encoding.OpenSSH, # paramiko flakiness fix + format=serialization.PublicFormat.OpenSSH, + ) + return _Keypair( + private_key=private_key_bytes, + public_key=public_key_bytes, + ) diff --git a/modules/sftp/testcontainers/sftp/py.typed b/modules/sftp/testcontainers/sftp/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/modules/sftp/tests/test_sftp.py b/modules/sftp/tests/test_sftp.py new file mode 100644 index 00000000..e3dab2e3 --- /dev/null +++ b/modules/sftp/tests/test_sftp.py @@ -0,0 +1,159 @@ +import tempfile +from pathlib import Path + +import paramiko +import pytest + +from testcontainers.sftp import SFTPContainer, SFTPUser + + +def test_sftp_login_with_default_basic_auth(): + with SFTPContainer() as sftp_container: + sftp_container.start() + host_ip = sftp_container.get_container_host_ip() + host_port = sftp_container.get_exposed_sftp_port() + ssh = paramiko.SSHClient() + ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + ssh.connect( + hostname=host_ip, + port=host_port, + username=sftp_container.users[0].name, + password=sftp_container.users[0].password, + ) + + +def test_sftp_login_with_default_keypair_auth(): + with SFTPContainer() as sftp_container: + sftp_container.start() + host_ip = sftp_container.get_container_host_ip() + host_port = sftp_container.get_exposed_sftp_port() + ssh = paramiko.SSHClient() + ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + ssh.connect( + hostname=host_ip, + port=host_port, + username=sftp_container.users[1].name, + key_filename=sftp_container.users[1].private_key_file, + ) + + +def test_sftp_login_with_custom_user_basic_auth(): + user = SFTPUser(name="custom", password="custom_password") + with SFTPContainer(users=[user]) as sftp_container: + sftp_container.start() + host_ip = sftp_container.get_container_host_ip() + host_port = sftp_container.get_exposed_sftp_port() + ssh = paramiko.SSHClient() + ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + ssh.connect( + hostname=host_ip, + port=host_port, + username=user.name, + password=user.password, + ) + + +def test_sftp_login_with_custom_user_keypair_auth(): + user = SFTPUser.with_keypair(name="custom") + with SFTPContainer(users=[user]) as sftp_container: + sftp_container.start() + host_ip = sftp_container.get_container_host_ip() + host_port = sftp_container.get_exposed_sftp_port() + ssh = paramiko.SSHClient() + ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + ssh.connect( + hostname=host_ip, + port=host_port, + username=user.name, + key_filename=user.private_key_file, + ) + + +def test_sftp_login_with_custom_user_password_and_keypair_auth(): + user = SFTPUser.with_keypair(name="custom", password="custom_password") + with SFTPContainer(users=[user]) as sftp_container: + sftp_container.start() + host_ip = sftp_container.get_container_host_ip() + host_port = sftp_container.get_exposed_sftp_port() + ssh = paramiko.SSHClient() + ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + ssh.connect( + hostname=host_ip, + port=host_port, + username=user.name, + password=user.password, + key_filename=user.private_key_file, + ) + + +def test_sftp_user_can_upload(): + with SFTPContainer() as sftp_container: + sftp_container.start() + host_ip = sftp_container.get_container_host_ip() + host_port = sftp_container.get_exposed_sftp_port() + ssh = paramiko.SSHClient() + ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + ssh.connect( + hostname=host_ip, + port=host_port, + username=sftp_container.users[0].name, + password=sftp_container.users[0].password, + ) + sftp = ssh.open_sftp() + sftp.chdir("upload") + with tempfile.NamedTemporaryFile() as f: + f.write(b"test") + f.seek(0) + sftp.put(f.name, "test.txt") + + with tempfile.NamedTemporaryFile() as f: + sftp.get("test.txt", f.name) + f.seek(0) + assert f.read() == b"test" + + +def test_sftp_user_can_download_from_mounted(tmp_path: Path): + temp_dir = tmp_path / "sub" + temp_dir.mkdir() + temp_file = temp_dir / "test.txt" + temp_file.write_text("test") + user = SFTPUser.with_keypair(name="custom", mount_dir=temp_dir.as_posix()) + with SFTPContainer(users=[user]) as sftp_container: + sftp_container.start() + host_ip = sftp_container.get_container_host_ip() + host_port = sftp_container.get_exposed_sftp_port() + ssh = paramiko.SSHClient() + ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + ssh.connect( + hostname=host_ip, + port=host_port, + username=user.name, + key_filename=user.private_key_file, + ) + sftp = ssh.open_sftp() + with tempfile.NamedTemporaryFile() as f: + sftp.get(temp_file.name, f.name) + f.seek(0) + assert f.read() == b"test" + + +def test_sftp_user_cant_upload_to_root(tmp_path: Path): + temp_dir = tmp_path / "sub" + temp_dir.mkdir() + temp_file = temp_dir / "test.txt" + temp_file.write_text("test") + with SFTPContainer() as sftp_container: + sftp_container.start() + host_ip = sftp_container.get_container_host_ip() + host_port = sftp_container.get_exposed_sftp_port() + ssh = paramiko.SSHClient() + ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + ssh.connect( + hostname=host_ip, + port=host_port, + username=sftp_container.users[0].name, + password=sftp_container.users[0].password, + ) + sftp = ssh.open_sftp() + with pytest.raises(PermissionError): + sftp.put(temp_file.as_posix(), temp_file.name) diff --git a/poetry.lock b/poetry.lock index 548521e5..596444b7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -239,7 +239,7 @@ files = [ name = "bcrypt" version = "4.1.2" description = "Modern password hashing for your software and your servers" -optional = true +optional = false python-versions = ">=3.7" files = [ {file = "bcrypt-4.1.2-cp37-abi3-macosx_10_12_universal2.whl", hash = "sha256:ac621c093edb28200728a9cca214d7e838529e557027ef0581685909acd28b5e"}, @@ -1913,6 +1913,7 @@ python-versions = ">=3.7" files = [ {file = "milvus_lite-2.4.7-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:c828190118b104b05b8c8e0b5a4147811c86b54b8fb67bc2e726ad10fc0b544e"}, {file = "milvus_lite-2.4.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:e1537633c39879714fb15082be56a4b97f74c905a6e98e302ec01320561081af"}, + {file = "milvus_lite-2.4.7-py3-none-manylinux2014_aarch64.whl", hash = "sha256:fcb909d38c83f21478ca9cb500c84264f988c69f62715ae9462e966767fb76dd"}, {file = "milvus_lite-2.4.7-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f016474d663045787dddf1c3aad13b7d8b61fd329220318f858184918143dcbf"}, ] @@ -2541,6 +2542,27 @@ sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-d test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"] xml = ["lxml (>=4.9.2)"] +[[package]] +name = "paramiko" +version = "3.4.0" +description = "SSH2 protocol library" +optional = false +python-versions = ">=3.6" +files = [ + {file = "paramiko-3.4.0-py3-none-any.whl", hash = "sha256:43f0b51115a896f9c00f59618023484cb3a14b98bbceab43394a39c6739b7ee7"}, + {file = "paramiko-3.4.0.tar.gz", hash = "sha256:aac08f26a31dc4dffd92821527d1682d99d52f9ef6851968114a8728f3c274d3"}, +] + +[package.dependencies] +bcrypt = ">=3.2" +cryptography = ">=3.3" +pynacl = ">=1.5" + +[package.extras] +all = ["gssapi (>=1.4.1)", "invoke (>=2.0)", "pyasn1 (>=0.1.7)", "pywin32 (>=2.1.8)"] +gssapi = ["gssapi (>=1.4.1)", "pyasn1 (>=0.1.7)", "pywin32 (>=2.1.8)"] +invoke = ["invoke (>=2.0)"] + [[package]] name = "pg8000" version = "1.30.5" @@ -3259,6 +3281,32 @@ cryptography = {version = "*", optional = true, markers = "extra == \"rsa\""} ed25519 = ["PyNaCl (>=1.4.0)"] rsa = ["cryptography"] +[[package]] +name = "pynacl" +version = "1.5.0" +description = "Python binding to the Networking and Cryptography (NaCl) library" +optional = false +python-versions = ">=3.6" +files = [ + {file = "PyNaCl-1.5.0-cp36-abi3-macosx_10_10_universal2.whl", hash = "sha256:401002a4aaa07c9414132aaed7f6836ff98f59277a234704ff66878c2ee4a0d1"}, + {file = "PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:52cb72a79269189d4e0dc537556f4740f7f0a9ec41c1322598799b0bdad4ef92"}, + {file = "PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a36d4a9dda1f19ce6e03c9a784a2921a4b726b02e1c736600ca9c22029474394"}, + {file = "PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:0c84947a22519e013607c9be43706dd42513f9e6ae5d39d3613ca1e142fba44d"}, + {file = "PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:06b8f6fa7f5de8d5d2f7573fe8c863c051225a27b61e6860fd047b1775807858"}, + {file = "PyNaCl-1.5.0-cp36-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:a422368fc821589c228f4c49438a368831cb5bbc0eab5ebe1d7fac9dded6567b"}, + {file = "PyNaCl-1.5.0-cp36-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:61f642bf2378713e2c2e1de73444a3778e5f0a38be6fee0fe532fe30060282ff"}, + {file = "PyNaCl-1.5.0-cp36-abi3-win32.whl", hash = "sha256:e46dae94e34b085175f8abb3b0aaa7da40767865ac82c928eeb9e57e1ea8a543"}, + {file = "PyNaCl-1.5.0-cp36-abi3-win_amd64.whl", hash = "sha256:20f42270d27e1b6a29f54032090b972d97f0a1b0948cc52392041ef7831fee93"}, + {file = "PyNaCl-1.5.0.tar.gz", hash = "sha256:8ac7448f09ab85811607bdd21ec2464495ac8b7c66d146bf545b0f08fb9220ba"}, +] + +[package.dependencies] +cffi = ">=1.4.1" + +[package.extras] +docs = ["sphinx (>=1.6.5)", "sphinx-rtd-theme"] +tests = ["hypothesis (>=3.27.0)", "pytest (>=3.2.1,!=3.3.0)"] + [[package]] name = "pysocks" version = "1.7.1" @@ -4118,6 +4166,20 @@ rfc3986 = ">=1.4.0" rich = ">=12.0.0" urllib3 = ">=1.26.0" +[[package]] +name = "types-paramiko" +version = "3.4.0.20240423" +description = "Typing stubs for paramiko" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types-paramiko-3.4.0.20240423.tar.gz", hash = "sha256:aaa98dda232c47886563d66743d3a8b66c432790c596bc3bdd3f17f91be2a8c1"}, + {file = "types_paramiko-3.4.0.20240423-py3-none-any.whl", hash = "sha256:c56e0d43399a1b909901b1e0375e0ff6ee62e16cd6e00695024abc2e9fe02035"}, +] + +[package.dependencies] +cryptography = ">=37.0.0" + [[package]] name = "typing-extensions" version = "4.11.0" @@ -4507,6 +4569,7 @@ rabbitmq = ["pika"] redis = ["redis"] registry = ["bcrypt"] selenium = ["selenium"] +sftp = ["cryptography"] test-module-import = ["httpx"] vault = [] weaviate = ["weaviate-client"] @@ -4514,4 +4577,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "e95316f2de630e690a4e62f240dad0461e9adb936474c9d7cb4a556ec54cb70b" +content-hash = "4694e6bedeb7263ba9b7de579b81913f285161d5d27d453bd36a616e9ce3eade" diff --git a/pyproject.toml b/pyproject.toml index 616c4206..dfa47f7f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,6 +63,7 @@ packages = [ { include = "testcontainers", from = "modules/rabbitmq" }, { include = "testcontainers", from = "modules/redis" }, { include = "testcontainers", from = "modules/registry" }, + { include = "testcontainers", from = "modules/sftp" }, { include = "testcontainers", from = "modules/selenium" }, { include = "testcontainers", from = "modules/vault" }, { include = "testcontainers", from = "modules/weaviate" }, @@ -148,6 +149,7 @@ rabbitmq = ["pika"] redis = ["redis"] registry = ["bcrypt"] selenium = ["selenium"] +sftp = ["cryptography"] vault = [] weaviate = ["weaviate-client"] chroma = ["chromadb-client"] @@ -173,6 +175,8 @@ pymilvus = "2.4.3" httpx = "0.27.0" paho-mqtt = "2.1.0" sqlalchemy-cockroachdb = "2.0.2" +paramiko = "^3.4.0" +types-paramiko = "^3.4.0.20240423" [[tool.poetry.source]] name = "PyPI" @@ -293,6 +297,7 @@ mypy_path = [ # "modules/rabbitmq", # "modules/redis", # "modules/selenium" + "modules/sftp", # "modules/vault" # "modules/weaviate" ]