diff --git a/creation/lib/cgWCreate.py b/creation/lib/cgWCreate.py index d9727aabd..99f652624 100644 --- a/creation/lib/cgWCreate.py +++ b/creation/lib/cgWCreate.py @@ -161,7 +161,7 @@ def populate(self, exe_fname, entry_name, conf, entry): submit_attrs = entry.get_child("config").get_child("submit").get_child_list("submit_attrs") enc_input_files = [] - enc_input_files.append("$ENV(IDTOKENS_FILE:)") + enc_input_files.append("$ENV(IDENTITY_CREDENTIALS:)") self.add_environment("IDTOKENS_FILE=$ENV(IDTOKENS_FILE:)") if gridtype not in ["ec2", "gce"] and not (gridtype == "arc" and auth_method == "grid_proxy"): @@ -308,7 +308,7 @@ def populate_standard_grid(self, rsl, auth_method, gridtype, entry_enabled, entr self.add("cream_attributes", "$ENV(GLIDEIN_RSL)") elif gridtype == "nordugrid" and rsl: self.add("nordugrid_rsl", "$ENV(GLIDEIN_RSL)") - elif (gridtype == "condor") and ("project_id" in auth_method): + elif (gridtype == "condor") and ("project_id" in auth_method): # TODO: Check for credentials refactoring impact self.add("+ProjectName", '"$ENV(GLIDEIN_PROJECT_ID)"') # Force the copy to spool to prevent caching at the CE side diff --git a/creation/lib/cvWParamDict.py b/creation/lib/cvWParamDict.py index 1e40c9ea1..08b9657b7 100644 --- a/creation/lib/cvWParamDict.py +++ b/creation/lib/cvWParamDict.py @@ -1121,6 +1121,8 @@ def populate_common_descript(descript_dict, params): "security_class": "ProxySecurityClasses", "trust_domain": "ProxyTrustDomains", "type": "ProxyTypes", + "purpose": "CredentialPurposes", + "context": "CredentialContexts", # credential files probably should be handles as a list, each w/ name and path # or the attributes ending in _file are files # "file": "CredentialFiles", # placeholder for when name will not be absfname @@ -1186,6 +1188,7 @@ def populate_common_descript(descript_dict, params): descript_dict.add(proxy_attr_names[attr], repr(proxy_descript_values[attr])) match_expr = params.match.match_expr + descript_dict.add("Parameters", repr(params.security.parameters)) descript_dict.add("MatchExpr", match_expr) @@ -1345,7 +1348,7 @@ def populate_group_security(client_security, params, sub_params, group_name): client_security["schedd_DNs"] = schedd_dns pilot_dns = [] - exclude_from_pilot_dns = ["SCITOKEN", "IDTOKEN"] + exclude_from_pilot_dns = ["SCITOKEN", "IDTOKEN", "GENERATOR"] for credentials in (params.security.credentials, sub_params.security.credentials): if is_true(params.groups[group_name].enabled): for pel in credentials: diff --git a/creation/lib/cvWParams.py b/creation/lib/cvWParams.py index b535c2b8c..d0af80827 100644 --- a/creation/lib/cvWParams.py +++ b/creation/lib/cvWParams.py @@ -297,6 +297,18 @@ def init_defaults(self): "Type of credential: grid_proxy,cert_pair,key_pair,username_password,auth_file", None, ) + proxy_defaults["purpose"] = ( + "request", + "credential purpose", + "Purpose of credential: request,payload", + None, + ) + proxy_defaults["context"] = ( + None, + "PythonExpr", + "Python mapping with a context for credential generators", + None, + ) proxy_defaults["trust_domain"] = ("OSG", "grid_type", "Trust Domain", None) proxy_defaults["creation_script"] = (None, "command", "Script to re-create credential", None) proxy_defaults["update_frequency"] = (None, "int", "Update proxy when there is this much time left", None) @@ -332,6 +344,23 @@ def init_defaults(self): proxy_defaults["vm_type_fname"] = (None, "fname", "to specify a vm type without reconfig", None) proxy_defaults["project_id"] = (None, "string", "OSG Project ID. Ex TG-12345", None) + # Parameter settings + parameter_defaults = cWParams.CommentedOrderedDict() + parameter_defaults["name"] = (None, "string", "parameter name", None) + parameter_defaults["value"] = (None, "string", "parameter value", None) + parameter_defaults["type"] = ( + None, + "string", + "parameter type (int, string, expr)", + None, + ) + parameter_defaults["context"] = ( + None, + "PythonExpr", + "Python mapping with a context for parameter generators", + None, + ) + security_defaults = cWParams.CommentedOrderedDict() security_defaults["proxy_selection_plugin"] = ( None, @@ -345,6 +374,12 @@ def init_defaults(self): "Each credential element contains", proxy_defaults, ) + security_defaults["parameters"] = ( + OrderedDict(), + "List of parameters", + "Each parameter element contains", + parameter_defaults, + ) security_defaults["security_name"] = ( None, "frontend_name", @@ -715,6 +750,7 @@ def get_xml_format(self): "attrs": {"el_name": "attr", "subtypes_params": {"class": {}}}, "groups": {"el_name": "group", "subtypes_params": {"class": {}}}, "match_attrs": {"el_name": "match_attr", "subtypes_params": {"class": {}}}, + "parameters": {"el_name": "parameter", "subtypes_params": {"class": {}}}, }, } diff --git a/creation/web_base/setup_x509.sh b/creation/web_base/setup_x509.sh index 451d0cf08..e6a982c43 100644 --- a/creation/web_base/setup_x509.sh +++ b/creation/web_base/setup_x509.sh @@ -348,6 +348,32 @@ copy_idtokens() { return } +# Copy non-idtoken credentials from the start directory to the credentials directory +# Credentials must match: ^credential_.*\.(scitoken|jwt|pem|rsa|txt)$ +copy_credentials() { + local start_dir from_dir=$1 to_dir=$2 + start_dir=$(pwd) + if ! cd "$from_dir"; then + ERROR="Cannot cd to from_dir ($from_dir)" + # did not change directory, OK to just return + return 1 + fi + for cred in credential_*; do + [[ -e "$cred" ]] || continue # protect against nullglob (no match) + if [[ "$cred" =~ ^credential_.*\.(scitoken|jwt|pem|rsa|txt)$ ]]; then + if cp "$cred" "$to_dir/$cred"; then + warn "Copied credential '${cred}' to '${to_dir}/'" + else + warn "Failed to copy credential '${cred}'" + fi + else + warn "Skipping credential '${cred}'" + fi + done + cd "$start_dir" || true + return +} + # Retrieve trust domain # Uses TRUST_DOMAIN, GLIDEIN_Collector and CCB_ADDRESS from glidein_config # Return only the first Collector if more are in the list (separators:,\ \t) @@ -457,6 +483,13 @@ _main() { else warn "$ERROR" fi + # TODO: Initial copy. Evaluate separately credentials refresh. + # PAYLOAD CREDENTIALS + if copy_credentials "$GLIDEIN_START_DIR_ORIG" "${gwms_credentials_dir}"; then + cred_updated+=credentials, + else + warn "$ERROR" + fi # x509 - skip if there is no X509_USER_PROXY if ! X509_USER_PROXY=$(get_x509_proxy); then diff --git a/factory/glideFactoryConfig.py b/factory/glideFactoryConfig.py index 89397678e..12e299ac2 100644 --- a/factory/glideFactoryConfig.py +++ b/factory/glideFactoryConfig.py @@ -5,7 +5,7 @@ import os.path import shutil -from glideinwms.lib import pubCrypto, symCrypto +from glideinwms.lib import credentials, pubCrypto, symCrypto ############################################################ # @@ -135,7 +135,7 @@ def __init__(self, entry_name, config_file, convert_function=repr): ############################################################ -class GlideinKey: +class GlideinKey: # TODO: Check for credentials refactor def __init__(self, pub_key_type, key_fname=None, recreate=False): self.pub_key_type = pub_key_type self.load(key_fname, recreate) @@ -247,7 +247,7 @@ def load_old_rsa_key(self): if self.data["OldPubKeyType"] is not None: try: - self.data["OldPubKeyObj"] = GlideinKey(self.data["OldPubKeyType"], key_fname=self.backup_rsakey_fname) + self.data["OldPubKeyObj"] = credentials.RSAKey(path=self.backup_rsakey_fname) except Exception: self.data["OldPubKeyType"] = None self.data["OldPubKeyObj"] = None @@ -271,9 +271,9 @@ def load_pub_key(self, recreate=False): recreate (bool): Create a new key overwriting the old one. Defaults to False """ if self.data["PubKeyType"] is not None: - self.data["PubKeyObj"] = GlideinKey( - self.data["PubKeyType"], key_fname=self.default_rsakey_fname, recreate=recreate - ) + self.data["PubKeyObj"] = credentials.RSAKey(path=self.default_rsakey_fname) + if recreate: + self.data["PubKeyObj"].recreate() else: self.data["PubKeyObj"] = None return diff --git a/factory/glideFactoryCredentials.py b/factory/glideFactoryCredentials.py index 2664ef384..88c1daa05 100644 --- a/factory/glideFactoryCredentials.py +++ b/factory/glideFactoryCredentials.py @@ -11,6 +11,7 @@ from glideinwms.lib import condorMonitor, logSupport from glideinwms.lib.defaults import force_bytes +from glideinwms.lib.util import is_str_safe from . import glideFactoryInterface, glideFactoryLib @@ -51,7 +52,7 @@ def add_security_credential(self, cred_type, filename): """ Adds a security credential. """ - if not glideFactoryLib.is_str_safe(filename): + if not is_str_safe(filename): return False cred_fname = os.path.join(self.cred_dir, "credential_%s" % filename) @@ -283,11 +284,12 @@ def check_security_credentials(auth_method, params, client_int_name, entry_name, CredentialError: if the credentials in params don't match what is defined for the auth method """ - auth_method_list = auth_method.split("+") - if not set(auth_method_list) & set(SUPPORTED_AUTH_METHODS): + # TODO: This function policies need to be reviewed and updated. + + auth_set = params.get("AuthSet", auth_method.split("+")) # Fall back to auth_method (str) for retrocompatibility + if isinstance(auth_set, str) and not set(auth_set) & set(SUPPORTED_AUTH_METHODS): logSupport.log.warning( - "None of the supported auth methods %s in provided auth methods: %s" - % (SUPPORTED_AUTH_METHODS, auth_method_list) + f"None of the supported auth methods {SUPPORTED_AUTH_METHODS} in provided auth methods: {auth_set}" ) return @@ -306,12 +308,12 @@ def check_security_credentials(auth_method, params, client_int_name, entry_name, "AuthFile", } - if "scitoken" in auth_method_list or "frontend_scitoken" in params and scitoken_passthru: + if "scitoken" in auth_set or "frontend_scitoken" in params and scitoken_passthru: # TODO check validity # TODO Specifically, Add checks that no undesired credentials are # sent also when token is used return - if "grid_proxy" in auth_method_list: + if "grid_proxy" in auth_set: if not scitoken_passthru: if "SubmitProxy" in params: # v3+ protocol @@ -319,89 +321,79 @@ def check_security_credentials(auth_method, params, client_int_name, entry_name, invalid_keys = relevant_keys.difference(valid_keys) if params_keys.intersection(invalid_keys): raise CredentialError( - "Request from %s has credentials not required by the entry %s, skipping request" - % (client_int_name, entry_name) + f"Request from {client_int_name} has credentials not required by the entry {entry_name}, skipping request" ) else: # No proxy sent raise CredentialError( - "Request from client %s did not provide a proxy as required by the entry %s, skipping request" - % (client_int_name, entry_name) + f"Request from client {client_int_name} did not provide a proxy as required by the entry {entry_name}, skipping request" ) else: # Only v3+ protocol supports non grid entries # Verify that the glidein proxy was provided for non-proxy auth methods if "GlideinProxy" not in params and not scitoken_passthru: - raise CredentialError("Glidein proxy cannot be found for client %s, skipping request" % client_int_name) + raise CredentialError(f"Glidein proxy cannot be found for client {client_int_name}, skipping request") - if "cert_pair" in auth_method_list: + if "cert_pair" in auth_set: # Validate both the public and private certs were passed if not (("PublicCert" in params) and ("PrivateCert" in params)): # if not ('PublicCert' in params and 'PrivateCert' in params): # cert pair is required, cannot service request raise CredentialError( - "Client '%s' did not specify the certificate pair in the request, this is required by entry %s, skipping " - % (client_int_name, entry_name) + f"Client '{client_int_name}' did not specify the certificate pair in the request, this is required by entry {entry_name}, skipping" ) # Verify no other credentials were passed valid_keys = {"GlideinProxy", "PublicCert", "PrivateCert", "VMId", "VMType"} invalid_keys = relevant_keys.difference(valid_keys) if params_keys.intersection(invalid_keys): raise CredentialError( - "Request from %s has credentials not required by the entry %s, skipping request" - % (client_int_name, entry_name) + f"Request from {client_int_name} has credentials not required by the entry {entry_name}, skipping request" ) - elif "key_pair" in auth_method_list: + elif "key_pair" in auth_set: # Validate both the public and private keys were passed if not (("PublicKey" in params) and ("PrivateKey" in params)): # key pair is required, cannot service request raise CredentialError( - "Client '%s' did not specify the key pair in the request, this is required by entry %s, skipping " - % (client_int_name, entry_name) + f"Client '{client_int_name}' did not specify the key pair in the request, this is required by entry {entry_name}, skipping" ) # Verify no other credentials were passed valid_keys = {"GlideinProxy", "PublicKey", "PrivateKey", "VMId", "VMType"} invalid_keys = relevant_keys.difference(valid_keys) if params_keys.intersection(invalid_keys): raise CredentialError( - "Request from %s has credentials not required by the entry %s, skipping request" - % (client_int_name, entry_name) + f"Request from {client_int_name} has credentials not required by the entry {entry_name}, skipping request" ) - elif "auth_file" in auth_method_list: + elif "auth_file" in auth_set: # Validate auth_file is passed if "AuthFile" not in params: # auth_file is required, cannot service request raise CredentialError( - "Client '%s' did not specify the auth_file in the request, this is required by entry %s, skipping " - % (client_int_name, entry_name) + f"Client '{client_int_name}' did not specify the auth_file in the request, this is required by entry {entry_name}, skipping" ) # Verify no other credentials were passed valid_keys = {"GlideinProxy", "AuthFile", "VMId", "VMType"} invalid_keys = relevant_keys.difference(valid_keys) if params_keys.intersection(invalid_keys): raise CredentialError( - "Request from %s has credentials not required by the entry %s, skipping request" - % (client_int_name, entry_name) + f"Request from {client_int_name} has credentials not required by the entry {entry_name}, skipping request" ) - elif "username_password" in auth_method_list: + elif "username_password" in auth_set: # Validate username and password keys were passed if not (("Username" in params) and ("Password" in params)): # username and password is required, cannot service request raise CredentialError( - "Client '%s' did not specify the username and password in the request, this is required by entry %s, skipping " - % (client_int_name, entry_name) + f"Client '{client_int_name}' did not specify the username and password in the request, this is required by entry {entry_name}, skipping request" ) # Verify no other credentials were passed valid_keys = {"GlideinProxy", "Username", "Password", "VMId", "VMType"} invalid_keys = relevant_keys.difference(valid_keys) if params_keys.intersection(invalid_keys): raise CredentialError( - "Request from %s has credentials not required by the entry %s, skipping request" - % (client_int_name, entry_name) + f"Request from {client_int_name} has credentials not required by the entry {entry_name}, skipping request" ) else: diff --git a/factory/glideFactoryEntry.py b/factory/glideFactoryEntry.py index 22c097968..5a3c0d906 100644 --- a/factory/glideFactoryEntry.py +++ b/factory/glideFactoryEntry.py @@ -10,6 +10,7 @@ import copy import os import os.path +import pickle import signal import sys import tempfile @@ -19,7 +20,18 @@ from glideinwms.factory import glideFactoryInterface as gfi from glideinwms.factory import glideFactoryLib, glideFactoryLogParser, glideFactoryMonitoring from glideinwms.lib import classadSupport, cleanupSupport, defaults, glideinWMSVersion, logSupport, token_util, util -from glideinwms.lib.util import chmod +from glideinwms.lib.credentials import ( + create_parameter, + CredentialError, + CredentialPair, + CredentialPairType, + CredentialPurpose, + CredentialType, + ParameterName, + standard_path, + SubmitBundle, +) +from glideinwms.lib.util import chmod, is_str_safe ############################################################ @@ -985,14 +997,27 @@ def check_and_perform_work(factory_in_downtime, entry, work): # STEP: Process every work one at a time. This is done only for entries with work to do # for work_key in work: - if not glideFactoryLib.is_str_safe(work_key): + if not is_str_safe(work_key): # may be used to write files... make sure it is reasonable entry.log.warning("Request name '%s' not safe. Skipping request" % work_key) continue # merge work and default params params = work[work_key]["params"] - decrypted_params = {key: value.decode() for key, value in work[work_key]["params_decrypted"].items()} + decrypted_params = {} + for key, value in work[work_key]["params_decrypted"].items(): + decrypted_value = None + try: + decrypted_value = value.decode() + except UnicodeDecodeError: + try: + decrypted_value = pickle.loads(value) + except pickle.PickleError: + entry.log.exception( + f'Failed to load decrypted value for key "{key}", value "{value}". Continuing with other keys.' + ) + if decrypted_value is not None: + decrypted_params[key] = decrypted_value # add default values if not defined for k in entry.jobParams.data: @@ -1008,7 +1033,7 @@ def check_and_perform_work(factory_in_downtime, entry, work): entry.log.warning("Request %s did not provide the client and/or request name. Skipping request" % work_key) continue - if not glideFactoryLib.is_str_safe(client_int_name): + if not is_str_safe(client_int_name): # may be used to write files... make sure it is reasonable entry.log.warning("Client name '%s' not safe. Skipping request" % client_int_name) continue @@ -1038,7 +1063,7 @@ def check_and_perform_work(factory_in_downtime, entry, work): glideFactoryCredentials.check_security_credentials( auth_method, decrypted_params, client_int_name, entry.name, scitoken_passthru ) - except glideFactoryCredentials.CredentialError: + except CredentialError: entry.log.exception("Error checking credentials, skipping request: ") continue @@ -1075,7 +1100,12 @@ def check_and_perform_work(factory_in_downtime, entry, work): # # STEP: Actually process the unit work using v3 protocol # - work_performed = unit_work_v3( + if "AuthSet" in decrypted_params: + unit_work = unit_work_v3_11 + else: + unit_work = unit_work_v3 + + work_performed = unit_work( entry, work[work_key], work_key, @@ -1706,6 +1736,309 @@ def unit_work_v3( return return_dict +def unit_work_v3_11( + entry, + work, + client_name, + client_int_name, + client_int_req, + client_expected_identity, + decrypted_params, + params, + in_downtime, + condorQ, +): + """Perform a single work unit using the v3 protocol. + + :param entry: Entry + :param work: work requests + :param client_name: work_key (key used in the work request) + :param client_int_name: client name declared in the request + :param client_int_req: name of the request (declared in the request) + :param client_expected_identity: + :param decrypted_params: + :param params: + :param in_downtime: + :param condorQ: list of HTCondor jobs for this entry as returned by entry.queryQueuedGlideins() + :return: Return dictionary w/ success, security_names and work_done + """ + + # Return dictionary. Only populate information to be passed at the end + # just before returning. + return_dict = { + "success": False, + "security_names": None, + "work_done": None, + } + + # + # STEP: CHECK THAT GLIDEINS ARE WITHIN ALLOWED LIMITS + # + can_submit_glideins = entry.glideinsWithinLimits(condorQ) + + # Get grid type + grid_type = entry.jobDescript.data["GridType"] + + # Get credential security class + credential_security_class = decrypted_params.get("SecurityClass") + client_security_name = decrypted_params.get("SecurityName") + + if not credential_security_class: + entry.log.warning("Client %s did not provide a security class. Skipping bad request." % client_int_name) + return return_dict + + # Check security class for downtime (in downtimes file) + entry.log.info( + "Checking downtime for frontend %s security class: %s (entry %s)." + % (client_security_name, credential_security_class, entry.name) + ) + + if entry.isSecurityClassInDowntime(client_security_name, credential_security_class): + # Cannot use proxy for submission but entry is not in downtime + # since other proxies may map to valid security classes + entry.log.warning( + "Security class %s is currently in a downtime window for entry: %s. Ignoring request." + % (credential_security_class, entry.name) + ) + # this below change is based on redmine ticket 3110. + # even though we do not return here, setting in_downtime=True (for entry downtime) + # will make sure no new glideins will be submitted in the same way that + # the code does for the factory downtime + in_downtime = True + # return return_dict + + # Deny Frontend from requesting glideins if the whitelist + # does not have its security class (or "All" for everyone) + if entry.isClientWhitelisted(client_security_name): + if entry.isSecurityClassAllowed(client_security_name, credential_security_class): + entry.log.info(f"Security test passed for : {entry.name} {credential_security_class} ") + else: + entry.log.warning( + "Security class not in whitelist, skipping request (%s %s)." + % (client_security_name, credential_security_class) + ) + return return_dict + + # Check that security class maps to a username for submission + # The username is still used also in single user factory (for log dirs, ...) + credential_username = entry.frontendDescript.get_username(client_security_name, credential_security_class) + if credential_username is None: + entry.log.warning( + "No username mapping for security class %s of credential for %s (secid: %s), skipping request." + % (credential_security_class, client_int_name, client_security_name) + ) + return return_dict + + # Initialize submit credential object & determine the credential location + submit_credentials = SubmitBundle(credential_username, credential_security_class) + submit_credentials.auth_set = decrypted_params.get("AuthSet") + submit_credentials.cred_dir = entry.gflFactoryConfig.get_client_proxies_dir(credential_username) + + # Load credentials + frontend_scitoken = decrypted_params.get("frontend_scitoken", None) + frontend_condortoken = decrypted_params.get(f"{entry.name}.idtoken", None) + for cred in decrypted_params.get("RequestCredentials") + decrypted_params.get("PayloadCredentials"): + if not cred.path: + cred.path = cred.id + cred.path = os.path.join(submit_credentials.cred_dir, os.path.basename(cred.path)) + cred.path = standard_path(cred) + cred.save_to_file(backup=True) + if isinstance(cred, CredentialPair): + if cred.private_credential.path: + cred.private_credential.path = os.path.join( + submit_credentials.cred_dir, os.path.basename(cred.private_credential.path) + ) + cred.private_credential.path = standard_path(cred.private_credential) + cred.private_credential.save_to_file(backup=True) + if cred.purpose == CredentialPurpose.REQUEST: + submit_credentials.add_security_credential(cred) # TODO: Check if should use classad_attribute as id + if cred.string.decode() == frontend_scitoken: + submit_credentials.add_identity_credential(cred, "frontend_scitoken") + else: + submit_credentials.add_identity_credential(cred) + if cred.string.decode() == frontend_condortoken: + submit_credentials.add_identity_credential(cred, "frontend_condortoken") + + # Handle cloud-specific credentials # TODO: Check if still needed + if grid_type in ("ec2", "gce"): + for cred in submit_credentials.security_credentials: + if cred.credential_type == CredentialType.X509_CERT: + if os.path.exists(cred.path + "_compressed"): + cred.path = cred.path + "_compressed" + + # Load parameters + for param in decrypted_params.get("SecurityParameters"): + if submit_credentials.auth_set.supports(param.name): + submit_credentials.add_parameter(param) + + # Load default username from entry if needed + if submit_credentials.auth_set.supports(CredentialPairType.KEY_PAIR): + if ParameterName.REMOTE_USERNAME not in submit_credentials.parameters: + gatekeeper_list = entry.jobDescript.data["Gatekeeper"].split("@") + if len(gatekeeper_list) == 2: + submit_credentials.add_parameter( + create_parameter(ParameterName.REMOTE_USERNAME, gatekeeper_list[0].strip()) + ) + else: + entry.log.warning( + f"Client '{client_int_name}' did not specify a Username to use with {CredentialPairType.KEY_PAIR} and the entry {entry.name} does not provide a default username in the gatekeeper string, skipping request" + ) + return return_dict + + # Determine submit_credentials.id # TODO: What if there are multiple RequestCredentials? + submit_credentials.id = decrypted_params.get("RequestCredentials")[-1].id + + # Set the downtime status so the frontend-specific + # downtime is advertised in glidefactoryclient ads + entry.setDowntime(in_downtime) + entry.gflFactoryConfig.qc_stats.set_downtime(in_downtime) + + # + # STEP: CHECK IF CLEANUP OF IDLE GLIDEINS IS REQUIRED + # + + remove_excess = ( + work["requests"].get("RemoveExcess", "NO"), + work["requests"].get("RemoveExcessMargin", 0), + work["requests"].get("IdleGlideins", 0), + ) + idle_lifetime = work["requests"].get("IdleLifetime", 0) + + if "IdleGlideins" not in work["requests"]: + # Malformed, if no IdleGlideins + entry.log.warning("Skipping malformed classad for client %s" % client_name) + return return_dict + + try: + idle_glideins = int(work["requests"]["IdleGlideins"]) + except ValueError: + entry.log.warning( + f"Client {client_int_name} provided an invalid ReqIdleGlideins: '{work['requests']['IdleGlideins']}' not a number. Skipping request" + ) + return return_dict + + if "MaxGlideins" in work["requests"]: + try: + max_glideins = int(work["requests"]["MaxGlideins"]) + except ValueError: + entry.log.warning( + f"Client {client_int_name} provided an invalid ReqMaxGlideins: '{work['requests']['MaxGlideins']}' not a number. Skipping request." + ) + return return_dict + else: + try: + max_glideins = int(work["requests"]["MaxRunningGlideins"]) + except ValueError: + entry.log.warning( + f"Client {client_int_name} provided an invalid ReqMaxRunningGlideins: '{work['requests']['MaxRunningGlideins']}' not a number. Skipping request" + ) + return return_dict + + # If we got this far, it was because we were able to + # successfully update all the credentials in the request + # If we already have hit our limits checked at beginning of this + # method and logged there, we can't submit. + # We still need to check/update all the other request credentials + # and do cleanup. + + # We'll set idle glideins to zero if hit max or in downtime. + if in_downtime or not can_submit_glideins: + idle_glideins = 0 + + try: + client_web_url = work["web"]["URL"] + client_signtype = work["web"]["SignType"] + client_descript = work["web"]["DescriptFile"] + client_sign = work["web"]["DescriptSign"] + client_group = work["internals"]["GroupName"] + client_group_web_url = work["web"]["GroupURL"] + client_group_descript = work["web"]["GroupDescriptFile"] + client_group_sign = work["web"]["GroupDescriptSign"] + + client_web = glideFactoryLib.ClientWeb( + client_web_url, + client_signtype, + client_descript, + client_sign, + client_group, + client_group_web_url, + client_group_descript, + client_group_sign, + ) + except Exception: + # malformed classad, skip + entry.log.warning("Malformed classad for client %s, missing web parameters, skipping request." % client_name) + return return_dict + + # Should log here or in perform_work + glideFactoryLib.logWorkRequest( + client_int_name, + client_security_name, + submit_credentials.security_class, + idle_glideins, + max_glideins, + remove_excess, + work, + log=entry.log, + factoryConfig=entry.gflFactoryConfig, + ) + + # Iv v2 this was: + # entry_condorQ = glideFactoryLib.getQProxSecClass( + # condorQ, client_int_name, + # submit_credentials.security_class, + # client_schedd_attribute=entry.gflFactoryConfig.client_schedd_attribute, + # credential_secclass_schedd_attribute=entry.gflFactoryConfig.credential_secclass_schedd_attribute, + # factoryConfig=entry.gflFactoryConfig) + + # Sub-query selecting jobs in Factory schedd (still dictionary keyed by cluster, proc) + # for (client_schedd_attribute, credential_secclass_schedd_attribute, credential_id_schedd_attribute) + # ie (GlideinClient, GlideinSecurityClass, GlideinCredentialIdentifier) + entry_condorQ = glideFactoryLib.getQCredentials( + condorQ, + client_int_name, + submit_credentials, + entry.gflFactoryConfig.client_schedd_attribute, + entry.gflFactoryConfig.credential_secclass_schedd_attribute, + entry.gflFactoryConfig.credential_id_schedd_attribute, + ) + + # Map the identity to a frontend:sec_class for tracking totals + frontend_name = "{}:{}".format( + entry.frontendDescript.get_frontend_name(client_expected_identity), + credential_security_class, + ) + + # do one iteration for the credential set (maps to a single security class) + # entry.gflFactoryConfig.client_internals[client_int_name] = \ + # {"CompleteName":client_name, "ReqName":client_int_req} + + done_something = perform_work_v3( + entry, + entry_condorQ, + client_name, + client_int_name, + client_security_name, + submit_credentials, + remove_excess, + idle_glideins, + max_glideins, + idle_lifetime, + credential_username, + entry.glideinTotals, + frontend_name, + client_web, + params, + ) + + # Gather the information to be returned back + return_dict["success"] = True + return_dict["work_done"] = done_something + return_dict["security_names"] = {client_security_name, credential_security_class} + + return return_dict + + ############################################################################### # removed diff --git a/factory/glideFactoryInterface.py b/factory/glideFactoryInterface.py index 655b215fa..4d18e6f7e 100644 --- a/factory/glideFactoryInterface.py +++ b/factory/glideFactoryInterface.py @@ -211,7 +211,7 @@ def findGroupWork( # Any other key will not work status_constraint += ( ' && (((ReqPubKeyID=?="%s") && (ReqEncKeyCode=!=Undefined) && (ReqEncIdentity=!=Undefined)) || (ReqPubKeyID=?=Undefined))' - % pub_key_obj.get_pub_key_id() + % pub_key_obj.pub_key_id ) if additional_constraints is not None: @@ -623,10 +623,10 @@ def __init__( self.adParams["GlideinWMSVersion"] = factoryConfig.glideinwms_version if pub_key_obj is not None: - self.adParams["PubKeyID"] = "%s" % pub_key_obj.get_pub_key_id() - self.adParams["PubKeyType"] = "%s" % pub_key_obj.get_pub_key_type() - self.adParams["PubKeyValue"] = "%s" % pub_key_obj.get_pub_key_value().decode("ascii").replace("\n", "\\n") - if "grid_proxy" in auth_method: + self.adParams["PubKeyID"] = "%s" % pub_key_obj.pub_key_id + self.adParams["PubKeyType"] = "%s" % pub_key_obj.key_type + self.adParams["PubKeyValue"] = "%s" % pub_key_obj.pub_key.get().decode("ascii").replace("\n", "\\n") + if "grid_proxy" in auth_method: # TODO: Check for credentials refactoring impact self.adParams["GlideinAllowx509_Proxy"] = "%s" % True self.adParams["GlideinRequirex509_Proxy"] = "%s" % True self.adParams["GlideinRequireGlideinProxy"] = "%s" % False @@ -707,9 +707,9 @@ def __init__(self, factory_name, glidein_name, supported_signtypes, pub_key_obj) self.adParams["UpdateSequenceNumber"] = advertizeGlobalCounter advertizeGlobalCounter += 1 self.adParams["GlideinWMSVersion"] = factoryConfig.glideinwms_version - self.adParams["PubKeyID"] = "%s" % pub_key_obj.get_pub_key_id() - self.adParams["PubKeyType"] = "%s" % pub_key_obj.get_pub_key_type() - self.adParams["PubKeyValue"] = "%s" % pub_key_obj.get_pub_key_value().decode("ascii").replace("\n", "\\n") + self.adParams["PubKeyID"] = "%s" % pub_key_obj.pub_key_id + self.adParams["PubKeyType"] = "%s" % pub_key_obj.key_type + self.adParams["PubKeyValue"] = "%s" % pub_key_obj.pub_key.get().decode("ascii").replace("\n", "\\n") def advertizeGlobal( diff --git a/factory/glideFactoryLib.py b/factory/glideFactoryLib.py index c6f1fef06..bd1ff9121 100644 --- a/factory/glideFactoryLib.py +++ b/factory/glideFactoryLib.py @@ -11,7 +11,6 @@ import os import pwd import re -import string # in codice commentato: import tempfile import time @@ -28,6 +27,7 @@ logSupport, timeConversion, ) +from glideinwms.lib.credentials import cred_path, CredentialPair, ParameterName from glideinwms.lib.defaults import BINARY_ENCODING MY_USERNAME = pwd.getpwuid(os.getuid())[0] @@ -801,9 +801,9 @@ def keepIdleGlideins( except RuntimeError as e: log.warning("%s" % e) return 0 # something is wrong... assume 0 and exit - except Exception: - log.warning("Unexpected error submiting glideins") - log.exception("Unexpected error submiting glideins") + except Exception as e: + log.warning(f"Unexpected error submiting glideins: {e}") + log.exception(f"Unexpected error submiting glideins: {e}") return 0 # something is wrong... assume 0 and exit return 0 @@ -1818,13 +1818,26 @@ def get_submit_environment( exe_env = ["GLIDEIN_ENTRY_NAME=%s" % entry_name] if "frontend_scitoken" in submit_credentials.identity_credentials: - exe_env.append("SCITOKENS_FILE=%s" % submit_credentials.identity_credentials["frontend_scitoken"]) + exe_env.append( + cred_path("SCITOKENS_FILE=%s" % submit_credentials.identity_credentials["frontend_scitoken"]) + ) if "frontend_condortoken" in submit_credentials.identity_credentials: - exe_env.append("IDTOKENS_FILE=%s" % submit_credentials.identity_credentials["frontend_condortoken"]) + exe_env.append( + cred_path("IDTOKENS_FILE=%s" % submit_credentials.identity_credentials["frontend_condortoken"]) + ) else: # TODO: this ends up transferring an empty file called 'null' in the Glidein start dir. Find a better way exe_env.append("IDTOKENS_FILE=/dev/null") + id_cred_paths = [] + for cred in submit_credentials.identity_credentials.values(): + if cred_path(cred): + id_cred_paths.append(cred_path(cred)) + if isinstance(cred, CredentialPair): + if cred_path(cred.private_credential): + id_cred_paths.append(cred_path(cred.private_credential)) + exe_env.append(f"IDENTITY_CREDENTIALS={','.join(id_cred_paths)}") + # The parameter list to be added to the arguments for glidein_startup.sh params_str = "" # if client_web has been provided, get the arguments and add them to the string @@ -1938,12 +1951,12 @@ def get_submit_environment( pass exe_env.append( "GRID_RESOURCE_OPTIONS=--rgahp-key %s --rgahp-nopass" - % submit_credentials.security_credentials["PrivateKey"] + % cred_path(submit_credentials.security_credentials["PrivateKey"]) ) - exe_env.append("X509_USER_PROXY=%s" % submit_credentials.security_credentials["GlideinProxy"]) + exe_env.append("X509_USER_PROXY=%s" % cred_path(submit_credentials.security_credentials["GlideinProxy"])) exe_env.append( "X509_USER_PROXY_BASENAME=%s" - % os.path.basename(submit_credentials.security_credentials["GlideinProxy"]) + % os.path.basename(cred_path(submit_credentials.security_credentials["GlideinProxy"])) ) glidein_arguments += " -cluster $(Cluster) -subcluster $(Process)" # condor and batch (BLAH/BOSCO) submissions do not like arguments enclosed in quotes @@ -1961,21 +1974,29 @@ def get_submit_environment( # log.debug("submit_credentials.identity_credentials: %s" % str(submit_credentials.identity_credentials)) try: - exe_env.append("X509_USER_PROXY=%s" % submit_credentials.security_credentials["GlideinProxy"]) + exe_env.append( + "X509_USER_PROXY=%s" % cred_path(submit_credentials.security_credentials["GlideinProxy"]) + ) exe_env.append("IMAGE_ID=%s" % submit_credentials.identity_credentials["VMId"]) exe_env.append("INSTANCE_TYPE=%s" % submit_credentials.identity_credentials["VMType"]) if grid_type == "ec2": - exe_env.append("ACCESS_KEY_FILE=%s" % submit_credentials.security_credentials["PublicKey"]) - exe_env.append("SECRET_KEY_FILE=%s" % submit_credentials.security_credentials["PrivateKey"]) exe_env.append( - "CREDENTIAL_DIR=%s" % os.path.dirname(submit_credentials.security_credentials["PublicKey"]) + "ACCESS_KEY_FILE=%s" % cred_path(submit_credentials.security_credentials["PublicKey"]) + ) + exe_env.append( + "SECRET_KEY_FILE=%s" % cred_path(submit_credentials.security_credentials["PrivateKey"]) + ) + exe_env.append( + "CREDENTIAL_DIR=%s" + % os.path.dirname(cred_path(submit_credentials.security_credentials["PublicKey"])) ) elif grid_type == "gce": - exe_env.append("GCE_AUTH_FILE=%s" % submit_credentials.security_credentials["AuthFile"]) + exe_env.append("GCE_AUTH_FILE=%s" % cred_path(submit_credentials.security_credentials["AuthFile"])) exe_env.append("GRID_RESOURCE_OPTIONS=%s" % "$(gce_project_name) $(gce_availability_zone)") exe_env.append( - "CREDENTIAL_DIR=%s" % os.path.dirname(submit_credentials.security_credentials["AuthFile"]) + "CREDENTIAL_DIR=%s" + % os.path.dirname(cred_path(submit_credentials.security_credentials["AuthFile"])) ) try: @@ -2022,7 +2043,7 @@ def get_submit_environment( exe_env.append("USER_DATA=%s" % ini) # get the proxy - full_path_to_proxy = submit_credentials.security_credentials["GlideinProxy"] + full_path_to_proxy = cred_path(submit_credentials.security_credentials["GlideinProxy"]) exe_env.append("GLIDEIN_PROXY_FNAME=%s" % full_path_to_proxy) except KeyError: @@ -2038,8 +2059,9 @@ def get_submit_environment( # Unknown error, re-raise to stop the environment build raise else: - proxy = submit_credentials.security_credentials.get("SubmitProxy", "") - exe_env.append("X509_USER_PROXY=%s" % proxy) + proxy = submit_credentials.security_credentials.get("SubmitProxy", None) + proxy_path = cred_path(proxy) if proxy else "" + exe_env.append("X509_USER_PROXY=%s" % proxy_path) # TODO: we might review this part as we added this here because the macros were been expanded when used in the gt2 submission # we don't add the macros to the arguments for the EC2 submission since condor will never @@ -2057,10 +2079,22 @@ def get_submit_environment( if "GlobusRSL" in jobDescript.data: glidein_rsl = jobDescript.data["GlobusRSL"] - if "project_id" in jobDescript.data["AuthMethod"]: + supports_project_id = False + try: + # New protocol + if submit_credentials.auth_set.supports(ParameterName.PROJECT_ID): + supports_project_id = True + except AttributeError: + # Old protocol + if "ProjectId" in jobDescript.data["AuthMethod"]: + supports_project_id = True + + if supports_project_id: # Append project id to the rsl - glidein_rsl = "{}(project={})".format(glidein_rsl, submit_credentials.identity_credentials["ProjectId"]) - exe_env.append("GLIDEIN_PROJECT_ID=%s" % submit_credentials.identity_credentials["ProjectId"]) + glidein_rsl = "{}(project={})".format( + glidein_rsl, submit_credentials.parameters[ParameterName.PROJECT_ID] + ) + exe_env.append("GLIDEIN_PROJECT_ID=%s" % submit_credentials.parameters[ParameterName.PROJECT_ID]) exe_env.append("GLIDEIN_RSL=%s" % glidein_rsl) @@ -2211,15 +2245,6 @@ def isGlideinHeldNTimes(jobInfo, factoryConfig=None, n=20): return greater_than_n_iterations -############################################################ -# only allow simple strings -def is_str_safe(s): - for c in s: - if c not in ("._-@" + string.ascii_letters + string.digits): - return False - return True - - class GlideinTotals: """Keeps track of all glidein totals.""" diff --git a/frontend/glideinFrontend.py b/frontend/glideinFrontend.py index c7007ba75..3e2d3e54d 100755 --- a/frontend/glideinFrontend.py +++ b/frontend/glideinFrontend.py @@ -92,7 +92,7 @@ def poll_group_process(group_name, child): pass # ignore try: tempErr = child.stderr.read() - if tempOut: + if tempErr: logSupport.log.warning(f"[{group_name}]: {tempErr}") except OSError: pass # ignore diff --git a/frontend/glideinFrontendConfig.py b/frontend/glideinFrontendConfig.py index 7237310e2..c558cd6a6 100644 --- a/frontend/glideinFrontendConfig.py +++ b/frontend/glideinFrontendConfig.py @@ -427,18 +427,24 @@ def _merge(self): if t in data: self.merged_data[t] = data[t] - proxies = [] + proxies = [] # TODO: Investigate how to merge global and group credentials. + parameters = {} # switching the order, so that the group credential will # be chosen before the global credential when ProxyFirst is used. for data in (self.element_data, self.frontend_data): if "Proxies" in data: proxies += eval(data["Proxies"]) + if "Parameters" in data: + parameters.update(eval(data["Parameters"])) self.merged_data["Proxies"] = proxies + self.merged_data["Parameters"] = parameters proxy_descript_attrs = [ "ProxySecurityClasses", "ProxyTrustDomains", "ProxyTypes", + "CredentialPurposes", + "CredentialContexts", "CredentialGenerators", "ProxyKeyFiles", "ProxyPilotFiles", diff --git a/frontend/glideinFrontendElement.py b/frontend/glideinFrontendElement.py index 56fcb206d..327d2331b 100755 --- a/frontend/glideinFrontendElement.py +++ b/frontend/glideinFrontendElement.py @@ -21,8 +21,6 @@ import time import traceback -from importlib import import_module - from glideinwms.frontend import ( glideinFrontendConfig, glideinFrontendDowntimeLib, @@ -35,6 +33,7 @@ # from glideinwms.lib.util import file_tmp2final from glideinwms.lib import cleanupSupport, condorMonitor, logSupport, pubCrypto, servicePerformance, token_util +from glideinwms.lib.credentials import CredentialPurpose from glideinwms.lib.disk_cache import DiskCache from glideinwms.lib.fork import fork_in_bg, ForkManager, wait_for_pids from glideinwms.lib.pidSupport import register_sighandler @@ -163,8 +162,8 @@ def __init__(self, parent_pid, work_dir, group_name, action): self.removal_requests_tracking = self.elementDescript.element_data["RemovalRequestsTracking"] self.removal_margin = int(self.elementDescript.element_data["RemovalMargin"]) - # Default behavior: Use factory proxies unless configure overrides it - self.x509_proxy_plugin = None + # Default behavior: Use first credential that matches auth_method. + self.credentials_plugin: glideinFrontendPlugins.CredentialsPlugin = None # If not None, this is a request for removal of glideins only (i.e. do not ask for more) self.request_removal_wtype = None @@ -217,8 +216,8 @@ def configure(self): % (self.elementDescript.merged_data["ProxySelectionPlugin"], list(proxy_plugins.keys())) ) return 1 - self.x509_proxy_plugin = proxy_plugins[self.elementDescript.merged_data["ProxySelectionPlugin"]]( - group_dir, glideinFrontendPlugins.createCredentialList(self.elementDescript) + self.credentials_plugin = proxy_plugins[self.elementDescript.merged_data["ProxySelectionPlugin"]]( + group_dir, glideinFrontendPlugins.createRequestBundle(self.elementDescript) ) self.idtoken_lifetime = int(self.elementDescript.merged_data.get("IDTokenLifetime", 24)) self.idtoken_keyname = self.elementDescript.merged_data.get("IDTokenKeyname", "FRONTEND") @@ -535,9 +534,9 @@ def iterate_one(self): # Update x509 user map and give proxy plugin a chance # to update based on condor stats - if self.x509_proxy_plugin: + if self.credentials_plugin: logSupport.log.info("Updating usermap") - self.x509_proxy_plugin.update_usermap( + self.credentials_plugin.update_usermap( self.condorq_dict, condorq_dict_types, self.status_dict, self.status_dict_types ) @@ -552,7 +551,7 @@ def iterate_one(self): self.signatureDescript.signature_type, self.signatureDescript.frontend_descript_signature, self.signatureDescript.group_descript_signature, - x509_proxies_plugin=self.x509_proxy_plugin, + credentials_plugin=self.credentials_plugin, ha_mode=self.ha_mode, ) descript_obj.add_monitoring_url(self.monitoring_web_url) @@ -566,7 +565,7 @@ def iterate_one(self): self.condorq_match_list = [f[0] for f in self.elementDescript.merged_data["JobMatchAttrs"]] servicePerformance.startPerfMetricEvent(self.group_name, "matchmaking") - self.do_match() + self.do_match() # Populates condorq_dict_types["Idle"]["total"] servicePerformance.endPerfMetricEvent(self.group_name, "matchmaking") logSupport.log.info( @@ -776,10 +775,10 @@ def iterate_one(self): glidein_monitors["Glideins%s" % t] = count_status[t] """ - for cred in self.x509_proxy_plugin.cred_list: - glidein_monitors_per_cred[cred.getId()] = {} + for cred in self.credentials_plugin.cred_list: + glidein_monitors_per_cred[cred.id] = {} for t in count_status: - glidein_monitors_per_cred[cred.getId()]['Glideins%s' % t] = count_status_per_cred[cred.getId()][t] + glidein_monitors_per_cred[cred.id]['Glideins%s' % t] = count_status_per_cred[cred.id][t] """ # Number of credentials that have running and glideins. @@ -789,35 +788,36 @@ def iterate_one(self): # Credential specific stats are not presented anywhere except the # classad. Monitoring info in frontend and factory shows # aggregated info considering all the credentials + submit_creds = self.credentials_plugin.get_credentials(credential_purpose=CredentialPurpose.REQUEST) creds_with_running = 0 - for cred in self.x509_proxy_plugin.cred_list: - glidein_monitors_per_cred[cred.getId()] = {} + for cred in submit_creds: + glidein_monitors_per_cred[cred.id] = {} for t in count_status: - glidein_monitors_per_cred[cred.getId()]["Glideins%s" % t] = count_status_per_cred[cred.getId()][t] - glidein_monitors_per_cred[cred.getId()]["ScaledRunning"] = 0 + glidein_monitors_per_cred[cred.id]["Glideins%s" % t] = count_status_per_cred[cred.id][t] + glidein_monitors_per_cred[cred.id]["ScaledRunning"] = 0 # This credential has running glideins. - if glidein_monitors_per_cred[cred.getId()]["GlideinsRunning"]: + if glidein_monitors_per_cred[cred.id]["GlideinsRunning"]: creds_with_running += 1 if creds_with_running: # Counter to handle rounding errors scaled = 0 tr = glidein_monitors["Running"] - for cred in self.x509_proxy_plugin.cred_list: - if glidein_monitors_per_cred[cred.getId()]["GlideinsRunning"]: + for cred in submit_creds: + if glidein_monitors_per_cred[cred.id]["GlideinsRunning"]: # This cred has running. Scale them down if (creds_with_running - scaled) == 1: # This is the last one. Assign remaining running - glidein_monitors_per_cred[cred.getId()]["ScaledRunning"] = ( + glidein_monitors_per_cred[cred.id]["ScaledRunning"] = ( tr - (tr // creds_with_running) * scaled ) scaled += 1 break else: - glidein_monitors_per_cred[cred.getId()]["ScaledRunning"] = tr // creds_with_running + glidein_monitors_per_cred[cred.id]["ScaledRunning"] = tr // creds_with_running scaled += 1 key_obj = None @@ -852,25 +852,46 @@ def iterate_one(self): else: logSupport.log.debug("could NOT find condor token: %s" % entry_token_name) - # now try to generate a credential using a generator plugin - generator_name, stkn = self.generate_credential( - self.elementDescript, glidein_el, self.group_name, trust_domain + # Generate credentials and parameters + self.credentials_plugin.generate_credentials( + elementDescript=self.elementDescript, + glidein_el=glidein_el, + group_name=self.group_name, + trust_domain=trust_domain, ) + self.credentials_plugin.generate_parameters( + elementDescript=self.elementDescript, + glidein_el=glidein_el, + group_name=self.group_name, + trust_domain=trust_domain, + ) + + # TODO: Add a `context` attribute to the Credential class. + # TODO: Add a `context` argument to the Credential constructor, load, and renew methods. + # TODO: Add a `context` attribute to the credential elements in the frontend configuration. + # TODO: Pass `context` to Generator.generate on CredentialGenerator. + # NOTE: This allows contextual credential generation, and will make it possible to define round-robin from the config file. + # now try to generate a credential using a generator plugin + # generator_name, stkn = credentials.generate_credential( + # self.elementDescript, glidein_el, self.group_name, trust_domain + # ) + # TODO: Remove this code once we are sure the new credentials work properly + # # look for a local scitoken if no credential was generated - if not stkn: - stkn = self.get_scitoken(self.elementDescript, trust_domain) - - if stkn: - if generator_name: - for cred_el in advertizer.descript_obj.x509_proxies_plugin.cred_list: - if cred_el.filename == generator_name: - cred_el.generated_data = stkn - break - if token_util.token_str_expired(stkn): - logSupport.log.warning("SciToken is expired, not forwarding.") - else: - gp_encrypt["frontend_scitoken"] = stkn + # if not stkn: + # stkn = credentials.get_scitoken(self.elementDescript, trust_domain) + + # if stkn: + # if generator_name: + # for cred_el in advertizer.descript_obj.credentials_plugin.cred_list: + # if cred_el.filename == generator_name: + # cred_el.generated_data = stkn + # break + # if token_util.token_str_expired(stkn): + # logSupport.log.warning("SciToken is expired, not forwarding.") + # else: + # gp_encrypt["frontend_scitoken"] = stkn # now advertise logSupport.log.debug("advertising tokens %s" % gp_encrypt.keys()) @@ -951,109 +972,6 @@ def iterate_one(self): return - def get_scitoken(self, elementDescript, trust_domain): - """Look for a local SciToken specified for the trust domain. - - Args: - elementDescript (ElementMergedDescript): element descript - trust_domain (string): trust domain for the element - - Returns: - string, None: SciToken or None if not found - """ - - scitoken_fullpath = "" - cred_type_data = elementDescript.element_data.get("ProxyTypes") - trust_domain_data = elementDescript.element_data.get("ProxyTrustDomains") - if not cred_type_data: - cred_type_data = elementDescript.frontend_data.get("ProxyTypes") - if not trust_domain_data: - trust_domain_data = elementDescript.frontend_data.get("ProxyTrustDomains") - if trust_domain_data and cred_type_data: - cred_type_map = eval(cred_type_data) - trust_domain_map = eval(trust_domain_data) - for cfname in cred_type_map: - if cred_type_map[cfname] == "scitoken": - if trust_domain_map[cfname] == trust_domain: - scitoken_fullpath = cfname - - if os.path.exists(scitoken_fullpath): - try: - logSupport.log.debug(f"found scitoken {scitoken_fullpath}") - stkn = "" - with open(scitoken_fullpath) as fbuf: - for line in fbuf: - stkn += line - stkn = stkn.strip() - return stkn - except Exception as err: - logSupport.log.exception(f"failed to read scitoken: {err}") - - return None - - def generate_credential(self, elementDescript, glidein_el, group_name, trust_domain): - """Generates a credential with a credential generator plugin provided for the trust domain. - - Args: - elementDescript (ElementMergedDescript): element descript - glidein_el (dict): glidein element - group_name (string): group name - trust_domain (string): trust domain for the element - - Returns: - string, None: Credential or None if not generated - """ - - ### The credential generator plugin should define the following function: - # def get_credential(log:logger, group:str, entry:dict{name:str, gatekeeper:str}, trust_domain:str): - # Generates a credential given the parameter - - # Args: - # log:logger - # group:str, - # entry:dict{ - # name:str, - # gatekeeper:str}, - # trust_domain:str, - # Return - # tuple - # token:str - # lifetime:int seconds of remaining lifetime - # Exception - # KeyError - miss some information to generate - # ValueError - could not generate the token - - generator = None - generators = elementDescript.element_data.get("CredentialGenerators") - trust_domain_data = elementDescript.element_data.get("ProxyTrustDomains") - if not generators: - generators = elementDescript.frontend_data.get("CredentialGenerators") - if not trust_domain_data: - trust_domain_data = elementDescript.frontend_data.get("ProxyTrustDomains") - if trust_domain_data and generators: - generators_map = eval(generators) - trust_domain_map = eval(trust_domain_data) - for cfname in generators_map: - if trust_domain_map[cfname] == trust_domain: - generator = generators_map[cfname] - logSupport.log.debug(f"found credential generator plugin {generator}") - try: - if generator not in plugins: - plugins[generator] = import_module(generator) - entry = { - "name": glidein_el["attrs"].get("EntryName"), - "gatekeeper": glidein_el["attrs"].get("GLIDEIN_Gatekeeper"), - "factory": glidein_el["attrs"].get("AuthenticatedIdentity"), - } - stkn, _ = plugins[generator].get_credential(logSupport, group_name, entry, trust_domain) - return cfname, stkn - except ModuleNotFoundError: - logSupport.log.warning(f"Failed to load credential generator plugin {generator}") - except Exception as e: # catch any exception from the plugin to prevent the frontend from crashing - logSupport.log.warning(f"Failed to generate credential: {e}.") - - return None, None - def refresh_entry_token(self, glidein_el): """Create or update a condor token for an entry point @@ -2025,9 +1943,9 @@ def get_condor_q(self, schedd_name): condorq_dict = {} try: condorq_format_list = self.elementDescript.merged_data["JobMatchAttrs"] - if self.x509_proxy_plugin: + if self.credentials_plugin: condorq_format_list = list(condorq_format_list) + list( - self.x509_proxy_plugin.get_required_job_attributes() + self.credentials_plugin.get_required_job_attributes() ) ### Add in elements to help in determining if jobs have voms creds @@ -2076,9 +1994,9 @@ def get_condor_status(self): ("TotalSlotCpus", "i"), ] - if self.x509_proxy_plugin: + if self.credentials_plugin: status_format_list = list(status_format_list) + list( - self.x509_proxy_plugin.get_required_classad_attributes() + self.credentials_plugin.get_required_classad_attributes() ) # Consider multicore slots with free cpus/memory only @@ -2303,6 +2221,8 @@ def subprocess_count_glidein(self, glidein_list): """ out = () + submit_creds = self.credentials_plugin.get_credentials(credential_purpose=CredentialPurpose.REQUEST) + count_status_multi = {} # Count distribution per credentials count_status_multi_per_cred = {} @@ -2311,8 +2231,8 @@ def subprocess_count_glidein(self, glidein_list): count_status_multi[request_name] = {} count_status_multi_per_cred[request_name] = {} - for cred in self.x509_proxy_plugin.cred_list: - count_status_multi_per_cred[request_name][cred.getId()] = {} + for cred in submit_creds: + count_status_multi_per_cred[request_name][cred.id] = {} # It is cheaper to get Idle and Running from request-only # classads then filter out requests from Idle and Running @@ -2331,8 +2251,7 @@ def subprocess_count_glidein(self, glidein_list): "RunningCores": glideinFrontendLib.getRunningCoresCondorStatus(total_req_dict), } - for st in req_dict_types: - req_dict = req_dict_types[st] + for st, req_dict in req_dict_types.items(): if st in ("TotalCores", "IdleCores", "RunningCores"): count_status_multi[request_name][st] = glideinFrontendLib.countCoresCondorStatus(req_dict, st) elif st == "Running": @@ -2343,23 +2262,22 @@ def subprocess_count_glidein(self, glidein_list): else: count_status_multi[request_name][st] = glideinFrontendLib.countCondorStatus(req_dict) - for cred in self.x509_proxy_plugin.cred_list: - cred_id = cred.getId() - cred_dict = glideinFrontendLib.getClientCondorStatusCredIdOnly(req_dict, cred_id) + for cred in submit_creds: + cred_dict = glideinFrontendLib.getClientCondorStatusCredIdOnly(req_dict, cred.id) if st in ("TotalCores", "IdleCores", "RunningCores"): - count_status_multi_per_cred[request_name][cred_id][st] = ( + count_status_multi_per_cred[request_name][cred.id][st] = ( glideinFrontendLib.countCoresCondorStatus(cred_dict, st) ) elif st == "Running": # Running counts are computed differently because of # the dict composition. Dict also has p-slots # corresponding to the dynamic slots - count_status_multi_per_cred[request_name][cred_id][st] = ( + count_status_multi_per_cred[request_name][cred.id][st] = ( glideinFrontendLib.countRunningCondorStatus(cred_dict) ) else: - count_status_multi_per_cred[request_name][cred_id][st] = glideinFrontendLib.countCondorStatus( + count_status_multi_per_cred[request_name][cred.id][st] = glideinFrontendLib.countCondorStatus( cred_dict ) diff --git a/frontend/glideinFrontendInterface.py b/frontend/glideinFrontendInterface.py index 655df1812..cecbbeca1 100644 --- a/frontend/glideinFrontendInterface.py +++ b/frontend/glideinFrontendInterface.py @@ -8,6 +8,7 @@ import calendar import copy import os +import pickle import time from glideinwms.lib import symCrypto # pubCrypto was removed because unused @@ -19,9 +20,16 @@ defaults, glideinWMSVersion, logSupport, - token_util, x509Support, ) +from glideinwms.lib.credentials import ( + AuthenticationMethod, + create_credential, + CredentialError, + CredentialPair, + CredentialPurpose, + CredentialType, +) from glideinwms.lib.util import hash_nc ############################################################ @@ -286,7 +294,8 @@ def format_condor_dict(data): # and not for every iteration. -class Credential: +# DEPRECATED: use glideinwms.lib.credentials.Credential +class LegacyCredential: def __init__(self, proxy_id, proxy_fname, elementDescript): self.req_idle = 0 self.req_max_run = 0 @@ -457,7 +466,7 @@ def renew(self): if (remaining != -1) and (self.update_frequency != -1) and (remaining < self.update_frequency): self.create() - def supports_auth_method(self, auth_method): + def supports_auth_method(self, auth_method): # TODO: Check for credentials refactoring impact """ Check if this credential has all the necessary info to support auth_method for a given factory entry @@ -522,7 +531,7 @@ def __init__( signtype, main_sign, group_sign, - x509_proxies_plugin=None, + credentials_plugin=None, ha_mode="master", ): self.my_name = my_name @@ -532,7 +541,7 @@ def __init__( self.main_descript = main_descript self.signtype = signtype self.main_sign = main_sign - self.x509_proxies_plugin = x509_proxies_plugin + self.credentials_plugin = credentials_plugin self.group_name = group_name self.group_descript = group_descript self.group_sign = group_sign @@ -543,7 +552,7 @@ def add_monitoring_url(self, monitoring_web_url): self.monitoring_web_url = monitoring_web_url def need_encryption(self): - return self.x509_proxies_plugin is not None + return self.credentials_plugin is not None # return a list of strings def get_id_attrs(self): @@ -769,7 +778,7 @@ def __init__(self, descript_obj): # must be of type FrontendDescript # set a few defaults self.unique_id = 1 self.adname = None - self.x509_proxies_data = [] + self.request_credentials = [] self.ha_mode = "master" self.glidein_config_limits = {} @@ -828,41 +837,26 @@ def get_queue_len(self): count += len(self.factory_queue[factory_pool]) return count + # TODO: Review the need for this method. If we're keeping it, it should be refactored. def renew_and_load_credentials(self): """ Get the list of proxies, invoke the renew scripts if any, and read the credentials in memory. - Modifies the self.x509_proxies_data variable. + Modifies the self.request_credentials variable. """ - self.x509_proxies_data = [] - if self.descript_obj.x509_proxies_plugin is not None: - self.x509_proxies_data = self.descript_obj.x509_proxies_plugin.get_credentials() - nr_credentials = len(self.x509_proxies_data) + # TODO: remove references to x509 + self.request_credentials = [] + if self.descript_obj.credentials_plugin is not None: + self.request_credentials = self.descript_obj.credentials_plugin.get_request_credentials() + nr_credentials = len(self.request_credentials) else: nr_credentials = 0 - for i in range(nr_credentials): - cred_el = self.x509_proxies_data[i] + for cred_el in self.request_credentials: cred_el.advertize = True - cred_el.renew() - cred_el.createIfNotExist() - - cred_el.loaded_data = [] - for cred_file in (cred_el.filename, cred_el.key_fname, cred_el.pilot_fname): - if cred_file: - try: - cred_data = cred_el.generated_data - except AttributeError: - # TODO: credential parsing form file could fail (wrong permission, not found, ...) - # Add message? Handle here or declare raising - cred_data = cred_el.getString(cred_file) - if cred_data: - cred_el.loaded_data.append((cred_file, cred_data)) - else: - # We encountered error with this credential - # Move onto next credential - break + cred_el.credential.renew() + cred_el.credential.save_to_file(overwrite=False, continue_if_no_path=True) return nr_credentials @@ -962,35 +956,34 @@ def createGlobalAdvertizeWorkFile(self, factory_pool): tmpname = self.adname glidein_params_to_encrypt = {} with open(tmpname, "a") as fd: - nr_credentials = len(self.x509_proxies_data) + nr_credentials = len(self.request_credentials) if nr_credentials > 0: - glidein_params_to_encrypt["NumberOfCredentials"] = "%s" % nr_credentials + glidein_params_to_encrypt["NumberOfCredentials"] = ( + f"{nr_credentials}" # TODO: Check if it needs refactoring + ) request_name = "Global" if factory_pool in self.global_params: request_name, security_name = self.global_params[factory_pool] glidein_params_to_encrypt["SecurityName"] = security_name classad_name = f"{request_name}@{self.descript_obj.my_name}" - fd.write('MyType = "%s"\n' % frontendConfig.client_global) - fd.write('GlideinMyType = "%s"\n' % frontendConfig.client_global) - fd.write('GlideinWMSVersion = "%s"\n' % frontendConfig.glideinwms_version) - fd.write('Name = "%s"\n' % classad_name) - fd.write('FrontendName = "%s"\n' % self.descript_obj.frontend_name) - fd.write('FrontendHAMode = "%s"\n' % self.ha_mode) - fd.write('GroupName = "%s"\n' % self.descript_obj.group_name) - fd.write('ClientName = "%s"\n' % self.descript_obj.my_name) - for i in range(nr_credentials): - cred_el = self.x509_proxies_data[i] + fd.write(f'MyType = "{frontendConfig.client_global}"\n') + fd.write(f'GlideinMyType = "{frontendConfig.client_global}"\n') + fd.write(f'GlideinWMSVersion = "{frontendConfig.glideinwms_version}"\n') + fd.write(f'Name = "{classad_name}"\n') + fd.write(f'FrontendName = "{self.descript_obj.frontend_name}"\n') + fd.write(f'FrontendHAMode = "{self.ha_mode}"\n') + fd.write(f'GroupName = "{self.descript_obj.group_name}"\n') + fd.write(f'ClientName = "{self.descript_obj.my_name}"\n') + for cred_el in self.request_credentials: if not cred_el.advertize: continue # we already determined it cannot be used - for ld_el in cred_el.loaded_data: - ld_fname, ld_data = ld_el - glidein_params_to_encrypt[cred_el.file_id(ld_fname)] = ld_data - if hasattr(cred_el, "security_class"): - # Convert the sec class to a string so the Factory can interpret the value correctly - glidein_params_to_encrypt["SecurityClass" + cred_el.file_id(ld_fname)] = str( - cred_el.security_class - ) + glidein_params_to_encrypt[cred_el.credential.id] = cred_el.credential.string + if hasattr(cred_el, "security_class"): + # Convert the sec class to a string so the Factory can interpret the value correctly + glidein_params_to_encrypt["SecurityClass" + cred_el.credential.id] = str( + cred_el.credential.security_class + ) key_obj = None if factory_pool in self.global_key: @@ -1007,7 +1000,7 @@ def createGlobalAdvertizeWorkFile(self, factory_pool): advertizeGCGounter[classad_name] += 1 else: advertizeGCGounter[classad_name] = 0 - fd.write("UpdateSequenceNumber = %s\n" % advertizeGCGounter[classad_name]) + fd.write(f"UpdateSequenceNumber = {advertizeGCGounter[classad_name]}\n") # add a final empty line... useful when appending fd.write("\n") @@ -1144,230 +1137,180 @@ def createAdvertizeWorkFile(self, factory_pool, params_obj, key_obj=None, file_i adname, unique_id and x509_proxies_data to be set. """ - global frontendConfig - global advertizeGCCounter - descript_obj = self.descript_obj + cred_filename_arr = [] logSupport.log.debug("In create Advertize work") - factory_trust, factory_auth = self.factory_constraint[params_obj.request_name] - - total_nr_credentials = len(self.x509_proxies_data) - - cred_filename_arr = [] - - if total_nr_credentials == 0: + # Make sure we have credentials to work with + if len(self.request_credentials) == 0: + logSupport.log.warning(f"No credentials match for factory pool {factory_pool}, not advertising request") raise NoCredentialException - # get_credentials will augment the needed credentials with the requests - # A little weird, but that's how it works right now - # The credential objects are also persistent, so this will be a subset of self.x509_proxies_data - credentials_with_requests = descript_obj.x509_proxies_plugin.get_credentials( - params_obj=params_obj, credential_type=factory_auth, trust_domain=factory_trust - ) - nr_credentials = len(credentials_with_requests) - if nr_credentials == 0: + # Determine required credentials for the entry + factory_trust, factory_auth = self.factory_constraint[params_obj.request_name] + factory_auth = AuthenticationMethod(factory_auth) + auth_set = factory_auth.match(self.descript_obj.credentials_plugin.security_bundle) + if not auth_set: raise NoCredentialException + params_obj.glidein_params_to_encrypt["AuthSet"] = pickle.dumps(auth_set) - if file_id_cache is None: - # create a local cache, if no global provided - file_id_cache = CredentialCache() + # Pack payload credentials to send with the request + payload_creds = [ + cred.copy() + for cred in self.descript_obj.credentials_plugin.get_credentials( + trust_domain=factory_trust, credential_purpose=CredentialPurpose.PAYLOAD + ) + ] + for key in params_obj.glidein_params_to_encrypt.keys(): + if key.endswith(".idtoken"): + try: + idtoken_str = params_obj.glidein_params_to_encrypt[key] + idtoken = create_credential(idtoken_str, cred_type=CredentialType.IDTOKEN) + payload_creds.append(idtoken) + except CredentialError as e: + logSupport.log.warning(f"Failed to create idtoken credential: {e}") + params_obj.glidein_params_to_encrypt["PayloadCredentials"] = pickle.dumps(payload_creds) + + # Pack request credentials to send with the request + request_creds = [ + rc.credential.copy() for rc in self.request_credentials if rc.credential.trust_domain == factory_trust + ] + params_obj.glidein_params_to_encrypt["RequestCredentials"] = pickle.dumps(request_creds) + + # Pack parameters to send to the request + security_params = [param.copy() for param in self.descript_obj.credentials_plugin.params_dict.values()] + params_obj.glidein_params_to_encrypt["SecurityParameters"] = pickle.dumps(security_params) + + # Assign work to the credentials per the plugin policy + self.descript_obj.credentials_plugin.assign_work(self.request_credentials, params_obj, auth_set) + + for request_cred in self.request_credentials: + if not request_cred.advertize: + logSupport.log.debug( + f"Skipping credential with 'advertize' set to False. ({request_cred.credential.path})" + ) + continue # We already determined it cannot be used + if (request_cred.credential.trust_domain != factory_trust) and (factory_trust != "Any"): + logSupport.log.warning( + f"Skipping credential with trust_domain {request_cred.credential.trust_domain}. " + f"Factory requires {factory_trust}. ({request_cred.credential.path})" + ) + continue # Skip credentials that don't match the trust domain + if request_cred.req_idle == 0 and request_cred.req_max_run == 0: + logSupport.log.debug(f"Skipping credential with no work assigned. ({request_cred.credential.path})") + continue # Skip credentials with no work assigned + + classad_name = f"{request_cred.credential.id}_{params_obj.request_name}@{self.descript_obj.my_name}" + + glidein_params_to_encrypt = {} + if params_obj.glidein_params_to_encrypt: + glidein_params_to_encrypt = copy.deepcopy(params_obj.glidein_params_to_encrypt) + + # Convert the security class to a string so the Factory can interpret the value correctly + glidein_params_to_encrypt["SecurityClass"] = str(request_cred.credential.security_class) + if params_obj.security_name is not None: + glidein_params_to_encrypt["SecurityName"] = params_obj.security_name + + glidein_params_to_encrypt[request_cred.credential.classad_attribute] = request_cred.credential.id + if request_cred.credential.cred_type is CredentialType.SCITOKEN: + glidein_params_to_encrypt["frontend_scitoken"] = request_cred.credential.string + if isinstance(request_cred.credential, CredentialPair): + glidein_params_to_encrypt[request_cred.credential.private_credential.classad_attribute] = ( + request_cred.credential.private_credential.id + ) - for i in range(nr_credentials): - fd = None - glidein_monitors_this_cred = {} - try: - encrypted_params = {} # none by default - glidein_params_to_encrypt = params_obj.glidein_params_to_encrypt - if glidein_params_to_encrypt is None: - glidein_params_to_encrypt = {} - else: - glidein_params_to_encrypt = copy.deepcopy(glidein_params_to_encrypt) - classad_name = f"{params_obj.request_name}@{descript_obj.my_name}" - - req_idle = 0 - req_max_run = 0 - - # credential_el (Credebtial()) - credential_el = credentials_with_requests[i] - logSupport.log.debug(f"Checking Credential file {credential_el.filename} ...") - if not credential_el.advertize: - # We already determined it cannot be used - # if hasattr(credential_el,'filename'): - # filestr=credential_el.filename - # logSupport.log.warning("Credential file %s had some earlier problem in loading so not advertizing, skipping..."%(filestr)) - continue - - if credential_el.supports_auth_method("scitoken"): - try: - # try first for credential generator - token_expired = token_util.token_str_expired(credential_el.generated_data) - except AttributeError: - # then try file stored credential - token_expired = token_util.token_file_expired(credential_el.filename) - if token_expired: - logSupport.log.warning( - f"Credential file {credential_el.filename} has expired scitoken, skipping" - ) - continue - glidein_params_to_encrypt["ScitokenId"] = file_id_cache.file_id( - credential_el, credential_el.filename + # Encrypt parameters + encrypted_params = {} + if key_obj: + for attr in glidein_params_to_encrypt: + encrypted_params[attr] = key_obj.encrypt_hex(glidein_params_to_encrypt[attr]).decode( + defaults.BINARY_ENCODING_CRYPTO ) - if params_obj.request_name in self.factory_constraint: - if (factory_auth != "Any") and (not credential_el.supports_auth_method(factory_auth)): - logSupport.log.warning( - "Credential %s does not match auth method %s (for %s), skipping..." - % (credential_el.type, factory_auth, params_obj.request_name) - ) - continue - if (credential_el.trust_domain != factory_trust) and (factory_trust != "Any"): - logSupport.log.warning( - "Credential %s does not match %s (for %s) domain, skipping..." - % (credential_el.trust_domain, factory_trust, params_obj.request_name) - ) - continue - # Convert the sec class to a string so the Factory can interpret the value correctly - glidein_params_to_encrypt["SecurityClass"] = str(credential_el.security_class) - classad_name = credential_el.file_id(credential_el.filename, ignoredn=True) + "_" + classad_name - if "username_password" in credential_el.type: - glidein_params_to_encrypt["Username"] = file_id_cache.file_id(credential_el, credential_el.filename) - glidein_params_to_encrypt["Password"] = file_id_cache.file_id( - credential_el, credential_el.key_fname - ) - if "grid_proxy" in credential_el.type: - glidein_params_to_encrypt["SubmitProxy"] = file_id_cache.file_id( - credential_el, credential_el.filename - ) - if "cert_pair" in credential_el.type: - glidein_params_to_encrypt["PublicCert"] = file_id_cache.file_id( - credential_el, credential_el.filename - ) - glidein_params_to_encrypt["PrivateCert"] = file_id_cache.file_id( - credential_el, credential_el.key_fname - ) - if "key_pair" in credential_el.type: - glidein_params_to_encrypt["PublicKey"] = file_id_cache.file_id( - credential_el, credential_el.filename - ) - glidein_params_to_encrypt["PrivateKey"] = file_id_cache.file_id( - credential_el, credential_el.key_fname - ) - if "auth_file" in credential_el.type: - glidein_params_to_encrypt["AuthFile"] = file_id_cache.file_id(credential_el, credential_el.filename) - if "vm_id" in credential_el.type: - if credential_el.vm_id_fname: - glidein_params_to_encrypt["VMId"] = self.vm_attribute_from_file( - credential_el.vm_id_fname, "VM_ID" - ) - else: - glidein_params_to_encrypt["VMId"] = str(credential_el.vm_id) - if "vm_type" in credential_el.type: - if credential_el.vm_type_fname: - glidein_params_to_encrypt["VMType"] = self.vm_attribute_from_file( - credential_el.vm_type_fname, "VM_TYPE" - ) - else: - glidein_params_to_encrypt["VMType"] = str(credential_el.vm_type) - # removing this, was here by mistake? glidein_params_to_encrypt['VMType']=str(credential_el.vm_type) - - # Process additional information of the credential - if credential_el.pilot_fname: - glidein_params_to_encrypt["GlideinProxy"] = file_id_cache.file_id( - credential_el, credential_el.pilot_fname - ) + # Generate classad info tuples + classad_info_tuples = ( + (frontendConfig.glidein_param_prefix, params_obj.glidein_params), + (frontendConfig.encrypted_param_prefix, encrypted_params), + (frontendConfig.glidein_config_prefix, self.glidein_config_limits), + ) - if credential_el.remote_username: # MM: or "username" in credential_el.type - glidein_params_to_encrypt["RemoteUsername"] = str(credential_el.remote_username) - if credential_el.project_id: - glidein_params_to_encrypt["ProjectId"] = str(credential_el.project_id) + # Get the glidein monitors for this credential + glidein_monitors_this_cred = params_obj.glidein_monitors_per_cred.get( + request_cred.credential.id, {} # type: ignore[attr-defined] + ) - (req_idle, req_max_run) = credential_el.get_usage_details() - logSupport.log.debug( - "Advertizing credential %s with (%d idle, %d max run) for request %s" - % (credential_el.filename, req_idle, req_max_run, params_obj.request_name) - ) + # Update Sequence number information + if classad_name in advertizeGCCounter: + advertizeGCCounter[classad_name] += 1 + else: + advertizeGCCounter[classad_name] = 0 - glidein_monitors_this_cred = params_obj.glidein_monitors_per_cred.get(credential_el.getId(), {}) + fname = f"{self.adname}" + if not frontendConfig.advertise_use_multi: + fname += f"_{self.unique_id}" + self.unique_id += 1 - if frontendConfig.advertise_use_multi is True: - fname = self.adname - cred_filename_arr.append(fname) - else: - fname = self.adname + "_" + str(self.unique_id) - self.unique_id += 1 - cred_filename_arr.append(fname) + try: logSupport.log.debug(f"Writing {fname}") - fd = open(fname, "a") - - fd.write('MyType = "%s"\n' % frontendConfig.client_id) - fd.write('GlideinMyType = "%s"\n' % frontendConfig.client_id) - fd.write('GlideinWMSVersion = "%s"\n' % frontendConfig.glideinwms_version) - fd.write('Name = "%s"\n' % classad_name) - fd.write("\n".join(descript_obj.get_id_attrs()) + "\n") - fd.write('ReqName = "%s"\n' % params_obj.request_name) - fd.write('ReqGlidein = "%s"\n' % params_obj.glidein_name) - - fd.write("\n".join(descript_obj.get_web_attrs()) + "\n") - - if params_obj.security_name is not None: - glidein_params_to_encrypt["SecurityName"] = params_obj.security_name - - if key_obj is not None: - fd.write("\n".join(key_obj.get_key_attrs()) + "\n") - for attr in glidein_params_to_encrypt: - encrypted_params[attr] = key_obj.encrypt_hex(glidein_params_to_encrypt[attr]).decode( - defaults.BINARY_ENCODING_CRYPTO - ) - - fd.write("ReqIdleGlideins = %i\n" % req_idle) - fd.write("ReqMaxGlideins = %i\n" % req_max_run) - fd.write('ReqRemoveExcess = "%s"\n' % params_obj.remove_excess_str) - fd.write("ReqRemoveExcessMargin = %i\n" % params_obj.remove_excess_margin) - fd.write('ReqIdleLifetime = "%s"\n' % params_obj.idle_lifetime) - fd.write('WebMonitoringURL = "%s"\n' % descript_obj.monitoring_web_url) - - # write out both the params - classad_info_tuples = ( - (frontendConfig.glidein_param_prefix, params_obj.glidein_params), - (frontendConfig.encrypted_param_prefix, encrypted_params), - (frontendConfig.glidein_config_prefix, self.glidein_config_limits), - ) - for prefix, data in classad_info_tuples: - for attr in list(data.keys()): - writeTypedClassadAttrToFile(fd, f"{prefix}{attr}", data[attr]) - - for attr_name in params_obj.glidein_monitors: - prefix = frontendConfig.glidein_monitor_prefix - # attr_value = params_obj.glidein_monitors[attr_name] - if (attr_name == "RunningHere") and glidein_monitors_this_cred: - # This double check is for backward compatibility - attr_value = glidein_monitors_this_cred.get("GlideinsRunning", 0) - elif (attr_name == "Running") and glidein_monitors_this_cred: - # This double check is for backward compatibility - attr_value = glidein_monitors_this_cred.get("ScaledRunning", 0) - else: - attr_value = glidein_monitors_this_cred.get(attr_name, params_obj.glidein_monitors[attr_name]) - writeTypedClassadAttrToFile(fd, f"{prefix}{attr_name}", attr_value) - - # Update Sequence number information - if classad_name in advertizeGCCounter: - advertizeGCCounter[classad_name] += 1 - else: - advertizeGCCounter[classad_name] = 0 - fd.write("UpdateSequenceNumber = %s\n" % advertizeGCCounter[classad_name]) - - # add a final empty line... useful when appending - fd.write("\n") - fd.close() - except Exception: - logSupport.log.exception("Exception writing advertisement file: ") - # remove file in case of problems - if fd is not None: - fd.close() + with open(fname, "a", encoding="utf-8") as fd: + fd.write(f'MyType = "{frontendConfig.client_id}"\n') + fd.write(f'GlideinMyType = "{frontendConfig.client_id}"\n') + fd.write(f'GlideinWMSVersion = "{frontendConfig.glideinwms_version}"\n') + fd.write(f'Name = "{classad_name}"\n') + fd.write("\n".join(self.descript_obj.get_id_attrs()) + "\n") + fd.write(f'ReqName = "{params_obj.request_name}"\n') + fd.write(f'ReqGlidein = "{params_obj.glidein_name}"\n') + + fd.write("\n".join(self.descript_obj.get_web_attrs()) + "\n") + + if key_obj: + fd.write("\n".join(key_obj.get_key_attrs()) + "\n") + + fd.write(f"ReqIdleGlideins = {request_cred.req_idle}\n") + fd.write(f"ReqMaxGlideins = {request_cred.req_max_run}\n") + fd.write(f'ReqRemoveExcess = "{params_obj.remove_excess_str}"\n') + fd.write(f"ReqRemoveExcessMargin = {params_obj.remove_excess_margin}\n") + fd.write(f'ReqIdleLifetime = "{params_obj.idle_lifetime}"\n') + fd.write(f'WebMonitoringURL = "{self.descript_obj.monitoring_web_url}"\n') + + for prefix, data in classad_info_tuples: + for attr in list(data.keys()): + writeTypedClassadAttrToFile(fd, f"{prefix}{attr}", data[attr]) + + for attr_name in params_obj.glidein_monitors: + prefix = frontendConfig.glidein_monitor_prefix + # attr_value = params_obj.glidein_monitors[attr_name] + if (attr_name == "RunningHere") and glidein_monitors_this_cred: + # This double check is for backward compatibility + attr_value = glidein_monitors_this_cred.get("GlideinsRunning", 0) + elif (attr_name == "Running") and glidein_monitors_this_cred: + # This double check is for backward compatibility + attr_value = glidein_monitors_this_cred.get("ScaledRunning", 0) + else: + attr_value = glidein_monitors_this_cred.get( + attr_name, params_obj.glidein_monitors[attr_name] + ) + writeTypedClassadAttrToFile(fd, f"{prefix}{attr_name}", attr_value) + + fd.write(f"UpdateSequenceNumber = {advertizeGCCounter[classad_name]}\n") + + # add a final empty line... useful when appending + fd.write("\n") + except Exception as e: + logSupport.log.exception(f"Exception writing advertisement file: {e}") + if os.path.exists(fname): os.remove(fname) raise + # TODO: Should we revert the changes done to advertizeGCCounter[classad_name]? + + cred_filename_arr.append(fname) + + logSupport.log.debug( + f"Advertizing credential {request_cred.credential.path} " # type: ignore[attr-defined] + f"with ({request_cred.req_idle} idle, {request_cred.req_max_run} max run) for request {params_obj.request_name}" + ) + return cred_filename_arr def set_glidein_config_limits(self, limits_data): diff --git a/frontend/glideinFrontendPlugins.py b/frontend/glideinFrontendPlugins.py index 3aed865c6..b2b937b13 100644 --- a/frontend/glideinFrontendPlugins.py +++ b/frontend/glideinFrontendPlugins.py @@ -10,9 +10,24 @@ import random import time -from glideinwms.lib import logSupport, util +from abc import ABC, abstractmethod +from typing import Iterable, List, Mapping -from . import glideinFrontendInterface, glideinFrontendLib +from glideinwms.frontend import glideinFrontendLib +from glideinwms.frontend.glideinFrontendInterface import AdvertizeParams +from glideinwms.lib import logSupport, util +from glideinwms.lib.credentials import ( + AuthenticationSet, + Credential, + CredentialGenerator, + CredentialPurpose, + CredentialType, + Parameter, + ParameterGenerator, + ParameterName, + RequestCredential, + SecurityBundle, +) ################################################################################ # # @@ -43,6 +58,128 @@ ################################################################################ +# TODO: Add type annotations to this class +class CredentialsPlugin(ABC): + """Base class for all credential plugins""" + + def __init__(self, config_dir: str, security_bundle: SecurityBundle): + self.security_bundle = security_bundle + + @property + def cred_list(self) -> List[Credential]: + return list(self.security_bundle.credentials.values()) + + @property + def params_dict(self) -> Mapping[ParameterName, Parameter]: + return self.security_bundle.parameters + + def get_required_job_attributes(self): + """what job attributes are used by this plugin + + Returns: + list: used job attributes, none + """ + return [] + + def get_required_classad_attributes(self): + """what glidein attributes are used by this plugin + + Returns: + list: used glidein attributes, none + """ + return [] + + def generate_credentials(self, **kwargs): + """ + Generate all credentials that are generators + + Args: + **kwargs: keyword arguments to be passed to the generator + """ + + for cred in self.cred_list: + if isinstance(cred, CredentialGenerator): + cred.generate(**kwargs) + + def generate_parameters(self, **kwargs): + """ + Generate all parameters that are generators + + Args: + **kwargs: keyword arguments to be passed to the generator + """ + + for param in self.params_dict.values(): + if isinstance(param, ParameterGenerator): + param.generate(**kwargs) + + def update_usermap(self, condorq_dict, condorq_dict_types, status_dict, status_dict_types): + return + + @abstractmethod + def get_credentials(self, credential_type=None, trust_domain=None, credential_purpose=None) -> List[Credential]: + pass + + @abstractmethod + def get_request_credentials(self) -> List[RequestCredential]: + pass + + @abstractmethod + def assign_work( + self, req_creds: Iterable[RequestCredential], params_obj: AdvertizeParams, auth_set: AuthenticationSet + ): + pass + + +class CredentialsBasic(CredentialsPlugin): + """This plugin returns all credentials + + This is can be a very useful default policy + """ + + def get_credentials(self, credential_type=None, trust_domain=None, credential_purpose=None) -> List[Credential]: + """get the credentials, given the condor_q and condor_status data + + Args: + params_obj: optional parameters to be used in job splitting + credential_type (str): optional credential type to match with a supported auth_metod + trust_domain (str): optional trust domain + + Returns: + list: list of credentials + """ + rtnlist = [] + for cred in self.cred_list: + if trust_domain and hasattr(cred, "trust_domain") and cred.trust_domain != trust_domain: + continue + if credential_type and hasattr(cred, "cred_type") and cred.cred_type != credential_type: + continue + if credential_purpose and cred.purpose != credential_purpose: + continue + rtnlist.append(cred) + return rtnlist + + def get_request_credentials(self) -> List[RequestCredential]: + req_creds = self.get_credentials(credential_purpose=CredentialPurpose.REQUEST) + req_creds = [RequestCredential(cred) for cred in req_creds] + return req_creds + + def assign_work( + self, req_creds: Iterable[RequestCredential], params_obj: AdvertizeParams, auth_set: AuthenticationSet + ): + for req_cred in req_creds: + if not req_cred.credential.cred_type: + continue + if auth_set.supports(req_cred.credential.cred_type): + req_cred.add_usage_details(params_obj.min_nr_glideins, params_obj.max_run_glideins) + return + + +########################################################### +### Proxy plugins are deprecated and should not be used ### +########################################################### + + class ProxyFirst: """This plugin always returns the first proxy Useful when there is only one proxy or for testing @@ -570,14 +707,19 @@ def list2ilist(lst): return out +# TODO: Deprecate in favor of createRequestBundle def createCredentialList(elementDescript): """Creates a list of Credentials for a proxy plugin""" - credential_list = [] - num = 0 - for proxy in elementDescript.merged_data["Proxies"]: - credential_list.append(glideinFrontendInterface.Credential(num, proxy, elementDescript)) - num += 1 - return credential_list + security_bundle = SecurityBundle() + security_bundle.load_from_element(elementDescript) + return security_bundle + + +def createRequestBundle(elementDescript): + """Creates a list of Credentials for a proxy plugin""" + security_bundle = SecurityBundle() + security_bundle.load_from_element(elementDescript) + return security_bundle def fair_split(i, n, p): @@ -622,10 +764,10 @@ def fair_assign(cred_list, params_obj): # TODO: Remove this block when we stop sending scitokens with proxies automatically # This is a special case to send a scitoken as a secondary credential alongside a proxy if num_cred == 2: - cred_pair = {cred.type: cred for cred in cred_list} - if set(cred_pair.keys()) == {"grid_proxy", "scitoken"}: - cred_pair["grid_proxy"].add_usage_details(total_idle, total_max) - cred_pair["scitoken"].add_usage_details(0, 0) + cred_pair = {cred.credential.cred_type: cred for cred in cred_list} + if set(cred_pair.keys()) == {CredentialType.X509_CERT, CredentialType.TOKEN}: + cred_pair[CredentialType.X509_CERT].add_usage_details(total_idle, total_max) + cred_pair[CredentialType.TOKEN].add_usage_details(0, 0) return cred_list # End of special case @@ -647,6 +789,7 @@ def fair_assign(cred_list, params_obj): # They should go throug the dictionaries below to find the appropriate plugin proxy_plugins = { + "CredentialsBasic": CredentialsBasic, "ProxyAll": ProxyAll, "ProxyUserRR": ProxyUserRR, "ProxyFirst": ProxyFirst, diff --git a/lib/credentials.py b/lib/credentials.py new file mode 100644 index 000000000..2a560a483 --- /dev/null +++ b/lib/credentials.py @@ -0,0 +1,2229 @@ +#!/usr/bin/env python3 + +# SPDX-FileCopyrightText: 2009 Fermi Research Alliance, LLC +# SPDX-License-Identifier: Apache-2.0 + +""" +This module contains classes and functions for managing GlideinWMS credentials. +""" + +import base64 +import enum # TODO: Use StrEnum starting from Python 3.11 +import gzip +import os +import shutil +import tempfile + +from abc import ABC, abstractmethod +from datetime import datetime +from hashlib import md5 +from inspect import signature +from io import BytesIO +from typing import Any, Generic, Iterable, List, Mapping, Optional, Set, Type, TypeVar, Union + +import jwt +import M2Crypto.EVP +import M2Crypto.X509 + +from glideinwms.lib import logSupport, pubCrypto, symCrypto +from glideinwms.lib.defaults import force_bytes +from glideinwms.lib.generators import Generator, load_generator +from glideinwms.lib.util import hash_nc + +T = TypeVar("T") + + +########################## +### Credentials ########## +########################## + + +class CredentialError(Exception): + """defining new exception so that we can catch only the credential errors here + and let the "real" errors propagate up + """ + + +class ParameterError(Exception): + """defining new exception so that we can catch only the parameter errors here + and let the "real" errors propagate up + """ + + +@enum.unique +class CredentialType(enum.Enum): + """ + Enum representing different types of credentials. + """ + + TOKEN = "token" + SCITOKEN = "scitoken" + IDTOKEN = "idtoken" + X509_CERT = "x509_cert" + RSA_KEY = "rsa_key" + GENERATOR = "generator" + TEXT = "text" + + def __str__(self) -> str: + return self.value + + def __repr__(self) -> str: + return f"{self.__class__.__name__}.{self.name}" + + @classmethod + def from_string(cls, string: str) -> "CredentialType": + """ + Converts a string representation of a credential type to a CredentialType object. + + Args: + string (str): The string representation of the credential type. + + Returns: + CredentialType: The corresponding CredentialType enum value. + + Raises: + CredentialError: If the string does not match any known credential type. + """ + + extended_map = {"grid_proxy": cls.X509_CERT, "auth_file": cls.TEXT} + + string = string.lower() + + try: + return CredentialType(string) + except ValueError: + pass + if string in extended_map: + return extended_map[string] + raise CredentialError(f"Unknown Credential type: {string}") + + +@enum.unique +class CredentialPairType(enum.Enum): + """ + Enum representing different types of credential pairs. + """ + + X509_PAIR = "x509_pair" + KEY_PAIR = "key_pair" + USERNAME_PASSWORD = "username_password" + + @classmethod + def from_string(cls, string: str) -> "CredentialPairType": + """ + Converts a string representation of a credential type to a CredentialPairType object. + + Args: + string (str): The string representation of the credential type. + + Returns: + CredentialPairType: The corresponding CredentialPairType object. + + Raises: + CredentialError: If the string representation is not a valid credential type. + """ + + extended_map = {"cert_pair": cls.X509_PAIR} + + string = string.lower() + + try: + return CredentialPairType(string) + except ValueError: + pass + if string in extended_map: + return extended_map[string] + raise CredentialError(f"Unknown Credential type: {string}") + + def __str__(self) -> str: + return self.value + + def __repr__(self) -> str: + return f"{self.__class__.__name__}.{self.name}" + + +@enum.unique +class CredentialPurpose(enum.Enum): + """ + Enum representing different purposes for credentials. + """ + + REQUEST = "request" + PAYLOAD = "payload" + + @classmethod + def from_string(cls, string: str) -> "CredentialPurpose": + """ + Converts a string representation of a CredentialPurpose to a CredentialPurpose object. + + Args: + string (str): The string representation of the CredentialPurpose. + + Returns: + CredentialPurpose: The CredentialPurpose object. + """ + + string = string.lower() + return CredentialPurpose(string) + + def __str__(self) -> str: + return self.value + + def __repr__(self) -> str: + return f"{self.__class__.__name__}.{self.name}" + + +class Credential(ABC, Generic[T]): + """ + Represents a credential used for authentication or authorization purposes. + + Attributes: + cred_type (Optional[CredentialType]): The type of the credential. + classad_attribute (Optional[str]): The classad attribute associated with the credential. + extension (Optional[str]): The file extension of the credential. + + Raises: + CredentialError: If the credential cannot be initialized or loaded. + + """ + + cred_type: Optional[CredentialType] = None + classad_attribute: Optional[str] = None + extension: Optional[str] = None + + def __init__( + self, + string: Optional[Union[str, bytes]] = None, + path: Optional[str] = None, + purpose: Optional[CredentialPurpose] = None, + trust_domain: Optional[str] = None, + security_class: Optional[str] = None, + ) -> None: + """ + Initialize a Credentials object. + + Args: + string (Optional[Union[str, bytes]]): The credential string. + path (Optional[str]): The path to the credential file. + purpose (Optional[CredentialPurpose]): The purpose of the credential. + trust_domain (Optional[str]): The trust domain of the credential. + security_class (Optional[str]): The security class of the credential. + """ + + self._string = None + self.path = path + self.purpose = purpose + self.trust_domain = trust_domain + self.security_class = security_class + if string or path: + self.load(string, path) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(string={self.string!r}, path={self.path!r}, purpose={self.purpose!r}, trust_domain={self.trust_domain!r}, security_class={self.security_class!r})" + + def __str__(self) -> str: + return self.string.decode() if self.string else "" + + def __renew__(self) -> None: + raise NotImplementedError("Renewal not implemented for this credential type") + + @property + def _payload(self) -> Optional[T]: + return self.decode(self.string) if self.string else None + + @property + def string(self) -> Optional[bytes]: + """ + Credential string. + """ + + return self._string + + @property + def id(self) -> str: + """ + Credential unique identifier. + """ + + if not str(self.string): + raise CredentialError("Credential not initialized") + + return hash_nc(f"{str(self._string)}{self.purpose}{self.trust_domain}{self.security_class}", 8) + + @property + def purpose(self) -> Optional[CredentialPurpose]: + """ + Credential purpose. + """ + + return self._purpose[0] + + @purpose.setter + def purpose(self, value: Optional[Union[CredentialPurpose, str]]): + if not value: + self._purpose = (None, None) + elif isinstance(value, CredentialPurpose): + self._purpose = (value, None) + elif isinstance(value, str): + try: + self._purpose = (CredentialPurpose.from_string(value), None) + except ValueError: + self._purpose = (CredentialPurpose.PAYLOAD, value) + else: + raise CredentialError(f"Invalid purpose: {value}") + + @property + def purpose_alias(self) -> Optional[str]: + """ + Credential purpose alias. + """ + + if self._purpose[0] is not None: + return self._purpose[1] or self._purpose[0].value + + @property + def valid(self) -> bool: + """ + Whether the credential is valid. + """ + return self.invalid_reason() is None + + @staticmethod + @abstractmethod + def decode(string: Union[str, bytes]) -> T: + """ + Decode the given string to provide the credential. + + Args: + string (bytes): The string to decode. + + Returns: + T: The decoded value. + """ + + @abstractmethod + def invalid_reason(self) -> Optional[str]: + """ + Returns the reason why the credential is invalid. + + Returns: + str: The reason why the credential is invalid. None if the credential is valid. + """ + + def load_from_string(self, string: Union[str, bytes]) -> None: + """ + Load the credential from a string. + + Args: + string (bytes): The credential string to load. + + Raises: + CredentialError: If the input string is not of type bytes or if the credential cannot be loaded from the string. + """ + + string = force_bytes(string) + if not isinstance(string, bytes): + raise CredentialError("Credential string must be bytes") + try: + self.decode(string) + except Exception as err: + raise CredentialError(f"Could not load credential from string: {string}") from err + self._string = string + + def load_from_file(self, path: str) -> None: + """ + Load credentials from a file. + + Args: + path (str): The path to the credential file. + + Raises: + CredentialError: If the specified file does not exist. + """ + + if not os.path.isfile(path): + raise CredentialError(f"Credential file {self.path} does not exist") + with open(path, "rb") as cred_file: + self.load_from_string(cred_file.read()) + self.path = path + + def load(self, string: Optional[Union[str, bytes]] = None, path: Optional[str] = None) -> None: + """ + Load credentials from either a string or a file. + If both are defined, the string takes precedence. + + Args: + string (Optional[bytes]): The credentials string to load. + path (Optional[str]): The path to the file containing the credentials. + + Raises: + CredentialError: If neither `string` nor `path` is specified. + """ + + if string: + self.load_from_string(string) + if path: + self.path = path + elif path: + self.load_from_file(path) + else: + raise CredentialError("No string or path specified") + + def copy(self) -> "Credential": + """ + Create a copy of the credential. + + Returns: + Credential: The static credential. + + Raises: + CredentialError: If the credential is not initialized. + """ + + if not self.string: + raise CredentialError("Credential not initialized") + return create_credential( + string=self.string, + path=self.path, + purpose=self.purpose, + trust_domain=self.trust_domain, + security_class=self.security_class, + cred_type=self.cred_type, + ) + + def save_to_file( + self, + path: Optional[str] = None, + permissions: int = 0o600, + backup: bool = False, + compress: bool = False, + data_pattern: Optional[bytes] = None, + overwrite: bool = True, + continue_if_no_path=False, + ) -> None: + """ + Save the credential to a file. + + Args: + path (Optional[str]): The path to the file where the credential will be saved. + permissions (int): The permissions to set for the saved file. Default is 0o600. + backup (bool): Whether to create a backup of the existing file. Default is False. + compress (bool): Whether to compress the credential before saving. Default is False. + data_pattern (Optional[bytes]): A pattern to format the credential data before saving. Default is None. + overwrite (bool): Whether to overwrite the existing file if it already exists. Default is True. + continue_if_no_path (bool): If True, silently return without saving a file if no path is specified. Default is False. + + Raises: + CredentialError: If the credential is not initialized or if there is an error saving the credential. + """ + + if not self.string: + raise CredentialError("Credential not initialized") + + path = path or self.path + if not path: + if continue_if_no_path: + return + raise CredentialError("No path specified") + + if os.path.isfile(path) and not overwrite: + return + + text = self.string + if compress: + text = compress_credential(text) + if data_pattern: + text = data_pattern % text + + try: + # NOTE: NamedTemporaryFile is creted in private mode by default (0600) + with tempfile.NamedTemporaryFile(mode="wb", delete=False) as fd: + os.chmod(fd.name, permissions) + fd.write(text) + fd.flush() + if backup: + try: + shutil.copy2(path, f"{path}.old") + except FileNotFoundError as err: + logSupport.log.debug( + f"Tried to backup credential at {path} but file does not exist: {err}. Probably first time saving credential." + ) + shutil.move(fd.name, path) # os.replace() may cause issues if moving across filesystems + except OSError as err: + raise CredentialError(f"Could not save credential to {path}: {err}") from err + + def renew(self) -> None: + """ + Renews the credentials. + + This method attempts to renew the credentials by calling the private __renew__ method. + If the __renew__ method is not implemented, it will silently pass. + """ + try: + self.__renew__() + except NotImplementedError: + pass + + +class CredentialPair: + """ + Represents a pair of credentials, consisting of a public and private credential. + + NOTE: This class requires a Credential subclass as a second base class. + + Attributes: + cred_type (Optional[CredentialPairType]): The type of the credential pair. + private_credential (Credential): The private credential associated with this pair. + NOTE: Includes all attributes from the Credential class. + """ + + cred_type: Optional[CredentialPairType] = None + + def __init__( + self, + string: Optional[bytes] = None, + path: Optional[str] = None, + private_string: Optional[bytes] = None, + private_path: Optional[str] = None, + purpose: Optional[CredentialPurpose] = None, + trust_domain: Optional[str] = None, + security_class: Optional[str] = None, + ) -> None: + """ + Initialize a CredentialPair object. + + Args: + string (Optional[bytes]): The string representation of the public credential. + path (Optional[str]): The path to the public credential file. + private_string (Optional[bytes]): The string representation of the private credential. + private_path (Optional[str]): The path to the private credential file. + purpose (Optional[CredentialPurpose]): The purpose of the credential. + trust_domain (Optional[str]): The trust domain of the credential. + security_class (Optional[str]): The security class of the credential. + """ + + if len(self.__class__.__bases__) < 2 or not issubclass(self.__class__.__bases__[1], Credential): + raise CredentialError("CredentialPair requires a Credential subclass as second base class") + + credential_class = self.__class__.__bases__[1] + super(credential_class, self).__init__( # pylint: disable=bad-super-call # type: ignore[call-arg] + string, path, purpose, trust_domain, security_class + ) # pylint: disable=bad-super-call # type: ignore[call-arg] + self.private_credential = credential_class(private_string, private_path, purpose, trust_domain, security_class) + + def renew(self) -> None: + """ + Renews the credentials by calling the __renew__() method on both the public and private credentials. + """ + + try: + self.__renew__() # pylint: disable=no-member # type: ignore[attr-defined] + self.private_credential.__renew__() + except NotImplementedError: + pass + + def copy(self) -> "CredentialPair": + """ + Create a copy of the credential pair. + + Returns: + CredentialPair: The static credential pair. + + Raises: + CredentialError: If the credential pair is not initialized. + """ + + if not self.string: # pylint: disable=no-member # type: ignore[attr-defined] + raise CredentialError("Credential pair not initialized") + return create_credential_pair( + string=self.string, # pylint: disable=no-member # type: ignore[attr-defined] + path=self.path, # pylint: disable=no-member # type: ignore[attr-defined] + private_string=self.private_credential.string, + private_path=self.private_credential.path, + purpose=self.purpose, # pylint: disable=no-member # type: ignore[attr-defined] + trust_domain=self.trust_domain, # pylint: disable=no-member # type: ignore[attr-defined] + security_class=self.security_class, # pylint: disable=no-member # type: ignore[attr-defined] + cred_type=self.cred_type, + ) + + +# Dictionary of Credentials +class CredentialDict(dict): + """ + A dictionary-like class for storing credentials. + + This class extends the built-in `dict` class and provides additional + functionality for storing and retrieving `Credential` objects. + """ + + def __setitem__(self, __k, __v): + if not isinstance(__v, Credential): + raise TypeError("Value must be a credential") + super().__setitem__(__k, __v) + + def add(self, credential: Credential, credential_id: Optional[str] = None): + """ + Add a credential to the dictionary. + + Args: + credential (Credential): The credential object to add. + id (str, optional): The ID to use as the key in the dictionary. + If not provided, the credential's ID will be used. + """ + if not isinstance(credential, Credential): + raise TypeError("Value must be a credential") + self[credential_id or credential.id] = credential + + +class CredentialGenerator(Credential[Generator]): + """ + Represents a credential generator used for generating credentials. + + Attributes: + cred_type (CredentialType): The type of the credential. + classad_attribute (str): The classad attribute associated with the credential. + path (str): The path of the credential file. + """ + + cred_type = CredentialType.GENERATOR + classad_attribute = "CredentialGenerator" + + def __init__( # pylint: disable=super-init-not-called + self, + string: Optional[Union[str, bytes]] = None, + path: Optional[str] = None, + purpose: Optional[CredentialPurpose] = None, + trust_domain: Optional[str] = None, + security_class: Optional[str] = None, + context: Optional[Mapping] = None, + ) -> None: + """ + Initialize a Credentials object. + + Args: + string (Optional[Union[str, bytes]]): The credential string. + path (Optional[str]): The path to the credential file. + purpose (Optional[CredentialPurpose]): The purpose of the credential. + trust_domain (Optional[str]): The trust domain of the credential. + security_class (Optional[str]): The security class of the credential. + context (Optional[Mapping]): The context of the generator. + """ + + self._string = None + self._context = context + self._generated_credential: Optional[Credential] = None + self.path = path + self.purpose = purpose + self.trust_domain = trust_domain + self.security_class = security_class + if string or path: + self.load(string, path, context) + + def __getattr__(self, attr: str) -> Any: + try: + return getattr(self._generated_credential, attr) + except Exception: + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{attr}'") from None + + @property + def _payload(self) -> Optional[Generator]: + return self.decode(self.string, self.context) if self.string and self.context else None + + @property + def context(self) -> Optional[Mapping]: + """ + The context of the generator. + """ + + return self._context + + @staticmethod + def decode(string: Union[str, bytes], context: Optional[Mapping] = None) -> Generator: + if isinstance(string, bytes): + string = string.decode() + return load_generator(string, context) + + def load( + self, string: Optional[Union[str, bytes]] = None, path: Optional[str] = None, context: Optional[Mapping] = None + ) -> None: + if string: + self.load_from_string(string, context) + if path: + self.path = path + elif path: + self.load_from_file(path, context) + else: + raise CredentialError("No string or path specified") + + def load_from_string(self, string: Union[str, bytes], context: Optional[Mapping] = None) -> None: + string = force_bytes(string) + if not isinstance(string, bytes): + raise CredentialError("Credential string must be bytes") + try: + self.decode(string, context) + except Exception as err: + raise CredentialError(f"Could not load credential from string: {string}") from err + self._string = string + self._context = context + if context and "type" in context: + self.cred_type = CredentialType.from_string(context.get("type")) + + def load_from_file(self, path: str, context: Optional[Mapping] = None) -> None: + self.load_from_string(path, context) + + def copy(self) -> Credential: + if not self._generated_credential: + raise CredentialError("Credential not generated.") + return self._generated_credential.copy() + + def invalid_reason(self) -> Optional[str]: + if self._generated_credential: + return self._generated_credential.invalid_reason() + if isinstance(self._payload, Generator): + return None + return "Credential not initialized." + + def generate(self, **kwargs): + """ + Generate a credential using the generator. + + Args: + **kwargs: Additional keyword arguments to pass to the generator. + + Raises: + CredentialError: If the generator is not initialized. + """ + + if not self._payload: + raise CredentialError("Credential generator not initialized") + self._generated_credential = create_credential( + string=self._payload.generate(**kwargs), + purpose=self.purpose, + trust_domain=self.trust_domain, + security_class=self.security_class, + cred_type=self.cred_type, + ) + + +class Token(Credential[Mapping]): + """ + Represents a token credential. + + Attributes: + cred_type (CredentialType): The type of the credential. + classad_attribute (str): The name of the attribute in the classad. + extension (str): The file extension for the token. + scope (Optional[str]): The scope of the token. + issue_time (Optional[datetime]): The issue time of the token. + not_before_time (Optional[datetime]): The not-before time of the token. + expiration_time (Optional[datetime]): The expiration time of the token. + """ + + cred_type = CredentialType.TOKEN + classad_attribute = "ScitokenId" # TODO: We might want to change this name to "TokenId" in the future + extension = "jwt" + + @property + def scope(self) -> Optional[str]: + """ + Token scope. + """ + return self._payload.get("scope", None) if self._payload else None + + @property + def issue_time(self) -> Optional[datetime]: + """ + Token issue time. + """ + return datetime.fromtimestamp(self._payload.get("iat", None)) if self._payload else None + + @property + def not_before_time(self) -> Optional[datetime]: + """ + Token not-before time. + """ + return datetime.fromtimestamp(self._payload.get("nbf", None)) if self._payload else None + + @property + def expiration_time(self) -> Optional[datetime]: + """ + Token expiration time. + """ + return datetime.fromtimestamp(self._payload.get("exp", None)) if self._payload else None + + @staticmethod + def decode(string: Union[str, bytes]) -> Mapping: + if isinstance(string, bytes): + string = string.decode() + return jwt.decode(string.strip(), options={"verify_signature": False}) + + def invalid_reason(self) -> Optional[str]: + if not self._payload: + return "Token not initialized." + if datetime.now() < self.not_before_time: + return "Token not yet valid." + if datetime.now() > self.expiration_time: + return "Token expired." + + +class SciToken(Token): + """ + Represents a SciToken credential. + + Attributes: + cred_type (CredentialType): The type of the credential. + classad_attribute (str): The name of the attribute in the classad. + extension (str): The file extension for the token. + scope (Optional[str]): The scope of the token. + issue_time (Optional[datetime]): The issue time of the token. + not_before_time (Optional[datetime]): The not-before time of the token. + expiration_time (Optional[datetime]): The expiration time of the token. + + NOTE: This class is a subclass of the `Token` class. + """ + + cred_type = CredentialType.SCITOKEN + classad_attribute = "ScitokenId" # TODO: We might want to change this name to "TokenId" in the future + extension = "scitoken" + + +class IdToken(Token): + """ + Represents an ID token credential. + + Attributes: + cred_type (CredentialType): The type of the credential. + classad_attribute (str): The name of the attribute in the classad. + extension (str): The file extension for the token. + scope (Optional[str]): The scope of the token. + issue_time (Optional[datetime]): The issue time of the token. + not_before_time (Optional[datetime]): The not-before time of the token. + expiration_time (Optional[datetime]): The expiration time of the token. + + NOTE: This class is a subclass of the `Token` class. + """ + + cred_type = CredentialType.IDTOKEN + classad_attribute = "IdToken" + extension = "idtoken" + + +class X509Cert(Credential[M2Crypto.X509.X509]): + """ + Represents an X.509 certificate credential. + + Attributes: + cred_type (CredentialType): The type of the credential. + classad_attribute (str): The attribute name used in ClassAds. + extension (str): The file extension for the credential. + pub_key (Optional[M2Crypto.EVP.PKey]): The public key of the certificate. + not_before_time (Optional[datetime]): The not-before time of the certificate. + not_after_time (Optional[datetime]): The not-after time of the certificate. + """ + + cred_type = CredentialType.X509_CERT + classad_attribute = "SubmitProxy" + extension = "pem" + + @property + def pub_key(self) -> Optional[M2Crypto.EVP.PKey]: + """ + X.509 public key. + """ + return self._payload.get_pubkey() if self._payload else None + + @property + def not_before_time(self) -> Optional[datetime]: + """ + X.509 not-before time. + """ + return self._payload.get_not_before().get_datetime() if self._payload else None + + @property + def not_after_time(self) -> Optional[datetime]: + """ + X.509 not-after time. + """ + return self._payload.get_not_after().get_datetime() if self._payload else None + + @staticmethod + def decode(string: Union[str, bytes]) -> M2Crypto.X509.X509: + string = force_bytes(string) + return M2Crypto.X509.load_cert_string(string) + + def invalid_reason(self) -> Optional[str]: + if not self._payload: + return "Certificate not initialized." + if datetime.now(self.not_before_time.tzinfo) < self.not_before_time: + return "Certificate not yet valid." + if datetime.now(self.not_after_time.tzinfo) > self.not_after_time: + return "Certificate expired." + + +class RSAKey(Credential[pubCrypto.RSAKey]): + """ + Represents an RSA key credential. + + Attributes: + cred_type (CredentialType): The type of the credential. + classad_attribute (str): The attribute name used in ClassAds. + extension (str): The file extension for the key. + pub_key (Optional[pubCrypto.PubRSAKey]): The public key of the RSA key. + pub_key_id (Optional[str]): The ID of the public key. + key_type (Optional[str]): The type of the RSA key. + """ + + cred_type = CredentialType.RSA_KEY + classad_attribute = "RSAKey" + extension = "rsa" + + @property + def pub_key(self) -> Optional[pubCrypto.PubRSAKey]: + """ + RSA public key. + """ + return self._payload.PubRSAKey() if self._payload else None + + @property + def pub_key_id(self) -> Optional[str]: + """ + RSA public key ID. + """ + return ( + md5(b" ".join((self.key_type.encode("utf-8"), self.pub_key.get()))).hexdigest() + if self.key_type and self.pub_key + else None + ) + + @property + def key_type(self) -> Optional[str]: + """ + RSA key type. + + NOTE: This property always returns "RSA" if the key is initialized. + """ + return "RSA" if self._payload else None + + @staticmethod + def decode(string: Union[str, bytes]) -> pubCrypto.RSAKey: + string = force_bytes(string) + return pubCrypto.RSAKey(key_str=string) + + def invalid_reason(self) -> Optional[str]: + if not self._payload: + return "RSA key not initialized." + if not self.pub_key: + return "RSA public key not initialized." + if not self.pub_key_id: + return "RSA public key ID not initialized." + + def recreate(self) -> None: + """ + Recreates the RSA key. + + Raises: + CredentialError: If the RSA key is not initialized. + """ + if self._payload is None: + raise CredentialError("RSAKey not initialized") + + new_key = self._payload + new_key.new() + self.load_from_string(new_key.get()) + if self.path: + self.save_to_file(self.path) + + def extract_sym_key(self, enc_sym_key) -> symCrypto.AutoSymKey: + """ + Extracts the symmetric key using the RSA key. + + Args: + enc_sym_key: The encrypted symmetric key. + + Returns: + symCrypto.AutoSymKey: The extracted symmetric key. + + Raises: + CredentialError: If the RSA key is not initialized. + """ + if self._payload is None: + raise CredentialError("RSAKey not initialized") + + return symCrypto.AutoSymKey(self._payload.decrypt_hex(enc_sym_key)) + + +class TextCredential(Credential[bytes]): + """ + Represents a text-based credential. + + Attributes: + cred_type (CredentialType): The type of the credential. + classad_attribute (str): The attribute name used in ClassAds. + extension (str): The file extension for the credential. + """ + + cred_type = CredentialType.TEXT + classad_attribute = "AuthFile" + extension = "txt" + + @staticmethod + def decode(string: Union[str, bytes]) -> bytes: + return force_bytes(string) + + def invalid_reason(self) -> Optional[str]: + if self._payload is None: + return "Text credential not initialized." + + +class X509Pair(CredentialPair, X509Cert): + """ + Represents a pair of X509 certificates, consisting of a public certificate and a private certificate. + + This class extends both the `CredentialPair` and `X509Cert` classes. + + Attributes: + cred_type (CredentialPairType): The type of the credential pair. + classad_attribute (str): The attribute name used in the ClassAd for the public certificate. + private_credential (X509Cert): The private certificate associated with this pair. + NOTE: Includes all attributes from the X509Cert class. + """ + + cred_type = CredentialPairType.X509_PAIR + + def __init__( + self, + string: Optional[bytes] = None, + path: Optional[str] = None, + private_string: Optional[bytes] = None, + private_path: Optional[str] = None, + purpose: Optional[CredentialPurpose] = None, + trust_domain: Optional[str] = None, + security_class: Optional[str] = None, + ) -> None: + """ + Initialize a X509Pair object. + + Args: + string (Optional[bytes]): The public certificate as a byte string. + path (Optional[str]): The path to the public certificate file. + private_string (Optional[bytes]): The private certificate as a byte string. + private_path (Optional[str]): The path to the private certificate file. + purpose (Optional[CredentialPurpose]): The purpose of the credentials. + trust_domain (Optional[str]): The trust domain of the credentials. + security_class (Optional[str]): The security class of the credentials. + """ + + super().__init__(string, path, private_string, private_path, purpose, trust_domain, security_class) + self.classad_attribute = "PublicCert" + self.private_credential.classad_attribute = "PrivateCert" + + +class KeyPair(CredentialPair, RSAKey): + """ + Represents a pair of RSA keys, consisting of a public key and a private key. + + This class extends both the `CredentialPair` and `RSAKey` classes. + + Attributes: + cred_type (CredentialPairType): The type of the credential pair. + classad_attribute (str): The attribute name used in the ClassAd for the public key. + private_credential (RSAKey): The private key associated with this pair. + NOTE: Includes all attributes from the RSAKey class. + """ + + cred_type = CredentialPairType.KEY_PAIR + + def __init__( + self, + string: Optional[bytes] = None, + path: Optional[str] = None, + private_string: Optional[bytes] = None, + private_path: Optional[str] = None, + purpose: Optional[CredentialPurpose] = None, + trust_domain: Optional[str] = None, + security_class: Optional[str] = None, + ) -> None: + """ + Initialize a KeyPair object. + + Args: + string (Optional[bytes]): The public key as a byte string. + path (Optional[str]): The path to the public key file. + private_string (Optional[bytes]): The private key as a byte string. + private_path (Optional[str]): The path to the private key file. + purpose (Optional[CredentialPurpose]): The purpose of the credentials. + trust_domain (Optional[str]): The trust domain of the credentials. + security_class (Optional[str]): The security class of the credentials. + """ + + super().__init__(string, path, private_string, private_path, purpose, trust_domain, security_class) + self.classad_attribute = "PublicKey" + self.private_credential.classad_attribute = "PrivateKey" + + +class UsernamePassword(CredentialPair, TextCredential): + """ + Represents a username and password credential pair. + + This class extends both the `CredentialPair` and `TextCredential` classes. + + Attributes: + cred_type (CredentialPairType): The type of the credential pair. + classad_attribute (str): The classad attribute for the username. + private_credential (Credential): The private credential object for the password. + NOTE: Includes all attributes from the TextCredential class. + """ + + cred_type = CredentialPairType.USERNAME_PASSWORD + + def __init__( + self, + string: Optional[bytes] = None, + path: Optional[str] = None, + private_string: Optional[bytes] = None, + private_path: Optional[str] = None, + purpose: Optional[CredentialPurpose] = None, + trust_domain: Optional[str] = None, + security_class: Optional[str] = None, + ) -> None: + """ + Initialize a UsernamePassword object. + + Args: + string (Optional[bytes]): The username as a byte string. + path (Optional[str]): The path to the username file. + private_string (Optional[bytes]): The password as a byte string. + private_path (Optional[str]): The path to the password file. + purpose (Optional[CredentialPurpose]): The purpose of the credentials. + trust_domain (Optional[str]): The trust domain of the credentials. + security_class (Optional[str]): The security class of the credentials. + """ + + super().__init__(string, path, private_string, private_path, purpose, trust_domain, security_class) + self.classad_attribute = "Username" + self.private_credential.classad_attribute = "Password" + + +class RequestCredential: + """ + Represents an extended credential used for requesting resources. + + Args: + credential (Credential): The credential object. + + Attributes: + credential (Credential): The credential object. + advertize (bool): Flag indicating whether to advertise the credential. + req_idle (int): Number of idle jobs requested. + req_max_run (int): Maximum number of running jobs requested. + """ + + def __init__( + self, + credential: Credential, + ): + self.credential = credential + self.advertize: bool = True + self.req_idle: int = 0 + self.req_max_run: int = 0 + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(credential={self.credential!s}, advertize={self.advertize}, req_idle={self.req_idle}, req_max_run={self.req_max_run})" + + def __str__(self) -> str: + return f"{self.credential!s}" + + def add_usage_details(self, req_idle=0, req_max_run=0): + """ + Add usage details to the request. + + Args: + req_idle (int): Number of idle jobs requested. + req_max_run (int): Maximum number of running jobs requested. + """ + self.req_idle = req_idle + self.req_max_run = req_max_run + + def get_usage_details(self): + return self.req_idle, self.req_max_run + + +def credential_type_from_string(string: str) -> Union[CredentialType, CredentialPairType]: + """ + Returns the credential type for a given string. + + Args: + string (str): The string to parse. + + Raises: + CredentialError: If the credential type is unknown. + + Returns: + Union[CredentialType, CredentialPairType]: The credential type. + """ + + try: + return CredentialType.from_string(string) + except CredentialError: + try: + return CredentialPairType.from_string(string) + except CredentialError: + raise CredentialError(f"Unknown credential type: {string}") # pylint: disable=raise-missing-from + + +def credential_of_type( + cred_type: Union[CredentialType, CredentialPairType] +) -> Union[Type[Credential], Type[CredentialPair]]: + """Returns the credential subclass for the given type. + + Args: + cred_type (CredentialType): credential type + + Raises: + CredentialError: if the credential type is unknown + + Returns: + Credential: credential subclass + """ + + def subclasses_dict(classes): + sc_dict = {} + for cls in classes: + for sc in cls.__subclasses__(): + sc_dict[sc.cred_type] = sc + sc_dict.update(subclasses_dict([sc])) + return sc_dict + + try: + return subclasses_dict([Credential, CredentialPair])[cred_type] + except KeyError: + pass + raise CredentialError(f"Unknown Credential type: {cred_type}") + + +def create_credential( + string: Optional[Union[str, bytes]] = None, + path: Optional[str] = None, + purpose: Optional[CredentialPurpose] = None, + trust_domain: Optional[str] = None, + security_class: Optional[str] = None, + cred_type: Optional[CredentialPairType] = None, + context: Optional[Mapping] = None, +) -> Credential: + """ + Creates a credential object. + + Args: + string (bytes, optional): The credential as a byte string. + path (str, optional): The path to the credential file. + purpose (CredentialPurpose, optional): The purpose of the credential. + trust_domain (str, optional): The trust domain of the credential. + security_class (str, optional): The security class of the credential. + cred_type (CredentialType, optional): The type of the credential. + context (Mapping, optional): The context to use for decoding the credential. + + Returns: + Credential: The credential object. + + Raises: + CredentialError: If the credential cannot be loaded. + """ + + credential_types = [cred_type] if cred_type else CredentialType + for c_type in credential_types: + try: + credential_class = credential_of_type(c_type) + if issubclass(credential_class, Credential): + cred_args = signature(credential_class.__init__).parameters.values() + cred_args = [param.name for param in cred_args if param.name != "self"] + kwargs = {key: value for key, value in locals().items() if key in cred_args and value is not None} + return credential_class(**kwargs) + except CredentialError: + pass # Credential type incompatible with input + except Exception as err: + raise CredentialError(f'Unexpected error loading credential: string="{string}", path="{path}"') from err + raise CredentialError(f'Could not load credential: string="{string}", path="{path}"') + + +def create_credential_pair( + string: Optional[Union[str, bytes]] = None, + path: Optional[str] = None, + private_string: Optional[bytes] = None, + private_path: Optional[str] = None, + purpose: Optional[CredentialPurpose] = None, + trust_domain: Optional[str] = None, + security_class: Optional[str] = None, + cred_type: Optional[CredentialPairType] = None, + context: Optional[Mapping] = None, +) -> CredentialPair: + """ + Creates a credential pair object. + + Args: + string (bytes, optional): The public credential as a byte string. + path (str, optional): The path to the public credential file. + private_string (bytes, optional): The private credential as a byte string. + private_path (str, optional): The path to the private credential file. + purpose (CredentialPurpose, optional): The purpose of the credentials. + trust_domain (str, optional): The trust domain of the credentials. + security_class (str, optional): The security class of the credentials. + cred_type (CredentialPairType, optional): The type of the credential pair. + context (Mapping, optional): The context to use for decoding the credentials. + + Returns: + CredentialPair: The credential pair object. + + Raises: + CredentialError: If the credential pair cannot be loaded. + """ + + credential_types = [cred_type] if cred_type else CredentialPairType + for c_type in credential_types: + try: + credential_class = credential_of_type(c_type) + if issubclass(credential_class, CredentialPair): + cred_args = signature(credential_class.__init__).parameters.values() + cred_args = [param.name for param in cred_args if param.name != "self"] + kwargs = {key: value for key, value in locals().items() if key in cred_args and value is not None} + return credential_class(**kwargs) + except CredentialError: + pass + except Exception as err: + raise CredentialError( + f'Unexpected error loading credential pair: string="{string}", path="{path}", private_string="{private_string}", private_path="{private_path}"' + ) from err + raise CredentialError( + f'Could not load credential pair: string="{string}", path="{path}", private_string="{private_string}", private_path="{private_path}"' + ) + + +def standard_path(cred: Credential) -> str: + """ + Returns the standard path for a credential. + + Args: + cred (Credential): The credential object. + + Returns: + str: The standard path for the credential. + """ + + if not cred.string: + raise CredentialError("Credential not initialized") + if not cred.path: + raise CredentialError("Credential path not set") + + filename = os.path.basename(cred.path) + if not filename: + raise CredentialError("Credential path is not a file") + + filename = f"credential_{cred.purpose_alias}_{filename}.{cred.extension}" + path = os.path.join(os.path.dirname(cred.path), filename) + + return path + + +def compress_credential(credential_data: bytes) -> bytes: + """ + Compresses a credential. + + Args: + credential_data (bytes): The credential data. + + Returns: + bytes: The compressed credential. + """ + + with BytesIO() as cfile: + with gzip.GzipFile(fileobj=cfile, mode="wb") as f: + # Calling a GzipFile object's close() method does not close fileobj, so cfile is available outside + f.write(credential_data) + return base64.b64encode(cfile.getvalue()) + + +########################## +### Parameters ########### +########################## + + +@enum.unique +class ParameterName(enum.Enum): + """ + Enum representing different parameter names. + """ + + VM_ID = "VMId" + VM_TYPE = "VMType" + GLIDEIN_PROXY = "GlideinProxy" + REMOTE_USERNAME = "RemoteUsername" + PROJECT_ID = "ProjectId" + + @classmethod + def from_string(cls, string: str) -> "ParameterName": + """ + Converts a string representation of a parameter name to a ParameterName object. + + Args: + string (str): The string representation of the parameter name. + + Returns: + ParameterName: The corresponding ParameterName object. + + Raises: + ParameterError: If the string does not match any known parameter name. + """ + + extended_map = {"vm_id": cls.VM_ID, "vm_type": cls.VM_TYPE, "project_id": cls.PROJECT_ID} + extended_map.update({param.value.lower(): param for param in cls}) + + string = string.lower() + + try: + return ParameterName(string) + except ValueError: + pass + if string in extended_map: + return extended_map[string] + raise ParameterError(f"Unknown Parameter name: {string}") + + def __str__(self) -> str: + return self.value + + def __repr__(self) -> str: + return f"{self.__class__.__name__}.{self.name}" + + +@enum.unique +class ParameterType(enum.Enum): + """ + Enum representing different types of parameters. + """ + + GENERATOR = "generator" + INTEGER = "integer" + EXPRESSION = "expression" + STRING = "string" + + @classmethod + def from_string(cls, string: str) -> "ParameterType": + """ + Create a ParameterType object from a string representation. + + Args: + string (str): The string representation of the ParameterType. + + Returns: + ParameterType: The created ParameterType object. + + Raises: + ParameterError: If the string does not match any known ParameterType. + """ + + extended_map = {"int": cls.INTEGER, "expr": cls.EXPRESSION, "str": cls.STRING} + extended_map.update({param.value.lower(): param for param in cls}) + + string = string.lower() + + try: + return ParameterType(string) + except ValueError: + pass + if string in extended_map: + return extended_map[string] + raise ParameterError(f"Unknown Parameter type: {string}") + + def __str__(self) -> str: + return self.value + + def __repr__(self) -> str: + return f"{self.__class__.__name__}.{self.name}" + + +class Parameter(ABC, Generic[T]): + """ + Represents a parameter with a name and value. + + Attributes: + param_type (ParameterType): The type of the parameter. + name (ParameterName): The name of the parameter. + value (str): The value of the parameter. + """ + + param_type = None + + def __init__(self, name: ParameterName, value: Union[T, str]): + """ + Initialize a Parameter object. + + Args: + name (ParameterName): The name of the parameter. + value (str): The value of the parameter. + """ + + if not isinstance(name, ParameterName): + raise TypeError("Parameter name must be a ParameterName") + + self._name = name + self._value = self.parse_value(value) + + @property + def name(self) -> ParameterName: + """ + Parameter name. + """ + + return self._name + + @property + def value(self) -> T: + """ + Parameter value. + """ + + return self._value + + @property + @abstractmethod + def quoted_value(self) -> str: + """ + Quoted parameter value. + """ + + @staticmethod + @abstractmethod + def parse_value(value: Union[T, str]) -> T: + """ + Parse a value to the parameter type. + + Args: + value: The value to parse. + + Returns: + T: The parsed value. + + Raises: + ValueError: If the value is invalid. + """ + + def copy(self) -> "Parameter": + """ + Create a copy of the parameter. + + Returns: + Parameter: The copied parameter. + """ + + return create_parameter(self._name, self._value, self.param_type) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(name={self._name.value!r}, value={self._value!r}, param_type={self.param_type.value!r})" + + def __str__(self) -> str: + return f"{self.name.value}={self.value}" + + +class ParameterGenerator(Parameter, Generator): + """ + A class representing a generator parameter. + + This class inherits from the base `Parameter` class and is used to define parameters + that generate their values dynamically using a generator function. + + Args: + name (ParameterName): The name of the parameter. + generator (str): The name of the generator to use. + + Attributes: + param_type (ParameterType): The type of the parameter (GENERATOR). + name (ParameterName): The name of the parameter. + value (str): The value of the parameter. + """ + + param_type = ParameterType.GENERATOR + + def __init__(self, name: ParameterName, value: str, context: Optional[Mapping] = None): + """ + Initialize a ParameterGenerator object. + + Args: + name (ParameterName): The name of the parameter. + value (str): The value of the parameter. + context (Mapping, optional): The context to use for the parameter. + """ + + if not isinstance(name, ParameterName): + raise TypeError("Parameter name must be a ParameterName") + + self._generated_parameter: Optional[Parameter] = None + self._name = name + self._value = self.parse_value(value, context) + if context and "type" in context: + self.param_type = ParameterType.from_string(context.get("type")) + + @property + def value(self) -> any: + """ + Parameter value. + + NOTE: None until the parameter is generated. + """ + + return self._generated_parameter.value if self._generated_parameter else None + + @property + def quoted_value(self) -> Optional[str]: + """ + Quoted parameter value. + """ + + return self._generated_parameter.quoted_value if self._generated_parameter else None + + @staticmethod + def parse_value(value: Union[Generator, str], context: Optional[Mapping]) -> Generator: + """ + Parse the parameter value to a generator. + + Args: + value (str): The value to parse. + + Returns: + Generator: The parsed value. + + Raises: + ImportError: If the generator could not be loaded. + """ + + try: + if isinstance(value, Generator): + return value + return load_generator(value, context) + except ImportError as err: + raise ImportError(f"Could not load generator: {value}") from err + + def copy(self) -> Parameter: + """ + Create a copy of the current generated parameter. + + NOTE: The resulting parameter is not a generator. + + Returns: + Parameter: The copied parameter. + """ + + return create_parameter(self.name, self.value, self.param_type) + + def generate(self, **kwargs): + """ + Generate the parameter value using the generator function. + + Args: + **kwargs: Additional keyword arguments to pass to the generator function. + """ + + self._generated_parameter = create_parameter(self._name, self._value.generate(**kwargs), self.param_type) + + +class IntegerParameter(Parameter[int]): + """ + Represents an integer parameter. + + This class extends the base `Parameter` class and is used to define parameters + with integer values. + + Attributes: + param_type (ParameterType): The type of the parameter (INTEGER). + name (ParameterName): The name of the parameter. + value (int): The value of the parameter. + """ + + param_type = ParameterType.INTEGER + + @property + def quoted_value(self) -> str: + """ + Quoted parameter value. + """ + + return str(self.value) + + @staticmethod + def parse_value(value: Union[int, str]) -> int: + """ + Parse a value to an integer. + """ + + try: + return int(value) + except ValueError as err: + raise ValueError(f"Invalid integer value: {value}") from err + + +class StringParameter(Parameter[str]): + """ + Represents a string parameter. + + This class extends the base `Parameter` class and is used to define parameters + with string values. + + Attributes: + param_type (ParameterType): The type of the parameter (STRING). + name (ParameterName): The name of the parameter. + value (str): The value of the parameter. + """ + + param_type = ParameterType.STRING + + @property + def quoted_value(self) -> str: + """ + Quoted parameter value. + """ + + return f'"{self.value}"' + + @staticmethod + def parse_value(value: str) -> str: + """ + Parse a value to a string. + + Args: + value (str): The value to parse. + """ + + return str(value) + + +class ExpressionParameter(Parameter[str]): + """ + Represents an expression parameter. + + This class extends the base `Parameter` class and is used to define parameters + with expression values. + + Attributes: + param_type (ParameterType): The type of the parameter (EXPRESSION). + name (ParameterName): The name of the parameter. + value (str): The value of the parameter. + """ + + param_type = ParameterType.EXPRESSION + + @property + def quoted_value(self) -> str: + """ + Quoted parameter value. + """ + + return self.value + + @staticmethod + def parse_value(value: str) -> str: + """ + Parse the parameter value. + + Args: + value (str): The value to parse. + + Returns: + str: The parsed value. + """ + + # TODO: Implement expression parsing + + return value + + +class ParameterDict(dict): + """ + A dictionary subclass for storing parameters. + + This class extends the built-in `dict` class and provides additional functionality + for storing and retrieving parameters. It enforces that keys must be of type `ParameterName` + and values must be of type `Parameter`. + """ + + def __setitem__(self, __k, __v): + if isinstance(__k, str): + __k = ParameterName.from_string(__k) + if not isinstance(__k, ParameterName): + raise TypeError("Key must be a ParameterType") + if not isinstance(__v, Parameter): + raise TypeError("Value must be a Parameter") + super().__setitem__(__k, __v) + + def __getitem__(self, __k): + if isinstance(__k, str): + __k = ParameterName.from_string(__k) + if not isinstance(__k, ParameterName): + raise TypeError("Key must be a ParameterType") + return super().__getitem__(__k) + + def add(self, parameter: Parameter): + """ + Adds a parameter to the dictionary. + + Args: + parameter (Parameter): The parameter to add. + """ + + if not isinstance(parameter, Parameter): + raise TypeError("Parameter must be a Parameter") + self[parameter.name] = parameter + + +def parameter_of_type(param_type: ParameterType) -> Type[Parameter]: + """Returns the parameter subclass for the given type. + + Args: + param_type (ParameterType): parameter type + + Raises: + CredentialError: if the parameter type is unknown + + Returns: + Parameter: parameter subclass + """ + + class_dict = {} + for i in Parameter.__subclasses__(): + class_dict[i.param_type] = i + try: + return class_dict[param_type] + except KeyError as err: + raise CredentialError(f"Unknown Parameter type: {param_type}") from err + + +def create_parameter( + name: ParameterName, value: str, param_type: Optional[ParameterType] = None, context: Optional[Mapping] = None +) -> Parameter: + """ + Creates a parameter. + + Args: + name (ParameterName): The name of the parameter. + value (str): The value of the parameter. + param_type (ParameterType, optional): The type of the parameter. + context (Mapping, optional): The context to use for the parameter. + + Returns: + Parameter: The parameter object. + """ + + parameter_types = [param_type] if param_type else ParameterType + for p_type in parameter_types: + try: + parameter_class = parameter_of_type(p_type) + if issubclass(parameter_class, Parameter): + param_args = signature(parameter_class.__init__).parameters.values() + param_args = [param.name for param in param_args if param.name != "self"] + kwargs = {key: value for key, value in locals().items() if key in param_args and value is not None} + return parameter_class(**kwargs) + except TypeError: + pass # Parameter type incompatible with input + except Exception as err: + raise CredentialError(f'Unexpected error loading parameter: name="{name}", value="{value}"') from err + raise CredentialError(f'Could not load parameter: name="{name}", value="{value}"') + + +########################## +### Tools ################ +########################## + + +class SecurityBundle: + """ + Represents a security bundle used for submitting jobs. + + Args: + username (str): The username for the security bundle. + security_class (str): The security class for the security bundle. + """ + + def __init__(self): + self.credentials = CredentialDict() + self.parameters = ParameterDict() + + def add_credential(self, credential, credential_id=None): + """ + Adds a credential to the security bundle. + + Args: + credential (Credential): The credential to add. + credential_id (str, optional): The ID to use as the key in the dictionary. + If not provided, the credential's ID will be used. + """ + + self.credentials.add(credential, credential_id) + + def add_parameter(self, parameter: Parameter): + """ + Adds a parameter to the security bundle. + + Args: + parameter (Parameter): The parameter to add. + """ + + self.parameters.add(parameter) + + def load_from_element(self, element_descript): + """ + Load the security bundle from an element descriptor. + + Args: + element_descript (ElementDescriptor): The element descriptor to load from. + """ + + for path in element_descript.merged_data["Proxies"]: + cred_type = credential_type_from_string(element_descript.merged_data["ProxyTypes"].get(path)) + purpose = element_descript.merged_data["CredentialPurposes"].get(path) + trust_domain = element_descript.merged_data["ProxyTrustDomains"].get(path, "None") + security_class = element_descript.merged_data["ProxySecurityClasses"].get( + path, "None" + ) # TODO: Should this be None? + context = load_context(element_descript.merged_data["CredentialContexts"].get(path, None)) + if isinstance(cred_type, CredentialType): + credential = create_credential( + path=path, + purpose=purpose, + trust_domain=trust_domain, + security_class=security_class, + cred_type=cred_type, + context=context, + ) + else: + cred_key = element_descript.merged_data["ProxyKeyFiles"].get(path, None) + credential = create_credential_pair( + path=path, + private_path=cred_key, + purpose=purpose, + trust_domain=trust_domain, + security_class=security_class, + cred_type=cred_type, + context=context, + ) + self.add_credential(credential) + for name, data in element_descript.merged_data["Parameters"].items(): + parameter = create_parameter( + ParameterName.from_string(name), + data["value"], + ParameterType.from_string(data["type"]), + load_context(data["context"]), + ) + self.add_parameter(parameter) + + +class SubmitBundle: + """ + Represents a submit bundle used for submitting jobs. + + This includes Frontend-provided security credentials, identity credentials, and parameters, + and Factory-provided security credentials. + + Attributes: + username (str): The username for the submit bundle. + security_class (str): The security class for the submit bundle. + id (str): The ID used for tracking the submit credentials. + cred_dir (str): The location of the credentials. + auth_set (AuthenticationSet): The authentication requirements for the submit bundle. + security_credentials (CredentialDict): A dictionary of security credentials. + identity_credentials (CredentialDict): A dictionary of identity credentials. + parameters (ParameterDict): A dictionary of parameters. + """ + + def __init__(self, username: str, security_class: str): + """ + Initialize a Credentials object. + + Args: + username (str): The username for the submit bundle. + security_class (str): The security class for the submit bundle. + + """ + self.username = username + self.security_class = security_class + self.id = None + self.cred_dir = "" + self.auth_set: Optional[AuthenticationSet] = None + self.security_credentials = CredentialDict() + self.identity_credentials = CredentialDict() + self.parameters = ParameterDict() + + def add_security_credential( + self, + credential: Credential, + cred_id: str = None, + ) -> bool: + """ + Adds a security credential. + + Args: + credential (Credential): The credential object. + cred_id (str): The ID of the credential. + + Returns: + bool: True if the credential was added, otherwise False. + """ + + try: + self.security_credentials.add(credential, cred_id) + return True + except TypeError: + return False + + def add_factory_credential(self, cred_id: str, credential: Credential) -> bool: + """ + Adds a factory provided security credential. + + Args: + cred_id (str): The ID of the credential. + credential (Credential): The credential object. + + Returns: + bool: True if the credential was added, otherwise False. + """ + + self.security_credentials[cred_id] = credential + return True + + def add_identity_credential(self, credential: Credential, cred_id: Optional[str] = None) -> bool: + """ + Adds an identity credential. + + Args: + cred_id (str): The ID of the credential. + credential (Credential): The credential object. + + Returns: + bool: True if the credential was added, otherwise False. + """ + + try: + self.identity_credentials.add(credential, cred_id) + return True + except TypeError: + return False + + def add_parameter(self, parameter: Parameter) -> bool: + """ + Adds a parameter. + + Args: + param_id (ParameterName): The ID of the parameter. + param_value (str): The value of the parameter. + + Returns: + bool: True if the parameter was added, otherwise False. + """ + + try: + self.parameters.add(parameter) + return True + except TypeError: + return False + + +class AuthenticationSet: + """ + Represents a set of authentication requirements. + """ + + _required_types: Set[Union[CredentialType, CredentialPairType, ParameterName]] = set() + + def __init__(self, auth_set: Iterable[Union[CredentialType, CredentialPairType, ParameterName]]): + """ + Initialize the Credentials object. + + Args: + auth_set: A collection of credential types, credential pair types, or parameter names. + + Raises: + TypeError: If an invalid credential type is provided. + """ + for auth_el in auth_set: + if not isinstance(auth_el, (CredentialType, CredentialPairType, ParameterName)): + raise TypeError(f"Invalid authentication element: {auth_el} ({type(auth_el)})") + self._required_types = set(auth_set) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({super().__repr__()})" + + def __str__(self) -> str: + return ",".join(str(cred_type) for cred_type in self._required_types) + + def __contains__(self, auth_el: Union[CredentialType, CredentialPairType, ParameterName, str]) -> bool: + return self.supports(auth_el) + + def supports(self, auth_el: Union[CredentialType, CredentialPairType, ParameterName, str]) -> bool: + """ + Checks if the authentication set supports a given credential type. + + Args: + auth_el (Union[CredentialType, CredentialPairType, ParameterName, str]): The authentication element to check. + + Returns: + bool: True if the credential type is supported, otherwise False. + """ + + if isinstance(auth_el, str): + try: + return CredentialType.from_string(auth_el) in self._required_types + except CredentialError: + pass + try: + return CredentialPairType.from_string(auth_el) in self._required_types + except CredentialError: + pass + try: + return ParameterName.from_string(auth_el) in self._required_types + except ParameterError: + pass + return auth_el in self._required_types + + def satisfied_by(self, auth_set: Iterable[Union[CredentialType, CredentialPairType, ParameterName]]) -> bool: + """ + Checks if the authentication set is satisfied by a given set of credential types. + + Args: + auauth_set: A collection of credential types, credential pair types, or parameter names. + + Returns: + bool: True if the authentication set is satisfied, otherwise False. + """ + + return self._required_types.issubset(auth_set) + + +class AuthenticationMethod: + """ + Represents an authentication method used for authenticating users. + """ + + def __init__(self, auth_method: str): + """ + Initialize the Credentials object. + + Args: + auth_method (str): The authentication method. + """ + + self._requirements: List[List[Union[CredentialType, ParameterName]]] = [] + self.load(auth_method) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self._requirements!r})" + + def __str__(self) -> str: + return ";".join(str(auth_set) for auth_set in self._requirements) + + def __contains__(self, cred_type: Union[CredentialType, str]) -> bool: + if isinstance(cred_type, str): + cred_type = CredentialType.from_string(cred_type) + return any(cred_type in group for group in self._requirements) + + def load(self, auth_method: str): + """ + Loads the authentication method from a string. + + Args: + auth_method (str): The authentication method. + """ + + for group in auth_method.split(";"): + if group.lower() == "any": + self._requirements.append([]) + else: + options = [] + for option in group.split(","): + try: + options.append(CredentialType.from_string(option)) + continue + except CredentialError: + pass + try: + options.append(CredentialPairType.from_string(option)) + continue + except CredentialError: + pass + try: + options.append(ParameterName.from_string(option)) + continue + except ParameterError: + pass + raise CredentialError(f"Unknown authentication requirement: {option}") + self._requirements.append(options) + + def match(self, security_bundle: SecurityBundle) -> Optional[AuthenticationSet]: + """ + Matches the authentication method to a security bundle and returns the authentication set if the requirements are met. + + Args: + security_bundle (SecurityBundle): The security bundle to match. + + Returns: + Optional[AuthenticationSet]: The authentication set if the security bundle matches the requirements, otherwise None. + """ + + if not self._requirements: + return AuthenticationSet([]) + + auth_set = [] + sec_items = {credential.cred_type for credential in security_bundle.credentials.values() if credential.valid} + sec_items.update(security_bundle.parameters.keys()) + for group in self._requirements: + # At least one group option must be in sec_items (select the first one) + selected = False + for option in group: + if option in sec_items: + auth_set.append(option) + selected = True + break + if not selected: + return None + return AuthenticationSet(auth_set) + + +def load_context(context: str) -> Optional[Mapping]: + """ + Load a context from a string. + + Args: + context (str): The context string. + + Returns: + Mapping: The context as a mapping. + """ + + try: + context = eval(context) # pylint: disable=eval-used + assert isinstance(context, Mapping) + return context + except Exception: # pylint: disable=bare-except + return None + + +########################## +### Compatibility ######## +########################## + + +def cred_path(cred: Optional[Union[Credential, str]]) -> Optional[str]: + """ + Returns the path of a credential. + + Args: + cred (Union[Credential, str]): The credential object or path. + + Returns: + Optional[str]: The path of the credential. + + Raises: + CredentialError: If the credential object is invalid. + """ + + if not cred: + return None + if issubclass(cred.__class__, Credential): + return cred.path + if isinstance(cred, str): + return cred + raise CredentialError("Invalid credential object") diff --git a/lib/credentialsLegacy.py b/lib/credentialsLegacy.py new file mode 100644 index 000000000..c301dc0ee --- /dev/null +++ b/lib/credentialsLegacy.py @@ -0,0 +1,123 @@ +#!/usr/bin/env python3 + +# SPDX-FileCopyrightText: 2009 Fermi Research Alliance, LLC +# SPDX-License-Identifier: Apache-2.0 + +""" +This module holds deprecated credential functions. +NOTE: This will likely be removed in the future. +""" + +import os +import sys + +from glideinwms.lib import logSupport +from glideinwms.lib.generators import import_module + +sys.path.append("/etc/gwms-frontend/plugin.d") +plugins = {} + + +def generate_credential(elementDescript, glidein_el, group_name, trust_domain): + """Generates a credential with a credential generator plugin provided for the trust domain. + + Args: + elementDescript (ElementMergedDescript): element descript + glidein_el (dict): glidein element + group_name (string): group name + trust_domain (string): trust domain for the element + + Returns: + string, None: Credential or None if not generated + """ + + ### The credential generator plugin should define the following function: + # def get_credential(log:logger, group:str, entry:dict{name:str, gatekeeper:str}, trust_domain:str): + # Generates a credential given the parameter + + # Args: + # log:logger + # group:str, + # entry:dict{ + # name:str, + # gatekeeper:str}, + # trust_domain:str, + # Return + # tuple + # token:str + # lifetime:int seconds of remaining lifetime + # Exception + # KeyError - miss some information to generate + # ValueError - could not generate the token + + generator = None + generators = elementDescript.element_data.get("CredentialGenerators") + trust_domain_data = elementDescript.element_data.get("ProxyTrustDomains") + if not generators: + generators = elementDescript.frontend_data.get("CredentialGenerators") + if not trust_domain_data: + trust_domain_data = elementDescript.frontend_data.get("ProxyTrustDomains") + if trust_domain_data and generators: + generators_map = eval(generators) + trust_domain_map = eval(trust_domain_data) + for cfname in generators_map: + if trust_domain_map[cfname] == trust_domain: + generator = generators_map[cfname] + logSupport.log.debug(f"found credential generator plugin {generator}") + try: + if generator not in plugins: + plugins[generator] = import_module(generator) + entry = { + "name": glidein_el["attrs"].get("EntryName"), + "gatekeeper": glidein_el["attrs"].get("GLIDEIN_Gatekeeper"), + "factory": glidein_el["attrs"].get("AuthenticatedIdentity"), + } + stkn, _ = plugins[generator].get_credential(logSupport, group_name, entry, trust_domain) + return cfname, stkn + except ModuleNotFoundError: + logSupport.log.warning(f"Failed to load credential generator plugin {generator}") + except Exception as e: # catch any exception from the plugin to prevent the frontend from crashing + logSupport.log.warning(f"Failed to generate credential: {e}.") + + return None, None + + +def get_scitoken(elementDescript, trust_domain): + """Look for a local SciToken specified for the trust domain. + + Args: + elementDescript (ElementMergedDescript): element descript + trust_domain (string): trust domain for the element + + Returns: + string, None: SciToken or None if not found + """ + + scitoken_fullpath = "" + cred_type_data = elementDescript.element_data.get("ProxyTypes") + trust_domain_data = elementDescript.element_data.get("ProxyTrustDomains") + if not cred_type_data: + cred_type_data = elementDescript.frontend_data.get("ProxyTypes") + if not trust_domain_data: + trust_domain_data = elementDescript.frontend_data.get("ProxyTrustDomains") + if trust_domain_data and cred_type_data: + cred_type_map = eval(cred_type_data) + trust_domain_map = eval(trust_domain_data) + for cfname in cred_type_map: + if cred_type_map[cfname] == "scitoken": + if trust_domain_map[cfname] == trust_domain: + scitoken_fullpath = cfname + + if os.path.exists(scitoken_fullpath): + try: + logSupport.log.debug(f"found scitoken {scitoken_fullpath}") + stkn = "" + with open(scitoken_fullpath) as fbuf: + for line in fbuf: + stkn += line + stkn = stkn.strip() + return stkn + except Exception as err: + logSupport.log.exception(f"failed to read scitoken: {err}") + + return None diff --git a/lib/defaults.py b/lib/defaults.py index 9a3d4bcbb..7ddfbb5d5 100644 --- a/lib/defaults.py +++ b/lib/defaults.py @@ -19,6 +19,8 @@ BINARY_ENCODING_ASCII = "ascii" # valid aliases: 646, us-ascii BINARY_ENCODING_DEFAULT = "utf_8" # valid aliases: utf-8, utf8 Default Python 3 encoding +PLUGINS_DIR = "/etc/gwms-frontend/plugin.d" + def force_bytes(instr, encoding=BINARY_ENCODING_CRYPTO): """Forces the output to be bytes, encoding the input if it is a unicode string (str) diff --git a/lib/generators.py b/lib/generators.py new file mode 100644 index 000000000..dab01a838 --- /dev/null +++ b/lib/generators.py @@ -0,0 +1,124 @@ +#!/usr/bin/env python3 + +# SPDX-FileCopyrightText: 2009 Fermi Research Alliance, LLC +# SPDX-License-Identifier: Apache-2.0 + +""" +This module contains the Generator base class and built-in generators +""" + + +import inspect +import os +import re +import sys + +from abc import ABC, abstractmethod +from collections import defaultdict +from typing import Generic, Mapping, Optional, Type, TypeVar + +from glideinwms.lib.defaults import PLUGINS_DIR +from glideinwms.lib.util import hash_nc, import_module + +sys.path.append(PLUGINS_DIR) +_loaded_generators = {} +_generator_instances = defaultdict(dict) + +T = TypeVar("T") + + +class GeneratorError(Exception): + """Base class for generator exceptions""" + + +class Generator(ABC, Generic[T]): + """Base class for generators""" + + def __init__(self, context: Optional[Mapping] = None): + self.context = context + + def __str__(self): + return f"{self.__class__.__name__}()" + + def __repr__(self): + return str(self) + + @abstractmethod + def generate(self, **kwargs) -> T: + """ + Generate an item using the context and keyword arguments + """ + + +def load_generator(module: str, context: Optional[Mapping] = None) -> Generator: + """Load a generator from a module + + Args: + module (str): module that exports a generator + + Raises: + ImportError: when a `Generator` object cannot be imported from `module` + + Returns: + Generator: generator object + """ + + module_name = re.sub(r"\.py[co]?$", "", os.path.basename(module)) # Extract module name from path + + try: + if module_name not in _loaded_generators: + imported_module = import_module(module) + if module_name not in _loaded_generators: + del imported_module + raise ImportError( + f"Module {module} does not export a generator. Please call export_generator(generator) in the module." + ) + except ImportError as e: + raise ImportError(f"Failed to import module {module}") from e + + try: + instance_id = hash_nc(f"{module_name}{str(context)}", 8) + if instance_id not in _generator_instances: + _generator_instances[module_name][instance_id] = _loaded_generators[module_name](context) + except GeneratorError as e: + raise GeneratorError(f"Failed to create generator from module {module}") from e + + return _generator_instances[module_name][instance_id] + + +def export_generator(generator: Type[Generator]): + """Make a Generator object available to the genearators module""" + + if not issubclass(generator, Generator): + raise TypeError("generator must be a Generator object") + module_fname = inspect.stack()[1].filename + module_name = re.sub(r"\.py[co]?$", "", os.path.basename(module_fname)) + _loaded_generators[module_name] = generator + + +def drop_generator(module: str) -> bool: + """Remove a generator from the generators module""" + + dropped = False + module_name = re.sub(r"\.py[co]?$", "", os.path.basename(module)) + if module_name in _loaded_generators: + del _loaded_generators[module_name] + dropped = True + if module_name in _generator_instances: + del _generator_instances[module_name] + + return dropped + + +def drop_generator_instance(generator: Generator) -> bool: + """Remove a generator instance from the generators module""" + + for module_name, instances in _generator_instances.items(): + for instance in instances: + if instances[instance] == generator: + del instances[instance] + if not instances: + del _generator_instances[module_name] + return True + + return False diff --git a/lib/logSupport.py b/lib/logSupport.py index 32676623d..e3a627e9e 100644 --- a/lib/logSupport.py +++ b/lib/logSupport.py @@ -41,7 +41,7 @@ # Create a placeholder for a global logger (logging.Logger), # individual modules can create their own loggers if necessary -log = None +log: logging.Logger = None log_dir = None disable_rotate = False diff --git a/lib/util.py b/lib/util.py index 7e9e39eaa..27e2e7beb 100644 --- a/lib/util.py +++ b/lib/util.py @@ -11,6 +11,7 @@ import pickle import re import shutil +import string import subprocess import sys import tempfile @@ -482,6 +483,15 @@ def chmod(*args, **kwargs): os.chmod(*args, **kwargs) +############################################################ +# only allow simple strings +def is_str_safe(s): + for c in s: + if c not in ("._-@" + string.ascii_letters + string.digits): + return False + return True + + def import_module(module, search_path=None): """Import a module by name or path diff --git a/plugins/RoundRobinGenerator.py b/plugins/RoundRobinGenerator.py new file mode 100644 index 000000000..2a3fd4bda --- /dev/null +++ b/plugins/RoundRobinGenerator.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python3 + +# SPDX-FileCopyrightText: 2009 Fermi Research Alliance, LLC +# SPDX-License-Identifier: Apache-2.0 + +""" +This module contains the RoundRobinGenerator class +""" + +from itertools import cycle +from typing import Any + +from glideinwms.lib.generators import export_generator, Generator, GeneratorError + + +class RoundRobinGenerator(Generator[Any]): + """Round-robin generator""" + + def __init__(self, context: Any = None): + super().__init__(context) + if "items" not in self.context: + raise GeneratorError("items not found in context for RoundRobinGenerator") + self.items_cycle = cycle(self.context["items"]) + + def generate(self, **kwargs) -> Any: + return next(self.items_cycle) + + +export_generator(RoundRobinGenerator) diff --git a/unittests/test_factory_glideFactoryLib.py b/unittests/test_factory_glideFactoryLib.py index 37e0620e7..7c3294045 100755 --- a/unittests/test_factory_glideFactoryLib.py +++ b/unittests/test_factory_glideFactoryLib.py @@ -66,12 +66,12 @@ getQStatusSF, getQStatusStale, hrs2sec, - is_str_safe, isGlideinUnrecoverable, secClass2Name, set_condor_integrity_checks, which, ) +from glideinwms.lib.util import is_str_safe from glideinwms.unittests.unittest_utils import FakeLogger, TestImportError try: diff --git a/unittests/test_frontend_element.py b/unittests/test_frontend_element.py index 4c9704ea1..22339fd54 100755 --- a/unittests/test_frontend_element.py +++ b/unittests/test_frontend_element.py @@ -245,7 +245,7 @@ def test_some_iterate_one_artifacts(self): self.gfe.stats = {"group": glideinFrontendMonitoring.groupStats()} self.gfe.published_frontend_name = f"{self.gfe.frontend_name}.XPVO_{self.gfe.group_name}" mockery = mock.MagicMock() - self.gfe.x509_proxy_plugin = mockery + self.gfe.credentials_plugin = mockery # keep logSupport.log.info in an array to search through later to # evaluate success glideinwms.frontend.glideinFrontendLib.logSupport.log = mockery diff --git a/unittests/test_frontend_glideinFrontendElement.py b/unittests/test_frontend_glideinFrontendElement.py index e845503c9..8a07d8b44 100755 --- a/unittests/test_frontend_glideinFrontendElement.py +++ b/unittests/test_frontend_glideinFrontendElement.py @@ -25,6 +25,7 @@ log_and_sum_factory_line, log_factory_header, ) +from glideinwms.lib.credentials import X509Cert from glideinwms.unittests.unittest_utils import FakeLogger LOG_DATA = [] @@ -129,7 +130,9 @@ def test_configure(self): b_ccm = os.environ.get(v) v = "X509_USER_PROXY" b_xup = os.environ.get(v) - self.gfe.configure() + with mock.patch("glideinwms.lib.credentials.create_credential") as mock_create_credential: + mock_create_credential.return_value = X509Cert() + self.gfe.configure() if self.verbose: print("\nc.glideinFrontendElement=%s" % self.gfe) print("\nc.dir glideinFrontendElement=%s" % dir(self.gfe)) @@ -197,7 +200,9 @@ def test_count_factory_entries_without_classads(self): assert False # TODO: implement your test here def test_deadvertiseAllClassads(self): - self.gfe.configure() + with mock.patch("glideinwms.lib.credentials.create_credential") as mock_create_credential: + mock_create_credential.return_value = X509Cert() + self.gfe.configure() self.gfe.deadvertiseAllClassads() @unittest.skip("for now") diff --git a/unittests/test_glideinFrontendPlugins.py b/unittests/test_glideinFrontendPlugins.py index 774010343..93b584997 100755 --- a/unittests/test_glideinFrontendPlugins.py +++ b/unittests/test_glideinFrontendPlugins.py @@ -34,7 +34,7 @@ try: from glideinwms.frontend import glideinFrontendPlugins - from glideinwms.frontend.glideinFrontendInterface import Credential + from glideinwms.frontend.glideinFrontendInterface import LegacyCredential except ImportError as err: raise TestImportError(str(err)) @@ -86,7 +86,7 @@ def getCredlist(self): (f, proxyfile) = tempfile.mkstemp() os.close(f) self.elementDescript.addproxy(proxyfile) - rtnlist.append(Credential(t, proxyfile, self.elementDescript)) + rtnlist.append(LegacyCredential(t, proxyfile, self.elementDescript)) return rtnlist def killCredlist(self, cred_list):