Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Key Vault] Add client creation decorator to certificates tests #18319

Merged
merged 1 commit into from
Apr 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions sdk/keyvault/azure-keyvault-certificates/tests/_test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,42 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import functools

from azure.keyvault.certificates import ApiVersion
from azure.keyvault.certificates._shared import HttpChallengeCache
from azure.keyvault.certificates._shared.client_base import DEFAULT_VERSION
from devtools_testutils import AzureTestCase
from parameterized import parameterized
from devtools_testutils import AzureTestCase, PowerShellPreparer
from parameterized import parameterized, param
import pytest


def client_setup(testcase_func):
"""decorator that creates a client to be passed in to a test method"""
@PowerShellPreparer("keyvault", azure_keyvault_url="https://vaultname.vault.azure.net")
@functools.wraps(testcase_func)
def wrapper(test_class_instance, azure_keyvault_url, api_version, **kwargs):
test_class_instance._skip_if_not_configured(api_version)
client = test_class_instance.create_client(azure_keyvault_url, api_version=api_version, **kwargs)

if kwargs.get("is_async"):
import asyncio

coroutine = testcase_func(test_class_instance, client)
loop = asyncio.get_event_loop()
loop.run_until_complete(coroutine)
else:
testcase_func(test_class_instance, client)
return wrapper


def get_decorator(**kwargs):
"""returns a test decorator for test parameterization"""
versions = kwargs.pop("api_versions", None) or ApiVersion
params = [param(api_version=api_version, **kwargs) for api_version in versions]
return functools.partial(parameterized.expand, params, name_func=suffixed_test_name)


def suffixed_test_name(testcase_func, param_num, param):
return "{}_{}".format(testcase_func.__name__, parameterized.to_safe_name(param.kwargs.get("api_version")))

Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,17 @@
IssuerProperties,
parse_key_vault_certificate_id
)
from devtools_testutils import PowerShellPreparer
from parameterized import parameterized, param
import pytest

from _shared.test_case import KeyVaultTestCase
from _test_case import CertificatesTestCase, suffixed_test_name
from _test_case import client_setup, get_decorator, CertificatesTestCase

KeyVaultPreparer = functools.partial(
PowerShellPreparer,
"keyvault",
azure_keyvault_url="https://vaultname.vault.azure.net"
)

all_api_versions = get_decorator()
logging_enabled = get_decorator(logging_enable=True)
logging_disabled = get_decorator(logging_enable=False)
exclude_2016_10_01 = get_decorator(api_versions=[v for v in ApiVersion if v != ApiVersion.V2016_10_01])
only_2016_10_01 = get_decorator(api_versions=[ApiVersion.V2016_10_01])


class RetryAfterReplacer(RecordingProcessor):
Expand Down Expand Up @@ -169,11 +168,9 @@ def _validate_certificate_issuer_properties(self, a, b):
self.assertEqual(a.name, b.name)
self.assertEqual(a.provider, b.provider)

@parameterized.expand([param(api_version=api_version) for api_version in ApiVersion], name_func=suffixed_test_name)
@KeyVaultPreparer()
def test_crud_operations(self, azure_keyvault_url, **kwargs):
client = self.create_client(azure_keyvault_url, **kwargs)

@all_api_versions()
@client_setup
def test_crud_operations(self, client, **kwargs):
cert_name = self.get_resource_name("cert")
lifetime_actions = [LifetimeAction(lifetime_percentage=80, action=CertificatePolicyAction.auto_renew)]
cert_policy = CertificatePolicy(
Expand Down Expand Up @@ -224,11 +221,9 @@ def test_crud_operations(self, azure_keyvault_url, **kwargs):
if not hasattr(ex, "message") or "not found" not in ex.message.lower():
raise ex

@parameterized.expand([param(api_version=api_version) for api_version in ApiVersion], name_func=suffixed_test_name)
@KeyVaultPreparer()
def test_import_certificate_not_password_encoded_no_policy(self, azure_keyvault_url, **kwargs):
client = self.create_client(azure_keyvault_url, **kwargs)

@all_api_versions()
@client_setup
def test_import_certificate_not_password_encoded_no_policy(self, client, **kwargs):
# If a certificate is not password encoded, we can import the certificate
# without passing in 'password'
certificate = client.import_certificate(
Expand All @@ -237,12 +232,9 @@ def test_import_certificate_not_password_encoded_no_policy(self, azure_keyvault_
)
self.assertIsNotNone(certificate.policy)

@parameterized.expand([param(api_version=api_version) for api_version in ApiVersion], name_func=suffixed_test_name)
@KeyVaultPreparer()
def test_import_certificate_password_encoded_no_policy(self, azure_keyvault_url, **kwargs):
self._skip_if_not_configured(**kwargs)
client = self.create_client(azure_keyvault_url, **kwargs)

@all_api_versions()
@client_setup
def test_import_certificate_password_encoded_no_policy(self, client, **kwargs):
# If a certificate is password encoded, we have to pass in 'password'
# when importing the certificate
certificate = client.import_certificate(
Expand All @@ -252,11 +244,9 @@ def test_import_certificate_password_encoded_no_policy(self, azure_keyvault_url,
)
self.assertIsNotNone(certificate.policy)

@parameterized.expand([param(api_version=api_version) for api_version in ApiVersion], name_func=suffixed_test_name)
@KeyVaultPreparer()
def test_list(self, azure_keyvault_url, **kwargs):
client = self.create_client(azure_keyvault_url, **kwargs)

@all_api_versions()
@client_setup
def test_list(self, client, **kwargs):
max_certificates = self.list_test_size
expected = {}

Expand All @@ -282,12 +272,9 @@ def test_list(self, azure_keyvault_url, **kwargs):
returned_certificates = client.list_properties_of_certificates(max_page_size=max_certificates - 1)
self._validate_certificate_list(expected, returned_certificates)

@parameterized.expand([param(api_version=api_version) for api_version in ApiVersion], name_func=suffixed_test_name)
@KeyVaultPreparer()
def test_list_certificate_versions(self, azure_keyvault_url, **kwargs):
self._skip_if_not_configured(**kwargs)
client = self.create_client(azure_keyvault_url, **kwargs)

@all_api_versions()
@client_setup
def test_list_certificate_versions(self, client, **kwargs):
cert_name = self.get_resource_name("certver")

max_certificates = self.list_test_size
Expand Down Expand Up @@ -315,12 +302,9 @@ def test_list_certificate_versions(self, azure_keyvault_url, **kwargs):
),
)

@parameterized.expand([param(api_version=api_version) for api_version in ApiVersion], name_func=suffixed_test_name)
@KeyVaultPreparer()
def test_crud_contacts(self, azure_keyvault_url, **kwargs):
self._skip_if_not_configured(**kwargs)
client = self.create_client(azure_keyvault_url, **kwargs)

@all_api_versions()
@client_setup
def test_crud_contacts(self, client, **kwargs):
contact_list = [
CertificateContact(email="admin@contoso.com", name="John Doe", phone="1111111111"),
CertificateContact(email="admin2@contoso.com", name="John Doe2", phone="2222222222"),
Expand All @@ -346,11 +330,9 @@ def test_crud_contacts(self, azure_keyvault_url, **kwargs):
if not hasattr(ex, "message") or "not found" not in ex.message.lower():
raise ex

@parameterized.expand([param(api_version=api_version) for api_version in ApiVersion], name_func=suffixed_test_name)
@KeyVaultPreparer()
def test_recover_and_purge(self, azure_keyvault_url, **kwargs):
client = self.create_client(azure_keyvault_url, **kwargs)

@all_api_versions()
@client_setup
def test_recover_and_purge(self, client, **kwargs):
certs = {}
# create certificates to recover
for i in range(self.list_test_size):
Expand Down Expand Up @@ -390,12 +372,9 @@ def test_recover_and_purge(self, azure_keyvault_url, **kwargs):
actual = {k: client.get_certificate_version(certificate_name=k, version="") for k in expected.keys()}
self.assertEqual(len(set(expected.keys()) & set(actual.keys())), len(expected))

@parameterized.expand([param(api_version=api_version) for api_version in ApiVersion], name_func=suffixed_test_name)
@KeyVaultPreparer()
def test_async_request_cancellation_and_deletion(self, azure_keyvault_url, **kwargs):
self._skip_if_not_configured(**kwargs)
client = self.create_client(azure_keyvault_url, **kwargs)

@all_api_versions()
@client_setup
def test_async_request_cancellation_and_deletion(self, client, **kwargs):
cert_name = self.get_resource_name("asyncCanceledDeletedCert")
cert_policy = CertificatePolicy.get_default()
# create certificate
Expand Down Expand Up @@ -444,15 +423,9 @@ def test_async_request_cancellation_and_deletion(self, azure_keyvault_url, **kwa
# delete cancelled certificate
client.begin_delete_certificate(cert_name).wait()

@parameterized.expand(
[param(api_version=api_version) for api_version in ApiVersion if api_version != ApiVersion.V2016_10_01],
name_func=suffixed_test_name
)
@KeyVaultPreparer()
def test_policy(self, azure_keyvault_url, **kwargs):
self._skip_if_not_configured(**kwargs)
client = self.create_client(azure_keyvault_url, **kwargs)

@exclude_2016_10_01()
@client_setup
def test_policy(self, client, **kwargs):
cert_name = self.get_resource_name("policyCertificate")
cert_policy = CertificatePolicy(
issuer_name="Self",
Expand Down Expand Up @@ -485,12 +458,9 @@ def test_policy(self, azure_keyvault_url, **kwargs):

self._validate_certificate_policy(cert_policy, returned_policy)

@parameterized.expand([param(api_version=api_version) for api_version in ApiVersion], name_func=suffixed_test_name)
@KeyVaultPreparer()
def test_get_pending_certificate_signing_request(self, azure_keyvault_url, **kwargs):
self._skip_if_not_configured(**kwargs)
client = self.create_client(azure_keyvault_url, **kwargs)

@all_api_versions()
@client_setup
def test_get_pending_certificate_signing_request(self, client, **kwargs):
cert_name = self.get_resource_name("unknownIssuerCert")

# get pending certificate signing request
Expand All @@ -500,14 +470,9 @@ def test_get_pending_certificate_signing_request(self, azure_keyvault_url, **kwa
pending_version_csr = client.get_certificate_operation(certificate_name=cert_name).csr
self.assertEqual(client.get_certificate_operation(certificate_name=cert_name).csr, pending_version_csr)

@parameterized.expand(
[param(api_version=api_version) for api_version in ApiVersion if api_version != ApiVersion.V2016_10_01],
name_func=suffixed_test_name
)
@KeyVaultPreparer()
def test_backup_restore(self, azure_keyvault_url, **kwargs):
client = self.create_client(azure_keyvault_url, **kwargs)

@exclude_2016_10_01()
@client_setup
def test_backup_restore(self, client, **kwargs):
policy = CertificatePolicy.get_default()
policy._san_user_principal_names = ["john.doe@domain.com"]
cert_name = self.get_resource_name("cert")
Expand All @@ -529,12 +494,9 @@ def test_backup_restore(self, azure_keyvault_url, **kwargs):
restored_certificate = self._poll_until_no_exception(restore_function, ResourceExistsError)
self._validate_certificate_bundle(cert=restored_certificate, cert_name=cert_name, cert_policy=policy)

@parameterized.expand([param(api_version=api_version) for api_version in ApiVersion], name_func=suffixed_test_name)
@KeyVaultPreparer()
def test_crud_issuer(self, azure_keyvault_url, **kwargs):
self._skip_if_not_configured(**kwargs)
client = self.create_client(azure_keyvault_url, **kwargs)

@all_api_versions()
@client_setup
def test_crud_issuer(self, client, **kwargs):
issuer_name = self.get_resource_name("issuer")
admin_contacts = [
AdministratorContact(first_name="John", last_name="Doe", email="admin@microsoft.com", phone="4255555555")
Expand Down Expand Up @@ -611,12 +573,9 @@ def test_crud_issuer(self, azure_keyvault_url, **kwargs):
if not hasattr(ex, "message") or "not found" not in ex.message.lower():
raise ex

@parameterized.expand([param(api_version=api_version) for api_version in ApiVersion], name_func=suffixed_test_name)
@KeyVaultPreparer()
def test_logging_enabled(self, azure_keyvault_url, **kwargs):
self._skip_if_not_configured(**kwargs)
client = self.create_client(azure_keyvault_url, logging_enable=True, **kwargs)

@logging_enabled()
@client_setup
def test_logging_enabled(self, client, **kwargs):
mock_handler = MockHandler()

logger = logging.getLogger("azure")
Expand All @@ -638,12 +597,9 @@ def test_logging_enabled(self, azure_keyvault_url, **kwargs):

assert False, "Expected request body wasn't logged"

@parameterized.expand([param(api_version=api_version) for api_version in ApiVersion], name_func=suffixed_test_name)
@KeyVaultPreparer()
def test_logging_disabled(self, azure_keyvault_url, **kwargs):
self._skip_if_not_configured(**kwargs)
client = self.create_client(azure_keyvault_url, logging_enable=False, **kwargs)

@logging_disabled()
@client_setup
def test_logging_disabled(self, client, **kwargs):
mock_handler = MockHandler()

logger = logging.getLogger("azure")
Expand All @@ -662,10 +618,9 @@ def test_logging_disabled(self, azure_keyvault_url, **kwargs):
# this means the message is not JSON or has no kty property
pass

@KeyVaultPreparer()
def test_2016_10_01_models(self, azure_keyvault_url, **kwargs):
client = self.create_client(azure_keyvault_url, api_version=ApiVersion.V2016_10_01)

@only_2016_10_01()
@client_setup
def test_models(self, client, **kwargs):
"""The client should correctly deserialize version 2016-10-01 models"""

cert_name = self.get_resource_name("cert")
Expand All @@ -675,12 +630,9 @@ def test_2016_10_01_models(self, azure_keyvault_url, **kwargs):
assert cert.policy.key_curve_name is None
assert cert.policy.certificate_transparency is None

@parameterized.expand([param(api_version=api_version) for api_version in ApiVersion], name_func=suffixed_test_name)
@KeyVaultPreparer()
def test_get_certificate_version(self, azure_keyvault_url, **kwargs):
self._skip_if_not_configured(**kwargs)
client = self.create_client(azure_keyvault_url, **kwargs)

@all_api_versions()
@client_setup
def test_get_certificate_version(self, client, **kwargs):
cert_name = self.get_resource_name("cert")
for _ in range(self.list_test_size):
client.begin_create_certificate(cert_name, CertificatePolicy.get_default()).wait()
Expand All @@ -703,9 +655,10 @@ def test_get_certificate_version(self, azure_keyvault_url, **kwargs):
assert version_properties.version == cert.properties.version
assert version_properties.x509_thumbprint == cert.properties.x509_thumbprint

@KeyVaultPreparer()
def test_list_properties_of_certificates_2016_10_01(self, azure_keyvault_url, **kwargs):
client = self.create_client(azure_keyvault_url, api_version=ApiVersion.V2016_10_01)
@only_2016_10_01()
@client_setup
def test_list_properties_of_certificates(self, client, **kwargs):
"""Tests API version v2016_10_01"""

[_ for _ in client.list_properties_of_certificates()]

Expand All @@ -714,11 +667,11 @@ def test_list_properties_of_certificates_2016_10_01(self, azure_keyvault_url, **

assert "The 'include_pending' parameter to `list_properties_of_certificates` is only available for API versions v7.0 and up" in str(excinfo.value)

@KeyVaultPreparer()
def test_list_deleted_certificates_2016_10_01(self, azure_keyvault_url, **kwargs):
client = self.create_client(azure_keyvault_url, api_version=ApiVersion.V2016_10_01)


@only_2016_10_01()
@client_setup
def test_list_deleted_certificates(self, client, **kwargs):
"""Tests API version v2016_10_01"""
[_ for _ in client.list_deleted_certificates()]

with pytest.raises(NotImplementedError) as excinfo:
Expand Down
Loading