Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sort_headers parameter to api_jwt.encode #832

Merged
merged 8 commits into from
Dec 8, 2022
3 changes: 2 additions & 1 deletion jwt/api_jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def encode(
headers: dict[str, Any] | None = None,
json_encoder: Type[json.JSONEncoder] | None = None,
is_payload_detached: bool = False,
sort_headers: bool = True,
) -> str:
segments = []

Expand Down Expand Up @@ -135,7 +136,7 @@ def encode(

# Fix for headers misorder - issue #715
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe get rid of this now-irrelevant comment?

json_header = json.dumps(
header, separators=(",", ":"), cls=json_encoder, sort_keys=True
header, separators=(",", ":"), cls=json_encoder, sort_keys=sort_headers
).encode()

segments.append(base64url_encode(json_header))
Expand Down
10 changes: 9 additions & 1 deletion jwt/api_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def encode(
algorithm: Optional[str] = "HS256",
headers: Optional[Dict[str, Any]] = None,
json_encoder: Optional[Type[json.JSONEncoder]] = None,
sort_headers: bool = True,
) -> str:
# Check that we get a mapping
if not isinstance(payload, Mapping):
Expand All @@ -64,7 +65,14 @@ def encode(
payload, separators=(",", ":"), cls=json_encoder
).encode("utf-8")

return api_jws.encode(json_payload, key, algorithm, headers, json_encoder)
return api_jws.encode(
json_payload,
key,
algorithm,
headers,
json_encoder,
sort_headers=sort_headers,
)

def decode_complete(
self,
Expand Down
16 changes: 16 additions & 0 deletions tests/test_api_jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,22 @@ def test_bytes_secret(self, jws, payload):

assert decoded_payload == payload

def test_sorting_headers(self, jws, payload):
secret = "\xc2"
encoded_without_sorting = jws.encode(payload, secret, sort_headers=False)
encoded_with_sorting = jws.encode(payload, secret, sort_headers=True)

assert encoded_with_sorting != encoded_without_sorting

decoded_without_sorting = jws.decode(
encoded_without_sorting, secret, algorithms=["HS256"]
)
decoded_with_sorting = jws.decode(
encoded_with_sorting, secret, algorithms=["HS256"]
)

assert decoded_without_sorting == decoded_with_sorting

def test_decode_invalid_header_padding(self, jws):
example_jws = (
"aeyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9"
Expand Down