Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

simplify updating config logic #216

Merged
merged 2 commits into from
Oct 14, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 4 additions & 10 deletions stardis/base.py
Original file line number Diff line number Diff line change
@@ -10,9 +10,7 @@
logger = logging.getLogger(__name__)


def run_stardis(
config_fname, tracing_lambdas_or_nus, add_config_keys=None, add_config_vals=None
):
def run_stardis(config_fname, tracing_lambdas_or_nus, add_config_dict=None):
"""
Runs a STARDIS simulation.

@@ -24,10 +22,8 @@ def run_stardis(
Numpy array of the frequencies or wavelengths to calculate the
spectrum for. Must have units attached to it, with dimensions
of either length or inverse time.
add_config_keys : list, optional
List of additional keys to add or overwrite for the configuration file.
add_config_vals : list, optional
List of corresponding additional values to add to the configuration file.
add_config_dict : dict, optional
Dictionary containing the the keys and values of the configuration to add or overwrite.

Returns
-------
@@ -37,9 +33,7 @@ def run_stardis(

tracing_nus = tracing_lambdas_or_nus.to(u.Hz, u.spectral())

config, adata, stellar_model = parse_config_to_model(
config_fname, add_config_keys, add_config_vals
)
config, adata, stellar_model = parse_config_to_model(config_fname, add_config_dict)
set_num_threads(config.n_threads)
stellar_plasma = create_stellar_plasma(stellar_model, adata, config)
stellar_radiation_field = create_stellar_radiation_field(
20 changes: 5 additions & 15 deletions stardis/io/base.py
Original file line number Diff line number Diff line change
@@ -17,7 +17,7 @@
logger = logging.getLogger(__name__)


def parse_config_to_model(config_fname, add_config_keys=None, add_config_vals=None):
def parse_config_to_model(config_fname, add_config_dict):
"""
Parses the config and model files and outputs python objects to be passed into run stardis so they can be individually modified in python.

@@ -47,28 +47,18 @@ def parse_config_to_model(config_fname, add_config_keys=None, add_config_vals=No
raise ValueError("Config failed to validate. Check the config file.")

if (
not add_config_keys
not add_config_dict
): # If a dictionary was passed, update the config with the dictionary
pass
else:
logger.info("Updating config with additional keys and values")
if isinstance(add_config_keys, str):
# Directly set the config item if add_config_keys is a string
config.set_config_item(add_config_keys, add_config_vals)
else:
# Proceed with iteration if add_config_keys is not a string
if len(add_config_keys) != len(add_config_vals):
raise ValueError(
"Length of additional config keys and values do not match."
)
for key, val in add_config_dict.items():
try:
for key, val in zip(add_config_keys, add_config_vals):
config.set_config_item(key, val)
config.set_config_item(key, val)
except:
raise ValueError(
f"{add_config_keys} not a valid type. Should be a single string or a list of strings for keys."
f"{key} not a valid type. Should be a string for keys."
)

try:
config_dict = validate_dict(config, schemapath=SCHEMA_PATH)
except: