forked from redis/redis-py
-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtoken.py
131 lines (95 loc) · 3.24 KB
/
token.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from redis.auth.err import InvalidTokenSchemaErr
class TokenInterface(ABC):
    """Abstract contract shared by every authentication token type.

    Concrete implementations in this module express timestamps and
    lifetimes in milliseconds and use -1 as the "never expires" marker.
    """

    @abstractmethod
    def is_expired(self) -> bool:
        """Return True once the token should no longer be used."""

    @abstractmethod
    def ttl(self) -> float:
        """Return the remaining lifetime of the token."""

    @abstractmethod
    def try_get(self, key: str) -> str:
        """Return the claim stored under ``key`` (implementations yield None if absent)."""

    @abstractmethod
    def get_value(self) -> str:
        """Return the raw token string."""

    @abstractmethod
    def get_expires_at_ms(self) -> float:
        """Return the expiry timestamp in milliseconds."""

    @abstractmethod
    def get_received_at_ms(self) -> float:
        """Return the timestamp, in milliseconds, associated with token receipt."""
class TokenResponse:
    """Thin wrapper that carries a token and derives its lifetime."""

    def __init__(self, token: TokenInterface):
        self._token = token

    def get_token(self) -> TokenInterface:
        """Return the wrapped token object unchanged."""
        return self._token

    def get_ttl_ms(self) -> float:
        """Return the token lifetime as expiry timestamp minus receipt timestamp.

        NOTE(review): this is a plain subtraction; a token that reports -1
        as its expiry (no-expiry marker) is not special-cased here — confirm
        callers account for that.
        """
        wrapped = self._token
        expires_at_ms = wrapped.get_expires_at_ms()
        received_at_ms = wrapped.get_received_at_ms()
        return expires_at_ms - received_at_ms
class SimpleToken(TokenInterface):
    """Token whose value, timestamps and claims are all supplied by the caller.

    Timestamps are milliseconds since the Unix epoch; an ``expires_at_ms``
    of -1 marks a token that never expires.
    """

    def __init__(
        self, value: str, expires_at_ms: float, received_at_ms: float, claims: dict
    ) -> None:
        self.value = value
        self.expires_at = expires_at_ms
        self.received_at = received_at_ms
        self.claims = claims

    def ttl(self) -> float:
        """Return milliseconds until expiry, or -1 for a non-expiring token."""
        if self.expires_at == -1:
            return -1
        now_ms = datetime.now(timezone.utc).timestamp() * 1000
        return self.expires_at - now_ms

    def is_expired(self) -> bool:
        """Return True when the remaining lifetime has run out."""
        # A -1 expiry short-circuits to False without consulting the clock.
        return self.expires_at != -1 and self.ttl() <= 0

    def try_get(self, key: str) -> str:
        """Return the claim under ``key``, or None when it is absent."""
        return self.claims.get(key)

    def get_value(self) -> str:
        """Return the raw token value."""
        return self.value

    def get_expires_at_ms(self) -> float:
        """Return the expiry timestamp exactly as supplied."""
        return self.expires_at

    def get_received_at_ms(self) -> float:
        """Return the receipt timestamp exactly as supplied."""
        return self.received_at
class JWToken(TokenInterface):
    """Token backed by a JSON Web Token string.

    The JWT is decoded with signature verification disabled, and its payload
    must contain every claim listed in ``REQUIRED_FIELDS``. Requires the
    PyJWT library, which is imported lazily on construction.
    """

    REQUIRED_FIELDS = {"exp"}

    def __init__(self, token: str):
        """Decode ``token`` and validate its claims.

        Raises:
            ImportError: if PyJWT is not installed.
            InvalidTokenSchemaErr: if a required claim is missing.
        """
        try:
            import jwt
        except ImportError as ie:
            raise ImportError(
                f"The PyJWT library is required for {self.__class__.__name__}.",
            ) from ie

        self._value = token
        # Capture the receipt time once; previously get_received_at_ms()
        # returned "now" on every call, so the value drifted over time.
        self._received_at_ms = datetime.now(timezone.utc).timestamp() * 1000
        self._decoded = jwt.decode(
            self._value,
            options={"verify_signature": False},
            algorithms=[jwt.get_unverified_header(self._value).get("alg")],
        )
        self._validate_token()

    def is_expired(self) -> bool:
        """Return True once the ``exp`` claim (seconds) is at or before now.

        An ``exp`` of -1 is treated as never expiring.
        """
        exp = self._decoded["exp"]
        if exp == -1:
            return False
        return exp * 1000 <= datetime.now(timezone.utc).timestamp() * 1000

    def ttl(self) -> float:
        """Return milliseconds until ``exp``, or -1 when ``exp`` is -1."""
        exp = self._decoded["exp"]
        if exp == -1:
            return -1
        return exp * 1000 - datetime.now(timezone.utc).timestamp() * 1000

    def try_get(self, key: str) -> str:
        """Return the decoded claim under ``key``, or None if absent."""
        return self._decoded.get(key)

    def get_value(self) -> str:
        """Return the original encoded JWT string."""
        return self._value

    def get_expires_at_ms(self) -> float:
        """Return the ``exp`` claim converted from seconds to milliseconds.

        NOTE(review): an ``exp`` of -1 becomes -1000.0 here rather than the
        -1 marker ttl() returns — confirm callers expect this.
        """
        return float(self._decoded["exp"] * 1000)

    def get_received_at_ms(self) -> float:
        """Return the time (ms since epoch, UTC) this token object was created."""
        return self._received_at_ms

    def _validate_token(self):
        """Raise InvalidTokenSchemaErr listing any missing required claims."""
        missing = self.REQUIRED_FIELDS - set(self._decoded)
        if missing:
            raise InvalidTokenSchemaErr(missing)