Skip to content

Commit

Permalink
Add JWK support to JWT encode (#979)
Browse files Browse the repository at this point in the history
* Allow JWK for JWS encode.

* Add PyJWK to JWT encode.

* Update CHANGELOG.

* Remove `DEFAULT_ALGORITHM`
  • Loading branch information
luhn authored Oct 7, 2024
1 parent 44d8605 commit c387281
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 7 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Changed

- Use ``Sequence`` for parameter types rather than ``List`` where applicable by @imnotjames in `#970 <https://github.com/jpadilla/pyjwt/pull/970>`__
- Remove algorithm requirement from JWT API, instead relying on JWS API for enforcement, by @luhn in `#975 <https://github.com/jpadilla/pyjwt/pull/975>`__
- Add JWK support to JWT encode by @luhn in `#979 <https://github.com/jpadilla/pyjwt/pull/979>`__

Fixed
~~~~~
Expand Down
14 changes: 11 additions & 3 deletions jwt/api_jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ def get_algorithm_by_name(self, alg_name: str) -> Algorithm:
def encode(
self,
payload: bytes,
key: AllowedPrivateKeys | str | bytes,
algorithm: str | None = "HS256",
key: AllowedPrivateKeys | PyJWK | str | bytes,
algorithm: str | None = None,
headers: dict[str, Any] | None = None,
json_encoder: type[json.JSONEncoder] | None = None,
is_payload_detached: bool = False,
Expand All @@ -115,7 +115,13 @@ def encode(
segments = []

# declare a new var to narrow the type for type checkers
algorithm_: str = algorithm if algorithm is not None else "none"
if algorithm is None:
if isinstance(key, PyJWK):
algorithm_ = key.algorithm_name
else:
algorithm_ = "HS256"
else:
algorithm_ = algorithm

# Prefer headers values if present to function parameters.
if headers:
Expand Down Expand Up @@ -159,6 +165,8 @@ def encode(
signing_input = b".".join(segments)

alg_obj = self.get_algorithm_by_name(algorithm_)
if isinstance(key, PyJWK):
key = key.key
key = alg_obj.prepare_key(key)
signature = alg_obj.sign(signing_input, key)

Expand Down
4 changes: 2 additions & 2 deletions jwt/api_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ def _get_default_options() -> dict[str, bool | list[str]]:
def encode(
self,
payload: dict[str, Any],
key: AllowedPrivateKeys | str | bytes,
algorithm: str | None = "HS256",
key: AllowedPrivateKeys | PyJWK | str | bytes,
algorithm: str | None = None,
headers: dict[str, Any] | None = None,
json_encoder: type[json.JSONEncoder] | None = None,
sort_headers: bool = True,
Expand Down
35 changes: 33 additions & 2 deletions tests/test_api_jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,18 @@ def test_decode_with_non_mapping_header_throws_exception(self, jws):
exception = context.value
assert str(exception) == "Invalid header string: must be a json object"

def test_encode_default_algorithm(self, jws, payload):
msg = jws.encode(payload, "secret")
decoded = jws.decode_complete(msg, "secret", algorithms=["HS256"])
assert decoded == {
"header": {"alg": "HS256", "typ": "JWT"},
"payload": payload,
"signature": (
b"H\x8a\xf4\xdf3:\xe1\xac\x16E\xd3\xeb\x00\xcf\xfa\xd5\x05\xac"
b"e\xc8@\xb6\x00\xd5\xde\x9aa|s\xcfZB"
),
}

def test_encode_algorithm_param_should_be_case_sensitive(self, jws, payload):
jws.encode(payload, "secret", algorithm="HS256")

Expand Down Expand Up @@ -193,6 +205,25 @@ def test_encode_with_alg_hs256_and_headers_alg_es256(self, jws, payload):
msg = jws.encode(payload, priv_key, algorithm="HS256", headers={"alg": "ES256"})
assert b"hello world" == jws.decode(msg, pub_key, algorithms=["ES256"])

def test_encode_with_jwk(self, jws, payload):
jwk = PyJWK(
{
"kty": "oct",
"alg": "HS256",
"k": "c2VjcmV0", # "secret"
}
)
msg = jws.encode(payload, key=jwk)
decoded = jws.decode_complete(msg, key=jwk, algorithms=["HS256"])
assert decoded == {
"header": {"alg": "HS256", "typ": "JWT"},
"payload": payload,
"signature": (
b"H\x8a\xf4\xdf3:\xe1\xac\x16E\xd3\xeb\x00\xcf\xfa\xd5\x05\xac"
b"e\xc8@\xb6\x00\xd5\xde\x9aa|s\xcfZB"
),
}

def test_decode_algorithm_param_should_be_case_sensitive(self, jws):
example_jws = (
"eyJhbGciOiJoczI1NiIsInR5cCI6IkpXVCJ9" # alg = hs256
Expand Down Expand Up @@ -531,13 +562,13 @@ def test_decode_invalid_crypto_padding(self, jws):
assert "Invalid crypto padding" in str(exc.value)

def test_decode_with_algo_none_should_fail(self, jws, payload):
jws_message = jws.encode(payload, key=None, algorithm=None)
jws_message = jws.encode(payload, key=None, algorithm="none")

with pytest.raises(DecodeError):
jws.decode(jws_message, algorithms=["none"])

def test_decode_with_algo_none_and_verify_false_should_pass(self, jws, payload):
jws_message = jws.encode(payload, key=None, algorithm=None)
jws_message = jws.encode(payload, key=None, algorithm="none")
jws.decode(jws_message, options={"verify_signature": False})

def test_get_unverified_header_returns_header_values(self, jws, payload):
Expand Down

0 comments on commit c387281

Please sign in to comment.