Skip to content

Commit

Permalink
Merge pull request #385 from sahiljhawar/systematics_refactor
Browse files Browse the repository at this point in the history
refactored systematics code
  • Loading branch information
tsunhopang authored Nov 5, 2024
2 parents 34f2079 + 6516602 commit 96cd587
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 123 deletions.
1 change: 1 addition & 0 deletions nmma/em/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from . import systematics


def from_list(self, systematics):
"""
Similar to `from_file` but instead of file buffer, takes a list of Prior strings
Expand Down
113 changes: 62 additions & 51 deletions nmma/em/systematics.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import yaml
from pathlib import Path
import inspect
import warnings
from pathlib import Path

import yaml
from bilby.core.prior import analytical

warnings.simplefilter("module", DeprecationWarning)

Expand Down Expand Up @@ -39,26 +42,38 @@ def __init__(self, key, message):
"ztfr",
]

# Class name -> class for every class found in bilby's `analytical` prior
# module. The keys are the names accepted as a distribution's `type`.
# NOTE(review): inspect.getmembers(..., inspect.isclass) returns every class
# bound in the module namespace, including ones `analytical` only imports --
# confirm that exposing those as selectable distributions is intended.
ALLOWED_DISTRIBUTIONS = dict(inspect.getmembers(analytical, inspect.isclass))


def get_positional_args(cls):
    """Return the names of the constructor parameters that have no default.

    Parameters
    ----------
    cls : type
        Class whose ``__init__`` signature is inspected.

    Returns
    -------
    list of str
        Names of the required ``__init__`` parameters in declaration order,
        excluding ``self`` and the variadic ``*args`` / ``**kwargs`` catch-alls.
    """
    variadic_kinds = (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
    )
    return [
        param.name
        for param in inspect.signature(cls.__init__).parameters.values()
        if param.name != "self"
        # ``*args`` / ``**kwargs`` never carry a default, yet callers are not
        # required to supply them.
        and param.kind not in variadic_kinds
        # Identity check: ``==`` would call the default value's own ``__eq__``,
        # which need not return a plain bool (e.g. array-like defaults).
        and param.default is inspect.Parameter.empty
    ]


# Distribution name -> list of its constructor arguments that have no default,
# keyed by the same names as ALLOWED_DISTRIBUTIONS. Built once at import time.
DISTRIBUTION_PARAMETERS = {k: get_positional_args(v) for k, v in ALLOWED_DISTRIBUTIONS.items()}


def load_yaml(file_path):
    """Read the file at ``file_path`` and return it parsed by ``yaml.safe_load``."""
    raw_text = Path(file_path).read_text()
    parsed = yaml.safe_load(raw_text)
    return parsed


def validate_only_one_true(yaml_dict):
    """Check that exactly one entry under ``config`` is switched on.

    Parameters
    ----------
    yaml_dict : dict
        Parsed configuration. ``yaml_dict["config"]`` maps each configuration
        key to a mapping that must contain a boolean ``"value"``.

    Raises
    ------
    ValidationError
        If an entry is not a mapping, lacks ``"value"``, has a non-boolean
        ``"value"``, or if the number of entries set to ``True`` is not
        exactly one.
    """
    config = yaml_dict["config"]
    for key, values in config.items():
        # An empty YAML entry parses to None; report it as a validation
        # problem rather than failing with TypeError on the `in` test.
        if (
            not isinstance(values, dict)
            or "value" not in values
            or not isinstance(values["value"], bool)
        ):
            raise ValidationError(key, "'value' key must be present and be a boolean")

    # Booleans sum as 0/1, so this counts the enabled entries.
    true_count = sum(values["value"] for values in config.values())
    if true_count > 1:
        raise ValidationError("config", "Only one configuration key can be set to True at a time")
    if true_count == 0:
        raise ValidationError("config", "At least one configuration key must be set to True")


def validate_filters(filter_groups):
Expand Down Expand Up @@ -100,67 +115,63 @@ def validate_filters(filter_groups):


def validate_distribution(distribution):
    """Validate the distribution settings of one configuration entry.

    Parameters
    ----------
    distribution : dict
        Configuration mapping. Its ``"type"`` must be a key of
        ``ALLOWED_DISTRIBUTIONS`` and it must contain every parameter listed
        for that type in ``DISTRIBUTION_PARAMETERS``.

    Raises
    ------
    ValidationError
        If ``"type"`` is missing or not an allowed distribution name, or if a
        required parameter of the chosen distribution is absent.
    """
    dist_type = distribution.get("type")
    # Only a string can name a distribution. The isinstance test also keeps an
    # unhashable YAML value (list/dict) from raising TypeError in the lookup.
    if not isinstance(dist_type, str) or dist_type not in ALLOWED_DISTRIBUTIONS:
        raise ValidationError(
            "distribution type",
            f"Invalid distribution '{dist_type}'. Allowed values are {', '.join([str(f) for f in ALLOWED_DISTRIBUTIONS])}",
        )

    # Sorted so the error message is deterministic (set order is not).
    missing_params = sorted(set(DISTRIBUTION_PARAMETERS[dist_type]) - set(distribution))
    if missing_params:
        raise ValidationError(
            "distribution",
            f"Missing required parameters for {dist_type} distribution: {', '.join(missing_params)}",
        )
for field, expected_type in required_fields.items():
if not isinstance(values[field], expected_type):
raise ValidationError(
key, f"'{field}' must be of type {expected_type}"
)


def handle_withTime(key, values):
required_fields = {
"type": str,
"min": (float, int),
"max": (float, int),
"time_nodes": int,
"filters": list,
}
def create_prior_string(name, distribution):
    """Build the ``"<name> = <prior repr>"`` line for one prior.

    Parameters
    ----------
    name : str
        Name given to the prior; it is also passed to the prior class as
        ``name``.
    distribution : dict
        Configuration mapping holding ``"type"`` plus the distribution
        parameters. It is not modified.

    Returns
    -------
    str
        ``name``, ``" = "`` and the ``repr`` of the instantiated prior.

    Raises
    ------
    KeyError
        If ``"type"`` is missing or is not a key of ``ALLOWED_DISTRIBUTIONS``.
    """
    # Work on a copy so the caller's configuration mapping is left untouched.
    params = dict(distribution)
    dist_type = params.pop("type")
    # These keys belong to the configuration entry, not to the prior itself.
    for config_key in ("value", "time_nodes", "filters"):
        params.pop(config_key, None)

    prior_class = ALLOWED_DISTRIBUTIONS[dist_type]
    required_params = DISTRIBUTION_PARAMETERS[dist_type]

    # NOTE(review): everything outside the required list is dropped, which
    # includes optional constructor arguments of the prior class -- confirm
    # that this is the intended behaviour.
    extra_params = set(params) - set(required_params)
    if extra_params:
        warnings.warn(f"Distribution parameters {extra_params} are not used by {dist_type} distribution and will be ignored")

    params = {k: params[k] for k in required_params if k in params}

    return f"{name} = {repr(prior_class(**params, name=name))}"

validate_fields(key, values, required_fields)

def handle_withTime(values):
    """Build one prior line per (filter group, time node) pair.

    Parameters
    ----------
    values : dict
        The ``withTime`` configuration entry: distribution settings plus
        ``"time_nodes"`` (int) and ``"filters"`` (list of filter groups).

    Returns
    -------
    list of str
        Prior lines named ``sys_err_<filter><n>`` with ``n`` running from 1 to
        ``time_nodes``. A list-valued filter group is joined with ``"___"``
        and a ``None`` group is named ``"all"``.

    Raises
    ------
    ValidationError
        If the distribution or filters are invalid, or if ``"time_nodes"`` is
        missing or not an integer.
    """
    validate_distribution(values)
    filter_groups = values.get("filters", [])
    validate_filters(filter_groups)

    # Report a missing or non-integer node count as a configuration error
    # instead of letting it surface as KeyError/TypeError further down.
    time_nodes = values.get("time_nodes")
    if not isinstance(time_nodes, int):
        raise ValidationError("time_nodes", "'time_nodes' key must be present and be an integer")

    result = []
    for filter_group in filter_groups:
        if isinstance(filter_group, list):
            filter_name = "___".join(filter_group)
        else:
            filter_name = filter_group if filter_group is not None else "all"

        for n in range(1, time_nodes + 1):
            prior_name = f"sys_err_{filter_name}{n}"
            result.append(create_prior_string(prior_name, values.copy()))

    return result


def handle_withoutTime(key, values):
required_fields = {"type": str, "min": (float, int), "max": (float, int)}
validate_fields(key, values, required_fields)
distribution = values.get("type")
validate_distribution(distribution)
return [
f'sys_err = {values["type"]}(minimum={values["min"]},maximum={values["max"]},name="sys_err")'
]
def handle_withoutTime(values):
    """Build the single time-independent ``sys_err`` prior line.

    Parameters
    ----------
    values : dict
        The ``withoutTime`` configuration entry holding the distribution
        settings.

    Returns
    -------
    list of str
        A one-element list with the ``sys_err`` prior line.

    Raises
    ------
    ValidationError
        If the distribution settings are invalid.
    """
    validate_distribution(values)
    # Hand over a copy, as handle_withTime does, so the caller's configuration
    # mapping cannot be altered while the prior line is built.
    return [create_prior_string("sys_err", values.copy())]


config_handlers = {
Expand All @@ -175,5 +186,5 @@ def main(yaml_file_path):
results = []
for key, values in yaml_dict["config"].items():
if values["value"] and key in config_handlers:
results.extend(config_handlers[key](key, values))
return results
results.extend(config_handlers[key](values))
return results
4 changes: 2 additions & 2 deletions nmma/em/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,8 +484,8 @@ def calc_lc(
y_pred, sigma2_pred = gp.predict(
np.atleast_2d(param_list_postprocess), return_std=True
)
cAproj[i] = y_pred
cAstd[i] = sigma2_pred
cAproj[i] = np.squeeze(y_pred)
cAstd[i] = np.squeeze(sigma2_pred)

# coverrors = np.dot(VA[:, :n_coeff], np.dot(np.power(np.diag(cAstd[:n_coeff]), 2), VA[:, :n_coeff].T))
# errors = np.diag(coverrors)
Expand Down
8 changes: 4 additions & 4 deletions nmma/tests/data/systematics_with_time.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ config:
- null # you can keep it or remove it, it will still be parsed as None (filter independent)
time_nodes: 4
type: "Uniform"
min: 0
max: 2
minimum: 0
maximum: 2
withoutTime:
value: false
type: "Uniform"
min: 0
max: 2
minimum: 0
maximum: 2
8 changes: 4 additions & 4 deletions nmma/tests/data/systematics_without_time.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ config:
- null # you can keep it or remove it, it will still be parsed as None (filter independent)
time_nodes: 4
type: "Uniform"
min: 0
max: 2
minimum: 0
maximum: 2
withoutTime:
value: true
type: "Uniform"
min: 0
max: 2
minimum: 0
maximum: 2
90 changes: 28 additions & 62 deletions nmma/tests/systematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@
ValidationError,
validate_only_one_true,
validate_filters,
validate_distribution,
validate_fields,
handle_withTime,
handle_withoutTime,
main,
ALLOWED_FILTERS,
ALLOWED_DISTRIBUTIONS
)


Expand All @@ -22,17 +21,17 @@ def sample_yaml_content():
withTime:
value: true
type: Uniform
min: 0.0
max: 1.0
minimum: 0.0
maximum: 1.0
time_nodes: 2
filters:
- [bessellb, bessellv]
- ztfr
withoutTime:
value: false
type: Uniform
min: 0.0
max: 1.0
minimum: 0.0
maximum: 1.0
"""


Expand Down Expand Up @@ -65,40 +64,18 @@ def test_validate_filters_invalid():
validate_filters(invalid_filters)


def test_validate_distribution_valid():
validate_distribution("Uniform") # Should not raise an exception


def test_validate_distribution_invalid():
with pytest.raises(ValidationError, match="Invalid distribution 'Normal'"):
validate_distribution("Normal")


def test_validate_fields_valid():
valid_values = {"type": "Uniform", "min": 0.0, "max": 1.0}
required_fields = {"type": str, "min": (float, int), "max": (float, int)}
validate_fields("test", valid_values, required_fields) # Should not raise an exception


def test_validate_fields_invalid():
invalid_values = {"type": "Uniform", "min": "0.0", "max": 1.0}
required_fields = {"type": str, "min": (float, int), "max": (float, int)}
with pytest.raises(ValidationError, match="'min' must be of type"):
validate_fields("test", invalid_values, required_fields)


def test_handle_withTime():
    """Two filter groups with two time nodes yield four correctly named priors."""
    values = {"type": "Uniform", "minimum": 0.0, "maximum": 1.0, "time_nodes": 2, "filters": [["bessellb", "bessellv"], "ztfr"]}
    result = handle_withTime(values)
    # 2 filter groups x 2 time nodes
    assert len(result) == 4
    assert "sys_err_bessellb___bessellv1" in result[0]
    assert "sys_err_ztfr2" in result[3]


def test_handle_withoutTime():
    """The time-independent handler returns exactly one ``sys_err`` prior line."""
    values = {"type": "Uniform", "minimum": 0.0, "maximum": 1.0}
    result = handle_withoutTime(values)
    assert len(result) == 1
    # NOTE(review): this pins the exact repr text produced by the prior class,
    # so it will need updating if that repr format ever changes.
    assert "sys_err = Uniform(minimum=0.0, maximum=1.0, name='sys_err', latex_label='sys_err', unit=None, boundary=None)" in result[0]


def test_main(sample_yaml_file):
Expand Down Expand Up @@ -146,45 +123,34 @@ def test_validate_filters_empty_list():
validate_filters([]) # Should not raise an exception


def test_validate_distribution_case_sensitive():
with pytest.raises(ValidationError, match="Invalid distribution 'uniform'"):
validate_distribution("uniform") # Should be "Uniform"
def test_validate_distribution_valid():
    """``"Uniform"`` is available as a key of ALLOWED_DISTRIBUTIONS."""
    uniform_entry = ALLOWED_DISTRIBUTIONS["Uniform"]  # a KeyError here fails the test
    assert uniform_entry


def test_validate_distribution_invalid():
    """An unknown distribution name is not a key of ALLOWED_DISTRIBUTIONS."""
    unknown_name = "nonuniform"
    with pytest.raises(KeyError):
        # The lookup itself raises before the assert is evaluated.
        assert ALLOWED_DISTRIBUTIONS[unknown_name]


@pytest.mark.parametrize("invalid_type", [123, True, [], {}])
def test_validate_fields_invalid_types(invalid_type):
invalid_values = {"type": invalid_type, "min": 0.0, "max": 1.0}
required_fields = {"type": str, "min": (float, int), "max": (float, int)}
with pytest.raises(ValidationError, match="'type' must be of type"):
validate_fields("test", invalid_values, required_fields)
def test_validate_distribution_case_sensitive():
    """Distribution names are case sensitive: ``"uniform"`` is not a key."""
    lowercase_name = "uniform"  # the registered key is "Uniform"
    with pytest.raises(KeyError):
        # The lookup itself raises before the assert is evaluated.
        assert ALLOWED_DISTRIBUTIONS[lowercase_name]


def test_handle_withTime_single_filter():
    """A single filter with two time nodes yields two priors for that filter."""
    values = {"type": "Uniform", "minimum": 0.0, "maximum": 1.0, "time_nodes": 2, "filters": ["ztfr"]}
    result = handle_withTime(values)
    assert len(result) == 2
    assert all("sys_err_ztfr" in line for line in result)


def test_handle_withTime_all_filters():
    """A ``None`` filter group is named ``all`` in the generated prior."""
    values = {"type": "Uniform", "minimum": 0.0, "maximum": 1.0, "time_nodes": 1, "filters": [None]}
    result = handle_withTime(values)
    assert len(result) == 1
    assert "sys_err_all1" in result[0]


def test_handle_withTime_integer_bounds():
values = {"type": "Uniform", "min": 0, "max": 10, "time_nodes": 1, "filters": ["ztfr"]}
result = handle_withTime("withTime", values)
assert "minimum=0" in result[0] and "maximum=10" in result[0]


def test_handle_withoutTime_integer_bounds():
values = {"type": "Uniform", "min": 0, "max": 10}
result = handle_withoutTime("withoutTime", values)
assert "minimum=0" in result[0] and "maximum=10" in result[0]


def test_main_withoutTime(tmp_path):
yaml_content = """
config:
Expand All @@ -193,8 +159,8 @@ def test_main_withoutTime(tmp_path):
withoutTime:
value: true
type: Uniform
min: 0.0
max: 1.0
minimum: 0.0
maximum: 1.0
"""
yaml_file = tmp_path / "withoutTime_config.yaml"
yaml_file.write_text(yaml_content)
Expand All @@ -219,8 +185,8 @@ def test_main_empty_config(tmp_path):

@pytest.mark.parametrize("filter_name", ALLOWED_FILTERS)
def test_all_allowed_filters(filter_name):
    """Every entry of ALLOWED_FILTERS is accepted and appears in the prior name."""
    values = {"type": "Uniform", "minimum": 0.0, "maximum": 1.0, "time_nodes": 1, "filters": [filter_name]}
    result = handle_withTime(values)
    assert len(result) == 1
    assert f"sys_err_{filter_name}1" in result[0]

Expand Down

0 comments on commit 96cd587

Please sign in to comment.