Skip to content

Commit

Permalink
Switch strict meaning in validate_number_positive
Browse files Browse the repository at this point in the history
  • Loading branch information
FrancescMartiEscofetQC committed Jun 14, 2024
1 parent dfa95d8 commit 7a11445
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 8 deletions.
19 changes: 13 additions & 6 deletions metalearners/_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# # Copyright (c) QuantCo 2024-2024
# # SPDX-License-Identifier: BSD-3-Clause

import operator
from collections.abc import Callable
from inspect import signature
from operator import le, lt
Expand Down Expand Up @@ -66,14 +65,22 @@ def validate_all_vectors_same_index(*args: Vector) -> None:


def validate_number_positive(
value: int | float, name: str, strict: bool = False
value: int | float, name: str, strict: bool = True
) -> None:
"""Validates that a number is positive.
If ``strict = True`` then it validates that the number is strictly positive.
"""
if strict:
comparison = operator.lt
if value <= 0:
raise ValueError(
f"{name} was expected to be strictly positive but was {value}."
)
else:
comparison = operator.le
if comparison(value, 0):
raise ValueError(f"{name} was expected to be positive but was {value}.")
if value < 0:
raise ValueError(
f"{name} was expected to be positive or zero but was {value}."
)


def check_propensity_score(
Expand Down
2 changes: 1 addition & 1 deletion metalearners/cross_fit_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _validate_data_match_prior_split(
) -> None:
"""Validate whether the previous test_indices and the passed data are based on the
same number of observations."""
validate_number_positive(n_observations, "n_observations", strict=False)
validate_number_positive(n_observations, "n_observations", strict=True)
if test_indices is None:
return
expected_n_observations = sum(len(x) for x in test_indices)
Expand Down
4 changes: 3 additions & 1 deletion tests/test_cross_fit_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,9 @@ def test_crossfitestimator_n_folds_1(rng, sample_size):
)
def test_validate_data_match(n_observations, test_indices, success):
if n_observations < 1:
with pytest.raises(ValueError, match="was expected to be positive"):
with pytest.raises(
ValueError, match=r"was expected to be (strictly )?positive"
):
_validate_data_match_prior_split(n_observations, test_indices)
return
if success:
Expand Down

0 comments on commit 7a11445

Please sign in to comment.