From e3ed15fd9130f18cdc01ad68b0ed92dad74de11a Mon Sep 17 00:00:00 2001 From: Hynek Schlawack Date: Mon, 12 Jun 2023 18:27:30 +0200 Subject: [PATCH] Add type hints --- .github/workflows/ci.yml | 7 ++- CHANGELOG.md | 4 +- docs/pyopenssl_example.py | 11 ++-- pyproject.toml | 46 +++++++++++++-- src/service_identity/__init__.py | 6 +- src/service_identity/cryptography.py | 6 +- src/service_identity/exceptions.py | 50 ++++++++--------- src/service_identity/hazmat.py | 83 ++++++++++++++-------------- src/service_identity/py.typed | 0 src/service_identity/pyopenssl.py | 31 ++++++----- tox.ini | 8 ++- 11 files changed, 148 insertions(+), 104 deletions(-) create mode 100644 src/service_identity/py.typed diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d166a7d..e300b13 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -39,7 +39,6 @@ jobs: strategy: matrix: python-version: - - "3.7" - "3.8" - "3.9" - "3.10" @@ -56,7 +55,7 @@ jobs: allow-prereleases: true cache: pip - - name: Install & run tox + - name: Prepare tox run: | V=${{ matrix.python-version }} @@ -67,7 +66,9 @@ jobs: fi python -Im pip install tox - python -Im tox run -f "$V" + + - run: python -Im tox run -f "$V" + - run: python -Im tox run -e mypy - name: Upload coverage data uses: actions/upload-artifact@v3 diff --git a/CHANGELOG.md b/CHANGELOG.md index f8f3bf7..0360719 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,7 +14,7 @@ If breaking changes are needed do be done, they are: ### Backwards-incompatible Changes -- All Python versions up to and including 3.6 have been dropped. +- All Python versions up to and including 3.7 have been dropped. - Support for `commonName` in certificates has been dropped. It has been deprecated since 2017 and isn't supported by any major browser. - The oldest supported pyOpenSSL version (when using the `pyopenssl` backend) is now 17.0.0. @@ -33,6 +33,8 @@ If breaking changes are needed do be done, they are: - `service_identity.(cryptography|pyopenssl).extract_patterns()` are now public APIs (FKA `extract_ids()`). You can use them to extract the patterns from a certificate without verifying anything. [#55](https://github.com/pyca/service-identity/pull/55) +- *service-identity* is now fully typed. + [#57](https://github.com/pyca/service-identity/pull/57) ## 21.1.0 (2021-05-09) diff --git a/docs/pyopenssl_example.py b/docs/pyopenssl_example.py index 8c93081..ae15305 100644 --- a/docs/pyopenssl_example.py +++ b/docs/pyopenssl_example.py @@ -12,7 +12,7 @@ hostname = sys.argv[1] ctx = SSL.Context(SSL.TLSv1_2_METHOD) -ctx.set_verify(SSL.VERIFY_PEER, lambda conn, cert, errno, depth, ok: ok) +ctx.set_verify(SSL.VERIFY_PEER, lambda conn, cert, errno, depth, ok: bool(ok)) ctx.set_default_verify_paths() conn = SSL.Connection(ctx, socket.socket(socket.AF_INET, socket.SOCK_STREAM)) @@ -22,12 +22,9 @@ try: conn.do_handshake() - print("Server certificate is valid for the following patterns:\n") - pprint.pprint( - service_identity.pyopenssl.extract_patterns( - conn.get_peer_certificate() - ) - ) + if cert := conn.get_peer_certificate(): + print("Server certificate is valid for the following patterns:\n") + pprint.pprint(service_identity.pyopenssl.extract_patterns(cert)) try: service_identity.pyopenssl.verify_hostname(conn, hostname) diff --git a/pyproject.toml b/pyproject.toml index a36a242..b64a207 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,14 +6,13 @@ build-backend = "hatchling.build" name = "service-identity" authors = [{ name = "Hynek Schlawack", email = "hs@ox.cx" }] license = "MIT" -requires-python = ">=3.7" +requires-python = ">=3.8" description = "Service identity verification for pyOpenSSL & cryptography." keywords = ["cryptography", "openssl", "pyopenssl"] classifiers = [ "Development Status :: 5 - Production/Stable", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", @@ -23,6 +22,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: PyPy", "Topic :: Security :: Cryptography", "Topic :: Software Development :: Libraries :: Python Modules", + "Typing :: Typed", ] dependencies = [ # Keep in-sync with tests/constraints/*. @@ -30,7 +30,6 @@ dependencies = [ "pyasn1-modules", "pyasn1", "cryptography", - "importlib_metadata;python_version<'3.8'", ] dynamic = ["version", "readme"] @@ -38,7 +37,8 @@ dynamic = ["version", "readme"] idna = ["idna"] tests = ["coverage[toml]>=5.0.2", "pytest"] docs = ["sphinx", "furo", "myst-parser", "sphinx-notfound-page"] -dev = ["service-identity[tests,docs,idna]", "pyOpenSSL"] +mypy = ["mypy", "types-pyOpenSSL", "idna"] +dev = ["service-identity[tests,mypy,docs,idna]", "pyOpenSSL"] [project.urls] Documentation = "https://service-identity.readthedocs.io/" @@ -105,6 +105,22 @@ source = ["src", ".tox/py*/**/site-packages"] [tool.coverage.report] show_missing = true skip_covered = true +exclude_lines = [ + # a more strict default pragma + "\\# pragma: no cover\\b", + + # allow defensive code + "^\\s*raise AssertionError\\b", + "^\\s*raise NotImplementedError\\b", + "^\\s*return NotImplemented\\b", + "^\\s*raise$", + + # typing-related code + "^if (False|TYPE_CHECKING):", + ": \\.\\.\\.(\\s*#.*)?$", + "^ +\\.\\.\\.$", + "-> ['\"]?NoReturn['\"]?:", +] [tool.black] @@ -152,3 +168,25 @@ ignore = [ [tool.ruff.isort] lines-between-types = 1 lines-after-imports = 2 + + +[tool.mypy] +strict = true + +show_error_codes = true +enable_error_code = ["ignore-without-code"] +ignore_missing_imports = true + +warn_return_any = false + +[[tool.mypy.overrides]] +module = "tests.*" +ignore_errors = true + +[[tool.mypy.overrides]] +module = "tests.typing.*" +ignore_errors = false + +[[tool.mypy.overrides]] +module = "cryptography.*" +follow_imports = "skip" diff --git a/src/service_identity/__init__.py b/src/service_identity/__init__.py index 981f52c..7d1ad3b 100644 --- a/src/service_identity/__init__.py +++ b/src/service_identity/__init__.py @@ -40,13 +40,9 @@ def __getattr__(name: str) -> str: if name not in dunder_to_metadata.keys(): raise AttributeError(f"module {__name__} has no attribute {name}") - import sys import warnings - if sys.version_info < (3, 8): - from importlib_metadata import metadata - else: - from importlib.metadata import metadata + from importlib.metadata import metadata warnings.warn( f"Accessing service_identity.{name} is deprecated and will be " diff --git a/src/service_identity/cryptography.py b/src/service_identity/cryptography.py index 259dc2d..a60e32f 100644 --- a/src/service_identity/cryptography.py +++ b/src/service_identity/cryptography.py @@ -59,7 +59,7 @@ def verify_certificate_hostname( """ verify_service_identity( cert_patterns=extract_patterns(certificate), - obligatory_ids=[DNS_ID(hostname)], + obligatory_ids=[DNS_ID(hostname)], # type: ignore[list-item] optional_ids=[], ) @@ -89,7 +89,7 @@ def verify_certificate_ip_address( """ verify_service_identity( cert_patterns=extract_patterns(certificate), - obligatory_ids=[IPAddress_ID(ip_address)], + obligatory_ids=[IPAddress_ID(ip_address)], # type: ignore[list-item] optional_ids=[], ) @@ -141,7 +141,7 @@ def extract_patterns(cert: Certificate) -> Sequence[CertificatePattern]: srv, _ = decode(other.value) if isinstance(srv, IA5String): ids.append(SRVPattern.from_bytes(srv.asOctets())) - else: # pragma: nocover + else: # pragma: no cover raise CertificateError("Unexpected certificate content.") return ids diff --git a/src/service_identity/exceptions.py b/src/service_identity/exceptions.py index 3fbb279..5d573f6 100644 --- a/src/service_identity/exceptions.py +++ b/src/service_identity/exceptions.py @@ -5,6 +5,13 @@ them from __init__.py. """ +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence + + +if TYPE_CHECKING: + from .hazmat import ServiceID import attr @@ -23,52 +30,45 @@ class SubjectAltNameWarning(DeprecationWarning): """ -@attr.s(auto_exc=True) -class VerificationError(Exception): - """ - Service identity verification failed. - """ - - errors = attr.ib() - - def __str__(self): - return self.__repr__() +@attr.s(slots=True) +class Mismatch: + mismatched_id: ServiceID = attr.ib() -@attr.s -class DNSMismatch: +class DNSMismatch(Mismatch): """ No matching DNSPattern could be found. """ - mismatched_id = attr.ib() - -@attr.s -class SRVMismatch: +class SRVMismatch(Mismatch): """ No matching SRVPattern could be found. """ - mismatched_id = attr.ib() - -@attr.s -class URIMismatch: +class URIMismatch(Mismatch): """ No matching URIPattern could be found. """ - mismatched_id = attr.ib() - -@attr.s -class IPAddressMismatch: +class IPAddressMismatch(Mismatch): """ No matching IPAddressPattern could be found. """ - mismatched_id = attr.ib() + +@attr.s(auto_exc=True) +class VerificationError(Exception): + """ + Service identity verification failed. + """ + + errors: Sequence[Mismatch] = attr.ib() + + def __str__(self) -> str: + return self.__repr__() class CertificateError(Exception): diff --git a/src/service_identity/hazmat.py b/src/service_identity/hazmat.py index dbba000..5ed3017 100644 --- a/src/service_identity/hazmat.py +++ b/src/service_identity/hazmat.py @@ -7,7 +7,7 @@ import ipaddress import re -from typing import Union +from typing import Protocol, Sequence, Union, runtime_checkable import attr @@ -15,6 +15,7 @@ CertificateError, DNSMismatch, IPAddressMismatch, + Mismatch, SRVMismatch, URIMismatch, VerificationError, @@ -24,7 +25,7 @@ try: import idna except ImportError: - idna = None + idna = None # type: ignore[assignment] @attr.s(slots=True) @@ -33,11 +34,15 @@ class ServiceMatch: A match of a service id and a certificate pattern. """ - service_id = attr.ib() - cert_pattern = attr.ib() + service_id: ServiceID = attr.ib() + cert_pattern: CertificatePattern = attr.ib() -def verify_service_identity(cert_patterns, obligatory_ids, optional_ids): +def verify_service_identity( + cert_patterns: Sequence[CertificatePattern], + obligatory_ids: Sequence[ServiceID], + optional_ids: Sequence[ServiceID], +) -> list[ServiceMatch]: """ Verify whether *cert_patterns* are valid for *obligatory_ids* and *optional_ids*. @@ -71,17 +76,15 @@ def verify_service_identity(cert_patterns, obligatory_ids, optional_ids): return matches -def _find_matches(cert_patterns, service_ids): +def _find_matches( + cert_patterns: Sequence[CertificatePattern], + service_ids: Sequence[ServiceID], +) -> list[ServiceMatch]: """ Search for matching certificate patterns and service_ids. - :param cert_ids: List certificate IDs like DNSPattern. - :type cert_ids: `list` - :param service_ids: List of service IDs like DNS_ID. :type service_ids: `list` - - :rtype: `list` of `ServiceMatch` """ matches = [] for sid in service_ids: @@ -91,25 +94,17 @@ def _find_matches(cert_patterns, service_ids): return matches -def _contains_instance_of(seq, cl): - """ - :type seq: iterable - :type cl: type - - :rtype: bool - """ +def _contains_instance_of(seq: Sequence[object], cl: type) -> bool: return any(isinstance(e, cl) for e in seq) -def _is_ip_address(pattern): +def _is_ip_address(pattern: str | bytes) -> bool: """ Check whether *pattern* could be/match an IP address. :param pattern: A pattern for a host name. - :type pattern: `bytes` or `str` :return: `True` if *pattern* could be an IP address, else `False`. - :rtype: bool """ if isinstance(pattern, bytes): try: @@ -143,7 +138,7 @@ class DNSPattern: _RE_LEGAL_CHARS = re.compile(rb"^[a-z0-9\-_.]+$") @classmethod - def from_bytes(cls, pattern) -> DNSPattern: + def from_bytes(cls, pattern: bytes) -> DNSPattern: if not isinstance(pattern, bytes): raise TypeError("The DNS pattern must be a bytes string.") @@ -247,6 +242,15 @@ def from_bytes(cls, pattern: bytes) -> SRVPattern: """ +@runtime_checkable +class ServiceID(Protocol): + pattern_class: type[CertificatePattern] + error_on_mismatch: type[Mismatch] + + def verify(self, pattern: CertificatePattern) -> bool: + ... + + @attr.s(init=False, slots=True) class DNS_ID: """ @@ -260,7 +264,7 @@ class DNS_ID: pattern_class = DNSPattern error_on_mismatch = DNSMismatch - def __init__(self, hostname): + def __init__(self, hostname: str): if not isinstance(hostname, str): raise TypeError("DNS-ID must be a text string.") @@ -282,7 +286,7 @@ def __init__(self, hostname): if self._RE_LEGAL_CHARS.match(self.hostname) is None: raise ValueError("Invalid DNS-ID.") - def verify(self, pattern: object) -> None: + def verify(self, pattern: CertificatePattern) -> bool: """ https://tools.ietf.org/search/rfc6125#section-6.4 """ @@ -305,11 +309,14 @@ class IPAddress_ID: pattern_class = IPAddressPattern error_on_mismatch = IPAddressMismatch - def verify(self, pattern): + def verify(self, pattern: CertificatePattern) -> bool: """ https://tools.ietf.org/search/rfc2818#section-3.1 """ - return self.ip == pattern.pattern + if isinstance(pattern, self.pattern_class): + return self.ip == pattern.pattern + + return False @attr.s(init=False, slots=True) @@ -318,8 +325,8 @@ class URI_ID: An URI service ID. """ - protocol = attr.ib() - dns_id = attr.ib() + protocol: bytes = attr.ib() + dns_id: DNS_ID = attr.ib() pattern_class = URIPattern error_on_mismatch = URIMismatch @@ -337,7 +344,7 @@ def __init__(self, uri: str): self.protocol = prot.encode("ascii").translate(_TRANS_TO_LOWER) self.dns_id = DNS_ID(hostname.strip("/")) - def verify(self, pattern): + def verify(self, pattern: CertificatePattern) -> bool: """ https://tools.ietf.org/search/rfc6125#section-6.5.2 """ @@ -356,8 +363,8 @@ class SRV_ID: An SRV service ID. """ - name = attr.ib() - dns_id = attr.ib() + name: bytes = attr.ib() + dns_id: DNS_ID = attr.ib() pattern_class = SRVPattern error_on_mismatch = SRVMismatch @@ -375,7 +382,7 @@ def __init__(self, srv: str): self.name = name[1:].encode("ascii").translate(_TRANS_TO_LOWER) self.dns_id = DNS_ID(hostname) - def verify(self, pattern): + def verify(self, pattern: CertificatePattern) -> bool: """ https://tools.ietf.org/search/rfc6125#section-6.5.1 """ @@ -387,13 +394,9 @@ def verify(self, pattern): return False -def _hostname_matches(cert_pattern, actual_hostname): +def _hostname_matches(cert_pattern: bytes, actual_hostname: bytes) -> bool: """ - :type cert_pattern: `bytes` - :type actual_hostname: `bytes` - :return: `True` if *cert_pattern* matches *actual_hostname*, else `False`. - :rtype: `bool` """ if b"*" in cert_pattern: cert_head, cert_tail = cert_pattern.split(b".", 1) @@ -409,14 +412,10 @@ def _hostname_matches(cert_pattern, actual_hostname): return cert_pattern == actual_hostname -def _validate_pattern(cert_pattern): +def _validate_pattern(cert_pattern: bytes) -> None: """ Check whether the usage of wildcards within *cert_pattern* conforms with our expectations. - - :type hostname: `bytes` - - :return: None """ cnt = cert_pattern.count(b"*") if cnt > 1: diff --git a/src/service_identity/py.typed b/src/service_identity/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/src/service_identity/pyopenssl.py b/src/service_identity/pyopenssl.py index e788ffc..7bfc619 100644 --- a/src/service_identity/pyopenssl.py +++ b/src/service_identity/pyopenssl.py @@ -14,9 +14,9 @@ from pyasn1.type.univ import ObjectIdentifier from pyasn1_modules.rfc2459 import GeneralNames +from .exceptions import CertificateError from .hazmat import ( DNS_ID, - CertificateError, CertificatePattern, DNSPattern, IPAddress_ID, @@ -28,12 +28,13 @@ with contextlib.suppress(ImportError): # we only use it for docstrings - from OpenSSL import SSL + from OpenSSL.crypto import X509 + from OpenSSL.SSL import Connection __all__ = ["verify_hostname"] -def verify_hostname(connection: SSL.Connection, hostname: str): +def verify_hostname(connection: Connection, hostname: str) -> None: """ Verify whether the certificate of *connection* is valid for *hostname*. @@ -49,13 +50,15 @@ def verify_hostname(connection: SSL.Connection, hostname: str): :returns: ``None`` """ verify_service_identity( - cert_patterns=extract_patterns(connection.get_peer_certificate()), - obligatory_ids=[DNS_ID(hostname)], + cert_patterns=extract_patterns( + connection.get_peer_certificate() # type:ignore[arg-type] + ), + obligatory_ids=[DNS_ID(hostname)], # type: ignore[list-item] optional_ids=[], ) -def verify_ip_address(connection: SSL.Connection, ip_address: str): +def verify_ip_address(connection: Connection, ip_address: str) -> None: """ Verify whether the certificate of *connection* is valid for *ip_address*. @@ -74,8 +77,10 @@ def verify_ip_address(connection: SSL.Connection, ip_address: str): .. versionadded:: 18.1.0 """ verify_service_identity( - cert_patterns=extract_patterns(connection.get_peer_certificate()), - obligatory_ids=[IPAddress_ID(ip_address)], + cert_patterns=extract_patterns( + connection.get_peer_certificate() # type:ignore[arg-type] + ), + obligatory_ids=[IPAddress_ID(ip_address)], # type: ignore[list-item] optional_ids=[], ) @@ -83,7 +88,7 @@ def verify_ip_address(connection: SSL.Connection, ip_address: str): ID_ON_DNS_SRV = ObjectIdentifier("1.3.6.1.5.5.7.8.7") # id_on_dnsSRV -def extract_patterns(cert: SSL.X509) -> Sequence[CertificatePattern]: +def extract_patterns(cert: X509) -> Sequence[CertificatePattern]: """ Extract all valid ID patterns from a certificate for service verification. @@ -122,19 +127,19 @@ def extract_patterns(cert: SSL.X509) -> Sequence[CertificatePattern]: srv, _ = decode(comp.getComponentByPosition(1)) if isinstance(srv, IA5String): ids.append(SRVPattern.from_bytes(srv.asOctets())) - else: # pragma: nocover + else: # pragma: no cover raise CertificateError( "Unexpected certificate content." ) - else: # pragma: nocover + else: # pragma: no cover pass - else: # pragma: nocover + else: # pragma: no cover pass return ids -def extract_ids(cert: SSL.X509) -> Sequence[CertificatePattern]: +def extract_ids(cert: X509) -> Sequence[CertificatePattern]: """ Deprecated and never public API. Use :func:`extract_patterns` instead. diff --git a/tox.ini b/tox.ini index f4f1558..94b5c40 100644 --- a/tox.ini +++ b/tox.ini @@ -2,9 +2,10 @@ min_version = 4 env_list = lint, + mypy, docs, pypy3{,-pyopenssl-latest-idna}, - py3{7,8,9,10,11,12}{,-pyopenssl}{,-oldest}{,-idna}, + py3{8,9,10,11,12}{,-pyopenssl}{,-oldest}{,-idna}, coverage-report @@ -33,6 +34,11 @@ deps = pre-commit commands = pre-commit run --all-files {posargs} +[testenv:mypy] +extras = mypy +commands = mypy src docs/pyopenssl_example.py + + [testenv:docs] # Keep in-sync with gh-actions and .readthedocs.yaml. base_python = py311