Skip to content

Commit

Permalink
Implement signature verification (#232)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarshalX authored Jan 12, 2024
1 parent a88db2e commit 7cf829e
Show file tree
Hide file tree
Showing 12 changed files with 239 additions and 28 deletions.
9 changes: 7 additions & 2 deletions packages/atproto_crypto/algs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import typing as t

from .p256 import P256
from .secp256k1 import Secp256k1

__all__ = ['P256', 'Secp256k1']
_ANY_ALG_TYPE = t.Union[t.Type[P256], t.Type[Secp256k1]]

AVAILABLE_ALGORITHMS: t.List[_ANY_ALG_TYPE] = [P256, Secp256k1]
ALGORITHM_TO_CLASS: t.Dict[str, _ANY_ALG_TYPE] = {alg.NAME: alg for alg in AVAILABLE_ALGORITHMS}

AVAILABLE_ALGORITHMS = [P256, Secp256k1]
__all__ = ['P256', 'Secp256k1', 'ALGORITHM_TO_CLASS']
48 changes: 46 additions & 2 deletions packages/atproto_crypto/algs/base_alg.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurve, EllipticCurvePublicKey
from cryptography.hazmat.primitives.asymmetric.ec import ECDSA, EllipticCurve, EllipticCurvePublicKey
from cryptography.hazmat.primitives.asymmetric.utils import encode_dss_signature
from cryptography.hazmat.primitives.hashes import SHA256

from atproto_crypto.exceptions import InvalidCompressedPubkeyError


class AlgBase:
"""Base class for all algorithms."""

def __init__(self, curve: EllipticCurve) -> None:
NAME = None

def __init__(self, curve: EllipticCurve, curve_order: hex) -> None:
self.curve = curve
self.curve_order = curve_order

def get_elliptic_curve_public_key(self, pubkey: bytes) -> EllipticCurvePublicKey:
"""Return the elliptic curve public key."""
Expand All @@ -28,3 +34,41 @@ def decompress_pubkey(self, pubkey: bytes) -> bytes:
return self.get_elliptic_curve_public_key(pubkey).public_bytes(
encoding=serialization.Encoding.X962, format=serialization.PublicFormat.UncompressedPoint
)

def _ensure_dss_signature_low_s(self, s: int) -> None:
"""Ensure DSS signature is low-S.
It prevents ECDSA signature malleability.
More info: https://atproto.com/specs/cryptography#ecdsa-signature-malleability
"""
if s > self.curve_order // 2:
raise InvalidSignature('Invalid signature. Non low-S signature variant is denied.')

def _encode_signature(self, signature: bytes) -> bytes:
"""Encode signature."""
r = int.from_bytes(signature[:32], 'big')
s = int.from_bytes(signature[32:], 'big')

self._ensure_dss_signature_low_s(s)

return encode_dss_signature(r, s)

def verify_signature(self, pubkey: bytes, signing_input: bytes, signature: bytes) -> bool:
"""Verify signature.
Args:
pubkey: Public key.
signing_input: Signing input (data).
signature: Signature.
Returns:
:obj:`bool`: True if signature is valid, False otherwise.
"""
try:
self.get_elliptic_curve_public_key(pubkey).verify(
signature=self._encode_signature(signature), data=signing_input, signature_algorithm=ECDSA(SHA256())
)
return True
except InvalidSignature:
return False
5 changes: 4 additions & 1 deletion packages/atproto_crypto/algs/p256.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from cryptography.hazmat.primitives.asymmetric.ec import SECP256R1

from atproto_crypto.algs.base_alg import AlgBase
from atproto_crypto.consts import P256_CURVE_ORDER, P256_JWT_ALG


class P256(AlgBase):
NAME = P256_JWT_ALG

def __init__(self) -> None:
super().__init__(SECP256R1())
super().__init__(SECP256R1(), P256_CURVE_ORDER)
5 changes: 4 additions & 1 deletion packages/atproto_crypto/algs/secp256k1.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from cryptography.hazmat.primitives.asymmetric.ec import SECP256K1

from atproto_crypto.algs.base_alg import AlgBase
from atproto_crypto.consts import SECP256K1_CURVE_ORDER, SECP256K1_JWT_ALG


class Secp256k1(AlgBase):
NAME = SECP256K1_JWT_ALG

def __init__(self) -> None:
super().__init__(SECP256K1())
super().__init__(SECP256K1(), SECP256K1_CURVE_ORDER)
3 changes: 3 additions & 0 deletions packages/atproto_crypto/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,6 @@

P256_JWT_ALG = 'ES256'
SECP256K1_JWT_ALG = 'ES256K'

P256_CURVE_ORDER = 0xFFFFFFFF_00000000_FFFFFFFF_FFFFFFFF_BCE6FAAD_A7179E84_F3B9CAC2_FC632551
SECP256K1_CURVE_ORDER = 0xFFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFE_BAAEDCE6_AF48A03B_BFD25E8C_D0364141
20 changes: 19 additions & 1 deletion packages/atproto_crypto/did.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
SECP256K1_DID_PREFIX,
SECP256K1_JWT_ALG,
)
from atproto_crypto.exceptions import IncorrectMultikeyPrefixError, UnsupportedKeyTypeError
from atproto_crypto.exceptions import IncorrectDidKeyPrefixError, IncorrectMultikeyPrefixError, UnsupportedKeyTypeError
from atproto_crypto.multibase import bytes_to_multibase, multibase_to_bytes


Expand Down Expand Up @@ -105,6 +105,24 @@ def format_multikey(jwt_alg: str, key: bytes) -> str:
return bytes_to_multibase(BASE58_MULTIBASE_PREFIX, prefixed_bytes)


def parse_did_key(did_key: str) -> Multikey:
"""Parse DID key.
Args:
did_key: DID key.
Returns:
:obj:`Multikey`: Multikey.
Raises:
:obj:`IncorrectDidKeyPrefixError`: Incorrect prefix for DID key.
"""
if not did_key.startswith(DID_KEY_PREFIX):
raise IncorrectDidKeyPrefixError(f'Incorrect prefix for DID key {did_key}')

return parse_multikey(did_key[len(DID_KEY_PREFIX) :])


def format_did_key(jwt_alg: str, key: bytes) -> str:
"""Format DID key.
Expand Down
8 changes: 8 additions & 0 deletions packages/atproto_crypto/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,13 @@ class IncorrectMultikeyPrefixError(DidKeyError):
...


class IncorrectDidKeyPrefixError(DidKeyError):
...


class UnsupportedKeyTypeError(DidKeyError):
...


class UnsupportedSignatureAlgorithmError(AtProtocolError):
...
29 changes: 20 additions & 9 deletions packages/atproto_crypto/verify.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,24 @@
import typing as t
import warnings

from atproto_crypto.algs import ALGORITHM_TO_CLASS
from atproto_crypto.did import parse_did_key
from atproto_crypto.exceptions import UnsupportedSignatureAlgorithmError


def verify_signature(did_key: str, signing_input: t.Union[str, bytes], signature: t.Union[str, bytes]) -> bool:
# TODO(MarshalX): implement
warnings.warn(
'verify_signature is not implemented yet. Do not trust to this signing_input',
RuntimeWarning,
stacklevel=0,
)

return True
"""Verify signature.
Args:
did_key: DID key.
signing_input: Signing input (data).
signature: Signature.
Returns:
bool: True if signature is valid, False otherwise.
"""
parsed_did_key = parse_did_key(did_key)
if parsed_did_key.jwt_alg not in ALGORITHM_TO_CLASS:
raise UnsupportedSignatureAlgorithmError('Unsupported signature alg')

algorithm_class = ALGORITHM_TO_CLASS[parsed_did_key.jwt_alg]
return algorithm_class().verify_signature(parsed_did_key.key_bytes, signing_input, signature)
4 changes: 2 additions & 2 deletions packages/atproto_server/auth/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def verify_jwt(

fresh_signing_key = get_signing_key_callback(payload.iss, True) # get signing key without a cache
if fresh_signing_key == signing_key:
raise TokenInvalidSignatureError('Could not verify JWT signature. Fresh signing key is equal to the old one')
raise TokenInvalidSignatureError('Invalid signature even with fresh signing key it is equal to the old one)')

if _verify_signature(fresh_signing_key, signing_input, signature):
return payload
Expand Down Expand Up @@ -252,7 +252,7 @@ async def verify_jwt_async(

fresh_signing_key = await get_signing_key_callback(payload.iss, True) # get signing key without a cache
if fresh_signing_key == signing_key:
raise TokenInvalidSignatureError('Could not verify JWT signature. Fresh signing key is equal to the old one')
raise TokenInvalidSignatureError('Invalid signature even with fresh signing key it is equal to the old one)')

if _verify_signature(fresh_signing_key, signing_input, signature):
return payload
Expand Down
68 changes: 68 additions & 0 deletions tests/test_atproto_crypto/signature-fixtures.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
[
{
"comment": "valid P-256 key and signature, with low-S signature",
"messageBase64": "oWVoZWxsb2V3b3JsZA",
"algorithm": "ES256",
"didDocSuite": "EcdsaSecp256r1VerificationKey2019",
"publicKeyDid": "did:key:zDnaembgSGUhZULN2Caob4HLJPaxBh92N7rtH21TErzqf8HQo",
"publicKeyMultibase": "zxdM8dSstjrpZaRUwBmDvjGXweKuEMVN95A9oJBFjkWMh",
"signatureBase64": "2vZNsG3UKvvO/CDlrdvyZRISOFylinBh0Jupc6KcWoJWExHptCfduPleDbG3rko3YZnn9Lw0IjpixVmexJDegg",
"validSignature": true,
"tags": []
},
{
"comment": "valid K-256 key and signature, with low-S signature",
"messageBase64": "oWVoZWxsb2V3b3JsZA",
"algorithm": "ES256K",
"didDocSuite": "EcdsaSecp256k1VerificationKey2019",
"publicKeyDid": "did:key:zQ3shqwJEJyMBsBXCWyCBpUBMqxcon9oHB7mCvx4sSpMdLJwc",
"publicKeyMultibase": "z25z9DTpsiYYJKGsWmSPJK2NFN8PcJtZig12K59UgW7q5t",
"signatureBase64": "5WpdIuEUUfVUYaozsi8G0B3cWO09cgZbIIwg1t2YKdUn/FEznOndsz/qgiYb89zwxYCbB71f7yQK5Lr7NasfoA",
"validSignature": true,
"tags": []
},
{
"comment": "P-256 key and signature, with non-low-S signature which is invalid in atproto",
"messageBase64": "oWVoZWxsb2V3b3JsZA",
"algorithm": "ES256",
"didDocSuite": "EcdsaSecp256r1VerificationKey2019",
"publicKeyDid": "did:key:zDnaembgSGUhZULN2Caob4HLJPaxBh92N7rtH21TErzqf8HQo",
"publicKeyMultibase": "zxdM8dSstjrpZaRUwBmDvjGXweKuEMVN95A9oJBFjkWMh",
"signatureBase64": "2vZNsG3UKvvO/CDlrdvyZRISOFylinBh0Jupc6KcWoKp7O4VS9giSAah8k5IUbXIW00SuOrjfEqQ9HEkN9JGzw",
"validSignature": false,
"tags": ["high-s"]
},
{
"comment": "K-256 key and signature, with non-low-S signature which is invalid in atproto",
"messageBase64": "oWVoZWxsb2V3b3JsZA",
"algorithm": "ES256K",
"didDocSuite": "EcdsaSecp256k1VerificationKey2019",
"publicKeyDid": "did:key:zQ3shqwJEJyMBsBXCWyCBpUBMqxcon9oHB7mCvx4sSpMdLJwc",
"publicKeyMultibase": "z25z9DTpsiYYJKGsWmSPJK2NFN8PcJtZig12K59UgW7q5t",
"signatureBase64": "5WpdIuEUUfVUYaozsi8G0B3cWO09cgZbIIwg1t2YKdXYA67MYxYiTMAVfdnkDCMN9S5B3vHosRe07aORmoshoQ",
"validSignature": false,
"tags": ["high-s"]
},
{
"comment": "P-256 key and signature, with DER-encoded signature which is invalid in atproto",
"messageBase64": "oWVoZWxsb2V3b3JsZA",
"algorithm": "ES256",
"didDocSuite": "EcdsaSecp256r1VerificationKey2019",
"publicKeyDid": "did:key:zDnaeT6hL2RnTdUhAPLij1QBkhYZnmuKyM7puQLW1tkF4Zkt8",
"publicKeyMultibase": "ze8N2PPxnu19hmBQ58t5P3E9Yj6CqakJmTVCaKvf9Byq2",
"signatureBase64": "MEQCIFxYelWJ9lNcAVt+jK0y/T+DC/X4ohFZ+m8f9SEItkY1AiACX7eXz5sgtaRrz/SdPR8kprnbHMQVde0T2R8yOTBweA",
"validSignature": false,
"tags": ["der-encoded"]
},
{
"comment": "K-256 key and signature, with DER-encoded signature which is invalid in atproto",
"messageBase64": "oWVoZWxsb2V3b3JsZA",
"algorithm": "ES256K",
"didDocSuite": "EcdsaSecp256k1VerificationKey2019",
"publicKeyDid": "did:key:zQ3shnriYMXc8wvkbJqfNWh5GXn2bVAeqTC92YuNbek4npqGF",
"publicKeyMultibase": "z22uZXWP8fdHXi4jyx8cCDiBf9qQTsAe6VcycoMQPfcMQX",
"signatureBase64": "MEUCIQCWumUqJqOCqInXF7AzhIRg2MhwRz2rWZcOEsOjPmNItgIgXJH7RnqfYY6M0eg33wU0sFYDlprwdOcpRn78Sz5ePgk",
"validSignature": false,
"tags": ["der-encoded"]
}
]
32 changes: 32 additions & 0 deletions tests/test_atproto_crypto/test_verify.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import base64
import json
import os

import pytest
from atproto_crypto.verify import verify_signature

# Ref: https://github.com/bluesky-social/atproto/blob/main/interop-test-files/crypto/signature-fixtures.json
_FIXTURES_FILE_PATH = os.path.join(os.path.dirname(__file__), 'signature-fixtures.json')


def _load_test_cases() -> list:
with open(_FIXTURES_FILE_PATH, encoding='UTF-8') as file:
return json.load(file)


def _fix_base64_padding(data: str) -> str:
return data + '=='


def _decode_b64(data: str) -> bytes:
return base64.b64decode(_fix_base64_padding(data))


@pytest.mark.parametrize('test_case', _load_test_cases(), ids=lambda x: x['comment'])
def test_verify_signature(test_case: dict) -> None:
did_key = test_case['publicKeyDid']
data = _decode_b64(test_case['messageBase64'])
signature = _decode_b64(test_case['signatureBase64'])
expected_valid = test_case['validSignature']

assert verify_signature(did_key, data, signature) == expected_valid
36 changes: 26 additions & 10 deletions tests/test_atproto_server/auth/test_jwt.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import typing as t

import pytest
from atproto_server.auth.jwt import get_jwt_payload, parse_jwt, validate_jwt_payload, verify_jwt, verify_jwt_async
from atproto_server.exceptions import TokenDecodeError, TokenExpiredSignatureError, TokenInvalidAudienceError
Expand All @@ -7,6 +9,10 @@
_TEST_JWT_INVALID_SIGN = 'eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NksifQ.eyJpc3MiOiJkaWQ6cGxjOmt2d3ZjbjVpcWZvb29wbXl6dmI0cXpiYSIsImF1ZCI6ImRpZDp3ZWI6ZmVlZC5hdHByb3RvLmJsdWUiLCJleHAiOjIwMDAwMDAwMDB9.50SlT6vw26HsDXVDM4D2D53_Dvzd6bjp3TDc5EyDVD4ob9i3EEB7fmaKE0XR4egMS9Kf9eMdVqH5gJNCaIah4Q' # noqa: E501


if t.TYPE_CHECKING:
from _pytest.monkeypatch import MonkeyPatch


def test_parse_jwt_empty() -> None:
with pytest.raises(TokenDecodeError):
parse_jwt('')
Expand Down Expand Up @@ -41,16 +47,26 @@ def test_validate_jwt_payload_valid() -> None:
validate_jwt_payload(payload)


def test_verify_jwt() -> None:
def test_verify_jwt_valid_signature(monkeypatch: 'MonkeyPatch') -> None:
def get_signing_key(_: str, __: bool) -> str:
return 'did:key:zQ3shc6V2kvUxn7hNmPy9JMToKT7u2NH27SnKNxGL1GcBcS4j'

# allow expired token
monkeypatch.setattr('atproto_server.auth.jwt._validate_exp', lambda *_: True)

verify_jwt(_TEST_JWT_EXPIRED, get_signing_key)


def test_verify_jwt_aud_validation(monkeypatch: 'MonkeyPatch') -> None:
expected_iss = 'did:plc:kvwvcn5iqfooopmyzvb4qzba'
expected_aud = 'did:web:feed.atproto.blue'

def get_signing_key(iss: str, force_refresh: bool) -> str:
def get_signing_key(iss: str, _: bool) -> str:
assert iss == expected_iss
return 'blabla'

if force_refresh:
return 'refreshedKey'
return 'key'
# allow invalid signature
monkeypatch.setattr('atproto_server.auth.jwt._verify_signature', lambda *_: True)

verify_jwt(_TEST_JWT_INVALID_SIGN, get_signing_key)
verify_jwt(_TEST_JWT_INVALID_SIGN, get_signing_key, expected_aud)
Expand All @@ -60,16 +76,16 @@ def get_signing_key(iss: str, force_refresh: bool) -> str:


@pytest.mark.asyncio
async def test_verify_jwt_async() -> None:
async def test_verify_jwt_aud_validation_async(monkeypatch: 'MonkeyPatch') -> None:
expected_iss = 'did:plc:kvwvcn5iqfooopmyzvb4qzba'
expected_aud = 'did:web:feed.atproto.blue'

async def get_signing_key(iss: str, force_refresh: bool) -> str:
async def get_signing_key(iss: str, _: bool) -> str:
assert iss == expected_iss
return 'blabla'

if force_refresh:
return 'refreshedKey'
return 'key'
# allow invalid signature
monkeypatch.setattr('atproto_server.auth.jwt._verify_signature', lambda *_: True)

await verify_jwt_async(_TEST_JWT_INVALID_SIGN, get_signing_key)
await verify_jwt_async(_TEST_JWT_INVALID_SIGN, get_signing_key, expected_aud)
Expand Down

0 comments on commit 7cf829e

Please sign in to comment.