Skip to content

Commit

Permalink
style: black on src n test
Browse files Browse the repository at this point in the history
  • Loading branch information
shafayetShafee committed Aug 16, 2024
1 parent 23f0462 commit 0f19aab
Show file tree
Hide file tree
Showing 2 changed files with 196 additions and 121 deletions.
42 changes: 29 additions & 13 deletions src/skmiscpy/cbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ def compute_smd(
A pandas DataFrame containing the columns specified in ``vars``, ``group``, and optionally ``wt_var``.
vars : List[str]
A list of strings representing the variable names for which to calculate the SMD, where the
variables should be either continuous or binary. The values of a binary variable may be
either strings or numbers; they are converted into 0 and 1 (if they are not already
A list of strings representing the variable names for which to calculate the SMD, where the
variables should be either continuous or binary. The values of a binary variable may be
either strings or numbers; they are converted into 0 and 1 (if they are not already
0-1), with the lower value mapped to 0 and the higher value mapped to 1. To compute SMD for
a discrete variable with more than two categories, pass that variable name in a list to the
``cat_vars`` parameter.
Expand Down Expand Up @@ -143,7 +143,9 @@ def compute_smd(
_check_param_type({"wt_var": wt_var}, str)

if cat_vars is not None:
if not (isinstance(cat_vars, list) and all(isinstance(v, str) for v in cat_vars)):
if not (
isinstance(cat_vars, list) and all(isinstance(v, str) for v in cat_vars)
):
raise TypeError("`cat_vars` must be a list of strings")

data = _check_prep_smd_data(
Expand All @@ -152,22 +154,32 @@ def compute_smd(

covariates = list(set(data.columns) - {wt_var, group})
covariates_with_types = _classify_columns(data, covariates)

if not std_binary:
if any(col_type == "binary" for col_type in covariates_with_types.values()):
print("For binary variables, the unstandardized mean differences are shown here. "
"See 'Notes' in function documentation for details."
print(
"For binary variables, the unstandardized mean differences are shown here. "
"See 'Notes' in function documentation for details."
)

smd_results = []

for var, var_type in covariates_with_types.items():
if wt_var is not None:
unadjusted_smd = _calc_smd_covar(
data=data, group=group, covar=var, estimand=estimand, std_binary=std_binary
data=data,
group=group,
covar=var,
estimand=estimand,
std_binary=std_binary,
)
adjusted_smd = _calc_smd_covar(
data=data, group=group, covar=var, wt_var=wt_var, estimand=estimand, std_binary=std_binary
data=data,
group=group,
covar=var,
wt_var=wt_var,
estimand=estimand,
std_binary=std_binary,
)
smd_results.append(
{
Expand All @@ -179,19 +191,23 @@ def compute_smd(
)
else:
unadjusted_smd = _calc_smd_covar(
data=data, group=group, covar=var, estimand=estimand, std_binary=std_binary
data=data,
group=group,
covar=var,
estimand=estimand,
std_binary=std_binary,
)
smd_results.append(
{
"variables": var,
"variables": var,
"var_types": var_type,
"unadjusted_smd": unadjusted_smd
"unadjusted_smd": unadjusted_smd,
}
)

smd_df = pd.DataFrame(smd_results)

return smd_df.sort_values(by = ['var_types', 'variables'], ascending=[False, True])
return smd_df.sort_values(by=["var_types", "variables"], ascending=[False, True])


def _check_prep_smd_data(
Expand Down
Loading

0 comments on commit 0f19aab

Please sign in to comment.