Skip to content

Commit

Permalink
style: black on src n test
Browse files Browse the repository at this point in the history
  • Loading branch information
shafayetShafee committed Aug 14, 2024
1 parent b097e4c commit 6d36970
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 35 deletions.
68 changes: 41 additions & 27 deletions src/skmiscpy/cbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def compute_smd(
The name of the column containing weights. Defaults to None.
std_binary: bool
Should the mean differences for binary variables (i.e., difference in proportion)
Should the mean differences for binary variables (i.e., difference in proportion)
be standardized or not. Default is False. See notes.
estimand : str, optional
Expand All @@ -53,17 +53,17 @@ def compute_smd(
-----
The mean differences for continuous variables are standardized so that they are on the same scale
and so that they can be compared across variables, and they allow for a simple interpretation even
when the details of the variable's original scale are unclear to the analyst.
None of these advantages are passed to binary variables because binary variables are already on the
same scale (i.e., a proportion), and the scale is easily interpretable. In addition, the details of
standardizing the proportion difference of a binary variable involve dividing the proportion difference
by a variance, but the variance of a binary variable is a function of its proportion. Standardizing
the proportion difference of a binary variable can yield the following counterintuitive result:
if P\ :sub:`T`\ = 0.2 and P\ :sub:`C`\ = 0.3, the standardized difference in proportion would be
different from that if P\ :sub:`T`\ = 0.5 and P\ :sub:`C`\ = 0.6, even though the expectation is that
the balance statistic should be the same for both scenarios because both would yield the same degree
and so that they can be compared across variables, and they allow for a simple interpretation even
when the details of the variable's original scale are unclear to the analyst.
None of these advantages are passed to binary variables because binary variables are already on the
same scale (i.e., a proportion), and the scale is easily interpretable. In addition, the details of
standardizing the proportion difference of a binary variable involve dividing the proportion difference
by a variance, but the variance of a binary variable is a function of its proportion. Standardizing
the proportion difference of a binary variable can yield the following counterintuitive result:
if P\ :sub:`T`\ = 0.2 and P\ :sub:`C`\ = 0.3, the standardized difference in proportion would be
different from that if P\ :sub:`T`\ = 0.5 and P\ :sub:`C`\ = 0.6, even though the expectation is that
the balance statistic should be the same for both scenarios because both would yield the same degree
of bias in the effect estimate.
Expand Down Expand Up @@ -145,7 +145,7 @@ def _calc_smd_covar(
covar: str,
wt_var: str = None,
estimand: str = "ATE",
std_binary: bool = False
std_binary: bool = False,
) -> float:
"""
Calculate the Standardized Mean Difference (SMD) for a covariate between two groups.
Expand All @@ -162,10 +162,10 @@ def _calc_smd_covar(
The column name of the weights. If None, only the unadjusted SMD is calculated. Defaults to None.
estimand : str, optional
The causal estimand to use. Defaults to "ATE" (Average Treatment Effect). Currently supported
options are "ATT" (Average Treatment Effect among the Treated) and
options are "ATT" (Average Treatment Effect among the Treated) and
"ATC" (Average Treatment Effect among the Control group).
std_binary: bool
Should the mean differences for binary variables (i.e., difference in proportion)
Should the mean differences for binary variables (i.e., difference in proportion)
be standardized or not. Default is False. See notes.
Returns
Expand Down Expand Up @@ -214,7 +214,9 @@ def _calc_smd_covar(
if data[covar].dropna().nunique() == 2:
_check_proportion_within_range(wt_m1, wt_bin_custom_msg_1)
_check_proportion_within_range(wt_m0, wt_bin_custom_msg_0)
return _calc_smd_bin_covar(estimand, m1=m1, m0=m0, wt_m1=wt_m1, wt_m0=wt_m0, std_binary=std_binary)
return _calc_smd_bin_covar(
estimand, m1=m1, m0=m0, wt_m1=wt_m1, wt_m0=wt_m0, std_binary=std_binary
)
else:
return _calc_smd_cont_covar(
estimand, m1=wt_m1, m0=wt_m0, s2_1=s2_1, s2_0=s2_0
Expand Down Expand Up @@ -348,7 +350,11 @@ def _calc_smd_cont_covar(estimand, *args, **kwargs):


def _calc_smd_bin_covar_ate(
m1: float, m0: float, wt_m1: float = None, wt_m0: float = None, std_binary: bool = False
m1: float,
m0: float,
wt_m1: float = None,
wt_m0: float = None,
std_binary: bool = False,
) -> float:
"""
Calculate the Standardized Mean Difference (SMD) for binary covariates using the Average Treatment Effect (ATE).
Expand Down Expand Up @@ -406,7 +412,11 @@ def _calc_smd_cont_covar_ate(m1: float, m0: float, s2_1: float, s2_0: float) ->


def _calc_smd_bin_covar_att(
m1: float, m0: float, wt_m1: float = None, wt_m0: float = None, std_binary: bool = False
m1: float,
m0: float,
wt_m1: float = None,
wt_m0: float = None,
std_binary: bool = False,
) -> float:
"""
Calculate the standardized mean difference (SMD) for binary covariates
Expand All @@ -419,7 +429,7 @@ def _calc_smd_bin_covar_att(
m0 : float
The mean of the covariate for the control group. Must be between 0 and 1.
wt_m1 : float, optional
The weighted mean of the covariate for the treatment group.
The weighted mean of the covariate for the treatment group.
If not provided, `m1` is used. Must be between 0 and 1.
wt_m0 : float, optional
The weighted mean of the covariate for the control group. I
Expand All @@ -432,15 +442,19 @@ def _calc_smd_bin_covar_att(
"""
wt_m1 = m1 if wt_m1 is None else wt_m1
wt_m0 = m0 if wt_m0 is None else wt_m0

std_factor = np.sqrt(m1 * (1 - m1)) if std_binary else 1

smd = _calc_raw_smd(a=wt_m1, b=wt_m0, std_factor=std_factor)
return smd


def _calc_smd_bin_covar_atc(
m1: float, m0: float, wt_m1: float = None, wt_m0: float = None, std_binary: bool = False
m1: float,
m0: float,
wt_m1: float = None,
wt_m0: float = None,
std_binary: bool = False,
) -> float:
"""
Calculate the standardized mean difference (SMD) for binary covariates
Expand All @@ -453,7 +467,7 @@ def _calc_smd_bin_covar_atc(
m0 : float
The mean of the covariate for the control group. Must be between 0 and 1.
wt_m1 : float, optional
The weighted mean of the covariate for the treatment group.
The weighted mean of the covariate for the treatment group.
If not provided, `m1` is used. Must be between 0 and 1.
wt_m0 : float, optional
The weighted mean of the covariate for the control group. I
Expand Down Expand Up @@ -485,10 +499,10 @@ def _calc_smd_cont_covar_att(m1: float, m0: float, s2_1: float, s2_0: float) ->
m0 : float
The mean of the covariate for control group (group 0).
s2_1 : float
The variance of the covariate for treated group (group 1).
The variance of the covariate for treated group (group 1).
Must be strictly positive.
s2_0 : float
The variance of the covariate for control group (group 0).
The variance of the covariate for control group (group 0).
Must be strictly positive.
Returns
Expand All @@ -513,10 +527,10 @@ def _calc_smd_cont_covar_atc(m1: float, m0: float, s2_1: float, s2_0: float) ->
m0 : float
The mean of the covariate for control group (group 0).
s2_1 : float
The variance of the covariate for treated group (group 1).
The variance of the covariate for treated group (group 1).
Must be strictly positive.
s2_0 : float
The variance of the covariate for control group (group 0).
The variance of the covariate for control group (group 0).
Must be strictly positive.
Returns
Expand Down Expand Up @@ -553,4 +567,4 @@ def _calc_raw_smd(a: float, b: float, std_factor: float) -> float:
The raw SMD is calculated as the absolute difference between `a` and `b` divided by `std_factor`.
"""
raw_smd = abs(a - b) / std_factor
return raw_smd
return raw_smd
19 changes: 11 additions & 8 deletions tests/test_smd.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,22 +338,25 @@ def test_compute_smd_att_atc(sample_data):

# --- Test std_binary param ----


def test_std_binary_calc_smd_covar(sample_data):
smd = _calc_smd_covar(
data=sample_data,
group='group',
covar='binary_var',
group="group",
covar="binary_var",
)
expected = np.float64(0.3333333)
np.testing.assert_allclose(smd, expected, rtol=1e-4, atol=0)


def test_compute_smd_invalid_std_binary_type(sample_data):
with pytest.raises(TypeError, match="The `std_binary` parameter must be of type bool"):
with pytest.raises(
TypeError, match="The `std_binary` parameter must be of type bool"
):
compute_smd(
sample_data,
group='group',
vars=["binary_var"],
sample_data,
group="group",
vars=["binary_var"],
wt_var="weights",
std_binary="True"
)
std_binary="True",
)

0 comments on commit 6d36970

Please sign in to comment.