-
Notifications
You must be signed in to change notification settings - Fork 12
/
opa_middleware.py
163 lines (141 loc) · 5.34 KB
/
opa_middleware.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import asyncio
import json
import logging
import re
from json.decoder import JSONDecodeError
from typing import List
from typing import Optional
import requests
from fastapi.responses import JSONResponse
from starlette.requests import Request
from starlette.responses import RedirectResponse
from starlette.types import ASGIApp
from starlette.types import Receive
from starlette.types import Scope
from starlette.types import Send
from fastapi_opa.auth.exceptions import AuthenticationException
from fastapi_opa.opa.opa_config import OPAConfig
# ``re.Pattern`` does not exist on Python 3.6; fall back to ``None`` there so
# that the ``List[Pattern]`` annotation on ``should_skip_endpoint`` still
# evaluates at import time.
Pattern = getattr(re, "Pattern", None)

logger = logging.getLogger(__name__)
def should_skip_endpoint(endpoint: str, skip_endpoints: List[Pattern]) -> bool:
    """Tell whether *endpoint* is exempt from the middleware's checks.

    :param endpoint: the request path to test.
    :param skip_endpoints: compiled patterns, each tried with
        ``Pattern.match`` -- so a pattern only has to match a prefix of
        *endpoint*, not the whole path.
    :return: ``True`` as soon as one pattern matches, ``False`` otherwise
        (including when *skip_endpoints* is empty).
    """
    return any(pattern.match(endpoint) for pattern in skip_endpoints)
class OwnReceive:
    """
    Replayable ASGI ``receive`` callable.

    This class is required in order to access the request body multiple
    times: the middleware reads the body through it (for authentication and
    injectables) and the downstream app later reads the same body through it.

    On the first call the wrapped ``receive`` is drained until the request
    body is complete (``more_body`` false) or a message that is not
    ``http.request`` (e.g. ``http.disconnect``) arrives. A body split over
    several ``http.request`` messages is merged into one message, so
    consumers never see a half body. Every call -- the first and all later
    ones -- returns that same cached message.

    NOTE: because the cached message is replayed forever, a consumer that
    keeps calling ``receive`` waiting for ``http.disconnect`` will never see
    one through this wrapper.
    """

    def __init__(self, receive: Receive):
        # The real ASGI receive callable being wrapped.
        self.receive = receive
        # Cached message returned by every call; None until the first call.
        self.data = None

    async def __call__(self):
        if self.data is None:
            self.data = await self._read_full_body()
        return self.data

    async def _read_full_body(self):
        """Drain the wrapped receive into a single message (see class doc)."""
        chunks = []
        while True:
            message = await self.receive()
            if message["type"] != "http.request":
                # Disconnect (or anything else) before the body completed:
                # surface it unchanged, as the original wrapper did.
                return message
            chunks.append(message.get("body", b""))
            if not message.get("more_body", False):
                break
        if len(chunks) == 1:
            # Common single-message body: hand back the server's own message.
            return message
        return {
            "type": "http.request",
            "body": b"".join(chunks),
            "more_body": False,
        }
class OPAMiddleware:
    """ASGI middleware that authenticates requests and asks OPA to authorize.

    For every non-lifespan request that is not ``OPTIONS`` and whose path does
    not match ``skip_endpoints``:

    1. The configured authentication backends run in order. The first one that
       returns a ``dict`` of user info wins. If none does, a
       ``RedirectResponse`` returned by a backend is sent as-is. Otherwise the
       request gets a 401.
    2. The user-info dict is enriched by the configured injectables and by the
       request method and path segments.
    3. The enriched dict is POSTed to ``config.opa_url`` as
       ``{"input": ...}``. The request reaches the wrapped app only if OPA
       answers 200 with ``{"result": {"allow": <truthy>}}``. Every other
       answer, and a failure to reach OPA at all, yields a 401.
    """

    # Paths left unauthenticated when the caller passes no skip_endpoints.
    # NOTE: patterns are applied with re.match (prefix match, and "." is a
    # regex wildcard), so e.g. "/docs" also exempts "/docs-anything".
    DEFAULT_SKIP_ENDPOINTS = ("/openapi.json", "/docs", "/redoc")

    def __init__(
        self,
        app: ASGIApp,
        config: OPAConfig,
        skip_endpoints: Optional[List[str]] = None,
    ) -> None:
        """Wrap *app* with authentication and OPA authorization.

        :param app: the downstream ASGI application.
        :param config: authentication backends, accepted methods, injectables
            and the OPA URL.
        :param skip_endpoints: regular expressions matched against the
            request path; matching requests bypass authentication and
            authorization entirely. ``None`` (the default) means
            ``DEFAULT_SKIP_ENDPOINTS``.
        """
        if skip_endpoints is None:
            skip_endpoints = self.DEFAULT_SKIP_ENDPOINTS
        self.config = config
        self.app = app
        self.skip_endpoints = [re.compile(skip) for skip in skip_endpoints]

    async def __call__(
        self, scope: Scope, receive: Receive, send: Send
    ) -> None:
        if scope["type"] == "lifespan":
            return await self.app(scope, receive, send)

        # Small hack to ensure that later we can still receive the body:
        # authentication and injectables read it through own_receive, and the
        # downstream app is handed the same own_receive in get_decision.
        own_receive = OwnReceive(receive)
        request = Request(scope, own_receive, send)

        # OPTIONS requests (presumably CORS preflight) bypass all checks.
        if request.method == "OPTIONS":
            return await self.app(scope, receive, send)

        # allow openapi endpoints without authentication
        if should_skip_endpoint(request.url.path, self.skip_endpoints):
            return await self.app(scope, receive, send)

        # authenticate user or get redirect to identity provider
        outcome = await self._authenticate(request)
        # Some authentication flows require a prior redirect to id provider
        if isinstance(outcome, RedirectResponse):
            return await outcome(scope, receive, send)
        if not isinstance(outcome, dict):
            return await self.get_unauthorized_response(scope, receive, send)

        user_info = outcome
        # Enrich user_info if injectables are provided, honouring each
        # injectable's own skip_endpoints.
        for injectable in self.config.injectables or ():
            if should_skip_endpoint(
                request.url.path, injectable.skip_endpoints
            ):
                continue
            user_info[injectable.key] = await injectable.extract(request)
        user_info["request_method"] = scope.get("method")
        # "/a/b" -> ["a", "b"]: drop the empty segment before the leading "/".
        user_info["request_path"] = scope.get("path").split("/")[1:]

        opa_decision = await self._query_opa({"input": user_info})
        if opa_decision is None:
            # OPA could not be reached: fail closed.
            return await self.get_unauthorized_response(scope, receive, send)
        return await self.get_decision(
            opa_decision, scope, own_receive, receive, send
        )

    async def _authenticate(self, request: Request):
        """Run the configured backends in order and return the outcome.

        Returns the first ``dict`` of user info produced by a backend. If no
        backend produces one, returns the result of the last backend that
        completed without raising (e.g. a ``RedirectResponse`` to an identity
        provider). Returns ``None`` when every backend raised
        ``AuthenticationException`` or none are configured. Previously that
        case hit an unbound local variable.
        """
        result = None
        for auth in self.config.authentication:
            try:
                # Backends may be sync or async.
                candidate = auth.authenticate(
                    request, self.config.accepted_methods
                )
                if asyncio.iscoroutine(candidate):
                    candidate = await candidate
            except AuthenticationException:
                logger.error("AuthenticationException raised on login")
                continue
            result = candidate
            if isinstance(result, dict):
                break
        return result

    async def _query_opa(self, data: dict):
        """POST *data* as JSON to ``config.opa_url``.

        ``requests`` is synchronous, so the call runs in the event loop's
        default executor instead of stalling every other in-flight request
        for up to the 5 s timeout. Returns the ``requests.Response``, or
        ``None`` when the HTTP request itself failed (connection error,
        timeout, ...).
        """
        # Serialize on the loop thread so serialization errors surface here.
        payload = json.dumps(data)
        # get_event_loop() rather than get_running_loop() keeps Python 3.6
        # working; inside a coroutine it returns the running loop.
        loop = asyncio.get_event_loop()
        try:
            return await loop.run_in_executor(
                None,
                lambda: requests.post(
                    self.config.opa_url, data=payload, timeout=5
                ),
            )
        except requests.RequestException:
            logger.exception("Unable to reach OPA.")
            return None

    def get_decision(
        self,
        opa_decision,
        scope: Scope,
        own_receive: OwnReceive,
        receive: Receive,
        send: Send,
    ):
        """Turn OPA's HTTP response into the awaitable that answers the request.

        :param opa_decision: the ``requests.Response`` returned by OPA.
        :return: an un-awaited coroutine. It is ``self.app(...)`` (fed the
            body-replaying *own_receive*) when OPA answered 200 with
            ``{"result": {"allow": <truthy>}}``, and the 401 response
            otherwise. The caller must await it.
        """
        if opa_decision.status_code != 200:
            logger.error("Returned with status %s.", opa_decision.status_code)
            return self.get_unauthorized_response(scope, receive, send)
        try:
            body = opa_decision.json()
        except ValueError:
            # json.JSONDecodeError and the simplejson-based error that
            # requests may raise instead are both ValueError subclasses.
            logger.error("Unable to decode OPA response.")
            return self.get_unauthorized_response(scope, receive, send)
        # Anything other than {"result": {...}} (missing/undefined result,
        # null, a bare boolean, a list) is a denial, not a crash.
        result = body.get("result") if isinstance(body, dict) else None
        is_authorized = (
            result.get("allow") if isinstance(result, dict) else False
        )
        if not is_authorized:
            return self.get_unauthorized_response(scope, receive, send)
        return self.app(scope, own_receive, send)

    @staticmethod
    async def get_unauthorized_response(
        scope: Scope, receive: Receive, send: Send
    ) -> None:
        """Send a 401 JSON response ``{"message": "Unauthorized"}``."""
        response = JSONResponse(
            status_code=401, content={"message": "Unauthorized"}
        )
        return await response(scope, receive, send)