Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix "Unsafe" token values that originate from lookups #291

Merged
merged 7 commits into from
Aug 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions changelogs/fragments/289-handle-unsafe-strings.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
---
bugfixes:
- community.hashi_vault plugins - tokens will be cast to a string type before being sent to ``hvac`` to prevent errors in ``requests`` when values are ``AnsibleUnsafe`` (https://github.com/ansible-collections/community.hashi_vault/issues/289).
2 changes: 1 addition & 1 deletion plugins/module_utils/_auth_method_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def validate(self):
raise HashiVaultValueError("No Vault Token specified or discovered.")

def authenticate(self, client, use_token=True, lookup_self=False):
token = self._options.get_option('token')
token = self._stringify(self._options.get_option('token'))
validate = self._options.get_option_default('token_validate')

response = None
Expand Down
52 changes: 51 additions & 1 deletion plugins/module_utils/_hashi_vault_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,55 @@
HAS_HVAC = False


def _stringify(input):
'''
This method is primarily used to Un-Unsafe values that come from Ansible.
We want to remove the Unsafe context so that libraries don't get confused
by the values.
'''

# Since this is a module_util, and will be used by both plugins and modules,
# we cannot import the AnsibleUnsafe* types, because they are controller-only.
# However, they subclass the native types, so we can check for that.

# bytes is the only consistent type to check against in both py2 and py3
if isinstance(input, bytes):
# seems redundant, but this will give us a regular bytes object even
# when the input is AnsibleUnsafeBytes
return bytes(input)
else:
# instead of checking for py2 vs. py3 to cast to str or unicode,
# let's get the type from the literal.
return type(u'')(input)


class HashiVaultValueError(ValueError):
'''Use in common code to raise an Exception that can be turned into AnsibleError or used to fail_json()'''


class HashiVaultHelper():

STRINGIFY_CANDIDATES = set(
'token', # Token will end up in a header, requests requires headers to be str or bytes,
# and newer versions of requests stopped converting automatically. Because our
# token could have been passed in from a previous lookup call, it could be one
# of the AnsibleUnsafe types instead, causing a failure. Tokens should always
# be strings, so we will convert them.
)

def __init__(self):
# TODO move hvac checking here?
pass

def get_vault_client(self, hashi_vault_logout_inferred_token=True, hashi_vault_revoke_on_logout=False, **kwargs):
@staticmethod
def _stringify(input):
return _stringify(input)

def get_vault_client(
self,
hashi_vault_logout_inferred_token=True, hashi_vault_revoke_on_logout=False, hashi_vault_stringify_args=True,
**kwargs
):
'''
creates a Vault client with the given kwargs

Expand All @@ -45,8 +83,16 @@ def get_vault_client(self, hashi_vault_logout_inferred_token=True, hashi_vault_r

:param hashi_vault_revoke_on_logout: if True revokes any current token on logout. Only used if a logout is performed. Not recommended.
:type hashi_vault_revoke_on_logout: bool

:param hashi_vault_stringify_args: if True converts a specific set of defined kwargs to a string type.
:type hashi_vault_stringify_args: bool
'''

if hashi_vault_stringify_args:
for key in kwargs.keys():
if key in self.STRINGIFY_CANDIDATES:
kwargs[key] = self._stringify(kwargs[key])

client = hvac.Client(**kwargs)

# logout to prevent accidental use of inferred tokens
Expand Down Expand Up @@ -249,3 +295,7 @@ def warn(self, message):

def deprecate(self, message, version=None, date=None, collection_name=None):
self._deprecator(message, version=version, date=date, collection_name=collection_name)

@staticmethod
def _stringify(input):
return _stringify(input)
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
HashiVaultAuthMethodBase,
HashiVaultOptionGroupBase,
HashiVaultValueError,
_stringify,
)


Expand Down Expand Up @@ -74,3 +75,12 @@ def test_deprecate_callback(self, auth_base, deprecator, version, date, collecti
auth_base.deprecate(msg, version, date, collection_name)

deprecator.assert_called_once_with(msg, version=version, date=date, collection_name=collection_name)

def test_has_stringify(self, auth_base):
v = 'X'
wrapper = mock.Mock(wraps=_stringify)
with mock.patch('ansible_collections.community.hashi_vault.plugins.module_utils._hashi_vault_common._stringify', wrapper):
r = auth_base._stringify(v)

wrapper.assert_called_once_with(v)
assert r == v
14 changes: 13 additions & 1 deletion tests/unit/plugins/module_utils/test_hashi_vault_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
import pytest

from ansible_collections.community.hashi_vault.tests.unit.compat import mock
from ansible_collections.community.hashi_vault.plugins.module_utils._hashi_vault_common import HashiVaultHelper
from ansible_collections.community.hashi_vault.plugins.module_utils._hashi_vault_common import (
HashiVaultHelper,
_stringify,
)


@pytest.fixture
Expand Down Expand Up @@ -46,3 +49,12 @@ def test_get_vault_client_with_logout_implicit_token(self, hashi_vault_helper, v
client = hashi_vault_helper.get_vault_client(hashi_vault_logout_inferred_token=True)

assert client.token is None

def test_has_stringify(self, hashi_vault_helper):
v = 'X'
wrapper = mock.Mock(wraps=_stringify)
with mock.patch('ansible_collections.community.hashi_vault.plugins.module_utils._hashi_vault_common._stringify', wrapper):
r = hashi_vault_helper._stringify(v)

wrapper.assert_called_once_with(v)
assert r == v, '%r != %r' % (r, v)
10 changes: 10 additions & 0 deletions tests/unit/plugins/plugin_utils/authentication/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2022 Brian Scholer (@briantist)
# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt)
# SPDX-License-Identifier: GPL-3.0-or-later

from __future__ import (absolute_import, division, print_function)
__metaclass__ = type

# pylint: disable=wildcard-import,unused-wildcard-import
from ...module_utils.authentication.conftest import *
55 changes: 55 additions & 0 deletions tests/unit/plugins/plugin_utils/authentication/test_auth_token.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2022 Brian Scholer (@briantist)
# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt)
# SPDX-License-Identifier: GPL-3.0-or-later

from __future__ import (absolute_import, division, print_function)
__metaclass__ = type

import pytest

from ansible.utils.unsafe_proxy import AnsibleUnsafe, AnsibleUnsafeBytes, AnsibleUnsafeText

from ansible_collections.community.hashi_vault.tests.unit.compat import mock

from ansible_collections.community.hashi_vault.plugins.module_utils._auth_method_token import (
HashiVaultAuthMethodToken,
)


@pytest.fixture
def option_dict():
return {
'auth_method': 'fake',
'token': None,
'token_path': None,
'token_file': '.vault-token',
'token_validate': True,
}


@pytest.fixture(params=[AnsibleUnsafeBytes(b'ub_opaque'), AnsibleUnsafeText(u'ut_opaque'), b'b_opaque', u't_opaque'])
def token(request):
return request.param


@pytest.fixture
def auth_token(adapter, warner, deprecator):
return HashiVaultAuthMethodToken(adapter, warner, deprecator)


class TestAuthToken(object):

def test_auth_token_unsafes(self, auth_token, client, adapter, token):
adapter.set_option('token', token)
adapter.set_option('token_validate', False)

wrapper = mock.Mock(wraps=auth_token._stringify)

with mock.patch.object(auth_token, '_stringify', wrapper):
response = auth_token.authenticate(client, use_token=True, lookup_self=False)

assert isinstance(response['auth']['client_token'], (bytes, type(u''))), repr(response['auth']['client_token'])
assert isinstance(client.token, (bytes, type(u''))), repr(client.token)
assert not isinstance(response['auth']['client_token'], AnsibleUnsafe), repr(response['auth']['client_token'])
assert not isinstance(client.token, AnsibleUnsafe), repr(client.token)
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2022 Brian Scholer (@briantist)
# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt)
# SPDX-License-Identifier: GPL-3.0-or-later

from __future__ import (absolute_import, division, print_function)
__metaclass__ = type

import pytest

from ansible.utils.unsafe_proxy import AnsibleUnsafe, AnsibleUnsafeBytes, AnsibleUnsafeText

from ansible_collections.community.hashi_vault.plugins.module_utils._hashi_vault_common import _stringify


@pytest.fixture
def uvalue():
return u'fake123'


@pytest.fixture
def bvalue():
return b'fake456'


class TestHashiVaultCommonStringify(object):
@pytest.mark.parametrize('unsafe', [True, False])
def test_stringify_bytes(self, unsafe, bvalue):
token = bvalue
if unsafe:
token = AnsibleUnsafeBytes(token)

r = _stringify(token)

assert isinstance(r, bytes)
assert not isinstance(r, AnsibleUnsafe)

@pytest.mark.parametrize('unsafe', [True, False])
def test_stringify_unicode(self, unsafe, uvalue):
token = uvalue
utype = type(token)
if unsafe:
token = AnsibleUnsafeText(token)

r = _stringify(token)

assert isinstance(r, utype)
assert not isinstance(r, AnsibleUnsafe)
57 changes: 57 additions & 0 deletions tests/unit/plugins/plugin_utils/test_hashi_vault_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2022 Brian Scholer (@briantist)
# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt)
# SPDX-License-Identifier: GPL-3.0-or-later

from __future__ import (absolute_import, division, print_function)
__metaclass__ = type

import pytest

from ansible.utils.unsafe_proxy import AnsibleUnsafe, AnsibleUnsafeBytes, AnsibleUnsafeText

from ansible_collections.community.hashi_vault.tests.unit.compat import mock
from ansible_collections.community.hashi_vault.plugins.module_utils._hashi_vault_common import HashiVaultHelper


@pytest.fixture
def hashi_vault_helper():
return HashiVaultHelper()


@pytest.fixture
def expected_stringify_candidates():
return set(
'token',
)


class TestHashiVaultHelper(object):
def test_expected_stringify_candidates(self, hashi_vault_helper, expected_stringify_candidates):
# If we add more candidates to the set without updating the tests,
# this will help us catch that. The purpose is not to simply update
# the set in the fixture, but to also add specific tests where appropriate.
assert hashi_vault_helper.STRINGIFY_CANDIDATES == expected_stringify_candidates, '%r' % (
hashi_vault_helper.STRINGIFY_CANDIDATES ^ expected_stringify_candidates
)

@pytest.mark.parametrize('input', [b'one', u'two', AnsibleUnsafeBytes(b'three'), AnsibleUnsafeText(u'four')])
@pytest.mark.parametrize('stringify', [True, False])
def test_get_vault_client_stringify(self, hashi_vault_helper, expected_stringify_candidates, input, stringify):
kwargs = {
'__no_candidate': AnsibleUnsafeText(u'value'),
}
expected_calls = []
for k in expected_stringify_candidates:
v = '%s_%s' % (k, input)
kwargs[k] = v
if stringify:
expected_calls.append(mock.call(v))

wrapper = mock.Mock(wraps=hashi_vault_helper._stringify)
with mock.patch('hvac.Client'):
with mock.patch.object(hashi_vault_helper, '_stringify', wrapper):
hashi_vault_helper.get_vault_client(hashi_vault_stringify_args=stringify, **kwargs)

assert wrapper.call_count == len(expected_calls)
wrapper.assert_has_calls(expected_calls)