Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactored systematics code #385

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions nmma/em/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from . import systematics


def from_list(self, systematics):
"""
Similar to `from_file` but instead of file buffer, takes a list of Prior strings
Expand Down
113 changes: 62 additions & 51 deletions nmma/em/systematics.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import yaml
from pathlib import Path
import inspect
import warnings
from pathlib import Path

import yaml
from bilby.core.prior import analytical

warnings.simplefilter("module", DeprecationWarning)

Expand Down Expand Up @@ -39,26 +42,38 @@ def __init__(self, key, message):
"ztfr",
]

ALLOWED_DISTRIBUTIONS = dict(inspect.getmembers(analytical, inspect.isclass))


def get_positional_args(cls):
init_method = cls.__init__

signature = inspect.signature(init_method)
params = [
param.name
for param in signature.parameters.values()
if param.name != "self" and param.default == inspect.Parameter.empty
]

return params


DISTRIBUTION_PARAMETERS = {k: get_positional_args(v) for k, v in ALLOWED_DISTRIBUTIONS.items()}


def load_yaml(file_path):
return yaml.safe_load(Path(file_path).read_text())


def validate_only_one_true(yaml_dict):
for key, values in yaml_dict["config"].items():
if "value" not in values or type(values["value"]) is not bool:
raise ValidationError(
key, "'value' key must be present and be a boolean"
)
if "value" not in values or not isinstance(values["value"], bool):
raise ValidationError(key, "'value' key must be present and be a boolean")
true_count = sum(value["value"] for value in yaml_dict["config"].values())
if true_count > 1:
raise ValidationError(
"config", "Only one configuration key can be set to True at a time"
)
raise ValidationError("config", "Only one configuration key can be set to True at a time")
elif true_count == 0:
raise ValidationError(
"config", "At least one configuration key must be set to True"
)
raise ValidationError("config", "At least one configuration key must be set to True")


def validate_filters(filter_groups):
Expand Down Expand Up @@ -100,67 +115,63 @@ def validate_filters(filter_groups):


def validate_distribution(distribution):
if distribution != "Uniform":
dist_type = distribution.get("type")
if dist_type not in ALLOWED_DISTRIBUTIONS:
raise ValidationError(
"type",
f"Invalid distribution '{distribution}'. Only 'Uniform' distribution is supported",
"distribution type",
f"Invalid distribution '{dist_type}'. Allowed values are {', '.join([str(f) for f in ALLOWED_DISTRIBUTIONS])}",
)

required_params = DISTRIBUTION_PARAMETERS[dist_type]

def validate_fields(key, values, required_fields):
missing_fields = [
field for field in required_fields if values.get(field) is None
]
if missing_fields:
missing_params = set(required_params) - set(distribution.keys())
if missing_params:
raise ValidationError(
key, f"Missing fields: {', '.join(missing_fields)}"
"distribution", f"Missing required parameters for {dist_type} distribution: {', '.join(missing_params)}"
)
for field, expected_type in required_fields.items():
if not isinstance(values[field], expected_type):
raise ValidationError(
key, f"'{field}' must be of type {expected_type}"
)


def handle_withTime(key, values):
required_fields = {
"type": str,
"min": (float, int),
"max": (float, int),
"time_nodes": int,
"filters": list,
}
def create_prior_string(name, distribution):
dist_type = distribution.pop("type")
_ = distribution.pop("value", None)
_ = distribution.pop("time_nodes", None)
_ = distribution.pop("filters", None)
prior_class = ALLOWED_DISTRIBUTIONS[dist_type]
required_params = DISTRIBUTION_PARAMETERS[dist_type]
params = distribution.copy()

extra_params = set(params.keys()) - set(required_params)
if extra_params:
warnings.warn(f"Distribution parameters {extra_params} are not used by {dist_type} distribution and will be ignored")

params = {k: params[k] for k in required_params if k in params}

return f"{name} = {repr(prior_class(**params, name=name))}"

validate_fields(key, values, required_fields)

def handle_withTime(values):
validate_distribution(values)
filter_groups = values.get("filters", [])
validate_filters(filter_groups)
distribution = values.get("type")
validate_distribution(distribution)
result = []
time_nodes = values["time_nodes"]

for filter_group in filter_groups:
if isinstance(filter_group, list):
filter_name = "___".join(filter_group)
else:
filter_name = filter_group if filter_group is not None else "all"

for n in range(1, values["time_nodes"] + 1):
result.append(
f'sys_err_{filter_name}{n} = {values["type"]}(minimum={values["min"]},maximum={values["max"]},name="sys_err_{filter_name}{n}")'
)
for n in range(1, time_nodes + 1):
prior_name = f"sys_err_{filter_name}{n}"
result.append(create_prior_string(prior_name, values.copy()))

return result


def handle_withoutTime(key, values):
required_fields = {"type": str, "min": (float, int), "max": (float, int)}
validate_fields(key, values, required_fields)
distribution = values.get("type")
validate_distribution(distribution)
return [
f'sys_err = {values["type"]}(minimum={values["min"]},maximum={values["max"]},name="sys_err")'
]
def handle_withoutTime(values):
validate_distribution(values)
return [create_prior_string("sys_err", values)]


config_handlers = {
Expand All @@ -175,5 +186,5 @@ def main(yaml_file_path):
results = []
for key, values in yaml_dict["config"].items():
if values["value"] and key in config_handlers:
results.extend(config_handlers[key](key, values))
return results
results.extend(config_handlers[key](values))
return results
4 changes: 2 additions & 2 deletions nmma/em/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,8 +484,8 @@ def calc_lc(
y_pred, sigma2_pred = gp.predict(
np.atleast_2d(param_list_postprocess), return_std=True
)
cAproj[i] = y_pred
cAstd[i] = sigma2_pred
cAproj[i] = np.squeeze(y_pred)
cAstd[i] = np.squeeze(sigma2_pred)

# coverrors = np.dot(VA[:, :n_coeff], np.dot(np.power(np.diag(cAstd[:n_coeff]), 2), VA[:, :n_coeff].T))
# errors = np.diag(coverrors)
Expand Down
8 changes: 4 additions & 4 deletions nmma/tests/data/systematics_with_time.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ config:
- null # you can keep it or remove it, it will still be parsed as None (filter independent)
time_nodes: 4
type: "Uniform"
min: 0
max: 2
minimum: 0
maximum: 2
withoutTime:
value: false
type: "Uniform"
min: 0
max: 2
minimum: 0
maximum: 2
8 changes: 4 additions & 4 deletions nmma/tests/data/systematics_without_time.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ config:
- null # you can keep it or remove it, it will still be parsed as None (filter independent)
time_nodes: 4
type: "Uniform"
min: 0
max: 2
minimum: 0
maximum: 2
withoutTime:
value: true
type: "Uniform"
min: 0
max: 2
minimum: 0
maximum: 2
90 changes: 28 additions & 62 deletions nmma/tests/systematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@
ValidationError,
validate_only_one_true,
validate_filters,
validate_distribution,
validate_fields,
handle_withTime,
handle_withoutTime,
main,
ALLOWED_FILTERS,
ALLOWED_DISTRIBUTIONS
)


Expand All @@ -22,17 +21,17 @@ def sample_yaml_content():
withTime:
value: true
type: Uniform
min: 0.0
max: 1.0
minimum: 0.0
maximum: 1.0
time_nodes: 2
filters:
- [bessellb, bessellv]
- ztfr
withoutTime:
value: false
type: Uniform
min: 0.0
max: 1.0
minimum: 0.0
maximum: 1.0
"""


Expand Down Expand Up @@ -65,40 +64,18 @@ def test_validate_filters_invalid():
validate_filters(invalid_filters)


def test_validate_distribution_valid():
validate_distribution("Uniform") # Should not raise an exception


def test_validate_distribution_invalid():
with pytest.raises(ValidationError, match="Invalid distribution 'Normal'"):
validate_distribution("Normal")


def test_validate_fields_valid():
valid_values = {"type": "Uniform", "min": 0.0, "max": 1.0}
required_fields = {"type": str, "min": (float, int), "max": (float, int)}
validate_fields("test", valid_values, required_fields) # Should not raise an exception


def test_validate_fields_invalid():
invalid_values = {"type": "Uniform", "min": "0.0", "max": 1.0}
required_fields = {"type": str, "min": (float, int), "max": (float, int)}
with pytest.raises(ValidationError, match="'min' must be of type"):
validate_fields("test", invalid_values, required_fields)


def test_handle_withTime():
values = {"type": "Uniform", "min": 0.0, "max": 1.0, "time_nodes": 2, "filters": [["bessellb", "bessellv"], "ztfr"]}
result = handle_withTime("withTime", values)
values = {"type": "Uniform", "minimum": 0.0, "maximum": 1.0, "time_nodes": 2, "filters": [["bessellb", "bessellv"], "ztfr"]}
result = handle_withTime(values)
assert "sys_err_bessellb___bessellv1" in result[0]
assert "sys_err_ztfr2" in result[3]


def test_handle_withoutTime():
values = {"type": "Uniform", "min": 0.0, "max": 1.0}
result = handle_withoutTime("withoutTime", values)
values = {"type": "Uniform", "minimum": 0.0, "maximum": 1.0}
result = handle_withoutTime(values)
assert len(result) == 1
assert 'sys_err = Uniform(minimum=0.0,maximum=1.0,name="sys_err")' in result[0]
assert "sys_err = Uniform(minimum=0.0, maximum=1.0, name='sys_err', latex_label='sys_err', unit=None, boundary=None)" in result[0]


def test_main(sample_yaml_file):
Expand Down Expand Up @@ -146,45 +123,34 @@ def test_validate_filters_empty_list():
validate_filters([]) # Should not raise an exception


def test_validate_distribution_case_sensitive():
with pytest.raises(ValidationError, match="Invalid distribution 'uniform'"):
validate_distribution("uniform") # Should be "Uniform"
def test_validate_distribution_valid():
assert ALLOWED_DISTRIBUTIONS["Uniform"] # Should not raise an exception


def test_validate_distribution_invalid():
with pytest.raises(KeyError):
assert ALLOWED_DISTRIBUTIONS["nonuniform"] # Should be "Uniform"


@pytest.mark.parametrize("invalid_type", [123, True, [], {}])
def test_validate_fields_invalid_types(invalid_type):
invalid_values = {"type": invalid_type, "min": 0.0, "max": 1.0}
required_fields = {"type": str, "min": (float, int), "max": (float, int)}
with pytest.raises(ValidationError, match="'type' must be of type"):
validate_fields("test", invalid_values, required_fields)
def test_validate_distribution_case_sensitive():
with pytest.raises(KeyError):
assert (ALLOWED_DISTRIBUTIONS["uniform"]) # Should be "Uniform"


def test_handle_withTime_single_filter():
values = {"type": "Uniform", "min": 0.0, "max": 1.0, "time_nodes": 2, "filters": ["ztfr"]}
result = handle_withTime("withTime", values)
values = {"type": "Uniform", "minimum": 0.0, "maximum": 1.0, "time_nodes": 2, "filters": ["ztfr"]}
result = handle_withTime(values)
assert len(result) == 2
assert all("sys_err_ztfr" in line for line in result)


def test_handle_withTime_all_filters():
values = {"type": "Uniform", "min": 0.0, "max": 1.0, "time_nodes": 1, "filters": [None]}
result = handle_withTime("withTime", values)
values = {"type": "Uniform", "minimum": 0.0, "maximum": 1.0, "time_nodes": 1, "filters": [None]}
result = handle_withTime(values)
assert len(result) == 1
assert "sys_err_all1" in result[0]


def test_handle_withTime_integer_bounds():
values = {"type": "Uniform", "min": 0, "max": 10, "time_nodes": 1, "filters": ["ztfr"]}
result = handle_withTime("withTime", values)
assert "minimum=0" in result[0] and "maximum=10" in result[0]


def test_handle_withoutTime_integer_bounds():
values = {"type": "Uniform", "min": 0, "max": 10}
result = handle_withoutTime("withoutTime", values)
assert "minimum=0" in result[0] and "maximum=10" in result[0]


def test_main_withoutTime(tmp_path):
yaml_content = """
config:
Expand All @@ -193,8 +159,8 @@ def test_main_withoutTime(tmp_path):
withoutTime:
value: true
type: Uniform
min: 0.0
max: 1.0
minimum: 0.0
maximum: 1.0
"""
yaml_file = tmp_path / "withoutTime_config.yaml"
yaml_file.write_text(yaml_content)
Expand All @@ -219,8 +185,8 @@ def test_main_empty_config(tmp_path):

@pytest.mark.parametrize("filter_name", ALLOWED_FILTERS)
def test_all_allowed_filters(filter_name):
values = {"type": "Uniform", "min": 0.0, "max": 1.0, "time_nodes": 1, "filters": [filter_name]}
result = handle_withTime("withTime", values)
values = {"type": "Uniform", "minimum": 0.0, "maximum": 1.0, "time_nodes": 1, "filters": [filter_name]}
result = handle_withTime(values)
assert len(result) == 1
assert f"sys_err_{filter_name}1" in result[0]

Expand Down
Loading