From 819d6d4bac9797edfb4c8dbd08eff5c3b4bc54b6 Mon Sep 17 00:00:00 2001 From: Kyle Finley Date: Thu, 14 Apr 2022 10:26:31 -0400 Subject: [PATCH] Add cert retrieval for requests (#5320) * Add cert retrieval for requests The authenticator.py code already retrieves credentials from the config for every request based on url matching. This change makes the authenticator also retrieve certs from the config for each request based on url matching. Also includes unit tests. Co-authored-by: Tucker Beck --- src/poetry/utils/authenticator.py | 47 +++++++++++++++++++++++++++++-- tests/utils/test_authenticator.py | 44 +++++++++++++++++++++++++++++ 2 files changed, 88 insertions(+), 3 deletions(-) diff --git a/src/poetry/utils/authenticator.py b/src/poetry/utils/authenticator.py index a7d0dda3d89..6c7758a1503 100644 --- a/src/poetry/utils/authenticator.py +++ b/src/poetry/utils/authenticator.py @@ -6,16 +6,21 @@ from typing import TYPE_CHECKING from typing import Any +from typing import Iterator import requests import requests.auth import requests.exceptions from poetry.exceptions import PoetryException +from poetry.utils.helpers import get_cert +from poetry.utils.helpers import get_client_cert from poetry.utils.password_manager import PasswordManager if TYPE_CHECKING: + from pathlib import Path + from cleo.io.io import IO from poetry.config.config import Config @@ -30,6 +35,7 @@ def __init__(self, config: Config, io: IO | None = None) -> None: self._io = io self._session = None self._credentials = {} + self._certs = {} self._password_manager = PasswordManager(self._config) def _log(self, message: str, level: str = "debug") -> None: @@ -61,8 +67,16 @@ def request(self, method: str, url: str, **kwargs: Any) -> requests.Response: proxies = kwargs.get("proxies", {}) stream = kwargs.get("stream") - verify = kwargs.get("verify") - cert = kwargs.get("cert") + + certs = self.get_certs_for_url(url) + verify = kwargs.get("verify") or certs.get("verify") + cert = kwargs.get("cert") or certs.get("cert") + + if cert is not None: + cert = str(cert) + + if verify is not None: + verify = str(verify) settings = session.merge_environment_settings( prepared_request.url, proxies, stream, verify, cert @@ -157,7 +171,7 @@ def _get_http_auth(self, name: str, netloc: str | None) -> dict[str, str] | None return auth def _get_credentials_for_netloc(self, netloc: str) -> tuple[str | None, str | None]: - for repository_name in self._config.get("repositories", []): + for repository_name, _ in self._get_repository_netlocs(): auth = self._get_http_auth(repository_name, netloc) if auth is None: @@ -167,6 +181,22 @@ def _get_credentials_for_netloc(self, netloc: str) -> tuple[str | None, str | No return None, None + def get_certs_for_url(self, url: str) -> dict[str, Path | None]: + parsed_url = urllib.parse.urlsplit(url) + + netloc = parsed_url.netloc + + return self._certs.setdefault( + netloc, + self._get_certs_for_netloc_from_config(netloc), + ) + + def _get_repository_netlocs(self) -> Iterator[tuple[str, str]]: + for repository_name in self._config.get("repositories", []): + url = self._config.get(f"repositories.{repository_name}.url") + parsed_url = urllib.parse.urlsplit(url) + yield repository_name, parsed_url.netloc + def _get_credentials_for_netloc_from_keyring( self, url: str, netloc: str, username: str | None ) -> dict[str, str] | None: @@ -193,3 +223,14 @@ def _get_credentials_for_netloc_from_keyring( } return None + + def _get_certs_for_netloc_from_config(self, netloc: str) -> dict[str, Path | None]: + certs = {"cert": None, "verify": None} + + for repository_name, repository_netloc in self._get_repository_netlocs(): + if netloc == repository_netloc: + certs["cert"] = get_client_cert(self._config, repository_name) + certs["verify"] = get_cert(self._config, repository_name) + break + + return certs diff --git a/tests/utils/test_authenticator.py b/tests/utils/test_authenticator.py index 25b312211b7..4c8f77faedf 100644 --- a/tests/utils/test_authenticator.py +++ b/tests/utils/test_authenticator.py @@ -4,6 +4,7 @@ import uuid from dataclasses import dataclass +from pathlib import Path from typing import TYPE_CHECKING from typing import Any @@ -306,3 +307,46 @@ def test_authenticator_uses_env_provided_credentials( request = http.last_request() assert request.headers["Authorization"] == "Basic YmFyOmJheg==" + + +@pytest.mark.parametrize( + "cert,client_cert", + [ + (None, None), + (None, "path/to/provided/client-cert"), + ("/path/to/provided/cert", None), + ("/path/to/provided/cert", "path/to/provided/client-cert"), + ], +) +def test_authenticator_uses_certs_from_config_if_not_provided( + config: Config, + mock_remote: type[httpretty.httpretty], + http: type[httpretty.httpretty], + mocker: MockerFixture, + cert: str | None, + client_cert: str | None, +): + configured_cert = "/path/to/cert" + configured_client_cert = "/path/to/client-cert" + config.merge( + { + "repositories": {"foo": {"url": "https://foo.bar/simple/"}}, + "http-basic": {"foo": {"username": "bar", "password": "baz"}}, + "certificates": { + "foo": {"cert": configured_cert, "client-cert": configured_client_cert} + }, + } + ) + + authenticator = Authenticator(config, NullIO()) + session_send = mocker.patch.object(authenticator.session, "send") + authenticator.request( + "get", + "https://foo.bar/files/foo-0.1.0.tar.gz", + verify=cert, + cert=client_cert, + ) + kwargs = session_send.call_args[1] + + assert Path(kwargs["verify"]) == Path(cert or configured_cert) + assert Path(kwargs["cert"]) == Path(client_cert or configured_client_cert)