Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

VisualStudioCodeCredential raises CredentialUnavailableError when configured for ADFS #13556

Merged
merged 2 commits into from
Sep 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any
from typing import Any, Iterable, Optional
from azure.core.credentials import AccessToken


Expand All @@ -37,9 +37,9 @@ def __init__(self, **kwargs):
# type: (**Any) -> None
self._refresh_token = None
self._client = kwargs.pop("_client", None)
self._tenant_id = kwargs.pop("tenant_id", None) or "organizations"
if not self._client:
tenant_id = kwargs.pop("tenant_id", None) or "organizations"
self._client = AadClient(tenant_id, AZURE_VSCODE_CLIENT_ID, **kwargs)
self._client = AadClient(self._tenant_id, AZURE_VSCODE_CLIENT_ID, **kwargs)

@log_get_token("VisualStudioCodeCredential")
def get_token(self, *scopes, **kwargs):
Expand All @@ -56,6 +56,11 @@ def get_token(self, *scopes, **kwargs):
if not scopes:
raise ValueError("'get_token' requires at least one scope")

if self._tenant_id.lower() == "adfs":
raise CredentialUnavailableError(
message="VisualStudioCodeCredential authentication unavailable. ADFS is not supported."
)

token = self._client.get_cached_access_token(scopes)

if not token:
Expand All @@ -68,7 +73,7 @@ def get_token(self, *scopes, **kwargs):
return token

def _redeem_refresh_token(self, scopes, **kwargs):
# type: (Sequence[str], **Any) -> Optional[AccessToken]
# type: (Iterable[str], **Any) -> Optional[AccessToken]
if not self._refresh_token:
self._refresh_token = get_credentials()
if not self._refresh_token:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any
from typing import Any, Iterable, Optional
from azure.core.credentials import AccessToken


Expand All @@ -30,9 +30,9 @@ class VisualStudioCodeCredential(AsyncContextManager):
def __init__(self, **kwargs: "Any") -> None:
self._refresh_token = None
self._client = kwargs.pop("_client", None)
self._tenant_id = kwargs.pop("tenant_id", None) or "organizations"
if not self._client:
tenant_id = kwargs.pop("tenant_id", None) or "organizations"
self._client = AadClient(tenant_id, AZURE_VSCODE_CLIENT_ID, **kwargs)
self._client = AadClient(self._tenant_id, AZURE_VSCODE_CLIENT_ID, **kwargs)

async def __aenter__(self):
if self._client:
Expand Down Expand Up @@ -60,6 +60,11 @@ async def get_token(self, *scopes, **kwargs):
if not scopes:
raise ValueError("'get_token' requires at least one scope")

if self._tenant_id.lower() == "adfs":
raise CredentialUnavailableError(
message="VisualStudioCodeCredential authentication unavailable. ADFS is not supported."
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ADFS tenant / authorities not supported.

)

token = self._client.get_cached_access_token(scopes)
if not token:
token = await self._redeem_refresh_token(scopes, **kwargs)
Expand All @@ -70,7 +75,7 @@ async def get_token(self, *scopes, **kwargs):
pass
return token

async def _redeem_refresh_token(self, scopes: "Sequence[str]", **kwargs: "Any") -> "Optional[AccessToken]":
async def _redeem_refresh_token(self, scopes: "Iterable[str]", **kwargs: "Any") -> "Optional[AccessToken]":
if not self._refresh_token:
self._refresh_token = get_credentials()
if not self._refresh_token:
Expand Down
9 changes: 9 additions & 0 deletions sdk/identity/azure-identity/tests/test_vscode_credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,12 @@ def test_mac_keychain_error():
credential = VisualStudioCodeCredential()
with pytest.raises(CredentialUnavailableError):
token = credential.get_token("scope")


def test_adfs():
"""The credential should raise CredentialUnavailableError when configured for ADFS"""

credential = VisualStudioCodeCredential(tenant_id="adfs")
with pytest.raises(CredentialUnavailableError) as ex:
credential.get_token("scope")
assert "adfs" in ex.value.message.lower()
10 changes: 10 additions & 0 deletions sdk/identity/azure-identity/tests/test_vscode_credential_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,13 @@ async def test_no_obtain_token_if_cached():
credential = VisualStudioCodeCredential(_client=mock_client)
token = await credential.get_token("scope")
assert token_by_refresh_token.call_count == 0


@pytest.mark.asyncio
async def test_adfs():
"""The credential should raise CredentialUnavailableError when configured for ADFS"""

credential = VisualStudioCodeCredential(tenant_id="adfs")
with pytest.raises(CredentialUnavailableError) as ex:
await credential.get_token("scope")
assert "adfs" in ex.value.message.lower()