diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d166a7d..3733c2c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -39,7 +39,6 @@ jobs: strategy: matrix: python-version: - - "3.7" - "3.8" - "3.9" - "3.10" @@ -56,7 +55,7 @@ jobs: allow-prereleases: true cache: pip - - name: Install & run tox + - name: Prepare tox & run tests run: | V=${{ matrix.python-version }} @@ -69,6 +68,9 @@ jobs: python -Im pip install tox python -Im tox run -f "$V" + - name: Run Mypy on API + run: python -Im tox run -e mypy-api + - name: Upload coverage data uses: actions/upload-artifact@v3 with: @@ -111,6 +113,21 @@ jobs: path: htmlcov if: ${{ failure() }} + mypy-pkg: + name: Type-check package + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + cache: pip + + - name: Install & run tox + run: | + python -Im pip install tox + python -Im tox run -e mypy-pkg + install-dev: strategy: matrix: @@ -155,6 +172,7 @@ jobs: - docs - install-dev - lint + - mypy-pkg runs-on: ubuntu-latest diff --git a/CHANGELOG.md b/CHANGELOG.md index f8f3bf7..0360719 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,7 +14,7 @@ If breaking changes are needed do be done, they are: ### Backwards-incompatible Changes -- All Python versions up to and including 3.6 have been dropped. +- All Python versions up to and including 3.7 have been dropped. - Support for `commonName` in certificates has been dropped. It has been deprecated since 2017 and isn't supported by any major browser. - The oldest supported pyOpenSSL version (when using the `pyopenssl` backend) is now 17.0.0. @@ -33,6 +33,8 @@ If breaking changes are needed do be done, they are: - `service_identity.(cryptography|pyopenssl).extract_patterns()` are now public APIs (FKA `extract_ids()`). You can use them to extract the patterns from a certificate without verifying anything. [#55](https://github.com/pyca/service-identity/pull/55) +- *service-identity* is now fully typed. + [#57](https://github.com/pyca/service-identity/pull/57) ## 21.1.0 (2021-05-09) diff --git a/docs/api.rst b/docs/api.rst index 598808d..ff07d53 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -4,7 +4,7 @@ API .. note:: - So far, public APIs are only available for host names (:rfc:`6125`) and IP addresses (:rfc:`2818`). + So far, public high-level APIs are only available for host names (:rfc:`6125`) and IP addresses (:rfc:`2818`). All IDs specified by :rfc:`6125` are already implemented though. If you'd like to play with them and provide feedback have a look at the ``verify_service_identity`` function in the `hazmat module `_. @@ -54,6 +54,10 @@ The following are the objects return by the ``extract_patterns`` functions. They each carry the attributes that are necessary to match an ID of their type. +.. autoclass:: CertificatePattern + + It includes all of those that follow now. + .. autoclass:: DNSPattern :members: .. autoclass:: IPAddressPattern diff --git a/docs/conf.py b/docs/conf.py index b459e42..1fa44b3 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -28,6 +28,9 @@ "deflist", ] +# Move type hints into the description block, instead of the func definition. +autodoc_typehints = "description" +autodoc_typehints_description_target = "documented" # Add any paths that contain templates here, relative to this directory. templates_path = ["_templates"] diff --git a/docs/pyopenssl_example.py b/docs/pyopenssl_example.py index 8c93081..ae15305 100644 --- a/docs/pyopenssl_example.py +++ b/docs/pyopenssl_example.py @@ -12,7 +12,7 @@ hostname = sys.argv[1] ctx = SSL.Context(SSL.TLSv1_2_METHOD) -ctx.set_verify(SSL.VERIFY_PEER, lambda conn, cert, errno, depth, ok: ok) +ctx.set_verify(SSL.VERIFY_PEER, lambda conn, cert, errno, depth, ok: bool(ok)) ctx.set_default_verify_paths() conn = SSL.Connection(ctx, socket.socket(socket.AF_INET, socket.SOCK_STREAM)) @@ -22,12 +22,9 @@ try: conn.do_handshake() - print("Server certificate is valid for the following patterns:\n") - pprint.pprint( - service_identity.pyopenssl.extract_patterns( - conn.get_peer_certificate() - ) - ) + if cert := conn.get_peer_certificate(): + print("Server certificate is valid for the following patterns:\n") + pprint.pprint(service_identity.pyopenssl.extract_patterns(cert)) try: service_identity.pyopenssl.verify_hostname(conn, hostname) diff --git a/pyproject.toml b/pyproject.toml index a36a242..f3eeebe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,14 +6,13 @@ build-backend = "hatchling.build" name = "service-identity" authors = [{ name = "Hynek Schlawack", email = "hs@ox.cx" }] license = "MIT" -requires-python = ">=3.7" +requires-python = ">=3.8" description = "Service identity verification for pyOpenSSL & cryptography." keywords = ["cryptography", "openssl", "pyopenssl"] classifiers = [ "Development Status :: 5 - Production/Stable", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", @@ -23,6 +22,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: PyPy", "Topic :: Security :: Cryptography", "Topic :: Software Development :: Libraries :: Python Modules", + "Typing :: Typed", ] dependencies = [ # Keep in-sync with tests/constraints/*. @@ -30,15 +30,15 @@ dependencies = [ "pyasn1-modules", "pyasn1", "cryptography", - "importlib_metadata;python_version<'3.8'", ] dynamic = ["version", "readme"] [project.optional-dependencies] idna = ["idna"] tests = ["coverage[toml]>=5.0.2", "pytest"] -docs = ["sphinx", "furo", "myst-parser", "sphinx-notfound-page"] -dev = ["service-identity[tests,docs,idna]", "pyOpenSSL"] +docs = ["sphinx", "furo", "myst-parser", "sphinx-notfound-page", "pyOpenSSL"] +mypy = ["mypy", "types-pyOpenSSL", "idna"] +dev = ["service-identity[tests,mypy,docs,idna]", "pyOpenSSL"] [project.urls] Documentation = "https://service-identity.readthedocs.io/" @@ -105,6 +105,22 @@ source = ["src", ".tox/py*/**/site-packages"] [tool.coverage.report] show_missing = true skip_covered = true +exclude_lines = [ + # a more strict default pragma + "\\# pragma: no cover\\b", + + # allow defensive code + "^\\s*raise AssertionError\\b", + "^\\s*raise NotImplementedError\\b", + "^\\s*return NotImplemented\\b", + "^\\s*raise$", + + # typing-related code + "^if (False|TYPE_CHECKING):", + ": \\.\\.\\.(\\s*#.*)?$", + "^ +\\.\\.\\.$", + "-> ['\"]?NoReturn['\"]?:", +] [tool.black] @@ -152,3 +168,23 @@ ignore = [ [tool.ruff.isort] lines-between-types = 1 lines-after-imports = 2 + + +[tool.mypy] +strict = true + +show_error_codes = true +enable_error_code = ["ignore-without-code"] +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = "tests.*" +ignore_errors = true + +[[tool.mypy.overrides]] +module = "tests.typing.*" +ignore_errors = false + +[[tool.mypy.overrides]] +module = "cryptography.*" +follow_imports = "skip" diff --git a/src/service_identity/__init__.py b/src/service_identity/__init__.py index 981f52c..7d1ad3b 100644 --- a/src/service_identity/__init__.py +++ b/src/service_identity/__init__.py @@ -40,13 +40,9 @@ def __getattr__(name: str) -> str: if name not in dunder_to_metadata.keys(): raise AttributeError(f"module {__name__} has no attribute {name}") - import sys import warnings - if sys.version_info < (3, 8): - from importlib_metadata import metadata - else: - from importlib.metadata import metadata + from importlib.metadata import metadata warnings.warn( f"Accessing service_identity.{name} is deprecated and will be " diff --git a/src/service_identity/cryptography.py b/src/service_identity/cryptography.py index 259dc2d..9397fcf 100644 --- a/src/service_identity/cryptography.py +++ b/src/service_identity/cryptography.py @@ -141,7 +141,7 @@ def extract_patterns(cert: Certificate) -> Sequence[CertificatePattern]: srv, _ = decode(other.value) if isinstance(srv, IA5String): ids.append(SRVPattern.from_bytes(srv.asOctets())) - else: # pragma: nocover + else: # pragma: no cover raise CertificateError("Unexpected certificate content.") return ids diff --git a/src/service_identity/exceptions.py b/src/service_identity/exceptions.py index 3fbb279..5d573f6 100644 --- a/src/service_identity/exceptions.py +++ b/src/service_identity/exceptions.py @@ -5,6 +5,13 @@ them from __init__.py. """ +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence + + +if TYPE_CHECKING: + from .hazmat import ServiceID import attr @@ -23,52 +30,45 @@ class SubjectAltNameWarning(DeprecationWarning): """ -@attr.s(auto_exc=True) -class VerificationError(Exception): - """ - Service identity verification failed. - """ - - errors = attr.ib() - - def __str__(self): - return self.__repr__() +@attr.s(slots=True) +class Mismatch: + mismatched_id: ServiceID = attr.ib() -@attr.s -class DNSMismatch: +class DNSMismatch(Mismatch): """ No matching DNSPattern could be found. """ - mismatched_id = attr.ib() - -@attr.s -class SRVMismatch: +class SRVMismatch(Mismatch): """ No matching SRVPattern could be found. """ - mismatched_id = attr.ib() - -@attr.s -class URIMismatch: +class URIMismatch(Mismatch): """ No matching URIPattern could be found. """ - mismatched_id = attr.ib() - -@attr.s -class IPAddressMismatch: +class IPAddressMismatch(Mismatch): """ No matching IPAddressPattern could be found. """ - mismatched_id = attr.ib() + +@attr.s(auto_exc=True) +class VerificationError(Exception): + """ + Service identity verification failed. + """ + + errors: Sequence[Mismatch] = attr.ib() + + def __str__(self) -> str: + return self.__repr__() class CertificateError(Exception): diff --git a/src/service_identity/hazmat.py b/src/service_identity/hazmat.py index dbba000..694757e 100644 --- a/src/service_identity/hazmat.py +++ b/src/service_identity/hazmat.py @@ -7,7 +7,7 @@ import ipaddress import re -from typing import Union +from typing import Protocol, Sequence, Union, runtime_checkable import attr @@ -15,6 +15,7 @@ CertificateError, DNSMismatch, IPAddressMismatch, + Mismatch, SRVMismatch, URIMismatch, VerificationError, @@ -24,7 +25,7 @@ try: import idna except ImportError: - idna = None + idna = None # type: ignore[assignment] @attr.s(slots=True) @@ -33,11 +34,15 @@ class ServiceMatch: A match of a service id and a certificate pattern. """ - service_id = attr.ib() - cert_pattern = attr.ib() + service_id: ServiceID = attr.ib() + cert_pattern: CertificatePattern = attr.ib() -def verify_service_identity(cert_patterns, obligatory_ids, optional_ids): +def verify_service_identity( + cert_patterns: Sequence[CertificatePattern], + obligatory_ids: Sequence[ServiceID], + optional_ids: Sequence[ServiceID], +) -> list[ServiceMatch]: """ Verify whether *cert_patterns* are valid for *obligatory_ids* and *optional_ids*. @@ -71,17 +76,15 @@ def verify_service_identity(cert_patterns, obligatory_ids, optional_ids): return matches -def _find_matches(cert_patterns, service_ids): +def _find_matches( + cert_patterns: Sequence[CertificatePattern], + service_ids: Sequence[ServiceID], +) -> list[ServiceMatch]: """ Search for matching certificate patterns and service_ids. - :param cert_ids: List certificate IDs like DNSPattern. - :type cert_ids: `list` - :param service_ids: List of service IDs like DNS_ID. :type service_ids: `list` - - :rtype: `list` of `ServiceMatch` """ matches = [] for sid in service_ids: @@ -91,25 +94,17 @@ def _find_matches(cert_patterns, service_ids): return matches -def _contains_instance_of(seq, cl): - """ - :type seq: iterable - :type cl: type - - :rtype: bool - """ +def _contains_instance_of(seq: Sequence[object], cl: type) -> bool: return any(isinstance(e, cl) for e in seq) -def _is_ip_address(pattern): +def _is_ip_address(pattern: str | bytes) -> bool: """ Check whether *pattern* could be/match an IP address. :param pattern: A pattern for a host name. - :type pattern: `bytes` or `str` :return: `True` if *pattern* could be an IP address, else `False`. - :rtype: bool """ if isinstance(pattern, bytes): try: @@ -143,7 +138,7 @@ class DNSPattern: _RE_LEGAL_CHARS = re.compile(rb"^[a-z0-9\-_.]+$") @classmethod - def from_bytes(cls, pattern) -> DNSPattern: + def from_bytes(cls, pattern: bytes) -> DNSPattern: if not isinstance(pattern, bytes): raise TypeError("The DNS pattern must be a bytes string.") @@ -243,10 +238,25 @@ def from_bytes(cls, pattern: bytes) -> SRVPattern: SRVPattern, URIPattern, DNSPattern, IPAddressPattern ] """ -All possible patterns that can be extracted from a certificate. +A :class:`Union` of all possible patterns that can be extracted from a +certificate. """ +@runtime_checkable +class ServiceID(Protocol): + @property + def pattern_class(self) -> type[CertificatePattern]: + ... + + @property + def error_on_mismatch(self) -> type[Mismatch]: + ... + + def verify(self, pattern: CertificatePattern) -> bool: + ... + + @attr.s(init=False, slots=True) class DNS_ID: """ @@ -260,7 +270,7 @@ class DNS_ID: pattern_class = DNSPattern error_on_mismatch = DNSMismatch - def __init__(self, hostname): + def __init__(self, hostname: str): if not isinstance(hostname, str): raise TypeError("DNS-ID must be a text string.") @@ -282,7 +292,7 @@ def __init__(self, hostname): if self._RE_LEGAL_CHARS.match(self.hostname) is None: raise ValueError("Invalid DNS-ID.") - def verify(self, pattern: object) -> None: + def verify(self, pattern: CertificatePattern) -> bool: """ https://tools.ietf.org/search/rfc6125#section-6.4 """ @@ -305,11 +315,14 @@ class IPAddress_ID: pattern_class = IPAddressPattern error_on_mismatch = IPAddressMismatch - def verify(self, pattern): + def verify(self, pattern: CertificatePattern) -> bool: """ https://tools.ietf.org/search/rfc2818#section-3.1 """ - return self.ip == pattern.pattern + if isinstance(pattern, self.pattern_class): + return self.ip == pattern.pattern + + return False @attr.s(init=False, slots=True) @@ -318,8 +331,8 @@ class URI_ID: An URI service ID. """ - protocol = attr.ib() - dns_id = attr.ib() + protocol: bytes = attr.ib() + dns_id: DNS_ID = attr.ib() pattern_class = URIPattern error_on_mismatch = URIMismatch @@ -337,7 +350,7 @@ def __init__(self, uri: str): self.protocol = prot.encode("ascii").translate(_TRANS_TO_LOWER) self.dns_id = DNS_ID(hostname.strip("/")) - def verify(self, pattern): + def verify(self, pattern: CertificatePattern) -> bool: """ https://tools.ietf.org/search/rfc6125#section-6.5.2 """ @@ -356,8 +369,8 @@ class SRV_ID: An SRV service ID. """ - name = attr.ib() - dns_id = attr.ib() + name: bytes = attr.ib() + dns_id: DNS_ID = attr.ib() pattern_class = SRVPattern error_on_mismatch = SRVMismatch @@ -375,7 +388,7 @@ def __init__(self, srv: str): self.name = name[1:].encode("ascii").translate(_TRANS_TO_LOWER) self.dns_id = DNS_ID(hostname) - def verify(self, pattern): + def verify(self, pattern: CertificatePattern) -> bool: """ https://tools.ietf.org/search/rfc6125#section-6.5.1 """ @@ -387,13 +400,9 @@ def verify(self, pattern): return False -def _hostname_matches(cert_pattern, actual_hostname): +def _hostname_matches(cert_pattern: bytes, actual_hostname: bytes) -> bool: """ - :type cert_pattern: `bytes` - :type actual_hostname: `bytes` - :return: `True` if *cert_pattern* matches *actual_hostname*, else `False`. - :rtype: `bool` """ if b"*" in cert_pattern: cert_head, cert_tail = cert_pattern.split(b".", 1) @@ -409,14 +418,10 @@ def _hostname_matches(cert_pattern, actual_hostname): return cert_pattern == actual_hostname -def _validate_pattern(cert_pattern): +def _validate_pattern(cert_pattern: bytes) -> None: """ Check whether the usage of wildcards within *cert_pattern* conforms with our expectations. - - :type hostname: `bytes` - - :return: None """ cnt = cert_pattern.count(b"*") if cnt > 1: diff --git a/src/service_identity/py.typed b/src/service_identity/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/src/service_identity/pyopenssl.py b/src/service_identity/pyopenssl.py index e788ffc..30b5d58 100644 --- a/src/service_identity/pyopenssl.py +++ b/src/service_identity/pyopenssl.py @@ -14,9 +14,9 @@ from pyasn1.type.univ import ObjectIdentifier from pyasn1_modules.rfc2459 import GeneralNames +from .exceptions import CertificateError from .hazmat import ( DNS_ID, - CertificateError, CertificatePattern, DNSPattern, IPAddress_ID, @@ -27,13 +27,16 @@ ) -with contextlib.suppress(ImportError): # we only use it for docstrings - from OpenSSL import SSL +with contextlib.suppress(ImportError): + # We only use it for docstrings -- `if TYPE_CHECKING`` does not work. + from OpenSSL.crypto import X509 + from OpenSSL.SSL import Connection + __all__ = ["verify_hostname"] -def verify_hostname(connection: SSL.Connection, hostname: str): +def verify_hostname(connection: Connection, hostname: str) -> None: """ Verify whether the certificate of *connection* is valid for *hostname*. @@ -49,13 +52,15 @@ def verify_hostname(connection: SSL.Connection, hostname: str): :returns: ``None`` """ verify_service_identity( - cert_patterns=extract_patterns(connection.get_peer_certificate()), + cert_patterns=extract_patterns( + connection.get_peer_certificate() # type:ignore[arg-type] + ), obligatory_ids=[DNS_ID(hostname)], optional_ids=[], ) -def verify_ip_address(connection: SSL.Connection, ip_address: str): +def verify_ip_address(connection: Connection, ip_address: str) -> None: """ Verify whether the certificate of *connection* is valid for *ip_address*. @@ -74,7 +79,9 @@ def verify_ip_address(connection: SSL.Connection, ip_address: str): .. versionadded:: 18.1.0 """ verify_service_identity( - cert_patterns=extract_patterns(connection.get_peer_certificate()), + cert_patterns=extract_patterns( + connection.get_peer_certificate() # type:ignore[arg-type] + ), obligatory_ids=[IPAddress_ID(ip_address)], optional_ids=[], ) @@ -83,7 +90,7 @@ def verify_ip_address(connection: SSL.Connection, ip_address: str): ID_ON_DNS_SRV = ObjectIdentifier("1.3.6.1.5.5.7.8.7") # id_on_dnsSRV -def extract_patterns(cert: SSL.X509) -> Sequence[CertificatePattern]: +def extract_patterns(cert: X509) -> Sequence[CertificatePattern]: """ Extract all valid ID patterns from a certificate for service verification. @@ -122,19 +129,19 @@ def extract_patterns(cert: SSL.X509) -> Sequence[CertificatePattern]: srv, _ = decode(comp.getComponentByPosition(1)) if isinstance(srv, IA5String): ids.append(SRVPattern.from_bytes(srv.asOctets())) - else: # pragma: nocover + else: # pragma: no cover raise CertificateError( "Unexpected certificate content." ) - else: # pragma: nocover + else: # pragma: no cover pass - else: # pragma: nocover + else: # pragma: no cover pass return ids -def extract_ids(cert: SSL.X509) -> Sequence[CertificatePattern]: +def extract_ids(cert: X509) -> Sequence[CertificatePattern]: """ Deprecated and never public API. Use :func:`extract_patterns` instead. diff --git a/tests/typing/api.py b/tests/typing/api.py new file mode 100644 index 0000000..d2836d6 --- /dev/null +++ b/tests/typing/api.py @@ -0,0 +1,43 @@ +""" +This module is used to test the typing of the public API of service-identity. + +It is NOT intended to be executed. +""" + +from __future__ import annotations + +import socket + +from typing import Sequence + +from cryptography.hazmat.backends import default_backend +from cryptography.x509 import load_pem_x509_certificate +from OpenSSL import SSL + +import service_identity + + +backend = default_backend() +c_cert = load_pem_x509_certificate("foo.pem", backend) + +c_ids: Sequence[ + service_identity.hazmat.CertificatePattern +] = service_identity.cryptography.extract_patterns(c_cert) +service_identity.cryptography.verify_certificate_hostname( + c_cert, "example.com" +) +service_identity.cryptography.verify_certificate_ip_address( + c_cert, "127.0.0.1" +) + + +ctx = SSL.Context(SSL.TLSv1_2_METHOD) +conn = SSL.Connection(ctx, socket.socket(socket.AF_INET, socket.SOCK_STREAM)) +p_cert = conn.get_peer_certificate() +assert p_cert + +p_ids: Sequence[ + service_identity.hazmat.CertificatePattern +] = service_identity.pyopenssl.extract_patterns(p_cert) +service_identity.pyopenssl.verify_hostname(conn, "example.com") +service_identity.pyopenssl.verify_ip_address(conn, "127.0.0.1") diff --git a/tox.ini b/tox.ini index f4f1558..3767b02 100644 --- a/tox.ini +++ b/tox.ini @@ -2,9 +2,10 @@ min_version = 4 env_list = lint, + mypy, docs, pypy3{,-pyopenssl-latest-idna}, - py3{7,8,9,10,11,12}{,-pyopenssl}{,-oldest}{,-idna}, + py3{8,9,10,11,12}{,-pyopenssl}{,-oldest}{,-idna}, coverage-report @@ -33,6 +34,16 @@ deps = pre-commit commands = pre-commit run --all-files {posargs} +[testenv:mypy-api] +extras = mypy +commands = mypy tests/typing docs/pyopenssl_example.py + + +[testenv:mypy-pkg] +extras = mypy +commands = mypy src + + [testenv:docs] # Keep in-sync with gh-actions and .readthedocs.yaml. base_python = py311