Skip to content

Commit

Permalink
extend sigv4 to accept aws_session_token
Browse files Browse the repository at this point in the history
Signed-off-by: Chinmay Gadgil <chinmay5j@gmail.com>
  • Loading branch information
cgchinmay committed Oct 16, 2023
1 parent b213ce5 commit 328f352
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 47 deletions.
15 changes: 12 additions & 3 deletions osbenchmark/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,9 @@ def __init__(self, hosts, client_options):
self.aws_log_in_dict = self.parse_aws_log_in_params()
masked_client_options["aws_access_key_id"] = "*****"
masked_client_options["aws_secret_access_key"] = "*****"
# session_token is optional and used only for role based access
if self.aws_log_in_dict.get("aws_session_token", None):
masked_client_options["aws_session_token"] = "*****"
self.logger.info("Creating OpenSearch client connected to %s with options [%s]", hosts, masked_client_options)

# we're using an SSL context now and it is not allowed to have use_ssl present in client options anymore
Expand Down Expand Up @@ -206,7 +209,7 @@ def __init__(self, hosts, client_options):
self.logger.info("HTTP basic authentication: off")

if self._is_set(self.client_options, "compressed"):
console.warn("You set the deprecated client option 'compressed. Please use 'http_compress' instead.", logger=self.logger)
console.warn("You set the deprecated client option 'compressed'. Please use 'http_compress' instead.", logger=self.logger)
self.client_options["http_compress"] = self.client_options.pop("compressed")

if self._is_set(self.client_options, "http_compress"):
Expand Down Expand Up @@ -251,12 +254,16 @@ def parse_aws_log_in_params(self):
aws_log_in_dict["aws_secret_access_key"] = os.environ.get("OSB_AWS_SECRET_ACCESS_KEY")
aws_log_in_dict["region"] = os.environ.get("OSB_REGION")
aws_log_in_dict["service"] = os.environ.get("OSB_SERVICE")
# optional: applicable only for role-based access
aws_log_in_dict["aws_session_token"] = os.environ.get("OSB_AWS_SESSION_TOKEN")
# aws log in : option 2) parameters are passed in from command line
elif self.client_options["amazon_aws_log_in"] == "client_option":
aws_log_in_dict["aws_access_key_id"] = self.client_options.get("aws_access_key_id")
aws_log_in_dict["aws_secret_access_key"] = self.client_options.get("aws_secret_access_key")
aws_log_in_dict["region"] = self.client_options.get("region")
aws_log_in_dict["service"] = self.client_options.get("service")
# optional: applicable only for role-based access
aws_log_in_dict["aws_session_token"] = self.client_options.get("aws_session_token")
if (not aws_log_in_dict["aws_access_key_id"] or not aws_log_in_dict["aws_secret_access_key"]
or not aws_log_in_dict["service"] or not aws_log_in_dict["region"]):
self.logger.error("Invalid amazon aws log in parameters, required input aws_access_key_id, "
Expand All @@ -282,7 +289,8 @@ def create(self):
return opensearchpy.OpenSearch(hosts=self.hosts, ssl_context=self.ssl_context, **self.client_options)

credentials = Credentials(access_key=self.aws_log_in_dict["aws_access_key_id"],
secret_key=self.aws_log_in_dict["aws_secret_access_key"])
secret_key=self.aws_log_in_dict["aws_secret_access_key"],
token=self.aws_log_in_dict["aws_session_token"])
aws_auth = opensearchpy.AWSV4SignerAuth(credentials, self.aws_log_in_dict["region"],
self.aws_log_in_dict["service"])
return opensearchpy.OpenSearch(hosts=self.hosts, use_ssl=True, verify_certs=True, http_auth=aws_auth,
Expand Down Expand Up @@ -332,7 +340,8 @@ class BenchmarkAsyncOpenSearch(opensearchpy.AsyncOpenSearch, RequestContextHolde
**self.client_options)

credentials = Credentials(access_key=self.aws_log_in_dict["aws_access_key_id"],
secret_key=self.aws_log_in_dict["aws_secret_access_key"])
secret_key=self.aws_log_in_dict["aws_secret_access_key"],
token=self.aws_log_in_dict["aws_session_token"])
aws_auth = opensearchpy.AWSV4SignerAsyncAuth(credentials, self.aws_log_in_dict["region"],
self.aws_log_in_dict["service"])
return BenchmarkAsyncOpenSearch(hosts=self.hosts,
Expand Down
10 changes: 9 additions & 1 deletion osbenchmark/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ def __init__(self, cfg):
default_value=None, mandatory=False)
metrics_aws_access_key_id = None
metrics_aws_secret_access_key = None
metrics_aws_session_token = None
metrics_aws_region = None
metrics_aws_service = None

Expand All @@ -196,13 +197,16 @@ def __init__(self, cfg):
default_value=None, mandatory=False)
metrics_aws_secret_access_key = self._config.opts("results_publishing", "datastore.aws_secret_access_key",
default_value=None, mandatory=False)
metrics_aws_session_token = self._config.opts("results_publishing", "datastore.aws_session_token",
default_value=None, mandatory=False)
metrics_aws_region = self._config.opts("results_publishing", "datastore.region",
default_value=None, mandatory=False)
metrics_aws_service = self._config.opts("results_publishing", "datastore.service",
default_value=None, mandatory=False)
elif metrics_amazon_aws_log_in == 'environment':
metrics_aws_access_key_id = os.getenv("OSB_DATASTORE_AWS_ACCESS_KEY_ID", default=None)
metrics_aws_secret_access_key = os.getenv("OSB_DATASTORE_AWS_SECRET_ACCESS_KEY", default=None)
metrics_aws_session_token = os.getenv("OSB_DATASTORE_AWS_SESSION_TOKEN", default=None)
metrics_aws_region = os.getenv("OSB_DATASTORE_REGION", default=None)
metrics_aws_service = os.getenv("OSB_DATASTORE_SERVICE", default=None)

Expand Down Expand Up @@ -254,14 +258,18 @@ def __init__(self, cfg):
client_options["basic_auth_user"] = user
client_options["basic_auth_password"] = password

#add options for aws user login: pass in aws access key id, aws secret access key, service and region on command
# add options for aws user login:
# pass in aws access key id, aws secret access key, aws session token, service and region on command
if metrics_amazon_aws_log_in is not None:
client_options["amazon_aws_log_in"] = 'client_option'
client_options["aws_access_key_id"] = metrics_aws_access_key_id
client_options["aws_secret_access_key"] = metrics_aws_secret_access_key
client_options["service"] = metrics_aws_service
client_options["region"] = metrics_aws_region

if metrics_aws_session_token:
client_options["aws_session_token"] = metrics_aws_session_token

factory = client.OsClientFactory(hosts=[{"host": host, "port": port}], client_options=client_options)
self._client = factory.create()

Expand Down
62 changes: 39 additions & 23 deletions tests/client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def test_create_https_connection_unverified_certificate(self, mocked_load_cert_c
@mock.patch.object(ssl.SSLContext, "load_cert_chain")
def test_create_https_connection_with_aws_creds(self, mocked_load_cert_chain):
hosts = [{"host": "localhost", "port": 9200}]
client_options = {
user_based_client_options = {
"use_ssl": True,
"timeout": 120,
"amazon_aws_log_in": 'client_option',
Expand All @@ -259,34 +259,50 @@ def test_create_https_connection_with_aws_creds(self, mocked_load_cert_chain):
"region": "us-east-1",
"verify_certs": True
}
# make a copy so we can verify later that the factory did not modify it
original_client_options = dict(client_options)

role_based_client_options = dict(user_based_client_options)
role_based_client_options["aws_session_token"] = "dummy_token"

client_options_list = [
user_based_client_options,
role_based_client_options
]

logger = logging.getLogger("osbenchmark.client")
with mock.patch.object(logger, "info") as mocked_info_logger:
f = client.OsClientFactory(hosts, client_options)
mocked_info_logger.assert_has_calls([
mock.call("SSL support: on"),
mock.call("SSL certificate verification: on"),
mock.call("SSL client authentication: off")
])

assert not mocked_load_cert_chain.called, "ssl_context.load_cert_chain should not have been called as we have not supplied " \
"client certs"
for client_options in client_options_list:
# make a copy so we can verify later that the factory did not modify it
original_client_options = dict(client_options)

self.assertEqual(hosts, f.hosts)
self.assertTrue(f.ssl_context.check_hostname)
self.assertEqual(ssl.CERT_REQUIRED, f.ssl_context.verify_mode)
with mock.patch.object(logger, "info") as mocked_info_logger:
f = client.OsClientFactory(hosts, client_options)

self.assertEqual("https", f.client_options["scheme"])
self.assertIn("timeout", f.client_options)
self.assertIn("aws_access_key_id", f.client_options)
self.assertIn("aws_secret_access_key", f.client_options)
self.assertIn("amazon_aws_log_in", f.client_options)
self.assertIn("service", f.client_options)
self.assertIn("region", f.client_options)
mocked_info_logger.assert_has_calls([
mock.call("SSL support: on"),
mock.call("SSL certificate verification: on"),
mock.call("SSL client authentication: off")
])

assert not mocked_load_cert_chain.called, "ssl_context.load_cert_chain should not have been called as we have not supplied " \
"client certs"

self.assertEqual(hosts, f.hosts)
self.assertTrue(f.ssl_context.check_hostname)
self.assertEqual(ssl.CERT_REQUIRED, f.ssl_context.verify_mode)

self.assertEqual("https", f.client_options["scheme"])
self.assertIn("timeout", f.client_options)
self.assertIn("aws_access_key_id", f.client_options)
self.assertIn("aws_secret_access_key", f.client_options)
self.assertIn("amazon_aws_log_in", f.client_options)
self.assertIn("service", f.client_options)
self.assertIn("region", f.client_options)

if "aws_session_token" in original_client_options:
self.assertIn("aws_session_token", f.client_options)

self.assertDictEqual(original_client_options, client_options)

self.assertDictEqual(original_client_options, client_options)

@mock.patch.object(ssl.SSLContext, "load_cert_chain")
def test_create_https_connection_unverified_certificate_present_client_certificates(self, mocked_load_cert_chain):
Expand Down
65 changes: 45 additions & 20 deletions tests/metrics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@
from osbenchmark.metrics import GlobalStatsCalculator
from osbenchmark.workload import Task, Operation, TestProcedure, Workload

AWS_ACCESS_KEY_ID_LENGTH = 12
AWS_SECRET_ACCESS_KEY_LENGTH = 40
AWS_SESSION_TOKEN_LENGTH = 752

class MockClientFactory:
def __init__(self, cfg):
Expand Down Expand Up @@ -239,24 +242,33 @@ def test_config_opts_parsing_aws_creds_with_env(self, client_OsClientfactory):
}
self.config_opts_parsing_aws_creds("environment", override_datastore=override_config)

# verify config parsing is successful when all required parameters are present
config_opts = self.config_opts_parsing_aws_creds("environment")

expected_client_options = {
"use_ssl": True,
"timeout": 120,
"amazon_aws_log_in": 'client_option',
"aws_access_key_id": config_opts["_datastore_aws_access_key_id"],
"aws_secret_access_key": config_opts["_datastore_aws_secret_access_key"],
"service": config_opts["_datastore_aws_service"],
"region": config_opts["_datastore_aws_region"],
"verify_certs": config_opts["_datastore_verify_certs"]
}

client_OsClientfactory.assert_called_with(
hosts=[{"host": config_opts["_datastore_host"], "port": config_opts["_datastore_port"]}],
client_options=expected_client_options
)
# validate client_options when session_token is passed
enable_role_access = [False, True]
for role_based in enable_role_access:
# verify config parsing is successful when all required parameters are present
config_opts = self.config_opts_parsing_aws_creds("environment", role_based=role_based)

expected_client_options = {
"use_ssl": True,
"timeout": 120,
"amazon_aws_log_in": 'client_option',
"aws_access_key_id": config_opts["_datastore_aws_access_key_id"],
"aws_secret_access_key": config_opts["_datastore_aws_secret_access_key"],
"service": config_opts["_datastore_aws_service"],
"region": config_opts["_datastore_aws_region"],
"verify_certs": config_opts["_datastore_verify_certs"]
}

if role_based:
expected_client_options["aws_session_token"] = config_opts["_datastore_aws_session_token"]

client_OsClientfactory.assert_called_with(
hosts=[{"host": config_opts["_datastore_host"], "port": config_opts["_datastore_port"]}],
client_options=expected_client_options
)


def config_opts_parsing(self, password_configuration):
cfg = config.Config()
Expand Down Expand Up @@ -302,7 +314,7 @@ def config_opts_parsing(self, password_configuration):
"_datastore_verify_certs": _datastore_verify_certs
}

def config_opts_parsing_aws_creds(self, configuration_source, override_datastore=None):
def config_opts_parsing_aws_creds(self, configuration_source, override_datastore=None, role_based=False):
if override_datastore is None:
override_datastore = {}
cfg = config.Config()
Expand All @@ -314,11 +326,14 @@ def config_opts_parsing_aws_creds(self, configuration_source, override_datastore
_datastore_password = ""
_datastore_verify_certs = random.choice([True, False])
_datastore_amazon_aws_log_in = configuration_source
_datastore_aws_access_key_id = "".join([random.choice(string.digits) for _ in range(12)])
_datastore_aws_secret_access_key = "".join([random.choice(string.ascii_letters + string.digits) for _ in range(40)])
_datastore_aws_access_key_id = "".join([random.choice(string.digits) for _ in range(AWS_ACCESS_KEY_ID_LENGTH)])
_datastore_aws_secret_access_key = "".join([random.choice(string.ascii_letters + string.digits) for _ in range(AWS_SECRET_ACCESS_KEY_LENGTH)])
_datastore_aws_service = random.choice(['es', 'aoss'])
_datastore_aws_region = random.choice(['us-east-1', 'eu-west-1'])

# optional
_datastore_aws_session_token = "".join([random.choice(string.ascii_letters + string.digits) for _ in range(AWS_SESSION_TOKEN_LENGTH)])

cfg.add(config.Scope.applicationOverride, "results_publishing", "datastore.host", _datastore_host)
cfg.add(config.Scope.applicationOverride, "results_publishing", "datastore.port", _datastore_port)
cfg.add(config.Scope.applicationOverride, "results_publishing", "datastore.secure", _datastore_secure)
Expand All @@ -334,12 +349,17 @@ def config_opts_parsing_aws_creds(self, configuration_source, override_datastore
_datastore_aws_secret_access_key)
cfg.add(config.Scope.applicationOverride, "results_publishing", "datastore.service", _datastore_aws_service)
cfg.add(config.Scope.applicationOverride, "results_publishing", "datastore.region", _datastore_aws_region)
if role_based:
cfg.add(config.Scope.applicationOverride, "results_publishing", "datastore.aws_session_token", _datastore_aws_session_token)
elif _datastore_amazon_aws_log_in == 'environment':
monkeypatch = pytest.MonkeyPatch()
monkeypatch.setenv("OSB_DATASTORE_AWS_ACCESS_KEY_ID", _datastore_aws_access_key_id)
monkeypatch.setenv("OSB_DATASTORE_AWS_SECRET_ACCESS_KEY", _datastore_aws_secret_access_key)
monkeypatch.setenv("OSB_DATASTORE_SERVICE", _datastore_aws_service)
monkeypatch.setenv("OSB_DATASTORE_REGION", _datastore_aws_region)
if role_based:
monkeypatch.setenv("OSB_DATASTORE_AWS_SESSION_TOKEN", _datastore_aws_session_token)


if not _datastore_verify_certs:
cfg.add(config.Scope.applicationOverride, "results_publishing", "datastore.ssl.verification_mode", "none")
Expand Down Expand Up @@ -375,7 +395,7 @@ def config_opts_parsing_aws_creds(self, configuration_source, override_datastore
assert e.message == missing_aws_credentials_message
return

return {
response = {
"_datastore_user": _datastore_user,
"_datastore_host": _datastore_host,
"_datastore_password": _datastore_password,
Expand All @@ -387,6 +407,11 @@ def config_opts_parsing_aws_creds(self, configuration_source, override_datastore
"_datastore_aws_region": _datastore_aws_region
}

if role_based:
response["_datastore_aws_session_token"] = _datastore_aws_session_token

return response

def test_raises_sytem_setup_error_on_connection_problems(self):
def raise_connection_error():
raise opensearchpy.exceptions.ConnectionError("unit-test")
Expand Down

0 comments on commit 328f352

Please sign in to comment.