Skip to content

Commit

Permalink
Don't send optional parameters unless explicitly specified (#533)
Browse files Browse the repository at this point in the history
* Only send optional parameters when supplied

This modifies functions to only send optional parameters when a value has explicitly been supplied.
Doing so prevents "update" functions from clobbering preexisting values when they're not supplied.
Additionally, all default parameter values have been changed to to None so that the Vault server can select the appropriate default value.

* Fix and extend Azure test case

Co-authored-by: Jeffrey Hogan <jeff@jeffhogan.me>
  • Loading branch information
llamasoft and jeffwecan authored Feb 13, 2020
1 parent 760d5b0 commit 09e0702
Show file tree
Hide file tree
Showing 24 changed files with 544 additions and 410 deletions.
139 changes: 59 additions & 80 deletions hvac/api/auth_methods/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class Aws(VaultApiBase):
Reference: https://www.vaultproject.io/api/auth/aws/index.html
"""

def configure(self, max_retries=-1, access_key=None, secret_key=None, endpoint=None, iam_endpoint=None,
def configure(self, max_retries=None, access_key=None, secret_key=None, endpoint=None, iam_endpoint=None,
sts_endpoint=None, iam_server_id_header_value=None, mount_point=AWS_DEFAULT_MOUNT_POINT):
"""Configures the credentials required to perform API calls to AWS as well as custom endpoints to talk to AWS
API
Expand Down Expand Up @@ -66,21 +66,15 @@ def configure(self, max_retries=-1, access_key=None, secret_key=None, endpoint=N
:rtype: requests.Response
"""

params = {
'max_retries': max_retries
}
if access_key is not None:
params['access_key'] = access_key
if secret_key is not None:
params['secret_key'] = secret_key
if endpoint is not None:
params['endpoint'] = endpoint
if iam_endpoint is not None:
params['iam_endpoint'] = iam_endpoint
if sts_endpoint is not None:
params['sts_endpoint'] = sts_endpoint
if iam_server_id_header_value is not None:
params['iam_server_id_header_value'] = iam_server_id_header_value
params = utils.remove_nones({
'max_retries': max_retries,
'access_key': access_key,
'secret_key': secret_key,
'endpoint': endpoint,
'iam_endpoint': iam_endpoint,
'sts_endpoint': sts_endpoint,
'iam_server_id_header_value': iam_server_id_header_value,
})
api_path = utils.format_url('/v1/auth/{mount_point}/config/client', mount_point=mount_point)
return self._adapter.post(
url=api_path,
Expand Down Expand Up @@ -120,7 +114,7 @@ def delete_config(self, mount_point=AWS_DEFAULT_MOUNT_POINT):
url=api_path
)

def configure_identity_integration(self, iam_alias='role_id', ec2_alias="role_id",
def configure_identity_integration(self, iam_alias=None, ec2_alias=None,
mount_point=AWS_DEFAULT_MOUNT_POINT):
"""Configures the way that Vault interacts with the Identity store. The default (as of Vault 1.0.3)
is role_id for both values
Expand All @@ -147,22 +141,22 @@ def configure_identity_integration(self, iam_alias='role_id', ec2_alias="role_id
:return: The response of the request
:rtype: request.Response
"""
if iam_alias not in ALLOWED_IAM_ALIAS_TYPES:
if iam_alias is not None and iam_alias not in ALLOWED_IAM_ALIAS_TYPES:
error_msg = 'invalid iam alias type provided: "{arg}"; supported iam alias types: "{alias_types}"'
raise exceptions.ParamValidationError(error_msg.format(
arg=iam_alias,
environments=','.join(ALLOWED_IAM_ALIAS_TYPES)
))
if ec2_alias not in ALLOWED_EC2_ALIAS_TYPES:
if ec2_alias is not None and ec2_alias not in ALLOWED_EC2_ALIAS_TYPES:
error_msg = 'invalid ec2 alias type provided: "{arg}"; supported ec2 alias types: "{alias_types}"'
raise exceptions.ParamValidationError(error_msg.format(
arg=ec2_alias,
environments=','.join(ALLOWED_EC2_ALIAS_TYPES)
))
params = {
params = utils.remove_nones({
'iam_alias': iam_alias,
'ec2_alias': ec2_alias,
}
})
api_auth = '/v1/auth/{mount_point}/config/identity'.format(mount_point=mount_point)
return self._adapter.post(
url=api_auth,
Expand All @@ -186,7 +180,7 @@ def read_identity_integration(self, mount_point=AWS_DEFAULT_MOUNT_POINT):
)
return response.json().get('data')

def create_certificate_configuration(self, cert_name, aws_public_cert, document_type="pkcs7", mount_point=AWS_DEFAULT_MOUNT_POINT):
def create_certificate_configuration(self, cert_name, aws_public_cert, document_type=None, mount_point=AWS_DEFAULT_MOUNT_POINT):
"""Registers an AWS public key to be used to verify the instance identity documents
While the PKCS#7 signature of the identity documents have DSA digest, the identity signature will have RSA
Expand All @@ -211,8 +205,12 @@ def create_certificate_configuration(self, cert_name, aws_public_cert, document_
params = {
'cert_name': cert_name,
'aws_public_cert': aws_public_cert,
'document_type': document_type,
}
params.update(
utils.remove_nones({
'document_type': document_type,
})
)
api_path = utils.format_url('/v1/auth/{0}/config/certificate/{1}', mount_point, cert_name)
return self._adapter.post(
url=api_path,
Expand Down Expand Up @@ -332,7 +330,7 @@ def delete_sts_role(self, account_id, mount_point=AWS_DEFAULT_MOUNT_POINT):
url=api_path,
)

def configure_identity_whitelist_tidy(self, safety_buffer="72h", disable_periodic_tidy=False,
def configure_identity_whitelist_tidy(self, safety_buffer=None, disable_periodic_tidy=None,
mount_point=AWS_DEFAULT_MOUNT_POINT):
"""Configures the periodic tidying operation of the whitelisted identity entries
Expand All @@ -342,10 +340,10 @@ def configure_identity_whitelist_tidy(self, safety_buffer="72h", disable_periodi
:return:
"""
api_path = utils.format_url('/v1/auth/{mount_point}/config/tidy/identity-whitelist', mount_point=mount_point)
params = {
params = utils.remove_nones({
'safety_buffer': safety_buffer,
'disable_periodic_tidy': disable_periodic_tidy,
}
})
return self._adapter.post(
url=api_path,
json=params,
Expand Down Expand Up @@ -374,7 +372,7 @@ def delete_identity_whitelist_tidy(self, mount_point=AWS_DEFAULT_MOUNT_POINT):
url=api_path,
)

def configure_role_tag_blacklist_tidy(self, safety_buffer='72h', disable_periodic_tidy=False,
def configure_role_tag_blacklist_tidy(self, safety_buffer=None, disable_periodic_tidy=None,
mount_point=AWS_DEFAULT_MOUNT_POINT):
"""Configures the periodic tidying operation of the blacklisted role tag entries
Expand All @@ -384,10 +382,10 @@ def configure_role_tag_blacklist_tidy(self, safety_buffer='72h', disable_periodi
:return:
"""
api_path = utils.format_url('/v1/auth/{mount_point}/config/tidy/roletag-blacklist', mount_point=mount_point)
params = {
params = utils.remove_nones({
'safety_buffer': safety_buffer,
'disable_periodic_tidy': disable_periodic_tidy,
}
})
return self._adapter.post(
url=api_path,
json=params,
Expand Down Expand Up @@ -416,7 +414,7 @@ def delete_role_tag_blacklist_tidy(self, mount_point=AWS_DEFAULT_MOUNT_POINT):
url=api_path
)

def create_role(self, role, auth_type="iam", bound_ami_id=None, bound_account_id=None,
def create_role(self, role, auth_type=None, bound_ami_id=None, bound_account_id=None,
bound_region=None, bound_vpc_id=None, bound_subnet_id=None, bound_iam_role_arn=None,
bound_iam_instance_profile_arn=None, bound_ec2_instance_id=None, role_tag=None,
bound_iam_principal_arn=None, inferred_entity_type=None, inferred_aws_region=None,
Expand Down Expand Up @@ -460,46 +458,31 @@ def create_role(self, role, auth_type="iam", bound_ami_id=None, bound_account_id
api_path = utils.format_url('/v1/auth/{0}/role/{1}', mount_point, role)
params = {
'role': role,
'auth_type': auth_type,
'resolve_aws_unique_ids': resolve_aws_unique_ids,
}
if bound_ami_id is not None:
params['bound_ami_id'] = bound_ami_id
if bound_account_id is not None:
params['bound_account_id'] = bound_account_id
if bound_region is not None:
params['bound_region'] = bound_region
if bound_vpc_id is not None:
params['bound_vpc_id'] = bound_vpc_id
if bound_subnet_id is not None:
params['bound_subnet_id'] = bound_subnet_id
if bound_iam_role_arn is not None:
params['bound_iam_role_arn'] = bound_iam_role_arn
if bound_iam_instance_profile_arn is not None:
params['bound_iam_instance_profile_arn'] = bound_iam_instance_profile_arn
if bound_ec2_instance_id is not None:
params['bound_ec2_instance_id'] = bound_ec2_instance_id
if role_tag is not None:
params['role_tag'] = role_tag
if bound_iam_principal_arn is not None:
params['bound_iam_principal_arn'] = bound_iam_principal_arn
if inferred_entity_type is not None:
params['inferred_entity_type'] = inferred_entity_type
if inferred_aws_region is not None:
params['inferred_aws_region'] = inferred_aws_region
if ttl is not None:
params['ttl'] = ttl
if max_ttl is not None:
params['max_ttl'] = max_ttl
if period is not None:
params['period'] = period
if policies is not None:
params['policies'] = policies
if allow_instance_migration is not None:
params['allow_instance_migration'] = allow_instance_migration
if disallow_reauthentication is not None:
params['disallow_reauthentication'] = disallow_reauthentication

params.update(
utils.remove_nones({
'auth_type': auth_type,
'resolve_aws_unique_ids': resolve_aws_unique_ids,
'bound_ami_id': bound_ami_id,
'bound_account_id': bound_account_id,
'bound_region': bound_region,
'bound_vpc_id': bound_vpc_id,
'bound_subnet_id': bound_subnet_id,
'bound_iam_role_arn': bound_iam_role_arn,
'bound_iam_instance_profile_arn': bound_iam_instance_profile_arn,
'bound_ec2_instance_id': bound_ec2_instance_id,
'role_tag': role_tag,
'bound_iam_principal_arn': bound_iam_principal_arn,
'inferred_entity_type': inferred_entity_type,
'inferred_aws_region': inferred_aws_region,
'ttl': ttl,
'max_ttl': max_ttl,
'period': period,
'policies': policies,
'allow_instance_migration': allow_instance_migration,
'disallow_reauthentication': disallow_reauthentication,
})
)
return self._adapter.post(
url=api_path,
json=params,
Expand Down Expand Up @@ -543,7 +526,7 @@ def delete_role(self, role, mount_point=AWS_DEFAULT_MOUNT_POINT):
)

def create_role_tags(self, role, policies=None, max_ttl=None, instance_id=None, allow_instance_migration=None,
disallow_reauthentication=False, mount_point=AWS_DEFAULT_MOUNT_POINT):
disallow_reauthentication=None, mount_point=AWS_DEFAULT_MOUNT_POINT):
"""Creates a role tag on the role, which helps in restricting the capabilities that are set on the role.
Role tags are not tied to any specific ec2 instance unless specified explicitly using the instance_id parameter
Expand All @@ -569,17 +552,13 @@ def create_role_tags(self, role, policies=None, max_ttl=None, instance_id=None,
"""
api_path = utils.format_url('/v1/auth/{0}/role/{1}/tag', mount_point, role)

params = {
params = utils.remove_nones({
'disallow_reauthentication': disallow_reauthentication,
}
if policies is not None:
params['policies'] = policies
if max_ttl is not None:
params['max_ttl'] = max_ttl
if instance_id is not None:
params['instance_id'] = instance_id
if allow_instance_migration is not None:
params['allow_instance_migration'] = allow_instance_migration
'policies': policies,
'max_ttl': max_ttl,
'instance_id': instance_id,
'allow_instance_migration': allow_instance_migration,
})

return self._adapter.post(
url=api_path,
Expand Down
56 changes: 30 additions & 26 deletions hvac/api/auth_methods/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class Azure(VaultApiBase):
Reference: https://www.vaultproject.io/api/auth/azure/index.html
"""

def configure(self, tenant_id, resource, environment='AzurePublicCloud', client_id=None, client_secret=None,
def configure(self, tenant_id, resource, environment=None, client_id=None, client_secret=None,
mount_point=DEFAULT_MOUNT_POINT):
"""Configure the credentials required for the plugin to perform API calls to Azure.
Expand All @@ -43,7 +43,7 @@ def configure(self, tenant_id, resource, environment='AzurePublicCloud', client_
:return: The response of the request.
:rtype: requests.Response
"""
if environment not in VALID_ENVIRONMENTS:
if environment is not None and environment not in VALID_ENVIRONMENTS:
error_msg = 'invalid environment argument provided: "{arg}"; supported environments: "{environments}"'
raise exceptions.ParamValidationError(error_msg.format(
arg=environment,
Expand All @@ -52,12 +52,14 @@ def configure(self, tenant_id, resource, environment='AzurePublicCloud', client_
params = {
'tenant_id': tenant_id,
'resource': resource,
'environment': environment,
}
if client_id is not None:
params['client_id'] = client_id
if client_secret is not None:
params['client_secret'] = client_secret
params.update(
utils.remove_nones({
'environment': environment,
'client_id': client_id,
'client_secret': client_secret,
})
)
api_path = utils.format_url('/v1/auth/{mount_point}/config', mount_point=mount_point)
return self._adapter.post(
url=api_path,
Expand Down Expand Up @@ -112,7 +114,7 @@ def create_role(self, name, policies=None, ttl=None, max_ttl=None, period=None,
:param name: Name of the role.
:type name: str | unicode
:param policies: Policies to be set on tokens issued using this role.
:type policies: list
:type policies: str | list
:param num_uses: Number of uses to set on a token produced by this role.
:type num_uses: int
:param ttl: The TTL period of tokens issued using this role in seconds.
Expand Down Expand Up @@ -140,15 +142,17 @@ def create_role(self, name, policies=None, ttl=None, max_ttl=None, period=None,
:return: The response of the request.
:rtype: requests.Response
"""
if policies is None:
policies = []
if not isinstance(policies, list) or not all([isinstance(p, str) for p in policies]):
error_msg = 'unsupported policies argument provided "{arg}" ({arg_type}), required type: List[str]"'
raise exceptions.ParamValidationError(error_msg.format(
arg=policies,
arg_type=type(policies),
))
params = {
if policies is not None:
if not (
isinstance(policies, str)
or (isinstance(policies, list) and all([isinstance(p, str) for p in policies]))
):
error_msg = 'unsupported policies argument provided "{arg}" ({arg_type}), required type: str or List[str]"'
raise exceptions.ParamValidationError(error_msg.format(
arg=policies,
arg_type=type(policies),
))
params = utils.remove_nones({
'policies': policies,
'ttl': ttl,
'max_ttl': max_ttl,
Expand All @@ -160,7 +164,7 @@ def create_role(self, name, policies=None, ttl=None, max_ttl=None, period=None,
'bound_resource_groups': bound_resource_groups,
'bound_scale_sets': bound_scale_sets,
'num_uses': num_uses,
}
})

api_path = utils.format_url('/v1/auth/{mount_point}/role/{name}', mount_point=mount_point, name=name)
return self._adapter.post(
Expand Down Expand Up @@ -272,14 +276,14 @@ def login(self, role, jwt, subscription_id=None, resource_group_name=None, vm_na
'role': role,
'jwt': jwt,
}
if subscription_id is not None:
params['subscription_id'] = subscription_id
if resource_group_name is not None:
params['resource_group_name'] = resource_group_name
if vm_name is not None:
params['vm_name'] = vm_name
if vmss_name is not None:
params['vmss_name'] = vmss_name
params.update(
utils.remove_nones({
'subscription_id': subscription_id,
'resource_group_name': resource_group_name,
'vm_name': vm_name,
'vmss_name': vmss_name,
})
)
api_path = utils.format_url('/v1/auth/{mount_point}/login', mount_point=mount_point)
response = self._adapter.login(
url=api_path,
Expand Down
Loading

0 comments on commit 09e0702

Please sign in to comment.