From 419479aee7fad2609df5ea4491e5dba8475e8320 Mon Sep 17 00:00:00 2001 From: chbndrhnns Date: Fri, 7 Jan 2022 11:58:56 +0100 Subject: [PATCH] Can pass custom JSONEncoder class to jws.sign --- jose/jws.py | 10 +++++----- tests/test_jws.py | 15 ++++++++++++++- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/jose/jws.py b/jose/jws.py index bfaf6bd0..3d4ace35 100644 --- a/jose/jws.py +++ b/jose/jws.py @@ -8,8 +8,7 @@ from jose.exceptions import JWSError, JWSSignatureError from jose.utils import base64url_decode, base64url_encode - -def sign(payload, key, headers=None, algorithm=ALGORITHMS.HS256): +def sign(payload, key, headers=None, algorithm=ALGORITHMS.HS256, encoder_cls=json.JSONEncoder): """Signs a claims set and returns a JWS string. Args: @@ -30,7 +29,7 @@ def sign(payload, key, headers=None, algorithm=ALGORITHMS.HS256): Examples: - >>> jws.sign({'a': 'b'}, 'secret', algorithm='HS256') + >>> jws.sign({'a': 'b'},'secret',algorithm='HS256') 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhIjoiYiJ9.jiMyrsmD8AoHWeQgmxZ5yq8z0lXS67_QGs52AzC8Ru8' """ @@ -39,7 +38,7 @@ def sign(payload, key, headers=None, algorithm=ALGORITHMS.HS256): raise JWSError("Algorithm %s not supported." % algorithm) encoded_header = _encode_header(algorithm, additional_headers=headers) - encoded_payload = _encode_payload(payload) + encoded_payload = _encode_payload(payload, encoder_cls=encoder_cls) signed_output = _sign_header_and_claims(encoded_header, encoded_payload, algorithm, key) return signed_output @@ -140,12 +139,13 @@ def _encode_header(algorithm, additional_headers=None): return base64url_encode(json_header) -def _encode_payload(payload): +def _encode_payload(payload, encoder_cls): if isinstance(payload, Mapping): try: payload = json.dumps( payload, separators=(",", ":"), + cls=encoder_cls, ).encode("utf-8") except ValueError: pass diff --git a/tests/test_jws.py b/tests/test_jws.py index 01b5fd05..199cceb7 100644 --- a/tests/test_jws.py +++ b/tests/test_jws.py @@ -1,4 +1,5 @@ import json +import typing import warnings import pytest @@ -75,6 +76,19 @@ def test_invalid_key(self, payload): with pytest.raises(JWSError): jws.sign(payload, "secret", algorithm="RS256") + def test_custom_json_encoder(self): + class MyEncoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, MySet): + return list(o) + return json.JSONEncoder.default(self, o) + + class MySet(typing.Set): + pass + + payload = {"custom": MySet({1, 2, 3})} + jws.sign(payload, "secret", algorithm="HS256", encoder_cls=MyEncoder) + @pytest.mark.parametrize( "key", [ @@ -126,7 +140,6 @@ def test_unsupported_alg(self, payload): jws.sign(payload, "secret", algorithm="SOMETHING") def test_add_headers(self, payload): - additional_headers = {"test": "header"} expected_headers = {