diff --git a/stardis/base.py b/stardis/base.py index 0754e325..a2a293ed 100644 --- a/stardis/base.py +++ b/stardis/base.py @@ -10,9 +10,7 @@ logger = logging.getLogger(__name__) -def run_stardis( - config_fname, tracing_lambdas_or_nus, add_config_keys=None, add_config_vals=None -): +def run_stardis(config_fname, tracing_lambdas_or_nus, add_config_dict=None): """ Runs a STARDIS simulation. @@ -24,10 +22,8 @@ def run_stardis( Numpy array of the frequencies or wavelengths to calculate the spectrum for. Must have units attached to it, with dimensions of either length or inverse time. - add_config_keys : list, optional - List of additional keys to add or overwrite for the configuration file. - add_config_vals : list, optional - List of corresponding additional values to add to the configuration file. + add_config_dict : dict, optional + Dictionary containing the the keys and values of the configuration to add or overwrite. Returns ------- @@ -37,9 +33,7 @@ def run_stardis( tracing_nus = tracing_lambdas_or_nus.to(u.Hz, u.spectral()) - config, adata, stellar_model = parse_config_to_model( - config_fname, add_config_keys, add_config_vals - ) + config, adata, stellar_model = parse_config_to_model(config_fname, add_config_dict) set_num_threads(config.n_threads) stellar_plasma = create_stellar_plasma(stellar_model, adata, config) stellar_radiation_field = create_stellar_radiation_field( diff --git a/stardis/io/base.py b/stardis/io/base.py index ba081689..eed7260e 100644 --- a/stardis/io/base.py +++ b/stardis/io/base.py @@ -17,7 +17,7 @@ logger = logging.getLogger(__name__) -def parse_config_to_model(config_fname, add_config_keys=None, add_config_vals=None): +def parse_config_to_model(config_fname, add_config_dict): """ Parses the config and model files and outputs python objects to be passed into run stardis so they can be individually modified in python. @@ -47,28 +47,18 @@ def parse_config_to_model(config_fname, add_config_keys=None, add_config_vals=No raise ValueError("Config failed to validate. Check the config file.") if ( - not add_config_keys + not add_config_dict ): # If a dictionary was passed, update the config with the dictionary pass else: logger.info("Updating config with additional keys and values") - if isinstance(add_config_keys, str): - # Directly set the config item if add_config_keys is a string - config.set_config_item(add_config_keys, add_config_vals) - else: - # Proceed with iteration if add_config_keys is not a string - if len(add_config_keys) != len(add_config_vals): - raise ValueError( - "Length of additional config keys and values do not match." - ) + for key, val in add_config_dict.items(): try: - for key, val in zip(add_config_keys, add_config_vals): - config.set_config_item(key, val) + config.set_config_item(key, val) except: raise ValueError( - f"{add_config_keys} not a valid type. Should be a single string or a list of strings for keys." + f"{key} not a valid type. Should be a string for keys." ) - try: config_dict = validate_dict(config, schemapath=SCHEMA_PATH) except: