Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sort_headers parameter to api_jwt.encode #832

Merged
merged 8 commits into from
Dec 8, 2022
4 changes: 2 additions & 2 deletions jwt/api_jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def encode(
headers: dict[str, Any] | None = None,
json_encoder: Type[json.JSONEncoder] | None = None,
is_payload_detached: bool = False,
sort_headers: bool = True,
) -> str:
segments = []

Expand Down Expand Up @@ -133,9 +134,8 @@ def encode(
# True is the standard value for b64, so no need for it
del header["b64"]

# Fix for headers misorder - issue #715
json_header = json.dumps(
header, separators=(",", ":"), cls=json_encoder, sort_keys=True
header, separators=(",", ":"), cls=json_encoder, sort_keys=sort_headers
).encode()

segments.append(base64url_encode(json_header))
Expand Down
10 changes: 9 additions & 1 deletion jwt/api_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def encode(
algorithm: Optional[str] = "HS256",
headers: Optional[Dict[str, Any]] = None,
json_encoder: Optional[Type[json.JSONEncoder]] = None,
sort_headers: bool = True,
) -> str:
# Check that we get a mapping
if not isinstance(payload, Mapping):
Expand All @@ -64,7 +65,14 @@ def encode(
payload, separators=(",", ":"), cls=json_encoder
).encode("utf-8")

return api_jws.encode(json_payload, key, algorithm, headers, json_encoder)
return api_jws.encode(
json_payload,
key,
algorithm,
headers,
json_encoder,
sort_headers=sort_headers,
)

def decode_complete(
self,
Expand Down
12 changes: 12 additions & 0 deletions tests/test_api_jws.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from collections import OrderedDict
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we use simple dict instead of orderedDict here?

Copy link
Contributor Author

@evroon evroon Dec 8, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I changed it now

from decimal import Decimal

import pytest
Expand Down Expand Up @@ -414,6 +415,17 @@ def test_bytes_secret(self, jws, payload):

assert decoded_payload == payload

@pytest.mark.parametrize("sort_headers", (False, True))
def test_sorting_of_headers(self, jws, payload, sort_headers):
jws_message = jws.encode(
payload,
key="\xc2",
headers=OrderedDict([("b", "1"), ("a", "2")]),
sort_headers=sort_headers,
)
header_json = base64url_decode(jws_message.split(".")[0])
assert sort_headers == (header_json.index(b'"a"') < header_json.index(b'"b"'))

def test_decode_invalid_header_padding(self, jws):
example_jws = (
"aeyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9"
Expand Down