Skip to content

Commit

Permalink
Merge pull request #736 from OpenIDC/refactoring
Browse files Browse the repository at this point in the history
Touch of refactoring to decrease complexity
  • Loading branch information
tpazderka authored Dec 31, 2019
2 parents e315144 + d19d87d commit 44d35cf
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 81 deletions.
2 changes: 1 addition & 1 deletion pylama.ini
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ ignore = D100,D101,D102,D103,D104,D105,D106,D107,D203,D212,E203
max_line_length = 120

[pylama:mccabe]
complexity = 32
complexity = 29
82 changes: 46 additions & 36 deletions src/oic/oic/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ class InvalidSectorIdentifier(Exception):
pass


class InvalidPostLogoutUri(Exception):
"""Raised when the post_logout_redirect_uris are not valid."""


def devnull(txt):
pass

Expand Down Expand Up @@ -1187,37 +1191,29 @@ def do_client_registration(self, request, client_id, ignore=None):
_cinfo[key] = val

if "post_logout_redirect_uris" in request:
plruri = []
for uri in request["post_logout_redirect_uris"]:
if urlparse(uri).fragment:
err = ClientRegistrationErrorResponse(
error="invalid_configuration_parameter",
error_description="post_logout_redirect_uris "
"contains "
"fragment",
)
return Response(
err.to_json(),
content="application/json",
status="400 Bad Request",
)
base, query = splitquery(uri)
if query:
plruri.append((base, parse_qs(query)))
else:
plruri.append((base, query))
try:
plruri = self._verify_post_logout_uri(request)
except InvalidPostLogoutUri as err:
error = ClientRegistrationErrorResponse(
error="invalid_configuration_parameter", error_description=str(err)
)
return Response(
error.to_json(),
content="application/json",
status="400 Bad Request",
)
_cinfo["post_logout_redirect_uris"] = plruri

if "redirect_uris" in request:
try:
ruri = self.verify_redirect_uris(request)
_cinfo["redirect_uris"] = ruri
except InvalidRedirectURIError as e:
err = ClientRegistrationErrorResponse(
error = ClientRegistrationErrorResponse(
error="invalid_redirect_uri", error_description=str(e)
)
return Response(
err.to_json(), content="application/json", status_code=400
error.to_json(), content="application/json", status_code=400
)

if "sector_identifier_uri" in request:
Expand All @@ -1228,21 +1224,20 @@ def do_client_registration(self, request, client_id, ignore=None):
) = self._verify_sector_identifier(request)
except InvalidSectorIdentifier as err:
return error_response("invalid_configuration_parameter", descr=str(err))
elif "redirect_uris" in request:
if len(request["redirect_uris"]) > 1:
# check that the hostnames are the same
host = ""
for url in request["redirect_uris"]:
part = urlparse(url)
_host = part.netloc.split(":")[0]
if not host:
host = _host
else:
if host != _host:
return error_response(
"invalid_configuration_parameter",
descr="'sector_identifier_uri' must be registered",
)
elif "redirect_uris" in request and len(request["redirect_uris"]) > 1:
# check that the hostnames are the same
host = ""
for url in request["redirect_uris"]:
part = urlparse(url)
_host = part.netloc.split(":")[0]
if not host:
host = _host
else:
if host != _host:
return error_response(
"invalid_configuration_parameter",
descr="'sector_identifier_uri' must be registered",
)

for item in ["policy_uri", "logo_uri", "tos_uri"]:
if item in request:
Expand Down Expand Up @@ -1334,6 +1329,21 @@ def verify_redirect_uris(registration_request):

return verified_redirect_uris

def _verify_post_logout_uri(self, request):
"""Verify correct format of post_logout_redirect_uris."""
plruri = []
for uri in request["post_logout_redirect_uris"]:
if urlparse(uri).fragment:
raise InvalidPostLogoutUri(
"post_logout_redirect_uris contains fragment"
)
base, query = splitquery(uri)
if query:
plruri.append((base, parse_qs(query)))
else:
plruri.append((base, query))
return plruri

def _verify_sector_identifier(self, request):
"""
Verify `sector_identifier_uri` is reachable and that it contains `redirect_uri`s.
Expand Down
89 changes: 45 additions & 44 deletions src/oic/utils/rp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,42 @@ def _err(self, txt):
logger.error(sanitize(txt))
raise OIDCError(txt)

def _do_code(self, response, authresp):
"""Perform code flow."""
# get the access token
try:
args = {
"code": authresp["code"],
"redirect_uri": self.registration_response["redirect_uris"][0],
"client_id": self.client_id,
"client_secret": self.client_secret,
}

try:
args["scope"] = response["scope"]
except KeyError:
pass

atresp = self.do_access_token_request(
state=authresp["state"],
request_args=args,
authn_method=self.registration_response["token_endpoint_auth_method"],
)
msg = "Access token response: {}"
logger.info(msg.format(sanitize(atresp)))
except Exception as err:
logger.error("%s" % err)
raise

if isinstance(atresp, ErrorResponse):
msg = "Error response: {}"
self._err(msg.format(sanitize(atresp.to_dict())))

_token = atresp["access_token"]

_id_token = atresp.get("id_token")
return _token, _id_token

def callback(self, response, session, format="dict"):
"""
Call when an AuthN response has been received from the OP.
Expand All @@ -135,52 +171,17 @@ def callback(self, response, session, format="dict"):

_state = authresp["state"]

try:
_id_token = authresp["id_token"]
except KeyError:
_id_token = None
else:
if _id_token["nonce"] != self.authz_req[_state]["nonce"]:
self._err("Received nonce not the same as expected.")
_id_token = authresp.get("id_token")
if (
_id_token is not None
and _id_token["nonce"] != self.authz_req[_state]["nonce"]
):
self._err("Received nonce not the same as expected.")

if self.behaviour["response_type"] == "code":
# get the access token
try:
args = {
"code": authresp["code"],
"redirect_uri": self.registration_response["redirect_uris"][0],
"client_id": self.client_id,
"client_secret": self.client_secret,
}

try:
args["scope"] = response["scope"]
except KeyError:
pass

atresp = self.do_access_token_request(
state=authresp["state"],
request_args=args,
authn_method=self.registration_response[
"token_endpoint_auth_method"
],
)
msg = "Access token response: {}"
logger.info(msg.format(sanitize(atresp)))
except Exception as err:
logger.error("%s" % err)
raise

if isinstance(atresp, ErrorResponse):
msg = "Error response: {}"
self._err(msg.format(sanitize(atresp.to_dict())))

_token = atresp["access_token"]

try:
_id_token = atresp["id_token"]
except KeyError:
pass
_token, new_id_token = self._do_code(response, authresp)
if new_id_token is not None:
_id_token = new_id_token
else:
_token = authresp["access_token"]

Expand Down

0 comments on commit 44d35cf

Please sign in to comment.