Skip to content

Commit

Permalink
✨ Handle catch-all signature errors (#538)
Browse files Browse the repository at this point in the history
* ✨ Handle catch-all signature errors

* ♻️ Refactor `SignatureSerializerTest`
  • Loading branch information
yezz123 authored Feb 18, 2024
1 parent 8d75e4a commit 2bd44ed
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 7 deletions.
3 changes: 2 additions & 1 deletion authx/_internal/_signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ def decode(self, token):
return None, "SignatureExpired"
except BadTimeSignature:
return None, "InvalidSignature"

except Exception:
return None, "BadSignature" # Catch-all for other signature errors
return decoded_obj, None


Expand Down
41 changes: 35 additions & 6 deletions tests/internal/test_signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,7 @@ def test_encode_decode(self):
self.assertEqual(data["session_id"], session_id)

def test_decode_with_no_token(self):
serializer = SignatureSerializer("MY_SECRET_KEY", expired_in=1)
token = None
data, err = serializer.decode(token)
self.assertIsNone(data)
self.assertEqual(err, "NoTokenSpecified")
self.decode_serializer(None, "NoTokenSpecified")

def test_decode_with_expired_token(self):
serializer = SignatureSerializer("MY_SECRET_KEY", expired_in=1)
Expand All @@ -33,6 +29,19 @@ def test_decode_with_expired_token(self):
self.assertIsNone(data)
self.assertEqual(err, "SignatureExpired")

def test_decode_with_invalid_signature(self):
self.decode_serializer("tampered_token", "BadSignature")

def test_decode_with_malformed_token(self):
self.decode_serializer("malformedtoken", "BadSignature")

def decode_serializer(self, token, expected_data):
serializer = SignatureSerializer("MY_SECRET_KEY", expired_in=1)
token = token
data, err = serializer.decode(token)
self.assertIsNone(data)
self.assertEqual(err, expected_data)


def test_token_expiration():
serializer = SignatureSerializer("MY_SECRET_KEY", expired_in=1)
Expand All @@ -58,7 +67,6 @@ def test_token_no_expiration():
), "Failed to decode or session_id does not match."


@unittest.skip("Dropping tampering test for now.")
def test_token_tampering():
serializer = SignatureSerializer("MY_SECRET_KEY", expired_in=3600)
dict_obj = {"session_id": 999}
Expand All @@ -69,3 +77,24 @@ def test_token_tampering():
assert (
data is None and err == "InvalidSignature"
), "Tampered token did not cause an error as expected."


def test_casual_ut():
secret_key = "MY_SECRET_KEY"
expired_in = 1
session_id = 1
dict_obj = {"session_id": session_id}

# Instantiate SignatureSerializer
serializer = SignatureSerializer(secret_key, expired_in=expired_in)

# Encode the dictionary object into a token
token = serializer.encode(dict_obj)

# Decode the token
data, err = serializer.decode(token)

# Assert the results
assert (
data is not None and err is None and data["session_id"] == session_id
), "Failed to decode or session_id does not match."

0 comments on commit 2bd44ed

Please sign in to comment.