diff --git a/pyscitt/pyscitt/crypto.py b/pyscitt/pyscitt/crypto.py index a20be9db..0e855c61 100644 --- a/pyscitt/pyscitt/crypto.py +++ b/pyscitt/pyscitt/crypto.py @@ -8,7 +8,7 @@ import warnings from dataclasses import dataclass from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Type, Union from uuid import uuid4 warnings.filterwarnings("ignore", category=Warning) @@ -69,12 +69,6 @@ RECOMMENDED_RSA_PUBLIC_EXPONENT = 65537 -REGISTERED_EC_CURVES = { - "P-256": P256, - "P-384": P384, - "P-521": P521, -} - Pem = str COSE_HEADER_PARAM_ISSUER = 391 @@ -90,12 +84,36 @@ RegistrationInfoValue = Union[str, bytes, int] RegistrationInfo = Dict[str, RegistrationInfoValue] +CoseCurveTypes = Union[Type[P256], Type[P384], Type[P521]] +CoseCurveType = Tuple[str, CoseCurveTypes] + + +def ec_curve_from_name(name: str) -> EllipticCurve: + if name == "P-256": + return ec.SECP256R1() + elif name == "P-384": + return ec.SECP384R1() + elif name == "P-521": + return ec.SECP521R1() + else: + raise ValueError(f"Unsupported EC curve: {name}") + + +def cose_curve_from_ec(curve: EllipticCurve) -> CoseCurveType: + if isinstance(curve, ec.SECP256R1): + return ("P-256", P256) + elif isinstance(curve, ec.SECP384R1): + return ("P-384", P384) + elif isinstance(curve, ec.SECP521R1): + return ("P-521", P521) + else: + raise ValueError(f"Unsupported EC curve: {curve}") -def generate_rsa_keypair(key_size: int) -> Tuple[Pem, Pem]: +def generate_rsa_keypair() -> Tuple[Pem, Pem]: priv = rsa.generate_private_key( public_exponent=RECOMMENDED_RSA_PUBLIC_EXPONENT, - key_size=key_size, + key_size=2048, ) pub = priv.public_key() priv_pem = priv.private_bytes( @@ -107,12 +125,8 @@ def generate_rsa_keypair(key_size: int) -> Tuple[Pem, Pem]: return priv_pem, pub_pem -def generate_ec_keypair(curve: str) -> Tuple[Pem, Pem]: - if curve not in REGISTERED_EC_CURVES: - raise NotImplementedError(f"Unsupported curve: {curve}") - curve_obj = REGISTERED_EC_CURVES[curve].curve_obj - assert isinstance(curve_obj, EllipticCurve) - priv = ec.generate_private_key(curve=curve_obj) +def generate_ec_keypair(curve_name: str) -> Tuple[Pem, Pem]: + priv = ec.generate_private_key(curve=ec_curve_from_name(curve_name)) pub = priv.public_key() priv_pem = priv.private_bytes( Encoding.PEM, PrivateFormat.PKCS8, NoEncryption() @@ -140,11 +154,10 @@ def generate_ed25519_keypair() -> Tuple[Pem, Pem]: def generate_keypair( kty: str, *, - rsa_key_size: Optional[int] = None, ec_curve: Optional[str] = None, ) -> Tuple[str, str]: if kty == "rsa": - return generate_rsa_keypair(rsa_key_size or 2048) + return generate_rsa_keypair() elif kty == "ec": return generate_ec_keypair(ec_curve or "P-256") elif kty == "ed25519": @@ -492,13 +505,7 @@ def from_cryptography_eckey_obj( priv_nums = None pub_nums = ext_key.public_numbers() - # Create map of cryptography curves to cose curves. E.g. {ec.SECP256R1: P256, ...} - registered_crvs = { - type(crv.curve_obj): crv for crv in REGISTERED_EC_CURVES.values() - } - if type(pub_nums.curve) not in registered_crvs: - raise ValueError(f"Unsupported EC Curve: {type(pub_nums.curve)}") - curve = registered_crvs[type(pub_nums.curve)] + _, curve = cose_curve_from_ec(pub_nums.curve) cose_key = {} if pub_nums: @@ -616,7 +623,7 @@ def get_last_embedded_receipt_from_cose(buf: bytes) -> Union[bytes, None]: def load_private_key(key_path: Path) -> Pem: - with open(key_path) as f: + with open(key_path, encoding="utf-8") as f: key_priv_pem = f.read() if is_ssh_private_key(key_priv_pem): key_priv_pem = ssh_private_key_to_pem(key_priv_pem) @@ -654,16 +661,9 @@ def encode_pub_num_jwk(dec): elif isinstance(pub_key, EllipticCurvePublicKey): pub_numbers = pub_key.public_numbers() curve = pub_numbers.curve - # Create map of curves to names. E.g. {ec.SECP256R1: "P-256", ...} - registered_crvs = { - type(crv.curve_obj): name for name, crv in REGISTERED_EC_CURVES.items() - } - if type(curve) not in registered_crvs: - raise ValueError(f"Unsupported EC Curve: {curve}") - crv_name = registered_crvs[type(curve)] - x = pub_numbers.x.to_bytes(REGISTERED_EC_CURVES[crv_name].size, "big") - y = pub_numbers.y.to_bytes(REGISTERED_EC_CURVES[crv_name].size, "big") - + crv_name, crv = cose_curve_from_ec(curve) + x = pub_numbers.x.to_bytes(crv.size, "big") + y = pub_numbers.y.to_bytes(crv.size, "big") jwk = { "kty": "EC", "crv": crv_name, @@ -715,7 +715,7 @@ def sign_claimset( claims: bytes, content_type: str, feed: Optional[str] = None, - registration_info: RegistrationInfo = {}, + registration_info: Optional[RegistrationInfo] = None, svn: Optional[int] = None, cwt: bool = False, ) -> bytes: @@ -804,8 +804,7 @@ def convert_jwk_to_pem(jwk: dict) -> Pem: if jwk.get("kty") == "EC": x = int.from_bytes(base64.urlsafe_b64decode(jwk["x"]), "big") y = int.from_bytes(base64.urlsafe_b64decode(jwk["y"]), "big") - crv = REGISTERED_EC_CURVES[jwk["crv"]].curve_obj - assert isinstance(crv, EllipticCurve) + crv = ec_curve_from_name(jwk["crv"]) key = EllipticCurvePublicNumbers(x, y, crv).public_key() else: raise NotImplementedError("Unsupported JWK type") diff --git a/pyscitt/pyscitt/key_vault_sign_client.py b/pyscitt/pyscitt/key_vault_sign_client.py index d4593448..e4598559 100644 --- a/pyscitt/pyscitt/key_vault_sign_client.py +++ b/pyscitt/pyscitt/key_vault_sign_client.py @@ -17,12 +17,6 @@ from . import crypto -ALGORITHMS = { - 256: ("ES256", "sha256"), - 384: ("ES384", "sha384"), - 521: ("ES512", "sha384"), -} - class KeyVaultSignClient(MemberAuthenticationMethod): """MemberIdentity implementation that uses Azure Key Vault.""" @@ -120,7 +114,15 @@ def http_sign(self, data: bytes): pub_key = cert.public_key() assert isinstance(pub_key, (EllipticCurvePublicKey)) key_size = pub_key.curve.key_size - signature_algorithm, hash_algorithm = ALGORITHMS[key_size] + + if key_size == 256: + signature_algorithm, hash_algorithm = ("ES256", "sha256") + elif key_size == 384: + signature_algorithm, hash_algorithm = ("ES384", "sha384") + elif key_size == 521: + signature_algorithm, hash_algorithm = ("ES512", "sha512") + else: + raise ValueError(f"Unsupported EC size: {key_size}") digest_to_sign = hashlib.new(hash_algorithm, data).digest() sign_result = crypto_client.sign( diff --git a/test/infra/did_web_server.py b/test/infra/did_web_server.py index 3bf0a29d..b0a1a984 100644 --- a/test/infra/did_web_server.py +++ b/test/infra/did_web_server.py @@ -99,7 +99,7 @@ def do_GET(handler_self): self.port = self.httpd.server_address[1] self.base_url = f"https://{self.host}:{self.port}" - tls_key_pem, _ = crypto.generate_rsa_keypair(2048) + tls_key_pem, _ = crypto.generate_rsa_keypair() self.tls_cert_pem = crypto.generate_cert(tls_key_pem, cn=host) context = _create_tls_context(self.tls_cert_pem, tls_key_pem) diff --git a/test/infra/jwt_issuer.py b/test/infra/jwt_issuer.py index 422c0aa8..d8b510ed 100644 --- a/test/infra/jwt_issuer.py +++ b/test/infra/jwt_issuer.py @@ -9,7 +9,7 @@ class JwtIssuer: def __init__(self, name="example.com"): self.name = name - self.key, _ = crypto.generate_rsa_keypair(2048) + self.key, _ = crypto.generate_rsa_keypair() self.cert = crypto.generate_cert(self.key, cn=name) self.key_id = crypto.get_cert_fingerprint(self.cert) diff --git a/test/test_cli.py b/test/test_cli.py index 65d8f3e8..9d8c6daa 100644 --- a/test/test_cli.py +++ b/test/test_cli.py @@ -347,7 +347,7 @@ def test_local_development(run, service_url, tmp_path: Path): def test_create_ssh_did_web(run, tmp_path: Path): - private_key, public_key = crypto.generate_rsa_keypair(2048) + private_key, public_key = crypto.generate_rsa_keypair() ssh_private_key = crypto.private_key_pem_to_ssh(private_key) ssh_public_key = crypto.pub_key_pem_to_ssh(public_key) @@ -401,7 +401,7 @@ def test_create_ssh_did_web(run, tmp_path: Path): def test_adhoc_signer(run, tmp_path: Path): - private_key, public_key = crypto.generate_rsa_keypair(2048) + private_key, public_key = crypto.generate_rsa_keypair() (tmp_path / "key.pem").write_text(private_key) (tmp_path / "key_pub.pem").write_text(public_key) (tmp_path / "claims.json").write_text(json.dumps({"foo": "bar"})) @@ -519,7 +519,7 @@ def test_prefix_tree(run, tmp_path: Path): def test_registration_info(run, tmp_path: Path): - private_key, public_key = crypto.generate_rsa_keypair(2048) + private_key, public_key = crypto.generate_rsa_keypair() (tmp_path / "key.pem").write_text(private_key) (tmp_path / "claims.json").write_text(json.dumps({"foo": "bar"}))