diff --git a/.gitignore b/.gitignore index 633c85c02..c856fb59a 100644 --- a/.gitignore +++ b/.gitignore @@ -331,4 +331,6 @@ coverage.json .tox/ # test results / junit -junit/ \ No newline at end of file +junit/ + +pyrightconfiguration.json diff --git a/azext_edge/edge/_help.py b/azext_edge/edge/_help.py index 50525f6d6..470373445 100644 --- a/azext_edge/edge/_help.py +++ b/azext_edge/edge/_help.py @@ -403,40 +403,6 @@ def load_iotops_help(): Pre-creating an app registration is useful when the logged-in principal has constrained Entra Id permissions. For example in CI/automation scenarios, or an orgs separation of user responsibility. - - examples: - - name: Minimum input for complete setup. This includes Key Vault configuration, CSI driver deployment, TLS config and deployment of IoT Operations. - text: > - az iot ops init --cluster mycluster -g myresourcegroup --kv-id /subscriptions/2cb3a427-1abc-48d0-9d03-dd240819742a/resourceGroups/myresourcegroup/providers/Microsoft.KeyVault/vaults/mykeyvault - - - name: Same setup as prior example, except with the usage of an existing app Id and a flag to include a simulated PLC server as part of the deployment. - Including the app Id will prevent init from creating an app registration. - text: > - az iot ops init --cluster mycluster -g myresourcegroup --kv-id $KEYVAULT_ID --sp-app-id a14e216b-6802-4e9c-a6ac-844f9ffd230d --simulate-plc - - - name: To skip deployment and focus only on the Key Vault CSI driver and TLS config workflows simple pass in --no-deploy. - This can be useful when desiring to deploy from a different tool such as Portal. - text: > - az iot ops init --cluster mycluster -g myresourcegroup --kv-id $KEYVAULT_ID --sp-app-id a14e216b-6802-4e9c-a6ac-844f9ffd230d --no-deploy - - - name: To only deploy IoT Operations on a cluster that has already been prepped, simply omit --kv-id and include --no-tls. - text: > - az iot ops init --cluster mycluster -g myresourcegroup --no-tls - - - name: Use --no-block to do other work while the deployment is on-going vs waiting for the deployment to finish before starting the other work. - text: > - az iot ops init --cluster mycluster -g myresourcegroup --kv-id $KEYVAULT_ID --sp-app-id a14e216b-6802-4e9c-a6ac-844f9ffd230d --no-block - - - name: This example shows providing values for --sp-app-id, --sp-object-id and --sp-secret. These values should reflect the desired service principal - that will be used for the Key Vault CSI driver secret synchronization. Please review the command summary for additional details. - text: > - az iot ops init --cluster mycluster -g myresourcegroup --kv-id $KEYVAULT_ID --sp-app-id a14e216b-6802-4e9c-a6ac-844f9ffd230d - --sp-object-id 224a7a3f-c63d-4923-8950-c4a85f0d2f29 --sp-secret $SP_SECRET - - - name: To customize runtime configuration of the Key Vault CSI driver, --csi-config can be used. For example setting resource limits on the telegraf container dependency. - text: > - az iot ops init --cluster mycluster -g myresourcegroup --kv-id $KEYVAULT_ID --sp-app-id a14e216b-6802-4e9c-a6ac-844f9ffd230d - --csi-config telegraf.resources.limits.memory=500Mi telegraf.resources.limits.cpu=100m """ helps[ diff --git a/azext_edge/edge/commands_edge.py b/azext_edge/edge/commands_edge.py index 523065942..4883fa7d6 100644 --- a/azext_edge/edge/commands_edge.py +++ b/azext_edge/edge/commands_edge.py @@ -17,9 +17,7 @@ from .providers.check.common import ResourceOutputDetailLevel from .providers.edge_api.orc import ORC_API_V1 from .providers.orchestration.common import ( - DEFAULT_SERVICE_PRINCIPAL_SECRET_DAYS, DEFAULT_X509_CA_VALID_DAYS, - KEYVAULT_ARC_EXTENSION_VERSION, KubernetesDistroType, MqMemoryProfile, MqServiceType, @@ -102,7 +100,6 @@ def init( instance_name: Optional[str] = None, instance_description: Optional[str] = None, cluster_namespace: str = DEFAULT_NAMESPACE, - keyvault_spc_secret_name: str = DEFAULT_NAMESPACE, custom_location_name: Optional[str] = None, location: Optional[str] = None, show_template: Optional[bool] = None, @@ -125,15 +122,8 @@ def init( mq_broker_config_file: Optional[str] = None, mq_insecure: Optional[bool] = None, dataflow_profile_instances: int = 1, - disable_secret_rotation: Optional[bool] = None, - rotation_poll_interval: str = "1h", - csi_driver_version: str = KEYVAULT_ARC_EXTENSION_VERSION, - csi_driver_config: Optional[List[str]] = None, - service_principal_app_id: Optional[str] = None, - service_principal_object_id: Optional[str] = None, - service_principal_secret: Optional[str] = None, - service_principal_secret_valid_days: int = DEFAULT_SERVICE_PRINCIPAL_SECRET_DAYS, - keyvault_resource_id: Optional[str] = None, + # TODO - @digimaun csi_driver_config: Optional[List[str]] = None, + keyvault_resource_id: Optional[str] = None, # TODO - @digimaun tls_ca_path: Optional[str] = None, tls_ca_key_path: Optional[str] = None, tls_ca_dir: Optional[str] = None, @@ -144,11 +134,12 @@ def init( disable_rsync_rules: Optional[bool] = None, context_name: Optional[str] = None, ensure_latest: Optional[bool] = None, + **kwargs, ) -> Union[Dict[str, Any], None]: from .common import INIT_NO_PREFLIGHT_ENV_KEY from .providers.orchestration import deploy from .util import ( - assemble_nargs_to_dict, + # assemble_nargs_to_dict, is_env_flag_enabled, read_file_content, url_safe_random_chars, @@ -184,9 +175,6 @@ def init( if not exists(tls_ca_key_path): raise InvalidArgumentValueError("Provided CA private key file does not exist.") - if csi_driver_config: - csi_driver_config = assemble_nargs_to_dict(csi_driver_config) - # TODO - @digimaun mq_broker_config = None if mq_broker_config_file: @@ -227,20 +215,12 @@ def init( mq_insecure=mq_insecure, dataflow_profile_instances=int(dataflow_profile_instances), keyvault_resource_id=keyvault_resource_id, - keyvault_spc_secret_name=str(keyvault_spc_secret_name), - disable_secret_rotation=disable_secret_rotation, - rotation_poll_interval=str(rotation_poll_interval), - csi_driver_version=str(csi_driver_version), - csi_driver_config=csi_driver_config, - service_principal_app_id=service_principal_app_id, - service_principal_object_id=service_principal_object_id, - service_principal_secret=service_principal_secret, - service_principal_secret_valid_days=int(service_principal_secret_valid_days), tls_ca_path=tls_ca_path, tls_ca_key_path=tls_ca_key_path, tls_ca_dir=tls_ca_dir, tls_ca_valid_days=int(tls_ca_valid_days), template_path=template_path, + **kwargs, ) diff --git a/azext_edge/edge/params.py b/azext_edge/edge/params.py index 112d303ef..55a6dd465 100644 --- a/azext_edge/edge/params.py +++ b/azext_edge/edge/params.py @@ -80,7 +80,7 @@ def load_iotops_arguments(self, _): options_list=["--tags"], arg_type=tags_type, help="Instance tags. Property bag in key-value pairs with the following format: a=b c=d. " - "Use --tags \"\" to remove all tags.", + 'Use --tags "" to remove all tags.', ) context.argument( "instance_description", @@ -514,81 +514,29 @@ def load_iotops_arguments(self, _): # AKV CSI Driver context.argument( "keyvault_resource_id", - options_list=["--kv-id"], + options_list=[ + "--kv-resource-id", + context.deprecate( + target="--kv-id", + redirect="--kv-resource-id", + hide=True, + ), + ], help="Key Vault ARM resource Id. Providing this resource Id will enable the client " "to setup all necessary resources and cluster side configuration to enable " "the Key Vault CSI driver for IoT Operations.", arg_group="Key Vault CSI Driver", ) - context.argument( - "keyvault_spc_secret_name", - options_list=["--kv-spc-secret-name"], - help="The Key Vault secret **name** to use as the default SPC secret. " - "If the secret does not exist, it will be created with a cryptographically secure placeholder value.", - arg_group="Key Vault CSI Driver", - ) - context.argument( - "disable_secret_rotation", - options_list=["--disable-rotation"], - arg_type=get_three_state_flag(), - help="Flag to disable secret rotation.", - arg_group="Key Vault CSI Driver", - ) - context.argument( - "rotation_poll_interval", - options_list=["--rotation-int"], - help="Rotation poll interval.", - arg_group="Key Vault CSI Driver", - ) - context.argument( - "csi_driver_version", - options_list=["--csi-ver"], - help="CSI driver extension version.", - arg_group="Key Vault CSI Driver", - ) - context.argument( - "csi_driver_config", - options_list=["--csi-config"], - nargs="+", - action="extend", - help="CSI driver extension custom configuration. Format is space-separated key=value pairs. " - "--csi-config can be used one or more times.", - arg_group="Key Vault CSI Driver", - ) - context.argument( - "service_principal_app_id", - options_list=["--sp-app-id"], - help="Service principal app Id. If provided will be used for CSI driver setup. " - "Otherwise an app registration will be created. " - "**Required** if the logged in principal does not have permissions to query graph.", - arg_group="Key Vault CSI Driver", - ) - context.argument( - "service_principal_object_id", - options_list=["--sp-object-id"], - help="Service principal (sp) object Id. If provided will be used for CSI driver setup. " - "Otherwise the object Id will be queried from the app Id - creating the sp if one does not exist. " - "**Required** if the logged in principal does not have permissions to query graph. " - "Use `az ad sp show --id --query id -o tsv` to produce the proper object Id. " - "Alternatively using Portal you can navigate to Enterprise Applications in your Entra Id tenant.", - arg_group="Key Vault CSI Driver", - ) - context.argument( - "service_principal_secret", - options_list=["--sp-secret"], - help="The secret corresponding to the provided service principal app Id. " - "If provided will be used for CSI driver setup. Otherwise a new secret will be created. " - "**Required** if the logged in principal does not have permissions to query graph.", - arg_group="Key Vault CSI Driver", - ) - context.argument( - "service_principal_secret_valid_days", - options_list=["--sp-secret-valid-days"], - help="Option to control the duration in days of the init generated service principal secret. " - "Applicable if --sp-secret is not provided.", - arg_group="Key Vault CSI Driver", - type=int, - ) + # TODO - @digimaun - still applicable + # context.argument( + # "csi_driver_config", + # options_list=["--csi-config"], + # nargs="+", + # action="extend", + # help="CSI driver extension custom configuration. Format is space-separated key=value pairs. " + # "--csi-config can be used one or more times.", + # arg_group="Key Vault CSI Driver", + # ) # TLS context.argument( "tls_ca_path", @@ -1122,7 +1070,7 @@ def load_iotops_arguments(self, _): options_list=["--tags"], arg_type=tags_type, help="Schema registry tags. Property bag in key-value pairs with the following format: a=b c=d. " - "Use --tags \"\" to remove all tags.", + 'Use --tags "" to remove all tags.', ) context.argument( "description", diff --git a/azext_edge/edge/providers/orchestration/base.py b/azext_edge/edge/providers/orchestration/base.py index 6e5925773..fbb5bdff1 100644 --- a/azext_edge/edge/providers/orchestration/base.py +++ b/azext_edge/edge/providers/orchestration/base.py @@ -4,173 +4,99 @@ # Licensed under the MIT License. See License file in the project root for license information. # ---------------------------------------------------------------------------------------------- -import json -import logging -from time import sleep -from typing import TYPE_CHECKING, Dict, List, NamedTuple, Optional, Tuple +from typing import Optional, Tuple -from azure.cli.core.azclierror import HTTPError, ValidationError +from azure.cli.core.azclierror import ValidationError from knack.log import get_logger from ...common import K8sSecretType from ...util import ( - generate_secret, generate_self_signed_cert, get_timestamp_now_utc, read_file_content, ) from ...util.az_client import ( get_resource_client, - get_tenant_id, - get_token_from_sp_credential, - wait_for_terminal_state, ) from ..base import ( create_cluster_namespace, create_namespaced_configmap, - create_namespaced_custom_objects, create_namespaced_secret, get_cluster_namespace, ) -from ..edge_api import KEYVAULT_API_V1 from ..k8s.cluster_role_binding import get_bindings from ..k8s.config_map import get_config_map from .common import ( ARC_CONFIG_MAP, ARC_NAMESPACE, CUSTOM_LOCATIONS_RP_APP_ID, - DEFAULT_SERVICE_PRINCIPAL_SECRET_DAYS, EXTENDED_LOCATION_ROLE_BINDING, - GRAPH_V1_APP_ENDPOINT, - GRAPH_V1_ENDPOINT, GRAPH_V1_SP_ENDPOINT, - KEYVAULT_ARC_EXTENSION_VERSION, - KEYVAULT_CLOUD_API_VERSION, - KEYVAULT_DATAPLANE_API_VERSION, -) -from .components import ( - get_kv_secret_store_yaml, ) from .connected_cluster import ConnectedCluster logger = get_logger(__name__) -if TYPE_CHECKING: - from requests.models import Response - - -EXTENSION_API_VERSION = "2022-11-01" # TODO: fun testing with newer api IOT_OPERATIONS_EXTENSION_PREFIX = "microsoft.iotoperations" -PROPAGATION_DELAY_SEC = 25 - - -class ServicePrincipal(NamedTuple): - client_id: str - object_id: str - tenant_id: str - secret: str - created_app: bool - - -def provision_akv_csi_driver( - subscription_id: str, - cluster_name: str, - resource_group_name: str, - enable_secret_rotation: str, - rotation_poll_interval: str = "1h", - extension_name: str = "akvsecretsprovider", - extension_version: str = KEYVAULT_ARC_EXTENSION_VERSION, - extension_config: Optional[Dict[str, str]] = None, - **kwargs, # TODO: someday remove all kwargs from the smaller funcs -) -> dict: - resource_client = get_resource_client(subscription_id=subscription_id) - - base_config_settings: Dict[str, str] = { - "secrets-store-csi-driver.enableSecretRotation": enable_secret_rotation, - "secrets-store-csi-driver.rotationPollInterval": rotation_poll_interval, - "secrets-store-csi-driver.syncSecret.enabled": "false", - } - - if extension_config: - base_config_settings.update(extension_config) - - return wait_for_terminal_state( - resource_client.resources.begin_create_or_update_by_id( - resource_id=f"/subscriptions/{subscription_id}/resourceGroups/{resource_group_name}" - f"/providers/Microsoft.Kubernetes/connectedClusters/{cluster_name}/Providers" - f"/Microsoft.KubernetesConfiguration/extensions/{extension_name}", - api_version=EXTENSION_API_VERSION, - parameters={ - "identity": {"type": "SystemAssigned"}, - "properties": { - "autoUpgradeMinorVersion": False, - "version": extension_version, - "extensionType": "microsoft.azurekeyvaultsecretsprovider", - "configurationSettings": base_config_settings, - "configurationProtectedSettings": {}, - }, - }, - ) - ).as_dict() - - -def configure_cluster_secrets( - cluster_namespace: str, - cluster_secret_ref: str, - cluster_akv_secret_class_name: str, - keyvault_spc_secret_name: str, - keyvault_resource_id: str, - sp_record: ServicePrincipal, - **kwargs, -): - if not KEYVAULT_API_V1.is_deployed(): - raise ValidationError( - f"The API {KEYVAULT_API_V1.as_str()} " - "is not available on the cluster the local kubeconfig is configured for.\n" - "Please ensure the local kubeconfig matches the target cluster intended for deployment." - ) - - if not get_cluster_namespace(namespace=cluster_namespace): - create_cluster_namespace(namespace=cluster_namespace) - - create_namespaced_secret( - secret_name=cluster_secret_ref, - namespace=cluster_namespace, - data={"clientid": sp_record.client_id, "clientsecret": sp_record.secret}, - labels={"secrets-store.csi.k8s.io/used": "true"}, - delete_first=True, - ) - yaml_configs = [] - keyvault_split = keyvault_resource_id.split("/") - keyvault_name = keyvault_split[-1] - - for secret_class in [ - cluster_akv_secret_class_name, - "aio-opc-ua-broker-client-certificate", - "aio-opc-ua-broker-user-authentication", - "aio-opc-ua-broker-trust-list", - "aio-opc-ua-broker-issuer-list", - ]: - yaml_configs.append( - get_kv_secret_store_yaml( - name=secret_class, - namespace=cluster_namespace, - keyvault_name=keyvault_name, - secret_name=keyvault_spc_secret_name, - tenantId=sp_record.tenant_id, - ) - ) - - create_namespaced_custom_objects( - group=KEYVAULT_API_V1.group, - version=KEYVAULT_API_V1.version, - plural="secretproviderclasses", # TODO - namespace=cluster_namespace, - yaml_objects=yaml_configs, - delete_first=True, - ) +# TODO - @digimaun - potentially can reuse +# def configure_cluster_secrets( +# cluster_namespace: str, +# cluster_secret_ref: str, +# cluster_akv_secret_class_name: str, +# keyvault_spc_secret_name: str, +# keyvault_resource_id: str, +# sp_record: ServicePrincipal, +# **kwargs, +# ): +# if not KEYVAULT_API_V1.is_deployed(): +# raise ValidationError( +# f"The API {KEYVAULT_API_V1.as_str()} " +# "is not available on the cluster the local kubeconfig is configured for.\n" +# "Please ensure the local kubeconfig matches the target cluster intended for deployment." +# ) + +# if not get_cluster_namespace(namespace=cluster_namespace): +# create_cluster_namespace(namespace=cluster_namespace) + +# create_namespaced_secret( +# secret_name=cluster_secret_ref, +# namespace=cluster_namespace, +# data={"clientid": sp_record.client_id, "clientsecret": sp_record.secret}, +# labels={"secrets-store.csi.k8s.io/used": "true"}, +# delete_first=True, +# ) + +# yaml_configs = [] +# keyvault_split = keyvault_resource_id.split("/") +# keyvault_name = keyvault_split[-1] + +# for secret_class in [ +# cluster_akv_secret_class_name, +# "aio-opc-ua-broker-client-certificate", +# "aio-opc-ua-broker-user-authentication", +# "aio-opc-ua-broker-trust-list", +# "aio-opc-ua-broker-issuer-list", +# ]: +# yaml_configs.append( +# get_kv_secret_store_yaml( +# name=secret_class, +# namespace=cluster_namespace, +# keyvault_name=keyvault_name, +# secret_name=keyvault_spc_secret_name, +# tenantId=sp_record.tenant_id, +# ) +# ) + +# create_namespaced_custom_objects( +# group=KEYVAULT_API_V1.group, +# version=KEYVAULT_API_V1.version, +# plural="secretproviderclasses", # TODO +# namespace=cluster_namespace, +# yaml_objects=yaml_configs, +# delete_first=True, +# ) def prepare_ca( @@ -232,246 +158,6 @@ def configure_cluster_tls( create_namespaced_configmap(namespace=cluster_namespace, cm_name=cm_name, data=data, delete_first=True) -def prepare_sp(cmd, deployment_name: str, **kwargs) -> ServicePrincipal: - from datetime import datetime, timedelta, timezone - - from azure.cli.core.util import send_raw_request - - sp_app_id = kwargs.get("service_principal_app_id") - sp_object_id = kwargs.get("service_principal_object_id") - sp_secret = kwargs.get("service_principal_secret") - sp_secret_valid_days = kwargs.get("service_principal_secret_valid_days", DEFAULT_SERVICE_PRINCIPAL_SECRET_DAYS) - - timestamp = datetime.now(timezone.utc) + timedelta(days=sp_secret_valid_days) - timestamp_str = timestamp.strftime("%Y-%m-%dT%H:%M:%SZ") - - app_reg = {} - app_created = False - - if all([sp_app_id, sp_object_id, sp_secret]): - return ServicePrincipal( - client_id=sp_app_id, - object_id=sp_object_id, - secret=sp_secret, - tenant_id=get_tenant_id(), - created_app=app_created, - ) - - if sp_object_id and not sp_app_id: - existing_sp = send_raw_request( - cli_ctx=cmd.cli_ctx, - method="GET", - url=f"{GRAPH_V1_SP_ENDPOINT}/{sp_object_id}", - ).json() - sp_app_id = existing_sp["appId"] - app_reg = send_raw_request( - cli_ctx=cmd.cli_ctx, - method="GET", - url=f"{GRAPH_V1_APP_ENDPOINT}(appId='{sp_app_id}')", - ).json() - - if not sp_app_id: - app_reg = send_raw_request( - cli_ctx=cmd.cli_ctx, - method="POST", - url=GRAPH_V1_APP_ENDPOINT, - body=json.dumps({"displayName": deployment_name, "signInAudience": "AzureADMyOrg"}), - ).json() - app_created = True - sp_app_id = app_reg["appId"] - - if not sp_object_id or app_created: - try: - existing_sp = send_raw_request( - cli_ctx=cmd.cli_ctx, - method="GET", - url=f"{GRAPH_V1_SP_ENDPOINT}(appId='{sp_app_id}')", - ).json() - sp_object_id = existing_sp["id"] - except HTTPError as http_error: - if http_error.response.status_code != 404: - raise http_error - sp = send_raw_request( - cli_ctx=cmd.cli_ctx, - method="POST", - url=GRAPH_V1_SP_ENDPOINT, - body=json.dumps({"appId": sp_app_id}), - ).json() - sp_object_id = sp["id"] - - if app_reg: - ensure_correct_access(cmd, sp_app_id, app_reg["requiredResourceAccess"]) - - if not sp_secret: - add_secret_op = send_raw_request( - cli_ctx=cmd.cli_ctx, - method="POST", - url=f"{GRAPH_V1_ENDPOINT}/myorganization/applications(appId='{sp_app_id}')/addPassword", - body=json.dumps({"passwordCredential": {"displayName": deployment_name, "endDateTime": timestamp_str}}), - ) - sp_secret = add_secret_op.json()["secretText"] - sleep(PROPAGATION_DELAY_SEC) - - return ServicePrincipal( - client_id=sp_app_id, - object_id=sp_object_id, - secret=sp_secret, - tenant_id=get_tenant_id(), - created_app=app_created, - ) - - -def ensure_correct_access(cmd, sp_app_id: str, existing_resource_access: List[dict]): - from azure.cli.core.util import send_raw_request - - permission_map = { - # keyvault to have full access to akv service - "cfa8b339-82a2-471a-a3c9-0fc0be7a4093": "f53da476-18e3-4152-8e01-aec403e6edc0", - # ms graph to Sign in and read user profile - "00000003-0000-0000-c000-000000000000": "e1fe6dd8-ba31-4d61-89e7-88639da4683d", - } - for resource_app in existing_resource_access: - if resource_app["resourceAppId"] in permission_map: - permission_map.pop(resource_app["resourceAppId"], None) - - for app, permission in permission_map.items(): - existing_resource_access.append( - { - "resourceAppId": app, - "resourceAccess": [{"id": permission, "type": "Scope"}], - }, - ) - - if permission_map: - send_raw_request( - cli_ctx=cmd.cli_ctx, - method="PATCH", - url=f"{GRAPH_V1_ENDPOINT}/myorganization/applications(appId='{sp_app_id}')", - body=json.dumps({"requiredResourceAccess": existing_resource_access}), - ) - - -def validate_keyvault_permission_model(subscription_id: str, keyvault_resource_id: str, **kwargs) -> dict: - resource_client = get_resource_client(subscription_id=subscription_id) - keyvault_resource: dict = resource_client.resources.get_by_id( - resource_id=keyvault_resource_id, api_version=KEYVAULT_CLOUD_API_VERSION - ).as_dict() - kv_properties = keyvault_resource["properties"] - if "enableRbacAuthorization" in kv_properties and kv_properties["enableRbacAuthorization"] is True: - raise ValidationError( - "Target Key Vault must be configured for access policy based permission model. " - "Rbac is not currently supported." - ) - return keyvault_resource - - -def prepare_keyvault_access_policy( - subscription_id: str, keyvault_resource: dict, keyvault_resource_id: str, sp_record: ServicePrincipal, **kwargs -) -> str: - resource_client = get_resource_client(subscription_id=subscription_id) - vault_uri: str = keyvault_resource["properties"]["vaultUri"] - if vault_uri[-1] == "/": - vault_uri = vault_uri[:-1] - - keyvault_access_policies: List[dict] = keyvault_resource["properties"].get("accessPolicies", []) - - add_access_policy = True - for access_policy in keyvault_access_policies: - if "objectId" in access_policy and access_policy["objectId"] == sp_record.object_id: - add_access_policy = False - - if add_access_policy: - keyvault_access_policies.append( - { - "tenantId": sp_record.tenant_id, - "objectId": sp_record.object_id, - # "applicationId": sp_record.client_id, # @digimaun - including turns into compound assignment. - "permissions": {"secrets": ["get", "list"], "keys": [], "certificates": [], "storage": []}, - } - ) - keyvault_resource["properties"]["accessPolicies"] = keyvault_access_policies - resource_client.resources.begin_create_or_update_by_id( - resource_id=f"{keyvault_resource_id}/accessPolicies/add", - api_version=KEYVAULT_CLOUD_API_VERSION, - parameters={"properties": {"accessPolicies": keyvault_access_policies}}, - ).result() - sleep(PROPAGATION_DELAY_SEC) - - return vault_uri - - -def prepare_keyvault_secret( - cmd, deployment_name: str, vault_uri: str, keyvault_spc_secret_name: Optional[str] = None, **kwargs -) -> str: - from azure.cli.core.util import send_raw_request - - url = vault_uri + "/secrets/{0}{1}?api-version={2}" - if keyvault_spc_secret_name: - get_secret_version: dict = send_raw_request( - cli_ctx=cmd.cli_ctx, - method="GET", - url=url.format(keyvault_spc_secret_name, "/versions", KEYVAULT_DATAPLANE_API_VERSION), - resource="https://vault.azure.net", - ).json() - if not get_secret_version.get("value"): - send_raw_request( - cli_ctx=cmd.cli_ctx, - method="PUT", - url=url.format(keyvault_spc_secret_name, "", KEYVAULT_DATAPLANE_API_VERSION), - resource="https://vault.azure.net", - body=json.dumps({"value": generate_secret()}), - ).json() - else: - keyvault_spc_secret_name = deployment_name.replace(".", "-") - send_raw_request( - cli_ctx=cmd.cli_ctx, - method="PUT", - url=url.format(keyvault_spc_secret_name, "", KEYVAULT_DATAPLANE_API_VERSION), - resource="https://vault.azure.net", - body=json.dumps({"value": generate_secret()}), - ).json() - - return keyvault_spc_secret_name - - -def eval_secret_via_sp(cmd, vault_uri: str, keyvault_spc_secret_name: str, sp_record: ServicePrincipal): - from azure.cli.core.util import send_raw_request - - identity_logger = logging.getLogger("azure.identity") - identity_logger_level = identity_logger.level - identity_logger.setLevel(logging.ERROR) - - auth_token = get_token_from_sp_credential( - tenant_id=sp_record.tenant_id, - client_id=sp_record.client_id, - client_secret=sp_record.secret, - scope="https://vault.azure.net/.default", - ) - identity_logger.setLevel(identity_logger_level) - - kv_secret_url = vault_uri + "/secrets/{0}?api-version={1}" - try: - send_raw_request( - cli_ctx=cmd.cli_ctx, - method="GET", - headers=[f"Authorization=Bearer {auth_token}"], # Expected header format :) - url=kv_secret_url.format(keyvault_spc_secret_name, KEYVAULT_DATAPLANE_API_VERSION), - ) - except HTTPError as e: - error_response: Response = e.response - http_error_msg = str(e) - if error_response.status_code in [401, 403]: - custom_error_msg = ( - f"{http_error_msg}\n\n" - "The error indicates an auth failure to fetch the default SPC secret from Key Vault. " - "If no access policy exists for the service principal used to setup the CSI driver" - "init will create a suitable access policy given the logged-in principal " - "has permission to do so." - ) - raise ValidationError(error_msg=custom_error_msg) - raise ValidationError(error_msg=http_error_msg) - - def deploy_template( template: dict, parameters: dict, diff --git a/azext_edge/edge/providers/orchestration/common.py b/azext_edge/edge/providers/orchestration/common.py index 337e5307d..a6efe4c38 100644 --- a/azext_edge/edge/providers/orchestration/common.py +++ b/azext_edge/edge/providers/orchestration/common.py @@ -14,8 +14,7 @@ GRAPH_ENDPOINT = "https://graph.microsoft.com/" GRAPH_V1_ENDPOINT = f"{GRAPH_ENDPOINT}v1.0" GRAPH_V1_SP_ENDPOINT = f"{GRAPH_V1_ENDPOINT}/servicePrincipals" -GRAPH_V1_APP_ENDPOINT = f"{GRAPH_V1_ENDPOINT}/applications" -DEFAULT_SERVICE_PRINCIPAL_SECRET_DAYS = 365 + CUSTOM_LOCATIONS_RP_APP_ID = "bc313c14-388c-4e7d-a58e-70017303ee3b" EXTENDED_LOCATION_ROLE_BINDING = "AzureArc-Microsoft.ExtendedLocation-RP-RoleBinding" @@ -23,9 +22,8 @@ ARC_NAMESPACE = "azure-arc" # Key Vault KPIs -KEYVAULT_ARC_EXTENSION_VERSION = "1.5.6" -KEYVAULT_DATAPLANE_API_VERSION = "7.4" -KEYVAULT_CLOUD_API_VERSION = "2022-07-01" +KEYVAULT_DATAPLANE_API_VERSION = "7.4" # TODO - @digimaun, maybe needed +KEYVAULT_CLOUD_API_VERSION = "2022-07-01" # TODO - @digimaun, maybe needed # Custom Locations KPIs CUSTOM_LOCATIONS_API_VERSION = "2021-08-31-preview" @@ -61,8 +59,6 @@ class KubernetesDistroType(Enum): "MqServiceType", "KubernetesDistroType", "DEFAULT_X509_CA_VALID_DAYS", - "DEFAULT_SERVICE_PRINCIPAL_SECRET_DAYS", - "KEYVAULT_ARC_EXTENSION_VERSION", "KEYVAULT_DATAPLANE_API_VERSION", "KEYVAULT_CLOUD_API_VERSION", ] diff --git a/azext_edge/edge/providers/orchestration/host.py b/azext_edge/edge/providers/orchestration/host.py index 3daa0d233..e0fa85fc1 100644 --- a/azext_edge/edge/providers/orchestration/host.py +++ b/azext_edge/edge/providers/orchestration/host.py @@ -12,7 +12,7 @@ from knack.log import get_logger from rich.console import Console -from .common import ARM_ENDPOINT, GRAPH_ENDPOINT, MCR_ENDPOINT +from .common import ARM_ENDPOINT, MCR_ENDPOINT logger = get_logger(__name__) console = Console(width=88) @@ -107,8 +107,6 @@ def preflight_http_connections(endpoints: List[str]) -> EndpointConnections: return EndpointConnections(connect_map=endpoint_connect_map) -def verify_cli_client_connections(include_graph: bool): +def verify_cli_client_connections(): test_endpoints = [ARM_ENDPOINT] - if include_graph: - test_endpoints.append(GRAPH_ENDPOINT) preflight_http_connections(test_endpoints).throw_if_failure(include_cluster=False) diff --git a/azext_edge/edge/providers/orchestration/work.py b/azext_edge/edge/providers/orchestration/work.py index 69cb51c32..1234336cc 100644 --- a/azext_edge/edge/providers/orchestration/work.py +++ b/azext_edge/edge/providers/orchestration/work.py @@ -21,6 +21,7 @@ from rich.table import Table from ...util import get_timestamp_now_utc +from ...util.az_client import wait_for_terminal_state from ...util.x509 import DEFAULT_EC_ALGO, DEFAULT_VALID_DAYS from .template import ( CURRENT_TEMPLATE, @@ -34,26 +35,16 @@ class WorkCategoryKey(IntEnum): PRE_FLIGHT = 1 - CSI_DRIVER = 2 - TLS_CA = 3 - DEPLOY_AIO = 4 + TLS_CA = 2 + DEPLOY_AIO = 3 class WorkStepKey(IntEnum): REG_RP = 1 ENUMERATE_PRE_FLIGHT = 2 WHAT_IF = 3 - - KV_CLOUD_PERM_MODEL = 4 - SP = 5 - KV_CLOUD_AP = 6 - KV_CLOUD_SEC = 7 - KV_CLOUD_TEST = 8 - KV_CSI_DEPLOY = 9 - KV_CSI_CLUSTER = 10 - - TLS_CERT = 11 - TLS_CLUSTER = 12 + TLS_CERT = 4 + TLS_CLUSTER = 5 class WorkRecord: @@ -61,9 +52,6 @@ def __init__(self, title: str): self.title = title -CLUSTER_SECRET_REF = "aio-akv-sp" -CLUSTER_SECRET_CLASS_NAME = "aio-default-spc" - PRE_FLIGHT_SUCCESS_STATUS = "succeeded" @@ -111,15 +99,9 @@ def __init__(self, **kwargs): self._keyvault_resource_id = kwargs.get("keyvault_resource_id") if self._keyvault_resource_id: self._keyvault_name = self._keyvault_resource_id.split("/")[-1] - self._keyvault_sat_secret_name = kwargs["keyvault_spc_secret_name"] - self._csi_driver_version: str = kwargs["csi_driver_version"] - self._csi_driver_config: Optional[Dict[str, str]] = kwargs.get("csi_driver_config") - self._sp_app_id = kwargs.get("service_principal_app_id") - self._sp_obj_id = kwargs.get("service_principal_object_id") self._tls_ca_path = kwargs.get("tls_ca_path") self._tls_ca_key_path = kwargs.get("tls_ca_key_path") self._tls_ca_valid_days = kwargs.get("tls_ca_valid_days", DEFAULT_VALID_DAYS) - self._tls_insecure = kwargs.get("tls_insecure", False) self._template_path = kwargs.get("template_path") self._progress_shown = False self._render_progress = not self._no_progress @@ -128,8 +110,6 @@ def __init__(self, **kwargs): self._active_step: int = 0 self._subscription_id = get_subscription_id(self._cmd.cli_ctx) kwargs["subscription_id"] = self._subscription_id # TODO: temporary - self._cluster_secret_ref = CLUSTER_SECRET_REF - self._cluster_secret_class_name = CLUSTER_SECRET_CLASS_NAME # TODO: Make cluster target with KPIs self._cluster_name: str = kwargs["cluster_name"] self._cluster_namespace: str = kwargs["cluster_namespace"] @@ -151,44 +131,6 @@ def _build_display(self): ) self.display.add_step(WorkCategoryKey.PRE_FLIGHT, WorkStepKey.WHAT_IF, "Verify What-If deployment") - kv_csi_cat_desc = "Key Vault CSI Driver" - self.display.add_category(WorkCategoryKey.CSI_DRIVER, kv_csi_cat_desc, skipped=not self._keyvault_resource_id) - - kv_cloud_perm_model_desc = "Verify Key Vault{}permission model" - kv_cloud_perm_model_desc = kv_cloud_perm_model_desc.format( - f" '[cyan]{self._keyvault_name}[/cyan]' " if self._keyvault_resource_id else " " - ) - self.display.add_step( - WorkCategoryKey.CSI_DRIVER, WorkStepKey.KV_CLOUD_PERM_MODEL, description=kv_cloud_perm_model_desc - ) - - if self._sp_app_id: - sp_desc = f"Use SP app Id '[cyan]{self._sp_app_id}[/cyan]'" - elif self._sp_obj_id: - sp_desc = f"Use SP object Id '[cyan]{self._sp_obj_id}[/cyan]'" - else: - sp_desc = "To create app" - self.display.add_step(WorkCategoryKey.CSI_DRIVER, WorkStepKey.SP, description=sp_desc) - - self.display.add_step( - WorkCategoryKey.CSI_DRIVER, WorkStepKey.KV_CLOUD_AP, description="Configure access policy" - ) - - kv_cloud_sec_desc = f"Ensure default SPC secret name '[cyan]{self._keyvault_sat_secret_name}[/cyan]'" - self.display.add_step(WorkCategoryKey.CSI_DRIVER, WorkStepKey.KV_CLOUD_SEC, description=kv_cloud_sec_desc) - - kv_sp_test_desc = "Test SP access" - self.display.add_step(WorkCategoryKey.CSI_DRIVER, WorkStepKey.KV_CLOUD_TEST, description=kv_sp_test_desc) - - kv_csi_deploy_desc = f"Deploy driver to cluster '[cyan]v{self._csi_driver_version}[/cyan]'" - self.display.add_step(WorkCategoryKey.CSI_DRIVER, WorkStepKey.KV_CSI_DEPLOY, description=kv_csi_deploy_desc) - - kv_csi_configure_desc = "Configure driver" - self.display.add_step( - WorkCategoryKey.CSI_DRIVER, WorkStepKey.KV_CSI_CLUSTER, description=kv_csi_configure_desc - ) - - # TODO @digimaun - MQ insecure mode self.display.add_category(WorkCategoryKey.TLS_CA, "TLS", self._no_tls) if self._tls_ca_path: tls_ca_desc = f"User provided CA '[cyan]{self._tls_ca_path}[/cyan]'" @@ -197,7 +139,6 @@ def _build_display(self): f"Generate test CA using '[cyan]{DEFAULT_EC_ALGO.name}[/cyan]' " f"valid for '[cyan]{self._tls_ca_valid_days}[/cyan]' days" ) - self.display.add_step(WorkCategoryKey.TLS_CA, WorkStepKey.TLS_CERT, tls_ca_desc) self.display.add_step(WorkCategoryKey.TLS_CA, WorkStepKey.TLS_CLUSTER, "Configure cluster for tls") @@ -211,22 +152,14 @@ def _build_display(self): def do_work(self): # noqa: C901 from ..edge_api.keyvault import KEYVAULT_API_V1 from .base import ( - configure_cluster_secrets, configure_cluster_tls, deploy_template, - eval_secret_via_sp, prepare_ca, - prepare_keyvault_access_policy, - prepare_keyvault_secret, - prepare_sp, - provision_akv_csi_driver, throw_if_iotops_deployed, - validate_keyvault_permission_model, verify_arc_cluster_config, verify_cluster_and_use_location, verify_custom_location_namespace, verify_custom_locations_enabled, - wait_for_terminal_state, ) from .host import verify_cli_client_connections from .permissions import verify_write_permission_against_rg @@ -236,8 +169,9 @@ def do_work(self): # noqa: C901 try: # Ensure connection to ARM if needed. Show remediation error message otherwise. + # TODO - @digimaun - self._keyvault_resource_id if any([not self._no_preflight, not self._no_deploy, self._keyvault_resource_id]): - verify_cli_client_connections(include_graph=bool(self._keyvault_resource_id)) + verify_cli_client_connections() # cluster_location uses actual connected cluster location. Same applies to location IF not provided. self._connected_cluster = verify_cluster_and_use_location(self._kwargs) verify_arc_cluster_config(self._connected_cluster) @@ -306,120 +240,6 @@ def do_work(self): # noqa: C901 if not self._render_progress: logger.warning("Skipped Pre-Flight as requested.") - # CSI driver segment - if self._keyvault_resource_id: - work_kpis["csiDriver"] = {} - if ( - WorkCategoryKey.CSI_DRIVER in self.display.categories - and not self.display.categories[WorkCategoryKey.CSI_DRIVER][1] - ): - self.render_display( - category=WorkCategoryKey.CSI_DRIVER, active_step=WorkStepKey.KV_CLOUD_PERM_MODEL - ) - - # WorkStepKey.KV_CLOUD_PERM_MODEL - keyvault_resource = validate_keyvault_permission_model(**self._kwargs) - - self.complete_step( - category=WorkCategoryKey.CSI_DRIVER, - completed_step=WorkStepKey.KV_CLOUD_PERM_MODEL, - active_step=WorkStepKey.SP, - ) - - # WorkStepKey.SP - sp_record = prepare_sp(deployment_name=self._work_name, **self._kwargs) - if sp_record.created_app: - self.display.steps[WorkCategoryKey.CSI_DRIVER][ - WorkStepKey.SP - ].title = f"Created app '[cyan]{sp_record.client_id}[/cyan]'" - self.render_display(category=WorkCategoryKey.CSI_DRIVER) - work_kpis["csiDriver"]["spAppId"] = sp_record.client_id - work_kpis["csiDriver"]["spObjectId"] = sp_record.object_id - work_kpis["csiDriver"]["keyVaultId"] = self._keyvault_resource_id - - self.complete_step( - category=WorkCategoryKey.CSI_DRIVER, - completed_step=WorkStepKey.SP, - active_step=WorkStepKey.KV_CLOUD_AP, - ) - - # WorkCategoryKey.KV_CLOUD_AP - vault_uri = prepare_keyvault_access_policy( - keyvault_resource=keyvault_resource, - sp_record=sp_record, - **self._kwargs, - ) - - self.complete_step( - category=WorkCategoryKey.CSI_DRIVER, - completed_step=WorkStepKey.KV_CLOUD_AP, - active_step=WorkStepKey.KV_CLOUD_SEC, - ) - - # WorkStepKey.KV_CLOUD_SEC - keyvault_spc_secret_name = prepare_keyvault_secret( - deployment_name=self._work_name, - vault_uri=vault_uri, - **self._kwargs, - ) - work_kpis["csiDriver"]["kvSpcSecretName"] = keyvault_spc_secret_name - - self.complete_step( - category=WorkCategoryKey.CSI_DRIVER, - completed_step=WorkStepKey.KV_CLOUD_SEC, - active_step=WorkStepKey.KV_CLOUD_TEST, - ) - - # WorkStepKey.KV_CLOUD_TEST - eval_secret_via_sp( - cmd=self._cmd, - vault_uri=vault_uri, - keyvault_spc_secret_name=keyvault_spc_secret_name, - sp_record=sp_record, - ) - - self.complete_step( - category=WorkCategoryKey.CSI_DRIVER, - completed_step=WorkStepKey.KV_CLOUD_TEST, - active_step=WorkStepKey.KV_CSI_DEPLOY, - ) - - # WorkStepKey.KV_CSI_DEPLOY - enable_secret_rotation = not self._kwargs.get("disable_secret_rotation", False) - enable_secret_rotation = "true" if enable_secret_rotation else "false" - - akv_csi_driver_result = provision_akv_csi_driver( - enable_secret_rotation=enable_secret_rotation, - extension_version=self._csi_driver_version, - extension_config=self._csi_driver_config, - **self._kwargs, - ) - work_kpis["csiDriver"]["version"] = akv_csi_driver_result["properties"]["version"] - work_kpis["csiDriver"]["configurationSettings"] = akv_csi_driver_result["properties"][ - "configurationSettings" - ] - - self.complete_step( - category=WorkCategoryKey.CSI_DRIVER, - completed_step=WorkStepKey.KV_CSI_DEPLOY, - active_step=WorkStepKey.KV_CSI_CLUSTER, - ) - - # WorkStepKey.KV_CSI_CLUSTER - configure_cluster_secrets( - cluster_secret_ref=self._cluster_secret_ref, - cluster_akv_secret_class_name=self._cluster_secret_class_name, - sp_record=sp_record, - **self._kwargs, - ) - - self.complete_step( - category=WorkCategoryKey.CSI_DRIVER, completed_step=WorkStepKey.KV_CSI_CLUSTER, active_step=-1 - ) - else: - if not self._render_progress: - logger.warning("Skipped AKV CSI driver setup as requested.") - # TLS segment if ( WorkCategoryKey.TLS_CA in self.display.categories @@ -627,16 +447,17 @@ def build_template(self, work_kpis: dict) -> Tuple[TemplateVer, dict]: ): parameters[template_pair[1]] = {"value": self._kwargs[template_pair[0]]} - parameters["mqSecrets"] = { - "value": { - "enabled": True, - "secretProviderClassName": self._cluster_secret_class_name, - "servicePrincipalSecretRef": self._cluster_secret_ref, - } - } - parameters["opcUaBrokerSecrets"] = { - "value": {"kind": "csi", "csiServicePrincipalSecretRef": self._cluster_secret_ref} - } + # TODO - @digimaun + # parameters["mqSecrets"] = { + # "value": { + # "enabled": True, + # "secretProviderClassName": self._cluster_secret_class_name, + # "servicePrincipalSecretRef": self._cluster_secret_ref, + # } + # } + # parameters["opcUaBrokerSecrets"] = { + # "value": {"kind": "csi", "csiServicePrincipalSecretRef": self._cluster_secret_ref} + # } parameters["deployResourceSyncRules"] = {"value": self._deploy_rsync_rules} # Covers cluster_namespace diff --git a/azext_edge/tests/edge/init/conftest.py b/azext_edge/tests/edge/init/conftest.py index 476ce145e..556a89b5c 100644 --- a/azext_edge/tests/edge/init/conftest.py +++ b/azext_edge/tests/edge/init/conftest.py @@ -4,11 +4,8 @@ # Licensed under the MIT License. See License file in the project root for license information. # ---------------------------------------------------------------------------------------------- -from typing import Dict - import pytest -from azext_edge.edge.providers.orchestration.base import KEYVAULT_ARC_EXTENSION_VERSION from azext_edge.edge.util import get_timestamp_now_utc @@ -24,38 +21,6 @@ def mocked_deploy(mocker): yield patched -@pytest.fixture -def mocked_provision_akv_csi_driver(mocker): - patched = mocker.patch("azext_edge.edge.providers.orchestration.base.provision_akv_csi_driver", autospec=True) - - base_config_settings: Dict[str, str] = { - "secrets-store-csi-driver.enableSecretRotation": "true", - "secrets-store-csi-driver.rotationPollInterval": "1h", - "secrets-store-csi-driver.syncSecret.enabled": "false", - } - - def handle_return(*args, **kwargs): - custom_config = kwargs.get("extension_config") - if custom_config: - base_config_settings.update(custom_config) - - return { - "properties": { - "version": kwargs.get("extension_version") or KEYVAULT_ARC_EXTENSION_VERSION, - "configurationSettings": base_config_settings, - } - } - - patched.side_effect = handle_return - yield patched - - -@pytest.fixture -def mocked_configure_cluster_secrets(mocker): - patched = mocker.patch("azext_edge.edge.providers.orchestration.base.configure_cluster_secrets", autospec=True) - yield patched - - @pytest.fixture def mocked_cluster_tls(mocker): patched = mocker.patch("azext_edge.edge.providers.orchestration.base.configure_cluster_tls", autospec=True) @@ -106,43 +71,6 @@ def mocked_register_providers(mocker): yield patched -@pytest.fixture -def mocked_validate_keyvault_permission_model(mocker): - patched = mocker.patch( - "azext_edge.edge.providers.orchestration.base.validate_keyvault_permission_model", autospec=True - ) - - def handle_return(*args, **kwargs): - return { - "id": ( - "/subscriptions/ae128775-4f16-4fd2-8dff-0ceed6a6a1f3/resourceGroups/FakeRg/" - "providers/Microsoft.KeyVault/vaults/myfakekeyvault" - ), - "name": "myfakekeyvault", - "type": "Microsoft.KeyVault/vaults", - "location": "westus3", - "tags": {}, - "properties": { - "sku": {"family": "A", "name": "standard"}, - "tenantId": "351f1fd9-7de2-4c5b-b730-14b54fddb737", - "accessPolicies": [ - { - "tenantId": "351f1fd9-7de2-4c5b-b730-14b54fddb737", - "objectId": "44e44a12-594e-464a-acbb-0038734403bf", - "permissions": {"keys": [], "secrets": ["Get", "List"], "certificates": [], "storage": []}, - }, - ], - "enableRbacAuthorization": False, - "vaultUri": "https://myfakekeyvault.vault.azure.net/", - "provisioningState": "Succeeded", - "publicNetworkAccess": "Enabled", - }, - } - - patched.side_effect = handle_return - yield patched - - @pytest.fixture def mocked_edge_api_keyvault_api_v1(mocker): patched = mocker.patch("azext_edge.edge.providers.edge_api.keyvault.KEYVAULT_API_V1", autospec=False) @@ -157,39 +85,9 @@ def mocked_verify_write_permission_against_rg(mocker): yield patched -@pytest.fixture -def mocked_prepare_keyvault_access_policy(mocker): - patched = mocker.patch( - "azext_edge.edge.providers.orchestration.base.prepare_keyvault_access_policy", autospec=True - ) - - def handle_return(*args, **kwargs): - return f"https://localhost/{kwargs['keyvault_resource_id']}/vault" - - patched.side_effect = handle_return - yield patched - - -@pytest.fixture -def mocked_prepare_keyvault_secret(mocker): - patched = mocker.patch("azext_edge.edge.providers.orchestration.base.prepare_keyvault_secret", autospec=True) - - def handle_return(*args, **kwargs): - return kwargs["keyvault_spc_secret_name"] - - patched.side_effect = handle_return - yield patched - - -@pytest.fixture -def mocked_prepare_sp(mocker): - patched = mocker.patch("azext_edge.edge.providers.orchestration.base.prepare_sp", autospec=True) - yield patched - - @pytest.fixture def mocked_wait_for_terminal_state(mocker): - patched = mocker.patch("azext_edge.edge.providers.orchestration.base.wait_for_terminal_state", autospec=True) + patched = mocker.patch("azext_edge.edge.providers.orchestration.work.wait_for_terminal_state", autospec=True) yield patched @@ -236,12 +134,6 @@ def mocked_verify_arc_cluster_config(mocker): yield patched -@pytest.fixture -def mocked_eval_secret_via_sp(mocker): - patched = mocker.patch("azext_edge.edge.providers.orchestration.base.eval_secret_via_sp", autospec=True) - yield patched - - @pytest.fixture def mocked_verify_custom_location_namespace(mocker): patched = mocker.patch( diff --git a/azext_edge/tests/edge/init/int/helper.py b/azext_edge/tests/edge/init/int/helper.py index 2bccf4ec7..34c1c3e63 100644 --- a/azext_edge/tests/edge/init/int/helper.py +++ b/azext_edge/tests/edge/init/int/helper.py @@ -7,7 +7,6 @@ from enum import Enum from typing import Any, Dict, List, Optional, Union -from azext_edge.edge.providers.orchestration.common import KEYVAULT_ARC_EXTENSION_VERSION from ....helpers import run @@ -53,7 +52,6 @@ def assert_init_result( # CSI driver assert result["csiDriver"]["keyVaultId"] == key_vault assert result["csiDriver"]["kvSpcSecretName"] == arg_dict.get("kv_spc_secret_name", "azure-iot-operations") - assert result["csiDriver"]["version"] == arg_dict.get("csi_ver", KEYVAULT_ARC_EXTENSION_VERSION) if sp_app_id: assert result["csiDriver"]["spAppId"] == sp_app_id diff --git a/azext_edge/tests/edge/init/test_base_unit.py b/azext_edge/tests/edge/init/test_base_unit.py index b46c47432..b8c8059ec 100644 --- a/azext_edge/tests/edge/init/test_base_unit.py +++ b/azext_edge/tests/edge/init/test_base_unit.py @@ -4,20 +4,12 @@ # Licensed under the MIT License. See License file in the project root for license information. # ---------------------------------------------------------------------------------------------- -import json from typing import Optional from unittest.mock import Mock import pytest from azure.cli.core.azclierror import HTTPError, ValidationError -from requests.models import Response -from azext_edge.edge.providers.orchestration.base import ServicePrincipal -from azext_edge.edge.providers.orchestration.common import ( - GRAPH_V1_APP_ENDPOINT, - GRAPH_V1_ENDPOINT, - GRAPH_V1_SP_ENDPOINT, -) from azext_edge.edge.providers.orchestration.connected_cluster import ConnectedCluster from ...generators import generate_random_string, get_zeroed_subscription @@ -26,15 +18,6 @@ ZEROED_SUB = get_zeroed_subscription() -# TODO: move fixtues once functions are moved -@pytest.fixture -def mocked_wait_for_terminal_state(mocker): - terminal_result = mocker.Mock() - terminal_result.as_dict.return_value = generate_random_string() - terminal_patch = mocker.patch(f"{BASE_PATH}.wait_for_terminal_state", autospec=True, return_value=terminal_result) - yield terminal_patch - - @pytest.fixture def mocked_sleep(mocker): yield mocker.patch(f"{BASE_PATH}.sleep") @@ -64,151 +47,14 @@ def mocked_base_namespace_functions(mocker, request): create_cluster_patch = mocker.patch(f"{path}.create_cluster_namespace") create_secret_patch = mocker.patch(f"{path}.create_namespaced_secret") create_configmap_patch = mocker.patch(f"{path}.create_namespaced_configmap") - create_object_patch = mocker.patch(f"{path}.create_namespaced_custom_objects") yield { "get_cluster_patch": get_cluster_patch, "create_cluster_patch": create_cluster_patch, "create_secret_patch": create_secret_patch, "create_configmap_patch": create_configmap_patch, - "create_object_patch": create_object_patch, } -@pytest.mark.parametrize( - "mocked_resource_management_client", - [{"client_path": BASE_PATH, "resources.begin_create_or_update_by_id": {"result": generate_random_string()}}], - indirect=True, -) -@pytest.mark.parametrize("rotation_poll_interval", ["1h"]) -@pytest.mark.parametrize("extension_name", ["akvsecretsprovider"]) -@pytest.mark.parametrize("extension_version", [None, "1.5.1", "1.5.3"]) -@pytest.mark.parametrize("extension_config", [None, {"arc.enableMonitoring": "false", "a": "b"}]) -def test_provision_akv_csi_driver( - mocked_resource_management_client, - mocked_wait_for_terminal_state, - rotation_poll_interval, - extension_name, - extension_version, - extension_config, -): - from azext_edge.edge.providers.orchestration.base import ( - KEYVAULT_ARC_EXTENSION_VERSION, - provision_akv_csi_driver, - ) - - subscription_id = generate_random_string() - cluster_name = generate_random_string() - resource_group_name = generate_random_string() - enable_secret_rotation = generate_random_string() - - options = {} - expected_extension_version = KEYVAULT_ARC_EXTENSION_VERSION - if extension_version: - options["extension_version"] = extension_version - expected_extension_version = extension_version - if extension_config: - options["extension_config"] = extension_config - - result = provision_akv_csi_driver( - subscription_id=subscription_id, - cluster_name=cluster_name, - resource_group_name=resource_group_name, - enable_secret_rotation=enable_secret_rotation, - rotation_poll_interval=rotation_poll_interval, - extension_name=extension_name, - **options, - ) - - assert result == mocked_wait_for_terminal_state.return_value.as_dict.return_value - poller = mocked_resource_management_client.resources.begin_create_or_update_by_id.return_value - mocked_wait_for_terminal_state.assert_called_once_with(poller) - - call_kwargs = mocked_resource_management_client.resources.begin_create_or_update_by_id.call_args.kwargs - expected_id = ( - f"/subscriptions/{subscription_id}/resourceGroups/{resource_group_name}" - f"/providers/Microsoft.Kubernetes/connectedClusters/{cluster_name}/Providers" - f"/Microsoft.KubernetesConfiguration/extensions/{extension_name}" - ) - assert call_kwargs["resource_id"] == expected_id - assert call_kwargs["api_version"] == "2022-11-01" - - params = call_kwargs["parameters"] - assert params["identity"] == {"type": "SystemAssigned"} - assert params["properties"]["autoUpgradeMinorVersion"] is False - assert params["properties"]["version"] == expected_extension_version - - assert params["properties"]["configurationProtectedSettings"] == {} - - config_settings = params["properties"]["configurationSettings"] - assert config_settings["secrets-store-csi-driver.enableSecretRotation"] == enable_secret_rotation - assert config_settings["secrets-store-csi-driver.rotationPollInterval"] == rotation_poll_interval - assert config_settings["secrets-store-csi-driver.syncSecret.enabled"] == "false" - - if extension_config: - for c in extension_config: - assert config_settings[c] == extension_config[c] - - -@pytest.mark.parametrize( - "mocked_base_namespace_functions", - [{"get_cluster_namespace": False}, {"get_cluster_namespace": True}], - indirect=True, -) -def test_configure_cluster_secrets(mocker, mocked_base_namespace_functions, mocked_keyvault_api): - get_store_patch = mocker.patch(f"{BASE_PATH}.get_kv_secret_store_yaml", return_value=generate_random_string()) - from azext_edge.edge.providers.orchestration.base import configure_cluster_secrets - - cluster_namespace = generate_random_string() - cluster_secret_ref = generate_random_string() - keyvault_spc_secret_name = generate_random_string() - keyvault_resource_id = generate_random_string() - sp_record = mocker.Mock( - client_id=generate_random_string(), secret=generate_random_string(), tenant_id=generate_random_string() - ) - configure_cluster_secrets( - cluster_namespace=cluster_namespace, - cluster_secret_ref=cluster_secret_ref, - cluster_akv_secret_class_name=generate_random_string(), - keyvault_spc_secret_name=keyvault_spc_secret_name, - keyvault_resource_id=keyvault_resource_id, - sp_record=sp_record, - ) - mocked_base_namespace_functions["get_cluster_patch"].assert_called_once() - expected_call_count = int(not mocked_base_namespace_functions["get_cluster_patch"].return_value) - assert mocked_base_namespace_functions["create_cluster_patch"].call_count == expected_call_count - mocked_base_namespace_functions["create_secret_patch"].assert_called_once_with( - secret_name=cluster_secret_ref, - namespace=cluster_namespace, - data={"clientid": sp_record.client_id, "clientsecret": sp_record.secret}, - labels={"secrets-store.csi.k8s.io/used": "true"}, - delete_first=True, - ) - assert get_store_patch.call_count == 5 - mocked_base_namespace_functions["create_object_patch"].assert_called_once_with( - group=mocked_keyvault_api.group, - version=mocked_keyvault_api.version, - plural="secretproviderclasses", - namespace=cluster_namespace, - yaml_objects=[get_store_patch.return_value] * 5, - delete_first=True, - ) - - -@pytest.mark.parametrize("mocked_keyvault_api", [False], indirect=True) -def test_configure_cluster_secrets_error(mocked_keyvault_api): - from azext_edge.edge.providers.orchestration.base import configure_cluster_secrets - - with pytest.raises(ValidationError): - configure_cluster_secrets( - cluster_namespace=generate_random_string(), - cluster_secret_ref=generate_random_string(), - cluster_akv_secret_class_name=generate_random_string(), - keyvault_spc_secret_name=generate_random_string(), - keyvault_resource_id=generate_random_string(), - sp_record=generate_random_string(), - ) - - @pytest.mark.parametrize( "mocked_base_namespace_functions", [{"get_cluster_namespace": False}, {"get_cluster_namespace": True}], @@ -237,215 +83,6 @@ def test_configure_cluster_tls(mocked_base_namespace_functions): ) -@pytest.mark.parametrize( - "mocked_send_raw_request", - [ - { - "return_value": { - "appId": generate_random_string(), - "id": generate_random_string(), - "secretText": generate_random_string(), - "requiredResourceAccess": [ - {"resourceAppId": "cfa8b339-82a2-471a-a3c9-0fc0be7a4093"}, - {"resourceAppId": "00000003-0000-0000-c000-000000000000"}, - ], - } - } - ], - ids=["pass everything"], - indirect=True, -) -@pytest.mark.parametrize("app_id", [None, generate_random_string()]) -@pytest.mark.parametrize("object_id", [None, generate_random_string()]) -@pytest.mark.parametrize("secret", [None, generate_random_string()]) -@pytest.mark.parametrize("secret_valid_days", [365, 100]) -def test_prepare_sp( - mocker, - mocked_cmd, - mocked_get_tenant_id, - mocked_send_raw_request, - mocked_sleep, - app_id, - object_id, - secret, - secret_valid_days, -): - import datetime - - timedelta_spy = mocker.spy(datetime, "timedelta") - access_patch = mocker.patch(f"{BASE_PATH}.ensure_correct_access") - from azext_edge.edge.providers.orchestration.base import prepare_sp - - deployment_name = generate_random_string() - sp = prepare_sp( - mocked_cmd, - deployment_name, - service_principal_app_id=app_id, - service_principal_object_id=object_id, - service_principal_secret=secret, - service_principal_secret_valid_days=secret_valid_days, - ) - raw_request_result = mocked_send_raw_request.return_value.json.return_value - assert sp.client_id == (app_id or raw_request_result["appId"]) - assert sp.object_id == (object_id or raw_request_result["id"]) - assert sp.secret == (secret or raw_request_result["secretText"]) - assert sp.tenant_id == mocked_get_tenant_id.return_value - assert sp.created_app is not (app_id or object_id) - timedelta_spy.assert_called_once_with(days=secret_valid_days) - - # check calls one by one - call_count = 0 - if not app_id: - if object_id: - get_sp_call = mocked_send_raw_request.call_args_list[call_count].kwargs - assert get_sp_call["method"] == "GET" - assert get_sp_call["url"] == f"{GRAPH_V1_SP_ENDPOINT}/{sp.object_id}" - get_app_call = mocked_send_raw_request.call_args_list[call_count + 1].kwargs - assert get_app_call["method"] == "GET" - assert get_app_call["url"] == f"{GRAPH_V1_APP_ENDPOINT}(appId='{sp.client_id}')" - call_count += 2 - else: - post_sp_call = mocked_send_raw_request.call_args_list[call_count].kwargs - assert post_sp_call["method"] == "POST" - assert post_sp_call["url"] == f"{GRAPH_V1_APP_ENDPOINT}" - assert post_sp_call["body"] == json.dumps( - {"displayName": deployment_name, "signInAudience": "AzureADMyOrg"} - ) - access_patch.assert_called_once() - call_count += 1 - if not object_id: - get_sp_call = mocked_send_raw_request.call_args_list[call_count].kwargs - assert get_sp_call["method"] == "GET" - assert get_sp_call["url"] == f"{GRAPH_V1_SP_ENDPOINT}(appId='{sp.client_id}')" - call_count += 1 - # exception case here - if not secret: - post_call = mocked_send_raw_request.call_args_list[call_count].kwargs - assert post_call["method"] == "POST" - assert post_call["url"] == ( - f"{GRAPH_V1_ENDPOINT}/myorganization/applications(appId='{sp.client_id}')/addPassword" - ) - body = json.loads(post_call["body"]) - assert body["passwordCredential"]["displayName"] == deployment_name - assert body["passwordCredential"]["endDateTime"] - mocked_sleep.assert_called_once() - call_count += 1 - - assert mocked_send_raw_request.call_count == call_count - - -@pytest.mark.parametrize("error_code", [401, 403]) -def test_prepare_sp_catches(mocker, mocked_cmd, mocked_get_tenant_id, mocked_send_raw_request, error_code): - """Test that this function does not error even if there is an http error - there are 3 cases.""" - from azext_edge.edge.providers.orchestration.base import prepare_sp - - app_id = generate_random_string() - all_result = { - "appId": app_id, - "id": generate_random_string(), - "secretText": generate_random_string(), - "requiredResourceAccess": [ - {"resourceAppId": "cfa8b339-82a2-471a-a3c9-0fc0be7a4093"}, - {"resourceAppId": "00000003-0000-0000-c000-000000000000"}, - ], - } - - def custom_responses(**kwargs): - if kwargs["url"].startswith(f"{GRAPH_V1_APP_ENDPOINT}/"): - raise HTTPError(error_msg=generate_random_string(), response=mocker.Mock(status_code=error_code)) - request_mock = mocker.Mock() - request_mock.json.return_value = all_result - return request_mock - - mocked_send_raw_request.side_effect = custom_responses - mocker.patch(f"{BASE_PATH}.ensure_correct_access") - sp = prepare_sp( - mocked_cmd, - generate_random_string(), - service_principal_object_id=generate_random_string(), - service_principal_secret=generate_random_string(), - ) - assert sp - mocked_send_raw_request.reset_mock() - - def custom_responses2(**kwargs): - if kwargs["url"].startswith(f"{GRAPH_V1_SP_ENDPOINT}(appId='"): - raise HTTPError(error_msg=generate_random_string(), response=mocker.Mock(status_code=404)) - request_mock = mocker.Mock() - request_mock.json.return_value = all_result - return request_mock - - mocked_send_raw_request.side_effect = custom_responses2 - sp = prepare_sp( - mocked_cmd, - generate_random_string(), - service_principal_secret=generate_random_string(), - ) - assert sp - post_call = mocked_send_raw_request.call_args_list[2].kwargs - assert post_call["method"] == "POST" - assert post_call["url"] == f"{GRAPH_V1_SP_ENDPOINT}" - assert post_call["body"] == json.dumps({"appId": app_id}) - - -@pytest.mark.parametrize("app_id", [None, generate_random_string()]) -@pytest.mark.parametrize("object_id", [None, generate_random_string()]) -@pytest.mark.parametrize("secret", [None, generate_random_string()]) -def test_prepare_sp_error( - mocker, mocked_cmd, mocked_get_tenant_id, mocked_send_raw_request, app_id, object_id, secret -): - response = mocker.Mock(status_code=400) - mocked_send_raw_request.return_value.json.side_effect = HTTPError( - error_msg=generate_random_string(), response=response - ) - mocker.patch(f"{BASE_PATH}.ensure_correct_access") - from azext_edge.edge.providers.orchestration.base import prepare_sp - - if not all([app_id, object_id, secret]): - with pytest.raises(HTTPError): - prepare_sp( - mocked_cmd, - generate_random_string(), - service_principal_app_id=app_id, - service_principal_object_id=object_id, - service_principal_secret=secret, - ) - - -@pytest.mark.parametrize("key_vault", [False, True]) -@pytest.mark.parametrize("ms_graph", [False, True]) -def test_ensure_correct_access(mocked_cmd, mocked_send_raw_request, key_vault, ms_graph): - from azext_edge.edge.providers.orchestration.base import ensure_correct_access - - app_id = generate_random_string() - resource_access = [] - if key_vault: - resource_access.append({"resourceAppId": "cfa8b339-82a2-471a-a3c9-0fc0be7a4093"}) - if ms_graph: - resource_access.append({"resourceAppId": "00000003-0000-0000-c000-000000000000"}) - ensure_correct_access(mocked_cmd, app_id, resource_access) - - if key_vault and ms_graph: - mocked_send_raw_request.assert_not_called() - else: - mocked_send_raw_request.assert_called_once() - patch_call = mocked_send_raw_request.call_args.kwargs - assert patch_call["method"] == "PATCH" - assert patch_call["url"] == (f"https://graph.microsoft.com/v1.0/myorganization/applications(appId='{app_id}')") - body = json.loads(patch_call["body"])["requiredResourceAccess"] - id_map = {app["resourceAppId"]: app.get("resourceAccess") for app in body} - if not key_vault: - assert "cfa8b339-82a2-471a-a3c9-0fc0be7a4093" in id_map - scope = id_map["cfa8b339-82a2-471a-a3c9-0fc0be7a4093"][0] - assert scope["type"] == "Scope" - assert scope["id"] == "f53da476-18e3-4152-8e01-aec403e6edc0" - if not ms_graph: - assert "00000003-0000-0000-c000-000000000000" in id_map - scope = id_map["00000003-0000-0000-c000-000000000000"][0] - assert scope["type"] == "Scope" - assert scope["id"] == "e1fe6dd8-ba31-4d61-89e7-88639da4683d" - - @pytest.mark.parametrize("tls_ca_path", [None, generate_random_string()]) @pytest.mark.parametrize("tls_ca_key_path", [None, generate_random_string()]) @pytest.mark.parametrize("tls_ca_dir", [None, generate_random_string()]) @@ -495,121 +132,6 @@ def test_prepare_ca(mocker, tls_ca_path, tls_ca_key_path, tls_ca_dir): assert normalize_dir_patch.call_count == 0 -@pytest.mark.parametrize( - "mocked_resource_management_client", - [ - { - "client_path": BASE_PATH, - "resources.get_by_id": {"properties": {"result": generate_random_string()}}, - } - ], - indirect=True, -) -def test_validate_keyvault_permission_model(mocked_resource_management_client): - from azext_edge.edge.providers.orchestration.base import ( - validate_keyvault_permission_model, - ) - - result = validate_keyvault_permission_model( - subscription_id=generate_random_string(), - keyvault_resource_id=generate_random_string(), - ) - assert result == mocked_resource_management_client.resources.get_by_id.return_value.as_dict.return_value - - -@pytest.mark.parametrize( - "mocked_resource_management_client", - [ - { - "client_path": BASE_PATH, - "resources.get_by_id": {"properties": {"enableRbacAuthorization": True}}, - } - ], - indirect=True, -) -def test_validate_keyvault_permission_model_error(mocked_resource_management_client): - from azext_edge.edge.providers.orchestration.base import ( - validate_keyvault_permission_model, - ) - - with pytest.raises(ValidationError): - validate_keyvault_permission_model( - subscription_id=generate_random_string(), - keyvault_resource_id=generate_random_string(), - ) - - -@pytest.mark.parametrize( - "mocked_resource_management_client", - [ - { - "client_path": BASE_PATH, - "resources.begin_create_or_update_by_id": {"result": generate_random_string()}, - } - ], - indirect=True, -) -@pytest.mark.parametrize("access_policy", [False, True]) -def test_prepare_keyvault_access_policy(mocker, mocked_resource_management_client, mocked_sleep, access_policy): - from azext_edge.edge.providers.orchestration.base import ( - prepare_keyvault_access_policy, - ) - - sp_record = mocker.Mock(object_id=generate_random_string(), tenant_id=generate_random_string()) - keyvault_resource = {"properties": {"vaultUri": generate_random_string()}} - if access_policy: - keyvault_resource["accessPolicies"] = [{"objectId": sp_record.object_id}] - sp_record = mocker.Mock(object_id=generate_random_string(), tenant_id=generate_random_string()) - result = prepare_keyvault_access_policy( - subscription_id=generate_random_string(), - keyvault_resource=keyvault_resource, - keyvault_resource_id=generate_random_string(), - sp_record=sp_record, - ) - assert result == keyvault_resource["properties"]["vaultUri"] - assert len(keyvault_resource["properties"]["accessPolicies"]) == 1 - if not access_policy: - mocked_resource_management_client.resources.begin_create_or_update_by_id.assert_called_once() - mocked_sleep.assert_called_once() - assert keyvault_resource["properties"]["accessPolicies"][0]["tenantId"] == sp_record.tenant_id - assert keyvault_resource["properties"]["accessPolicies"][0]["objectId"] == sp_record.object_id - assert keyvault_resource["properties"]["accessPolicies"][0]["permissions"] - - -@pytest.mark.parametrize( - "mocked_send_raw_request", - [ - {"return_value": {"value": [{"name": generate_random_string(), "result": generate_random_string()}]}}, - {"return_value": {"name": generate_random_string(), "result": generate_random_string()}}, - ], - ids=["value", "no value"], - indirect=True, -) -@pytest.mark.parametrize("secret_name", [None, generate_random_string()]) -def test_prepare_keyvault_secret(mocked_cmd, mocked_send_raw_request, secret_name): - from azext_edge.edge.providers.orchestration.base import prepare_keyvault_secret - - deployment_name = ".".join(generate_random_string()) - vault_uri = generate_random_string() - result = prepare_keyvault_secret( - cmd=mocked_cmd, deployment_name=deployment_name, vault_uri=vault_uri, keyvault_spc_secret_name=secret_name - ) - if secret_name: - get_kwargs = mocked_send_raw_request.call_args_list[0].kwargs - assert get_kwargs["method"] == "GET" - assert get_kwargs["url"] == f"{vault_uri}/secrets/{secret_name}/versions?api-version=7.4" - assert get_kwargs["resource"] == "https://vault.azure.net" - if not mocked_send_raw_request.return_value.json.return_value.get("value") or not secret_name: - if not secret_name: - secret_name = deployment_name.replace(".", "-") - put_kwargs = mocked_send_raw_request.call_args_list[-1].kwargs - assert put_kwargs["method"] == "PUT" - assert put_kwargs["url"] == f"{vault_uri}/secrets/{secret_name}?api-version=7.4" - assert put_kwargs["resource"] == "https://vault.azure.net" - assert put_kwargs["body"] - assert result == secret_name - - @pytest.mark.parametrize( "mocked_resource_management_client", [ @@ -922,62 +444,6 @@ def test_verify_arc_cluster_config(mocker, mocked_cmd, test_scenario): get_config_map_patch.assert_called_once() -@pytest.mark.parametrize("http_error", [None, 401, 403, 500]) -def test_eval_secret_via_sp(mocker, mocked_cmd, http_error): - - def assert_mocked_get_token_from_sp_credential(): - mocked_get_token_from_sp_credential.assert_called_once_with( - tenant_id=sp_record.tenant_id, - client_id=sp_record.client_id, - client_secret=sp_record.secret, - scope="https://vault.azure.net/.default", - ) - - mock_token = generate_random_string() - mocked_get_token_from_sp_credential: Mock = mocker.patch( - f"{BASE_PATH}.get_token_from_sp_credential", return_value=mock_token - ) - mocked_send_raw_request: Mock = mocker.patch("azure.cli.core.util.send_raw_request") - - if http_error: - test_response = Response() - test_response.status_code = http_error - mocked_send_raw_request.side_effect = HTTPError(error_msg=generate_random_string(), response=test_response) - - from azext_edge.edge.providers.orchestration.base import eval_secret_via_sp - - vault_uri = generate_random_string() - kv_spc_secret_name = generate_random_string() - sp_record = ServicePrincipal( - client_id=generate_random_string(), - object_id=generate_random_string(), - tenant_id=generate_random_string(), - secret=generate_random_string(), - created_app=False, - ) - - if http_error: - with pytest.raises(ValidationError) as ve: - eval_secret_via_sp( - cmd=mocked_cmd, vault_uri=vault_uri, keyvault_spc_secret_name=kv_spc_secret_name, sp_record=sp_record - ) - assert_mocked_get_token_from_sp_credential() - if http_error in [401, 403]: - assert "auth failure" in str(ve.value) - return - - eval_secret_via_sp( - cmd=mocked_cmd, vault_uri=vault_uri, keyvault_spc_secret_name=kv_spc_secret_name, sp_record=sp_record - ) - assert_mocked_get_token_from_sp_credential() - mocked_send_raw_request.assert_called_once_with( - cli_ctx=mocked_cmd.cli_ctx, - method="GET", - headers=[f"Authorization=Bearer {mock_token}"], - url=f"{vault_uri}/secrets/{kv_spc_secret_name}?api-version=7.4", - ) - - @pytest.mark.parametrize( "custom_location_name, namespace, get_cl_for_np_return_value", [ @@ -1032,7 +498,10 @@ def test_register_providers(mocker, registration_state): mocked_get_resource_client: Mock = mocker.patch( "azext_edge.edge.providers.orchestration.rp_namespace.get_resource_client" ) - from azext_edge.edge.providers.orchestration.rp_namespace import RP_NAMESPACE_SET, register_providers + from azext_edge.edge.providers.orchestration.rp_namespace import ( + RP_NAMESPACE_SET, + register_providers, + ) class MockProvider: def __init__(self, namespace: str, registration_state: str): diff --git a/azext_edge/tests/edge/init/test_work_unit.py b/azext_edge/tests/edge/init/test_work_unit.py index 05aa4acb5..3d9b99c8b 100644 --- a/azext_edge/tests/edge/init/test_work_unit.py +++ b/azext_edge/tests/edge/init/test_work_unit.py @@ -18,22 +18,18 @@ from azext_edge.edge.commands_edge import init from azext_edge.edge.common import INIT_NO_PREFLIGHT_ENV_KEY from azext_edge.edge.providers.base import DEFAULT_NAMESPACE -from azext_edge.edge.providers.orchestration.base import KEYVAULT_ARC_EXTENSION_VERSION from azext_edge.edge.providers.orchestration.common import ( KubernetesDistroType, MqMemoryProfile, MqServiceType, ) from azext_edge.edge.providers.orchestration.work import ( - CLUSTER_SECRET_CLASS_NAME, - CLUSTER_SECRET_REF, CURRENT_TEMPLATE, WorkCategoryKey, WorkManager, WorkStepKey, get_basic_dataflow_profile, ) -from azext_edge.edge.util import assemble_nargs_to_dict from ...generators import generate_random_string @@ -55,8 +51,6 @@ def mock_broker_config(): cluster_name, cluster_namespace, resource_group_name, - keyvault_spc_secret_name, - keyvault_resource_id, custom_location_name, custom_location_namespace, location, @@ -86,8 +80,6 @@ def mock_broker_config(): generate_random_string(), # cluster_name None, # cluster_namespace generate_random_string(), # resource_group_name - None, # keyvault_spc_secret_name - None, # keyvault_resource_id None, # custom_location_name None, # custom_location_namespace None, # location @@ -116,8 +108,6 @@ def mock_broker_config(): "Mixed_Cluster_Name", # cluster_name generate_random_string(), # cluster_namespace generate_random_string(), # resource_group_name - generate_random_string(), # keyvault_spc_secret_name - generate_random_string(), # keyvault_resource_id generate_random_string(), # custom_location_name None, # custom_location_namespace generate_random_string(), # location @@ -146,8 +136,6 @@ def mock_broker_config(): generate_random_string(), # cluster_name generate_random_string(), # cluster_namespace generate_random_string(), # resource_group_name - generate_random_string(), # keyvault_spc_secret_name - generate_random_string(), # keyvault_resource_id generate_random_string(), # custom_location_name None, # custom_location_namespace generate_random_string(), # location @@ -182,8 +170,6 @@ def test_init_to_template_params( cluster_name, cluster_namespace, resource_group_name, - keyvault_spc_secret_name, - keyvault_resource_id, custom_location_name, custom_location_namespace, location, @@ -212,8 +198,6 @@ def test_init_to_template_params( (instance_name, "instance_name"), (instance_description, "instance_description"), (cluster_namespace, "cluster_namespace"), - (keyvault_spc_secret_name, "keyvault_spc_secret_name"), - (keyvault_resource_id, "keyvault_resource_id"), (custom_location_name, "custom_location_name"), (custom_location_namespace, "custom_location_namespace"), (location, "location"), @@ -320,19 +304,6 @@ def test_init_to_template_params( else: assert passthrough_value_tuple[1] not in parameters - set_value_tuples = [ - ( - "mqSecrets", - {"enabled": True, "secretProviderClassName": "aio-default-spc", "servicePrincipalSecretRef": "aio-akv-sp"}, - ), - ( - "opcUaBrokerSecrets", - {"kind": "csi", "csiServicePrincipalSecretRef": "aio-akv-sp"}, - ), - ] - for set_value_tuple in set_value_tuples: - assert parameters[set_value_tuple[0]]["value"] == set_value_tuple[1] - assert template_ver.content["variables"]["AIO_CLUSTER_RELEASE_NAMESPACE"] == expected_cluster_namespace # TODO @@ -370,12 +341,6 @@ def test_init_to_template_params( cluster_name, cluster_namespace, resource_group_name, - keyvault_resource_id, - keyvault_spc_secret_name, - disable_secret_rotation, - rotation_poll_interval, - csi_driver_version, - csi_driver_config, tls_ca_path, tls_ca_key_path, tls_ca_dir, @@ -389,12 +354,6 @@ def test_init_to_template_params( generate_random_string(), # cluster_name None, # cluster_namespace generate_random_string(), # resource_group_name - None, # keyvault_resource_id - None, # keyvault_spc_secret_name - None, # disable_secret_rotation - None, # rotation_poll_interval - None, # csi_driver_version - None, # csi_driver_config None, # tls_ca_path None, # tls_ca_key_path None, # tls_ca_dir @@ -407,12 +366,6 @@ def test_init_to_template_params( generate_random_string(), # cluster_name None, # cluster_namespace generate_random_string(), # resource_group_name - generate_random_string(), # keyvault_resource_id - None, # keyvault_spc_secret_name - None, # disable_secret_rotation - None, # rotation_poll_interval - None, # csi_driver_version - None, # csi_driver_config None, # tls_ca_path None, # tls_ca_key_path None, # tls_ca_dir @@ -425,12 +378,6 @@ def test_init_to_template_params( generate_random_string(), # cluster_name generate_random_string(), # cluster_namespace generate_random_string(), # resource_group_name - generate_random_string(), # keyvault_resource_id - generate_random_string(), # keyvault_spc_secret_name - None, # disable_secret_rotation - None, # rotation_poll_interval - None, # csi_driver_version - None, # csi_driver_config None, # tls_ca_path None, # tls_ca_key_path None, # tls_ca_dir @@ -443,12 +390,6 @@ def test_init_to_template_params( generate_random_string(), # cluster_name None, # cluster_namespace generate_random_string(), # resource_group_name - generate_random_string(), # keyvault_resource_id - generate_random_string(), # keyvault_spc_secret_name - True, # disable_secret_rotation - "3h", # rotation_poll_interval - None, # csi_driver_version - None, # csi_driver_config None, # tls_ca_path None, # tls_ca_key_path "/certs/", # tls_ca_dir @@ -461,12 +402,6 @@ def test_init_to_template_params( generate_random_string(), # cluster_name None, # cluster_namespace generate_random_string(), # resource_group_name - generate_random_string(), # keyvault_resource_id - generate_random_string(), # keyvault_spc_secret_name - True, # disable_secret_rotation - "3h", # rotation_poll_interval - "2.0.0", # csi_driver_version - ["telegraf.resources.limits.memory=500Mi", "telegraf.resources.limits.cpu=100m"], # csi_driver_config "/my/ca.crt", # tls_ca_path "/my/key.pem", # tls_ca_key_path None, # tls_ca_dir @@ -479,12 +414,6 @@ def test_init_to_template_params( generate_random_string(), # cluster_name None, # cluster_namespace generate_random_string(), # resource_group_name - None, # keyvault_resource_id - None, # keyvault_spc_secret_name - None, # disable_secret_rotation - None, # rotation_poll_interval - None, # csi_driver_version - None, # csi_driver_config None, # tls_ca_path None, # tls_ca_key_path None, # tls_ca_dir @@ -497,12 +426,6 @@ def test_init_to_template_params( generate_random_string(), # cluster_name None, # cluster_namespace generate_random_string(), # resource_group_name - None, # keyvault_resource_id - None, # keyvault_spc_secret_name - None, # disable_secret_rotation - None, # rotation_poll_interval - None, # csi_driver_version - None, # csi_driver_config None, # tls_ca_path None, # tls_ca_key_path None, # tls_ca_dir @@ -516,18 +439,12 @@ def test_init_to_template_params( def test_work_order( mocked_cmd: Mock, mocked_config: Mock, - mocked_provision_akv_csi_driver: Mock, - mocked_configure_cluster_secrets: Mock, mocked_cluster_tls: Mock, mocked_deploy_template: Mock, mocked_prepare_ca: Mock, - mocked_prepare_keyvault_access_policy: Mock, - mocked_prepare_keyvault_secret: Mock, - mocked_prepare_sp: Mock, mocked_register_providers: Mock, mocked_verify_cli_client_connections: Mock, mocked_edge_api_keyvault_api_v1: Mock, - mocked_validate_keyvault_permission_model: Mock, mocked_verify_write_permission_against_rg: Mock, mocked_wait_for_terminal_state: Mock, mocked_file_exists: Mock, @@ -535,18 +452,11 @@ def test_work_order( mocked_connected_cluster_extensions: Mock, mocked_verify_custom_locations_enabled: Mock, mocked_verify_arc_cluster_config: Mock, - mocked_eval_secret_via_sp: Mock, mocked_verify_custom_location_namespace: Mock, spy_get_current_template_copy: Mock, cluster_name, cluster_namespace, resource_group_name, - keyvault_resource_id, - keyvault_spc_secret_name, - disable_secret_rotation, - rotation_poll_interval, - csi_driver_version, - csi_driver_config, tls_ca_path, tls_ca_key_path, tls_ca_dir, @@ -562,23 +472,18 @@ def test_work_order( "cmd": mocked_cmd, "cluster_name": cluster_name, "resource_group_name": resource_group_name, - "keyvault_resource_id": keyvault_resource_id, - "disable_secret_rotation": disable_secret_rotation, "no_deploy": no_deploy, "no_tls": no_tls, "no_progress": True, "disable_rsync_rules": disable_rsync_rules, + "wait_sec": 0.25, } if no_preflight: environ[INIT_NO_PREFLIGHT_ENV_KEY] = "true" for param_with_default in [ - (rotation_poll_interval, "rotation_poll_interval"), - (csi_driver_version, "csi_driver_version"), - (csi_driver_config, "csi_driver_config"), (cluster_namespace, "cluster_namespace"), - (keyvault_spc_secret_name, "keyvault_spc_secret_name"), (tls_ca_path, "tls_ca_path"), (tls_ca_key_path, "tls_ca_key_path"), (tls_ca_dir, "tls_ca_dir"), @@ -588,23 +493,25 @@ def test_work_order( result = init(**call_kwargs) expected_template_copies = 0 - nothing_to_do = all([not keyvault_resource_id, no_tls, no_deploy, no_preflight]) - if nothing_to_do: - assert not result - mocked_verify_cli_client_connections.assert_not_called() - mocked_edge_api_keyvault_api_v1.is_deployed.assert_not_called() - return - if any([not no_preflight, not no_deploy, keyvault_resource_id]): - mocked_verify_cli_client_connections.assert_called_once() - mocked_connected_cluster_location.assert_called_once() + # TODO - @digimaun + # nothing_to_do = all([not keyvault_resource_id, no_tls, no_deploy, no_preflight]) + # if nothing_to_do: + # assert not result + # mocked_verify_cli_client_connections.assert_not_called() + # mocked_edge_api_keyvault_api_v1.is_deployed.assert_not_called() + # return + + # if any([not no_preflight, not no_deploy, keyvault_resource_id]): + # mocked_verify_cli_client_connections.assert_called_once() + # mocked_connected_cluster_location.assert_called_once() expected_cluster_namespace = cluster_namespace.lower() if cluster_namespace else DEFAULT_NAMESPACE displays_to_eval = [] for category_tuple in [ (not no_preflight, WorkCategoryKey.PRE_FLIGHT), - (keyvault_resource_id, WorkCategoryKey.CSI_DRIVER), + # (keyvault_resource_id, WorkCategoryKey.CSI_DRIVER), (not no_tls, WorkCategoryKey.TLS_CA), (not no_deploy, WorkCategoryKey.DEPLOY_AIO), ]: @@ -633,98 +540,6 @@ def test_work_order( mocked_verify_arc_cluster_config.assert_not_called() mocked_verify_custom_location_namespace.assert_not_called() - if keyvault_resource_id: - assert result["csiDriver"] - assert result["csiDriver"]["spAppId"] - assert result["csiDriver"]["spObjectId"] - assert result["csiDriver"]["keyVaultId"] == keyvault_resource_id - - expected_csi_driver_version = csi_driver_version if csi_driver_version else KEYVAULT_ARC_EXTENSION_VERSION - assert result["csiDriver"]["version"] == expected_csi_driver_version - - expected_csi_driver_custom_config = assemble_nargs_to_dict(csi_driver_config) if csi_driver_config else {} - if expected_csi_driver_custom_config: - for key in expected_csi_driver_custom_config: - assert expected_csi_driver_custom_config[key] == result["csiDriver"]["configurationSettings"][key] - - expected_keyvault_spc_secret_name = keyvault_spc_secret_name if keyvault_spc_secret_name else DEFAULT_NAMESPACE - assert result["csiDriver"]["kvSpcSecretName"] == expected_keyvault_spc_secret_name - - mocked_validate_keyvault_permission_model.assert_called_once() - assert mocked_validate_keyvault_permission_model.call_args.kwargs["subscription_id"] - assert ( - mocked_validate_keyvault_permission_model.call_args.kwargs["keyvault_resource_id"] == keyvault_resource_id - ) - - mocked_prepare_sp.assert_called_once() - assert mocked_prepare_sp.call_args.kwargs["deployment_name"] - assert mocked_prepare_sp.call_args.kwargs["cmd"] - - mocked_prepare_keyvault_access_policy.assert_called_once() - assert mocked_prepare_keyvault_access_policy.call_args.kwargs["subscription_id"] - assert mocked_prepare_keyvault_access_policy.call_args.kwargs["keyvault_resource_id"] == keyvault_resource_id - assert mocked_prepare_keyvault_access_policy.call_args.kwargs["sp_record"] - - mocked_prepare_keyvault_secret.assert_called_once() - expected_vault_uri = f"https://localhost/{keyvault_resource_id}/vault" - - assert mocked_prepare_keyvault_secret.call_args.kwargs["cmd"] - assert mocked_prepare_keyvault_secret.call_args.kwargs["deployment_name"] - assert mocked_prepare_keyvault_secret.call_args.kwargs["vault_uri"] == expected_vault_uri - assert ( - mocked_prepare_keyvault_secret.call_args.kwargs["keyvault_spc_secret_name"] - == expected_keyvault_spc_secret_name - ) - - mocked_provision_akv_csi_driver.assert_called_once() - assert mocked_provision_akv_csi_driver.call_args.kwargs["subscription_id"] - assert mocked_provision_akv_csi_driver.call_args.kwargs["cluster_name"] == cluster_name - assert mocked_provision_akv_csi_driver.call_args.kwargs["resource_group_name"] == resource_group_name - assert ( - mocked_provision_akv_csi_driver.call_args.kwargs["enable_secret_rotation"] == "false" - if disable_secret_rotation - else "true" - ) - assert ( - mocked_provision_akv_csi_driver.call_args.kwargs["rotation_poll_interval"] == rotation_poll_interval - if rotation_poll_interval - else "1h" - ) - - assert "extension_name" not in mocked_provision_akv_csi_driver.call_args.kwargs - - mocked_configure_cluster_secrets.assert_called_once() - assert mocked_configure_cluster_secrets.call_args.kwargs["cluster_namespace"] == expected_cluster_namespace - assert mocked_configure_cluster_secrets.call_args.kwargs["cluster_secret_ref"] == CLUSTER_SECRET_REF - assert ( - mocked_configure_cluster_secrets.call_args.kwargs["cluster_akv_secret_class_name"] - == CLUSTER_SECRET_CLASS_NAME - ) - assert ( - mocked_configure_cluster_secrets.call_args.kwargs["keyvault_spc_secret_name"] - == expected_keyvault_spc_secret_name - ) - assert mocked_configure_cluster_secrets.call_args.kwargs["keyvault_resource_id"] == keyvault_resource_id - assert mocked_configure_cluster_secrets.call_args.kwargs["sp_record"] - - mocked_eval_secret_via_sp.assert_called_once() - assert mocked_eval_secret_via_sp.call_args.kwargs["vault_uri"] == expected_vault_uri - assert ( - mocked_eval_secret_via_sp.call_args.kwargs["keyvault_spc_secret_name"] == expected_keyvault_spc_secret_name - ) - assert mocked_eval_secret_via_sp.call_args.kwargs["sp_record"] - else: - if not nothing_to_do and result: - assert "csiDriver" not in result - mocked_prepare_sp.assert_not_called() - mocked_prepare_keyvault_access_policy.assert_not_called() - mocked_prepare_keyvault_secret.assert_not_called() - mocked_provision_akv_csi_driver.assert_not_called() - mocked_configure_cluster_secrets.assert_not_called() - mocked_eval_secret_via_sp.assert_not_called() - - mocked_edge_api_keyvault_api_v1.is_deployed.assert_called_once() - if not no_tls: assert result["tls"]["aioTrustConfigMap"] # TODO assert result["tls"]["aioTrustSecretName"] # TODO @@ -741,8 +556,9 @@ def test_work_order( assert mocked_cluster_tls.call_args.kwargs["secret_name"] assert mocked_cluster_tls.call_args.kwargs["cm_name"] else: - if not nothing_to_do and result: - assert "tls" not in result + # TODO - @digimaun + # if not nothing_to_do and result: + # assert "tls" not in result mocked_prepare_ca.assert_not_called() mocked_cluster_tls.assert_not_called() @@ -771,17 +587,18 @@ def test_work_order( assert mocked_deploy_template.call_args.kwargs["cluster_name"] == cluster_name assert mocked_deploy_template.call_args.kwargs["cluster_namespace"] == expected_cluster_namespace else: - if not nothing_to_do and result: - assert "deploymentName" not in result - assert "resourceGroup" not in result - assert "clusterName" not in result - assert "clusterNamespace" not in result - assert "deploymentLink" not in result - assert "deploymentState" not in result + pass + # if not nothing_to_do and result: + # assert "deploymentName" not in result + # assert "resourceGroup" not in result + # assert "clusterName" not in result + # assert "clusterNamespace" not in result + # assert "deploymentLink" not in result + # assert "deploymentState" not in result # TODO # mocked_deploy_template.assert_not_called() - assert spy_get_current_template_copy.call_count == expected_template_copies + # assert spy_get_current_template_copy.call_count == expected_template_copies def _assert_displays_for(work_category_set: FrozenSet[WorkCategoryKey], display_spys: Dict[str, Mock]): @@ -802,29 +619,6 @@ def _assert_displays_for(work_category_set: FrozenSet[WorkCategoryKey], display_ assert render_display_call_kwargs[index] == {"active_step": -1} index += 1 - if WorkCategoryKey.CSI_DRIVER in work_category_set: - assert render_display_call_kwargs[index] == { - "category": WorkCategoryKey.CSI_DRIVER, - "active_step": WorkStepKey.KV_CLOUD_PERM_MODEL, - } - index += 1 - assert render_display_call_kwargs[index] == {"active_step": WorkStepKey.SP} - index += 1 - assert render_display_call_kwargs[index] == {"category": WorkCategoryKey.CSI_DRIVER} - index += 1 - assert render_display_call_kwargs[index] == {"active_step": WorkStepKey.KV_CLOUD_AP} - index += 1 - assert render_display_call_kwargs[index] == {"active_step": WorkStepKey.KV_CLOUD_SEC} - index += 1 - assert render_display_call_kwargs[index] == {"active_step": WorkStepKey.KV_CLOUD_TEST} - index += 1 - assert render_display_call_kwargs[index] == {"active_step": WorkStepKey.KV_CSI_DEPLOY} - index += 1 - assert render_display_call_kwargs[index] == {"active_step": WorkStepKey.KV_CSI_CLUSTER} - index += 1 - assert render_display_call_kwargs[index] == {"active_step": -1} - index += 1 - if WorkCategoryKey.TLS_CA in work_category_set: assert render_display_call_kwargs[index] == { "category": WorkCategoryKey.TLS_CA, diff --git a/azext_edge/tests/edge/support/test_support_unit.py b/azext_edge/tests/edge/support/test_support_unit.py index 8a9fa4254..76e3d0f4b 100644 --- a/azext_edge/tests/edge/support/test_support_unit.py +++ b/azext_edge/tests/edge/support/test_support_unit.py @@ -35,7 +35,8 @@ AKRI_SERVICE_LABEL, AKRI_WEBHOOK_LABEL, ) -from azext_edge.edge.providers.support.arcagents import ARC_AGENTS, MONIKER +# TODO - @elsie4ever +# from azext_edge.edge.providers.support.arcagents import ARC_AGENTS, MONIKER from azext_edge.edge.providers.support.base import get_bundle_path from azext_edge.edge.providers.support.billing import ( AIO_BILLING_USAGE_NAME_LABEL, @@ -58,7 +59,7 @@ ORC_CONTROLLER_LABEL, ) from azext_edge.edge.providers.support.otel import OTEL_API, OTEL_NAME_LABEL -from azext_edge.edge.providers.support.common import COMPONENT_LABEL_FORMAT, NAME_LABEL_FORMAT +from azext_edge.edge.providers.support.common import NAME_LABEL_FORMAT # COMPONENT_LABEL_FORMAT # TODO - @elsie4ever from azext_edge.edge.providers.support_bundle import COMPAT_MQTT_BROKER_APIS from azext_edge.tests.edge.support.conftest import add_pod_to_mocked_pods @@ -895,57 +896,58 @@ def test_get_bundle_path(mocked_os_makedirs): assert str(path).startswith(expected) and str(path).endswith("_aio.zip") -# TODO - test zipfile write for specific resources -# MQ connector stateful sets need labels based on connector names -@pytest.mark.parametrize( - "mocked_cluster_resources", - [[MQTT_BROKER_API_V1B1]], - indirect=True, -) -@pytest.mark.parametrize( - "custom_objects", - [ - # connectors present - { - "items": [ - { - "metadata": {"name": "mock-connector", "namespace": "mock-namespace"}, - }, - { - "metadata": {"name": "mock-connector2", "namespace": "mock-namespace2"}, - }, - ] - }, - # no connectors - {"items": []}, - ], -) -def test_mq_list_stateful_sets( - mocker, - mocked_config, - mocked_client, - mocked_cluster_resources, - custom_objects, - mocked_zipfile, - mocked_os_makedirs, -): - - # mock MQ support bundle to return connectors - mocked_mq_support_active_api = mocker.patch("azext_edge.edge.providers.support.mq.MQ_ACTIVE_API") - mocked_mq_support_active_api.get_resources.return_value = custom_objects - result = support_bundle(None, bundle_dir=a_bundle_dir, ops_service="broker") - assert result - - # assert initial call to list stateful sets - mocked_client.AppsV1Api().list_stateful_set_for_all_namespaces.assert_any_call( - label_selector=MQ_NAME_LABEL, field_selector=None - ) - - # TODO - assert zipfile write of generic statefulset - if not custom_objects["items"]: - # TODO - will revert to initial call once the old label is removed - # mocked_client.AppsV1Api().list_stateful_set_for_all_namespaces.assert_called_once() - mocked_client.AppsV1Api().list_stateful_set_for_all_namespaces.assert_called() +# TODO - @elsie4ever +# # TODO - test zipfile write for specific resources +# # MQ connector stateful sets need labels based on connector names +# @pytest.mark.parametrize( +# "mocked_cluster_resources", +# [[MQTT_BROKER_API_V1B1]], +# indirect=True, +# ) +# @pytest.mark.parametrize( +# "custom_objects", +# [ +# # connectors present +# { +# "items": [ +# { +# "metadata": {"name": "mock-connector", "namespace": "mock-namespace"}, +# }, +# { +# "metadata": {"name": "mock-connector2", "namespace": "mock-namespace2"}, +# }, +# ] +# }, +# # no connectors +# {"items": []}, +# ], +# ) +# def test_mq_list_stateful_sets( +# mocker, +# mocked_config, +# mocked_client, +# mocked_cluster_resources, +# custom_objects, +# mocked_zipfile, +# mocked_os_makedirs, +# ): + +# # mock MQ support bundle to return connectors +# mocked_mq_support_active_api = mocker.patch("azext_edge.edge.providers.support.mq.MQ_ACTIVE_API") +# mocked_mq_support_active_api.get_resources.return_value = custom_objects +# result = support_bundle(None, bundle_dir=a_bundle_dir, ops_service="broker") +# assert result + +# # assert initial call to list stateful sets +# mocked_client.AppsV1Api().list_stateful_set_for_all_namespaces.assert_any_call( +# label_selector=MQ_NAME_LABEL, field_selector=None +# ) + +# # TODO - assert zipfile write of generic statefulset +# if not custom_objects["items"]: +# # TODO - will revert to initial call once the old label is removed +# # mocked_client.AppsV1Api().list_stateful_set_for_all_namespaces.assert_called_once() +# mocked_client.AppsV1Api().list_stateful_set_for_all_namespaces.assert_called() @pytest.mark.parametrize( @@ -991,75 +993,75 @@ def test_create_bundle_mq_traces( assert_zipfile_write(mocked_zipfile, zinfo=test_zipinfo, data="trace_data") -@pytest.mark.parametrize( - "mocked_cluster_resources", - [ - [MQTT_BROKER_API_V1B1], - [MQTT_BROKER_API_V1B1, MQ_ACTIVE_API], - [MQTT_BROKER_API_V1B1, OPCUA_API_V1], - [MQTT_BROKER_API_V1B1, OPCUA_API_V1, DEVICEREGISTRY_API_V1], - [MQTT_BROKER_API_V1B1, OPCUA_API_V1, ORC_API_V1], - [MQTT_BROKER_API_V1B1, OPCUA_API_V1, ORC_API_V1, AKRI_API_V0], - [MQTT_BROKER_API_V1B1, OPCUA_API_V1, ORC_API_V1, CLUSTER_CONFIG_API_V1], - ], - indirect=True, -) -def test_create_bundle_arc_agents( - mocked_client, - mocked_cluster_resources, - mocked_config, - mocked_os_makedirs, - mocked_zipfile, - mocked_get_custom_objects, - mocked_list_cron_jobs, - mocked_list_jobs, - mocked_list_deployments, - mocked_list_persistent_volume_claims, - mocked_list_pods, - mocked_list_replicasets, - mocked_list_statefulsets, - mocked_list_daemonsets, - mocked_list_nodes, - mocked_list_cluster_events, - mocked_list_storage_classes, - mocked_get_stats, - mocked_root_logger, - mocked_mq_active_api, - mocked_namespaced_custom_objects, - mocked_get_arc_services -): - since_seconds = random.randint(86400, 172800) - result = support_bundle(None, bundle_dir=a_bundle_dir, log_age_seconds=since_seconds) - - assert "bundlePath" in result - assert a_bundle_dir in result["bundlePath"] - - for component, has_service in ARC_AGENTS: - assert_list_pods( - mocked_client, - mocked_zipfile, - mocked_list_pods, - label_selector=COMPONENT_LABEL_FORMAT.format(label=component), - directory_path=f"{MONIKER}/{component}", - since_seconds=since_seconds - ) - assert_list_replica_sets( - mocked_client, - mocked_zipfile, - label_selector=COMPONENT_LABEL_FORMAT.format(label=component), - directory_path=f"{MONIKER}/{component}" - ) - assert_list_deployments( - mocked_client, - mocked_zipfile, - label_selector=COMPONENT_LABEL_FORMAT.format(label=component), - directory_path=f"{MONIKER}/{component}" - ) - if has_service: - assert_list_services( - mocked_client, - mocked_zipfile, - label_selector=None, - directory_path=f"{MONIKER}/{component}", - mock_names=[f"{component}"] - ) +# @pytest.mark.parametrize( +# "mocked_cluster_resources", +# [ +# [MQTT_BROKER_API_V1B1], +# [MQTT_BROKER_API_V1B1, MQ_ACTIVE_API], +# [MQTT_BROKER_API_V1B1, OPCUA_API_V1], +# [MQTT_BROKER_API_V1B1, OPCUA_API_V1, DEVICEREGISTRY_API_V1], +# [MQTT_BROKER_API_V1B1, OPCUA_API_V1, ORC_API_V1], +# [MQTT_BROKER_API_V1B1, OPCUA_API_V1, ORC_API_V1, AKRI_API_V0], +# [MQTT_BROKER_API_V1B1, OPCUA_API_V1, ORC_API_V1, CLUSTER_CONFIG_API_V1], +# ], +# indirect=True, +# ) +# def test_create_bundle_arc_agents( +# mocked_client, +# mocked_cluster_resources, +# mocked_config, +# mocked_os_makedirs, +# mocked_zipfile, +# mocked_get_custom_objects, +# mocked_list_cron_jobs, +# mocked_list_jobs, +# mocked_list_deployments, +# mocked_list_persistent_volume_claims, +# mocked_list_pods, +# mocked_list_replicasets, +# mocked_list_statefulsets, +# mocked_list_daemonsets, +# mocked_list_nodes, +# mocked_list_cluster_events, +# mocked_list_storage_classes, +# mocked_get_stats, +# mocked_root_logger, +# mocked_mq_active_api, +# mocked_namespaced_custom_objects, +# mocked_get_arc_services +# ): +# since_seconds = random.randint(86400, 172800) +# result = support_bundle(None, bundle_dir=a_bundle_dir, log_age_seconds=since_seconds) + +# assert "bundlePath" in result +# assert a_bundle_dir in result["bundlePath"] + +# for component, has_service in ARC_AGENTS: +# assert_list_pods( +# mocked_client, +# mocked_zipfile, +# mocked_list_pods, +# label_selector=COMPONENT_LABEL_FORMAT.format(label=component), +# directory_path=f"{MONIKER}/{component}", +# since_seconds=since_seconds +# ) +# assert_list_replica_sets( +# mocked_client, +# mocked_zipfile, +# label_selector=COMPONENT_LABEL_FORMAT.format(label=component), +# directory_path=f"{MONIKER}/{component}" +# ) +# assert_list_deployments( +# mocked_client, +# mocked_zipfile, +# label_selector=COMPONENT_LABEL_FORMAT.format(label=component), +# directory_path=f"{MONIKER}/{component}" +# ) +# if has_service: +# assert_list_services( +# mocked_client, +# mocked_zipfile, +# label_selector=None, +# directory_path=f"{MONIKER}/{component}", +# mock_names=[f"{component}"] +# )