Skip to content
6 changes: 3 additions & 3 deletions msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
except: # Python 3
from urllib.parse import urljoin
import logging
from base64 import b64encode
import sys

from .oauth2cli import Client, JwtSigner
Expand Down Expand Up @@ -404,9 +403,10 @@ def _acquire_token_by_username_password_federated(
if not grant_type:
raise RuntimeError(
"RSTR returned unknown token type: %s", wstrust_result.get("type"))
self.client.grant_assertion_encoders.setdefault( # Register a non-standard type
grant_type, self.client.encode_saml_assertion)
return self.client.obtain_token_by_assertion(
b64encode(wstrust_result["token"]),
grant_type=grant_type, scope=scopes, **kwargs)
wstrust_result["token"], grant_type, scope=scopes, **kwargs)


class ConfidentialClientApplication(ClientApplication): # server-side web app
Expand Down
2 changes: 1 addition & 1 deletion msal/oauth2cli/assertion.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import jwt


logger = logging.getLogger(__file__)
logger = logging.getLogger(__name__)

class Signer(object):
def sign_assertion(
Expand Down
2 changes: 1 addition & 1 deletion msal/oauth2cli/authcode.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from .oauth2 import Client


logger = logging.getLogger(__file__)
logger = logging.getLogger(__name__)

def obtain_auth_code(listen_port, auth_uri=None):
"""This function will start a web server listening on http://localhost:port
Expand Down
54 changes: 31 additions & 23 deletions msal/oauth2cli/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import logging
import warnings
import time
import base64

import requests

Expand All @@ -18,12 +19,21 @@ class BaseClient(object):
# This low-level interface works. Yet you'll find its sub-class
# more friendly to remind you what parameters are needed in each scenario.
# More on Client Types at https://tools.ietf.org/html/rfc6749#section-2.1

@staticmethod
def encode_saml_assertion(assertion):
return base64.urlsafe_b64encode(assertion).rstrip(b'=') # Per RFC 7522

CLIENT_ASSERTION_TYPE_JWT = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
CLIENT_ASSERTION_TYPE_SAML2 = "urn:ietf:params:oauth:client-assertion-type:saml2-bearer"
client_assertion_encoders = {CLIENT_ASSERTION_TYPE_SAML2: encode_saml_assertion}

def __init__(
self,
server_configuration, # type: dict
client_id, # type: str
client_secret=None, # type: Optional[str]
client_assertion=None, # type: Optional[str]
client_assertion=None, # type: Optional[bytes]
client_assertion_type=None, # type: Optional[str]
default_headers=None, # type: Optional[dict]
default_body=None, # type: Optional[dict]
Expand All @@ -45,14 +55,14 @@ def __init__(
https://example.com/.../.well-known/openid-configuration
client_id (str): The client's id, issued by the authorization server
client_secret (str): Triggers HTTP AUTH for Confidential Client
client_assertion (str):
client_assertion (bytes):
The client assertion to authenticate this client, per RFC 7521.
It can be a raw SAML2 assertion (this method will encode it for you),
or a raw JWT assertion.
client_assertion_type (str):
The format of the client_assertion.
If you leave it as the default None, this method will try to make
a guess between SAML2 (RFC 7522) and JWT (RFC 7523),
the only two profiles defined in RFC 7521.
But you can also explicitly provide a value, if needed.
The type of your :attr:`client_assertion` parameter.
It is typically the value of :attr:`CLIENT_ASSERTION_TYPE_SAML2` or
:attr:`CLIENT_ASSERTION_TYPE_JWT`, the only two defined in RFC 7521.
default_headers (dict):
A dict to be sent in each request header.
It is not required by OAuth2 specs, but you may use it for telemetry.
Expand All @@ -66,12 +76,10 @@ def __init__(
self.client_id = client_id
self.client_secret = client_secret
self.default_body = default_body or {}
if client_assertion is not None: # See https://tools.ietf.org/html/rfc7521#section-4.2
if client_assertion_type is None: # RFC7521 defines only 2 profiles
TYPE_JWT = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
TYPE_SAML2 = "urn:ietf:params:oauth:client-assertion-type:saml2-bearer"
client_assertion_type = TYPE_JWT if "." in client_assertion else TYPE_SAML2
self.default_body["client_assertion"] = client_assertion
if client_assertion is not None and client_assertion_type is not None:
# See https://tools.ietf.org/html/rfc7521#section-4.2
encoder = self.client_assertion_encoders.get(client_assertion_type, lambda a: a)
self.default_body["client_assertion"] = encoder(client_assertion)
self.default_body["client_assertion_type"] = client_assertion_type
self.logger = logging.getLogger(__name__)
self.session = s = requests.Session()
Expand Down Expand Up @@ -172,6 +180,8 @@ class Client(BaseClient): # We choose to implement all 4 grants in 1 class
DEVICE_FLOW_RETRIABLE_ERRORS = ("authorization_pending", "slow_down")
GRANT_TYPE_SAML2 = "urn:ietf:params:oauth:grant-type:saml2-bearer" # RFC7522
GRANT_TYPE_JWT = "urn:ietf:params:oauth:grant-type:jwt-bearer" # RFC7523
grant_assertion_encoders = {GRANT_TYPE_SAML2: BaseClient.encode_saml_assertion}


def initiate_device_flow(self, scope=None, timeout=None, **kwargs):
# type: (list, **dict) -> dict
Expand Down Expand Up @@ -409,22 +419,20 @@ def obtain_token_by_refresh_token(self, token_item, scope=None,
raise ValueError("token_item should not be a type %s" % type(token_item))

def obtain_token_by_assertion(
self, assertion, grant_type=None, scope=None, **kwargs):
# type: (str, Union[str, None], Union[str, list, set, tuple]) -> dict
self, assertion, grant_type, scope=None, **kwargs):
# type: (bytes, Union[str, None], Union[str, list, set, tuple]) -> dict
"""This method implements Assertion Framework for OAuth2 (RFC 7521).
See details at https://tools.ietf.org/html/rfc7521#section-4.1

:param assertion: The assertion string which will be sent on wire as-is
:param assertion:
The assertion bytes can be a raw SAML2 assertion, or a JWT assertion.
:param grant_type:
If you leave it as the default None, this method will try to make
a guess between SAML2 (RFC 7522) and JWT (RFC 7523),
the only two profiles defined in RFC 7521.
But you can also explicitly provide a value, if needed.
It is typically either the value of :attr:`GRANT_TYPE_SAML2`,
or :attr:`GRANT_TYPE_JWT`, the only two profiles defined in RFC 7521.
:param scope: Optional. It must be a subset of previously granted scopes.
"""
if grant_type is None:
grant_type = self.GRANT_TYPE_JWT if "." in assertion else self.GRANT_TYPE_SAML2
encoder = self.grant_assertion_encoders.get(grant_type, lambda a: a)
data = kwargs.pop("data", {})
data.update(scope=scope, assertion=assertion)
data.update(scope=scope, assertion=encoder(assertion))
return self._obtain_token(grant_type, data=data, **kwargs)

10 changes: 9 additions & 1 deletion msal/token_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from .authority import canonicalize


logger = logging.getLogger(__name__)

def is_subdict_of(small, big):
return dict(big, **small) == big

Expand Down Expand Up @@ -46,7 +48,13 @@ def add(self, event):
# type: (dict) -> None
# event typically contains: client_id, scope, token_endpoint,
# resposne, params, data, grant_type
logging.debug("event=%s", json.dumps(event, indent=4))
for sensitive in ("password", "client_secret"):
if sensitive in event.get("data", {}):
# Hide them from accidental exposure in logging
event["data"][sensitive] = "********"
logger.debug("event=%s", json.dumps(event, indent=4, sort_keys=True,
default=str, # A workaround when assertion is in bytes in Python 3
))
response = event.get("response", {})
access_token = response.get("access_token", {})
refresh_token = response.get("refresh_token", {})
Expand Down
2 changes: 1 addition & 1 deletion msal/wstrust_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from .wstrust_response import parse_response


logger = logging.getLogger(__file__)
logger = logging.getLogger(__name__)

def send_request(
username, password, cloud_audience_urn, endpoint_address, soap_action,
Expand Down
1 change: 1 addition & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def setUpClass(cls):
audience=CONFIG["openid_configuration"]["token_endpoint"],
issuer=CONFIG["client_id"],
),
client_assertion_type=Client.CLIENT_ASSERTION_TYPE_JWT,
)
else:
cls.client = Client(
Expand Down