From 6e882588341897765443bc1ab2eab455752bd458 Mon Sep 17 00:00:00 2001 From: Yonathan Randolph Date: Sun, 21 Jan 2018 14:14:28 -0800 Subject: [PATCH] Add support for Elliptic Curve keys --- src/josepy/jwk.py | 101 +++++++++++++++++++++++++++++++++++++++++---- src/josepy/util.py | 31 +++++++++++++- 2 files changed, 124 insertions(+), 8 deletions(-) diff --git a/src/josepy/jwk.py b/src/josepy/jwk.py index e9071ea7e..bfe9b26d9 100644 --- a/src/josepy/jwk.py +++ b/src/josepy/jwk.py @@ -9,7 +9,7 @@ from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes # type: ignore from cryptography.hazmat.primitives import serialization -from cryptography.hazmat.primitives.asymmetric import ec # type: ignore +from cryptography.hazmat.primitives.asymmetric import ec from cryptography.hazmat.primitives.asymmetric import rsa from josepy import errors, json_util, util @@ -121,27 +121,114 @@ def load(cls, data, password=None, backend=None): @JWK.register -class JWKES(JWK): # pragma: no cover +class JWKEC(JWK): # pragma: no cover # pylint: disable=abstract-class-not-used - """ES JWK. + """EC JWK. .. warning:: This is not yet implemented! """ - typ = 'ES' + typ = 'EC' + __slots__ = ('key',) cryptography_key_types = ( ec.EllipticCurvePublicKey, ec.EllipticCurvePrivateKey) required = ('crv', JWK.type_field_name, 'x', 'y') + def __init__(self, *args, **kwargs): + if 'key' in kwargs and not isinstance( + kwargs['key'], util.ComparableECKey): + kwargs['key'] = util.ComparableECKey(kwargs['key']) + super(JWKEC, self).__init__(*args, **kwargs) + + @classmethod + def _encode_param(cls, data): + """Encode Base64urlUInt. + + :type data: long + :rtype: unicode + + """ + def _leading_zeros(arg): + if len(arg) % 2: + return '0' + arg + return arg + + return json_util.encode_b64jose(binascii.unhexlify( + _leading_zeros(hex(data)[2:].rstrip('L')))) + + @classmethod + def _decode_param(cls, data, name, expected_length): + """Decode Base64urlUInt.""" + try: + binary = json_util.decode_b64jose(data) + if len(binary) != expected_length: + raise errors.Error( + 'Expected {name} to be {expected_length} bytes after base64-decoding; got {length}', + name=name, expected_length=expected_length, length=len(binary)) + return int(binascii.hexlify(binary), 16) + except ValueError: # invalid literal for long() with base 16 + raise errors.DeserializationError() + def fields_to_partial_json(self): - raise NotImplementedError() + params = {} + if isinstance(self.key._wrapped, ec.EllipticCurvePublicKey): + public = self.key.public_numbers() + elif isinstance(self.key._wrapped, ec.EllipticCurvePrivateKey): + private = self.key.private_numbers() + public = self.key.public_key().public_numbers() + params.update({ + 'd': private.private_value, + }) + else: raise AssertionError( + "key was not an EllipticCurvePublicKey or EllipticCurvePrivateKey") + + params.update({ + 'x': public.x, + 'y': public.y, + }) + params = dict((key, self._encode_param(value)) + for key, value in six.iteritems(params)) + params['crv'] = self._curve_name_to_crv(public.curve.name) + return params + + @classmethod + def _curve_name_to_crv(cls, curve_name): + if curve_name == "secp256r1": return "P-256" + if curve_name == "secp384r1": return "P-384" + if curve_name == "secp521r1": return "P-521" + raise errors.SerializationError() + + @classmethod + def _crv_to_curve(cls, crv): + # crv is case-sensitive + if crv == "P-256": return ec.SECP256R1() + if crv == "P-384": return ec.SECP384R1() + if crv == "P-521": return ec.SECP521R1() + raise errors.DeserializationError() @classmethod def fields_from_json(cls, jobj): - raise NotImplementedError() + # pylint: disable=invalid-name + curve = cls._crv_to_curve(jobj['crv']) + coord_length = (curve.key_size+7)//8 + x, y = (cls._decode_param(jobj[n], n, coord_length) for n in ('x', 'y')) + public_numbers = ec.EllipticCurvePublicNumbers(x=x, y=y, curve=curve) + if 'd' not in jobj: # public key + key = public_numbers.public_key(default_backend()) + else: # private key + exp_length = (curve.key_size.bit_length()+7)//8 + d = cls._decode_param(jobj['d'], 'd', exp_length) + key = ec.EllipticCurvePrivateNumbers(d, public_numbers).private_key( + default_backend()) + return cls(key=key) def public_key(self): - raise NotImplementedError() + # Unlike RSAPrivateKey, EllipticCurvePrivateKey does not contain public_key() + if hasattr(self.key, 'public_key'): + key = self.key.public_key() + else: + key = self.key.public_numbers().public_key(default_backend()) + return type(self)(key=key) @JWK.register diff --git a/src/josepy/util.py b/src/josepy/util.py index d2ed1c5cc..3aa832c42 100644 --- a/src/josepy/util.py +++ b/src/josepy/util.py @@ -3,7 +3,8 @@ import OpenSSL import six -from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.asymmetric import ec, rsa class abstractclassmethod(classmethod): @@ -134,6 +135,34 @@ def __hash__(self): pub = self.public_numbers() return hash((self.__class__, pub.n, pub.e)) +class ComparableECKey(ComparableKey): # pylint: disable=too-few-public-methods + """Wrapper for ``cryptography`` RSA keys. + + Wraps around: + + - :class:`~cryptography.hazmat.primitives.asymmetric.rsa.EllipticCurvePrivateKey` + - :class:`~cryptography.hazmat.primitives.asymmetric.rsa.EllipticCurvePublicKey` + + """ + + def __hash__(self): + # public_numbers() hasn't got stable hash! + # https://github.com/pyca/cryptography/issues/2143 + if isinstance(self._wrapped, ec.EllipticCurvePrivateKeyWithSerialization): + priv = self.private_numbers() + pub = priv.public_numbers + return hash((self.__class__, pub.curve.name, pub.x, pub.y, priv.d)) + elif isinstance(self._wrapped, rsa.EllipticCurvePublicKeyWithSerialization): + pub = self.public_numbers() + return hash((self.__class__, pub.curve.name, pub.x, pub.y)) + def public_key(self): + """Get wrapped public key.""" + # Unlike RSAPrivateKey, EllipticCurvePrivateKey does not have public_key() + if hasattr(self._wrapped, 'public_key'): + key = self._wrapped.public_key() + else: + key = self._wrapped.public_numbers().public_key(default_backend()) + return self.__class__(key) class ImmutableMap(collections.Mapping, collections.Hashable): # type: ignore # pylint: disable=too-few-public-methods