Skip to content

Commit

Permalink
Add cloud shell validation (#37278)
Browse files Browse the repository at this point in the history
* Add cloud shell validation

* update

* update

* update

* update

* update

* Update sdk/identity/azure-identity/CHANGELOG.md

Co-authored-by: Paul Van Eck <paulvaneck@microsoft.com>

* rename tests

---------

Co-authored-by: Paul Van Eck <paulvaneck@microsoft.com>
  • Loading branch information
xiangyan99 and pvaneck authored Sep 11, 2024
1 parent 04e0fda commit 548ea62
Show file tree
Hide file tree
Showing 9 changed files with 84 additions and 13 deletions.
1 change: 1 addition & 0 deletions sdk/identity/azure-identity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
### Other Changes

- Added identity config validation to `ManagedIdentityCredential` to avoid non-deterministic states (e.g. both `resource_id` and `object_id` are specified). ([#36950](https://github.com/Azure/azure-sdk-for-python/pull/36950))
- Additional validation was added for `ManagedIdentityCredential` in Azure Cloud Shell environments. ([#36438](https://github.com/Azure/azure-sdk-for-python/issues/36438))

## 1.18.0b2 (2024-08-09)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,33 @@
# ------------------------------------
import functools
import os
from typing import Any, Optional, Dict
from typing import Any, Optional, Dict, Mapping

from azure.core.pipeline.transport import HttpRequest

from .._constants import EnvironmentVariables
from .._internal import within_dac
from .._internal.managed_identity_client import ManagedIdentityClient
from .._internal.managed_identity_base import ManagedIdentityBase


def validate_client_id_and_config(client_id: Optional[str], identity_config: Optional[Mapping[str, str]]) -> None:
if within_dac.get():
return
if client_id:
raise ValueError("client_id should not be set for cloud shell managed identity.")
if identity_config:
valid_keys = {"object_id", "resource_id", "client_id"}
if len(identity_config.keys() & valid_keys) > 0:
raise ValueError(f"identity_config must not contain the following keys: {', '.join(valid_keys)}")


class CloudShellCredential(ManagedIdentityBase):
def get_client(self, **kwargs: Any) -> Optional[ManagedIdentityClient]:
client_id = kwargs.get("client_id")
identity_config = kwargs.get("identity_config")
validate_client_id_and_config(client_id, identity_config)

url = os.environ.get(EnvironmentVariables.MSI_ENDPOINT)
if url:
return ManagedIdentityClient(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def __init__(self, **kwargs: Any) -> None: # pylint: disable=too-many-statement
exclude_powershell_credential = kwargs.pop("exclude_powershell_credential", False)

credentials: List["TokenCredential"] = []
within_dac.set(True)
if not exclude_environment_credential:
credentials.append(EnvironmentCredential(authority=authority, _within_dac=True, **kwargs))
if not exclude_workload_identity_credential:
Expand Down Expand Up @@ -192,7 +193,7 @@ def __init__(self, **kwargs: Any) -> None: # pylint: disable=too-many-statement
)
else:
credentials.append(InteractiveBrowserCredential(tenant_id=interactive_browser_tenant_id, **kwargs))

within_dac.set(False)
super(DefaultAzureCredential, self).__init__(*credentials)

def get_token(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,15 @@
from .._internal.managed_identity_base import AsyncManagedIdentityBase
from .._internal.managed_identity_client import AsyncManagedIdentityClient
from ..._constants import EnvironmentVariables
from ..._credentials.cloud_shell import _get_request
from ..._credentials.cloud_shell import _get_request, validate_client_id_and_config


class CloudShellCredential(AsyncManagedIdentityBase):
def get_client(self, **kwargs: Any) -> Optional[AsyncManagedIdentityClient]:
client_id = kwargs.get("client_id")
identity_config = kwargs.get("identity_config")
validate_client_id_and_config(client_id, identity_config)

url = os.environ.get(EnvironmentVariables.MSI_ENDPOINT)
if url:
return AsyncManagedIdentityClient(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class DefaultAzureCredential(ChainedTokenCredential):
:caption: Create a DefaultAzureCredential.
"""

def __init__(self, **kwargs: Any) -> None:
def __init__(self, **kwargs: Any) -> None: # pylint: disable=too-many-statements, too-many-locals
if "tenant_id" in kwargs:
raise TypeError("'tenant_id' is not supported in DefaultAzureCredential.")

Expand Down Expand Up @@ -135,6 +135,7 @@ def __init__(self, **kwargs: Any) -> None:
exclude_powershell_credential = kwargs.pop("exclude_powershell_credential", False)

credentials = [] # type: List[AsyncTokenCredential]
within_dac.set(True)
if not exclude_environment_credential:
credentials.append(EnvironmentCredential(authority=authority, _within_dac=True, **kwargs))
if not exclude_workload_identity_credential:
Expand Down Expand Up @@ -173,7 +174,7 @@ def __init__(self, **kwargs: Any) -> None:
credentials.append(AzurePowerShellCredential(process_timeout=process_timeout))
if not exclude_developer_cli_credential:
credentials.append(AzureDeveloperCliCredential(process_timeout=process_timeout))

within_dac.set(False)
super().__init__(*credentials)

async def get_token(
Expand Down
10 changes: 10 additions & 0 deletions sdk/identity/azure-identity/tests/test_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,3 +412,13 @@ def test_unexpected_kwarg():
def test_error_tenant_id():
with pytest.raises(TypeError):
DefaultAzureCredential(tenant_id="foo")


def test_validate_cloud_shell_credential_in_dac():
MANAGED_IDENTITY_ENVIRON = "azure.identity._credentials.managed_identity.os.environ"
with patch.dict(MANAGED_IDENTITY_ENVIRON, {EnvironmentVariables.MSI_ENDPOINT: "https://localhost"}, clear=True):
DefaultAzureCredential()
DefaultAzureCredential(managed_identity_client_id="foo")
DefaultAzureCredential(identity_config={"client_id": "foo"})
DefaultAzureCredential(identity_config={"object_id": "foo"})
DefaultAzureCredential(identity_config={"resource_id": "foo"})
10 changes: 10 additions & 0 deletions sdk/identity/azure-identity/tests/test_default_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,3 +326,13 @@ def test_unexpected_kwarg():
def test_error_tenant_id():
with pytest.raises(TypeError):
DefaultAzureCredential(tenant_id="foo")


def test_validate_cloud_shell_credential_in_dac():
MANAGED_IDENTITY_ENVIRON = "azure.identity.aio._credentials.managed_identity.os.environ"
with patch.dict(MANAGED_IDENTITY_ENVIRON, {EnvironmentVariables.MSI_ENDPOINT: "https://localhost"}, clear=True):
DefaultAzureCredential()
DefaultAzureCredential(managed_identity_client_id="foo")
DefaultAzureCredential(identity_config={"client_id": "foo"})
DefaultAzureCredential(identity_config={"object_id": "foo"})
DefaultAzureCredential(identity_config={"resource_id": "foo"})
22 changes: 18 additions & 4 deletions sdk/identity/azure-identity/tests/test_managed_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,12 +309,11 @@ def test_azure_ml_tenant_id():
assert token.expires_on == expected_token.expires_on


def test_cloud_shell_user_assigned_identity():
def test_cloud_shell_identity_config():
"""Cloud Shell environment: only MSI_ENDPOINT set"""

expected_token = "****"
expires_on = 42
client_id = "some-guid"
endpoint = "http://localhost:42/token"
scope = "scope"
param_name, param_value = "foo", "bar"
Expand All @@ -325,7 +324,7 @@ def test_cloud_shell_user_assigned_identity():
base_url=endpoint,
method="POST",
required_headers={"Metadata": "true", "User-Agent": USER_AGENT},
required_data={"client_id": client_id, "resource": scope},
required_data={"resource": scope},
),
Request(
base_url=endpoint,
Expand All @@ -350,7 +349,7 @@ def test_cloud_shell_user_assigned_identity():
)

with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, {EnvironmentVariables.MSI_ENDPOINT: endpoint}, clear=True):
token = ManagedIdentityCredential(client_id=client_id, transport=transport).get_token(scope)
token = ManagedIdentityCredential(transport=transport).get_token(scope)
assert token.token == expected_token
assert token.expires_on == expires_on

Expand Down Expand Up @@ -965,3 +964,18 @@ def test_validate_identity_config():
ManagedIdentityCredential(identity_config={"object_id": "bar", "resource_id": "foo"})
with pytest.raises(ValueError):
ManagedIdentityCredential(identity_config={"object_id": "bar", "client_id": "foo"})


def test_validate_cloud_shell_credential():
with mock.patch.dict(
MANAGED_IDENTITY_ENVIRON, {EnvironmentVariables.MSI_ENDPOINT: "https://localhost"}, clear=True
):
ManagedIdentityCredential()
with pytest.raises(ValueError):
ManagedIdentityCredential(client_id="foo")
with pytest.raises(ValueError):
ManagedIdentityCredential(identity_config={"client_id": "foo"})
with pytest.raises(ValueError):
ManagedIdentityCredential(identity_config={"object_id": "foo"})
with pytest.raises(ValueError):
ManagedIdentityCredential(identity_config={"resource_id": "foo"})
22 changes: 18 additions & 4 deletions sdk/identity/azure-identity/tests/test_managed_identity_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,12 +313,11 @@ async def test_azure_ml_tenant_id():


@pytest.mark.asyncio
async def test_cloud_shell_user_assigned_identity():
async def test_cloud_shell_identity_config():
"""Cloud Shell environment: only MSI_ENDPOINT set"""

expected_token = "****"
expires_on = 42
client_id = "some-guid"
endpoint = "http://localhost:42/token"
scope = "scope"
param_name, param_value = "foo", "bar"
Expand All @@ -329,7 +328,7 @@ async def test_cloud_shell_user_assigned_identity():
base_url=endpoint,
method="POST",
required_headers={"Metadata": "true", "User-Agent": USER_AGENT},
required_data={"client_id": client_id, "resource": scope},
required_data={"resource": scope},
),
Request(
base_url=endpoint,
Expand All @@ -354,7 +353,7 @@ async def test_cloud_shell_user_assigned_identity():
)

with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, {EnvironmentVariables.MSI_ENDPOINT: endpoint}, clear=True):
credential = ManagedIdentityCredential(client_id=client_id, transport=transport)
credential = ManagedIdentityCredential(transport=transport)
token = await credential.get_token(scope)
assert token.token == expected_token
assert token.expires_on == expires_on
Expand Down Expand Up @@ -1234,3 +1233,18 @@ def test_validate_identity_config():
ManagedIdentityCredential(identity_config={"object_id": "bar", "resource_id": "foo"})
with pytest.raises(ValueError):
ManagedIdentityCredential(identity_config={"object_id": "bar", "client_id": "foo"})


def test_validate_cloud_shell_credential():
with mock.patch.dict(
MANAGED_IDENTITY_ENVIRON, {EnvironmentVariables.MSI_ENDPOINT: "https://localhost"}, clear=True
):
ManagedIdentityCredential()
with pytest.raises(ValueError):
ManagedIdentityCredential(client_id="foo")
with pytest.raises(ValueError):
ManagedIdentityCredential(identity_config={"client_id": "foo"})
with pytest.raises(ValueError):
ManagedIdentityCredential(identity_config={"object_id": "foo"})
with pytest.raises(ValueError):
ManagedIdentityCredential(identity_config={"resource_id": "foo"})

0 comments on commit 548ea62

Please sign in to comment.