Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactored systematics code #385

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions nmma/em/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from . import systematics


def from_list(self, systematics):
"""
Similar to `from_file` but instead of file buffer, takes a list of Prior strings
Expand Down
113 changes: 62 additions & 51 deletions nmma/em/systematics.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import yaml
from pathlib import Path
import inspect
import warnings
from pathlib import Path

import yaml
from bilby.core.prior import analytical

warnings.simplefilter("module", DeprecationWarning)

Expand Down Expand Up @@ -39,26 +42,38 @@ def __init__(self, key, message):
"ztfr",
]

ALLOWED_DISTRIBUTIONS = dict(inspect.getmembers(analytical, inspect.isclass))


def get_positional_args(cls):
init_method = cls.__init__

signature = inspect.signature(init_method)
params = [
param.name
for param in signature.parameters.values()
if param.name != "self" and param.default == inspect.Parameter.empty
]

return params


DISTRIBUTION_PARAMETERS = {k: get_positional_args(v) for k, v in ALLOWED_DISTRIBUTIONS.items()}


def load_yaml(file_path):
return yaml.safe_load(Path(file_path).read_text())


def validate_only_one_true(yaml_dict):
for key, values in yaml_dict["config"].items():
if "value" not in values or type(values["value"]) is not bool:
raise ValidationError(
key, "'value' key must be present and be a boolean"
)
if "value" not in values or not isinstance(values["value"], bool):
raise ValidationError(key, "'value' key must be present and be a boolean")
true_count = sum(value["value"] for value in yaml_dict["config"].values())
if true_count > 1:
raise ValidationError(
"config", "Only one configuration key can be set to True at a time"
)
raise ValidationError("config", "Only one configuration key can be set to True at a time")
elif true_count == 0:
raise ValidationError(
"config", "At least one configuration key must be set to True"
)
raise ValidationError("config", "At least one configuration key must be set to True")


def validate_filters(filter_groups):
Expand Down Expand Up @@ -100,67 +115,63 @@ def validate_filters(filter_groups):


def validate_distribution(distribution):
if distribution != "Uniform":
dist_type = distribution.get("type")
if dist_type not in ALLOWED_DISTRIBUTIONS:
raise ValidationError(
"type",
f"Invalid distribution '{distribution}'. Only 'Uniform' distribution is supported",
"distribution type",
f"Invalid distribution '{dist_type}'. Allowed values are {', '.join([str(f) for f in ALLOWED_DISTRIBUTIONS])}",
)

required_params = DISTRIBUTION_PARAMETERS[dist_type]

def validate_fields(key, values, required_fields):
missing_fields = [
field for field in required_fields if values.get(field) is None
]
if missing_fields:
missing_params = set(required_params) - set(distribution.keys())
if missing_params:
raise ValidationError(
key, f"Missing fields: {', '.join(missing_fields)}"
"distribution", f"Missing required parameters for {dist_type} distribution: {', '.join(missing_params)}"
)
for field, expected_type in required_fields.items():
if not isinstance(values[field], expected_type):
raise ValidationError(
key, f"'{field}' must be of type {expected_type}"
)


def handle_withTime(key, values):
required_fields = {
"type": str,
"min": (float, int),
"max": (float, int),
"time_nodes": int,
"filters": list,
}
def create_prior_string(name, distribution):
dist_type = distribution.pop("type")
_ = distribution.pop("value")
_ = distribution.pop("time_nodes", None)
_ = distribution.pop("filters", None)
prior_class = ALLOWED_DISTRIBUTIONS[dist_type]
required_params = DISTRIBUTION_PARAMETERS[dist_type]
params = distribution.copy()

extra_params = set(params.keys()) - set(required_params)
if extra_params:
warnings.warn(f"Distribution parameters {extra_params} are not used by {dist_type} distribution and will be ignored")

params = {k: params[k] for k in required_params if k in params}

return f"{name} = {repr(prior_class(**params, name=name))}"

validate_fields(key, values, required_fields)

def handle_withTime(values):
validate_distribution(values)
filter_groups = values.get("filters", [])
validate_filters(filter_groups)
distribution = values.get("type")
validate_distribution(distribution)
result = []
time_nodes = values["time_nodes"]

for filter_group in filter_groups:
if isinstance(filter_group, list):
filter_name = "___".join(filter_group)
else:
filter_name = filter_group if filter_group is not None else "all"

for n in range(1, values["time_nodes"] + 1):
result.append(
f'sys_err_{filter_name}{n} = {values["type"]}(minimum={values["min"]},maximum={values["max"]},name="sys_err_{filter_name}{n}")'
)
for n in range(1, time_nodes + 1):
prior_name = f"sys_err_{filter_name}{n}"
result.append(create_prior_string(prior_name, values.copy()))

return result


def handle_withoutTime(key, values):
required_fields = {"type": str, "min": (float, int), "max": (float, int)}
validate_fields(key, values, required_fields)
distribution = values.get("type")
validate_distribution(distribution)
return [
f'sys_err = {values["type"]}(minimum={values["min"]},maximum={values["max"]},name="sys_err")'
]
def handle_withoutTime(values):
validate_distribution(values)
return [create_prior_string("sys_err", values)]


config_handlers = {
Expand All @@ -175,5 +186,5 @@ def main(yaml_file_path):
results = []
for key, values in yaml_dict["config"].items():
if values["value"] and key in config_handlers:
results.extend(config_handlers[key](key, values))
return results
results.extend(config_handlers[key](values))
return results
4 changes: 2 additions & 2 deletions nmma/em/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,8 +484,8 @@ def calc_lc(
y_pred, sigma2_pred = gp.predict(
np.atleast_2d(param_list_postprocess), return_std=True
)
cAproj[i] = y_pred
cAstd[i] = sigma2_pred
cAproj[i] = np.squeeze(y_pred)
cAstd[i] = np.squeeze(sigma2_pred)

# coverrors = np.dot(VA[:, :n_coeff], np.dot(np.power(np.diag(cAstd[:n_coeff]), 2), VA[:, :n_coeff].T))
# errors = np.diag(coverrors)
Expand Down
Loading
Loading