Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

testDocumentGenerator fetches endpoint and auth from config file #1300

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@
import argparse
from datetime import datetime
import urllib3
import os
from collections import deque
import logging
import json
import yaml
from requests_aws4auth import AWS4Auth
from requests.auth import HTTPBasicAuth
import boto3

# Disable InsecureRequestWarning
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
Expand Down Expand Up @@ -59,16 +62,15 @@ def send_multi_type_request(session, index_name, type_name, payload, url_base, a
def parse_args():
"""Parse command line arguments."""
parser = argparse.ArgumentParser()
parser.add_argument("--endpoint", help="Cluster endpoint e.g. http://test.elb.us-west-2.amazonaws.com:9200.")
parser.add_argument("--username", help="Cluster username.")
parser.add_argument("--password", help="Cluster password.")
# Removed endpoint, username, password
parser.add_argument("--enable_multi_type", action='store_true',
help="Flag to enable sending documents to a multi-type index.")
parser.add_argument("--no-clear-output", action='store_true',
help="Flag to not clear the output before each run. " +
"Helpful for piping to a file or other utility.")
parser.add_argument("--requests-per-sec", type=float, default=10.0, help="Target requests per second to be sent.")
parser.add_argument("--no-refresh", action='store_true', help="Flag to disable refresh after each request.")
parser.add_argument("--target", type=str, default="source", help="Specify 'source' or 'target' cluster.")
return parser.parse_args()


Expand Down Expand Up @@ -121,11 +123,77 @@ def calculate_sleep_time(request_timestamps, target_requests_per_sec):
return max(0, sleep_time)


def load_config(yaml_path="/etc/migration_services.yaml"):
"""Loads the configuration from the specified YAML file."""
try:
with open(yaml_path, 'r') as f:
config = yaml.safe_load(f)
return config
except FileNotFoundError:
logger.error(f"Configuration file not found: {yaml_path}")
sys.exit(1)
except yaml.YAMLError as e:
logger.error(f"Error parsing YAML file: {e}")
sys.exit(1)


def get_cluster_config(config, cluster_type="source"):
"""Extracts cluster configuration from the loaded YAML."""
cluster_config = config.get(f"{cluster_type}_cluster")
if not cluster_config:
logger.error(f"Cluster configuration not found for type: {cluster_type}")
sys.exit(1)
return cluster_config


def get_auth(cluster_config):
"""Determine authentication method and return appropriate auth object."""
if 'sigv4' in cluster_config:
region = cluster_config['sigv4']['region']
service = cluster_config['sigv4']['service']

# boto3 to get session, which uses the default AWS credential provider chain
session = boto3.Session()
credentials = session.get_credentials()

if credentials:
auth = AWS4Auth(
credentials.access_key,
credentials.secret_key,
region,
service,
session_token=credentials.token
)
return auth
else:
logger.error("Could not retrieve AWS credentials from boto3 session.")
sys.exit(1)

elif 'basic' in cluster_config:
username = cluster_config['basic']['username']
password = cluster_config['basic']['password']
auth = HTTPBasicAuth(username, password)
return auth

elif 'no_auth' in cluster_config:
return None

else:
logger.warning("No authentication method found in configuration. Assuming no authentication.")
return None


def main():
args = parse_args()
config = load_config()

# Determine which cluster to target as endpoint, SOURCE or TARGET
cluster_type = args.target
cluster_config = get_cluster_config(config, cluster_type)

url_base = args.endpoint or os.environ.get('SOURCE_DOMAIN_ENDPOINT', 'https://capture-proxy:9200')
auth = (args.username, args.password) if args.username and args.password else None
# Extract endpoint and authentication details
url_base = cluster_config['endpoint']
auth = get_auth(cluster_config)

session = requests.Session()
keep_alive_headers = {'Connection': 'keep-alive'}
Expand Down
Loading