diff --git a/src/skmiscpy/cbs.py b/src/skmiscpy/cbs.py index 1813cf8..824092c 100644 --- a/src/skmiscpy/cbs.py +++ b/src/skmiscpy/cbs.py @@ -25,9 +25,9 @@ def compute_smd( A pandas DataFrame containing the columns specified in ``vars``, ``group``, and optionally ``wt_var``. vars : List[str] - A list of strings representing the variables names for which to calculate the SMD, where the - variables should be either continuous or binary. The values of the binary variable could be - either string type or numerical, they would be converted into 0 and 1 (if they are not already + A list of strings representing the variables names for which to calculate the SMD, where the + variables should be either continuous or binary. The values of the binary variable could be + either string type or numerical, they would be converted into 0 and 1 (if they are not already 0-1), where lower value converted into 0 and higher value converted into 1. To compute SMD for a discrete variable with more than two categories, pass that variable name in a list to the ``cat_vars`` parameter. @@ -143,7 +143,9 @@ def compute_smd( _check_param_type({"wt_var": wt_var}, str) if cat_vars is not None: - if not (isinstance(cat_vars, list) and all(isinstance(v, str) for v in cat_vars)): + if not ( + isinstance(cat_vars, list) and all(isinstance(v, str) for v in cat_vars) + ): raise TypeError("`cat_vars` must be a list of strings") data = _check_prep_smd_data( @@ -152,11 +154,12 @@ def compute_smd( covariates = list(set(data.columns) - {wt_var, group}) covariates_with_types = _classify_columns(data, covariates) - + if not std_binary: if any(col_type == "binary" for col_type in covariates_with_types.values()): - print("For binary variables, the unstandardized mean differences are shown here. " - "See 'Notes' in function documentation for details." + print( + "For binary variables, the unstandardized mean differences are shown here. " + "See 'Notes' in function documentation for details." ) smd_results = [] @@ -164,10 +167,19 @@ def compute_smd( for var, var_type in covariates_with_types.items(): if wt_var is not None: unadjusted_smd = _calc_smd_covar( - data=data, group=group, covar=var, estimand=estimand, std_binary=std_binary + data=data, + group=group, + covar=var, + estimand=estimand, + std_binary=std_binary, ) adjusted_smd = _calc_smd_covar( - data=data, group=group, covar=var, wt_var=wt_var, estimand=estimand, std_binary=std_binary + data=data, + group=group, + covar=var, + wt_var=wt_var, + estimand=estimand, + std_binary=std_binary, ) smd_results.append( { @@ -179,19 +191,23 @@ def compute_smd( ) else: unadjusted_smd = _calc_smd_covar( - data=data, group=group, covar=var, estimand=estimand, std_binary=std_binary + data=data, + group=group, + covar=var, + estimand=estimand, + std_binary=std_binary, ) smd_results.append( { - "variables": var, + "variables": var, "var_types": var_type, - "unadjusted_smd": unadjusted_smd + "unadjusted_smd": unadjusted_smd, } ) smd_df = pd.DataFrame(smd_results) - return smd_df.sort_values(by = ['var_types', 'variables'], ascending=[False, True]) + return smd_df.sort_values(by=["var_types", "variables"], ascending=[False, True]) def _check_prep_smd_data( diff --git a/tests/test_smd.py b/tests/test_smd.py index 50d2402..aa9bd68 100644 --- a/tests/test_smd.py +++ b/tests/test_smd.py @@ -8,17 +8,20 @@ @pytest.fixture def sample_data(): - return pd.DataFrame({ - 'age': [25, 30, 35, 40, 45], - 'weight': [150.5, 160.0, 155.3, 165.2, 170.8], - 'gender_binary': [0, 1, 0, 1, 0], - 'gender_label': ['male', 'female', 'male', 'female', 'male'], - 'race': ['white', 'black', 'hispanic', 'white', 'black'], - 'educ_level': ['bachelor', 'master', 'doctorate', 'bachelor', 'master'], - 'ps_wts': [0.2, 0.4, 0.6, 0.8, 1.0], - 'group': ['treated', 'control', 'treated', 'control', 'treated'], - 'date': pd.date_range(start='2024-01-01', periods=5, freq='D') - }) + return pd.DataFrame( + { + "age": [25, 30, 35, 40, 45], + "weight": [150.5, 160.0, 155.3, 165.2, 170.8], + "gender_binary": [0, 1, 0, 1, 0], + "gender_label": ["male", "female", "male", "female", "male"], + "race": ["white", "black", "hispanic", "white", "black"], + "educ_level": ["bachelor", "master", "doctorate", "bachelor", "master"], + "ps_wts": [0.2, 0.4, 0.6, 0.8, 1.0], + "group": ["treated", "control", "treated", "control", "treated"], + "date": pd.date_range(start="2024-01-01", periods=5, freq="D"), + } + ) + @pytest.fixture def small_sample_data(): @@ -32,6 +35,7 @@ def small_sample_data(): } ) + @pytest.fixture def df_bin_zero_variance(): """Fixture to provide sample data for tests.""" @@ -47,6 +51,7 @@ def df_bin_zero_variance(): # --- Test std_binary param ---------------------------------------------------------- + def test_std_binary_calc_smd_covar(small_sample_data): smd = _calc_smd_covar( data=small_sample_data, @@ -72,172 +77,208 @@ def test_compute_smd_invalid_std_binary_type(small_sample_data): # Testing _check_prep_smd_data() ----------------------------------------------------- + def test_check_prep_smd_data_transformations_1(sample_data): transformed_data = _check_prep_smd_data( sample_data, - group='group', - vars=['age', 'weight', 'gender_binary', 'gender_label', 'race', 'educ_level'], - wt_var='ps_wts', - cat_vars=['race', 'educ_level'] + group="group", + vars=["age", "weight", "gender_binary", "gender_label", "race", "educ_level"], + wt_var="ps_wts", + cat_vars=["race", "educ_level"], ) assert isinstance(transformed_data, pd.DataFrame) assert transformed_data.shape[1] == 12 - assert transformed_data['race_black'].equals(pd.Series([0, 1, 0, 0, 1], name='race_black')) - assert transformed_data['race_hispanic'].equals(pd.Series([0, 0, 1, 0, 0], name='race_hispanic')) - assert transformed_data['race_white'].equals(pd.Series([1, 0, 0, 1, 0], name='race_white')) - assert transformed_data['educ_level_bachelor'].equals(pd.Series([1, 0, 0, 1, 0], name='educ_level_bachelor')) - assert transformed_data['educ_level_doctorate'].equals(pd.Series([0, 0, 1, 0, 0], name='educ_level_doctorate')) - assert transformed_data['educ_level_master'].equals(pd.Series([0, 1, 0, 0, 1], name='educ_level_master')) + assert transformed_data["race_black"].equals( + pd.Series([0, 1, 0, 0, 1], name="race_black") + ) + assert transformed_data["race_hispanic"].equals( + pd.Series([0, 0, 1, 0, 0], name="race_hispanic") + ) + assert transformed_data["race_white"].equals( + pd.Series([1, 0, 0, 1, 0], name="race_white") + ) + assert transformed_data["educ_level_bachelor"].equals( + pd.Series([1, 0, 0, 1, 0], name="educ_level_bachelor") + ) + assert transformed_data["educ_level_doctorate"].equals( + pd.Series([0, 0, 1, 0, 0], name="educ_level_doctorate") + ) + assert transformed_data["educ_level_master"].equals( + pd.Series([0, 1, 0, 0, 1], name="educ_level_master") + ) pd.testing.assert_series_equal( - transformed_data['age'], - sample_data['age'], + transformed_data["age"], + sample_data["age"], check_index_type=False, - check_dtype=False + check_dtype=False, ) pd.testing.assert_series_equal( - transformed_data['weight'], - sample_data['weight'], + transformed_data["weight"], + sample_data["weight"], check_index_type=False, - check_dtype=False + check_dtype=False, ) pd.testing.assert_series_equal( - transformed_data['ps_wts'], - sample_data['ps_wts'], + transformed_data["ps_wts"], + sample_data["ps_wts"], check_index_type=False, - check_dtype=False + check_dtype=False, ) pd.testing.assert_series_equal( - transformed_data['gender_binary'], - sample_data['gender_binary'], + transformed_data["gender_binary"], + sample_data["gender_binary"], check_index_type=False, - check_dtype=False + check_dtype=False, ) pd.testing.assert_series_equal( - transformed_data['gender_label'], - pd.Series([1, 0, 1, 0, 1], name='gender_label'), + transformed_data["gender_label"], + pd.Series([1, 0, 1, 0, 1], name="gender_label"), check_index_type=False, - check_dtype=False + check_dtype=False, ) def test_check_prep_smd_data_transformations_2(sample_data): transformed_data = _check_prep_smd_data( sample_data, - group='group', - vars=['age', 'weight', 'gender_binary', 'gender_label'], - cat_vars=['race', 'educ_level'] + group="group", + vars=["age", "weight", "gender_binary", "gender_label"], + cat_vars=["race", "educ_level"], ) assert transformed_data.shape[1] == 11 - assert transformed_data['race_black'].equals(pd.Series([0, 1, 0, 0, 1], name='race_black')) - assert transformed_data['race_hispanic'].equals(pd.Series([0, 0, 1, 0, 0], name='race_hispanic')) - assert transformed_data['race_white'].equals(pd.Series([1, 0, 0, 1, 0], name='race_white')) - assert transformed_data['educ_level_bachelor'].equals(pd.Series([1, 0, 0, 1, 0], name='educ_level_bachelor')) - assert transformed_data['educ_level_doctorate'].equals(pd.Series([0, 0, 1, 0, 0], name='educ_level_doctorate')) - assert transformed_data['educ_level_master'].equals(pd.Series([0, 1, 0, 0, 1], name='educ_level_master')) + assert transformed_data["race_black"].equals( + pd.Series([0, 1, 0, 0, 1], name="race_black") + ) + assert transformed_data["race_hispanic"].equals( + pd.Series([0, 0, 1, 0, 0], name="race_hispanic") + ) + assert transformed_data["race_white"].equals( + pd.Series([1, 0, 0, 1, 0], name="race_white") + ) + assert transformed_data["educ_level_bachelor"].equals( + pd.Series([1, 0, 0, 1, 0], name="educ_level_bachelor") + ) + assert transformed_data["educ_level_doctorate"].equals( + pd.Series([0, 0, 1, 0, 0], name="educ_level_doctorate") + ) + assert transformed_data["educ_level_master"].equals( + pd.Series([0, 1, 0, 0, 1], name="educ_level_master") + ) pd.testing.assert_series_equal( - transformed_data['age'], - sample_data['age'], + transformed_data["age"], + sample_data["age"], check_index_type=False, - check_dtype=False + check_dtype=False, ) pd.testing.assert_series_equal( - transformed_data['weight'], - sample_data['weight'], + transformed_data["weight"], + sample_data["weight"], check_index_type=False, - check_dtype=False + check_dtype=False, ) pd.testing.assert_series_equal( - transformed_data['gender_binary'], - sample_data['gender_binary'], + transformed_data["gender_binary"], + sample_data["gender_binary"], check_index_type=False, - check_dtype=False + check_dtype=False, ) pd.testing.assert_series_equal( - transformed_data['gender_label'], - pd.Series([1, 0, 1, 0, 1], name='gender_label'), + transformed_data["gender_label"], + pd.Series([1, 0, 1, 0, 1], name="gender_label"), check_index_type=False, - check_dtype=False + check_dtype=False, ) + def test_check_prep_smd_data_missing_column(sample_data): - with pytest.raises(ValueError, match="The DataFrame is missing the following required columns"): + with pytest.raises( + ValueError, match="The DataFrame is missing the following required columns" + ): _check_prep_smd_data( - sample_data, - group='group', - vars=['age', 'nonexistent'], - wt_var='ps_wts' + sample_data, group="group", vars=["age", "nonexistent"], wt_var="ps_wts" ) + def test_check_prep_smd_data_invalid_weight(sample_data): invalid_data = sample_data.copy() - invalid_data['ps_wts'] = [-0.2, 0.4, -0.6, 0.8, 1.0] - with pytest.raises(ValueError, match="The 'ps_wts' column contains negative weight values."): + invalid_data["ps_wts"] = [-0.2, 0.4, -0.6, 0.8, 1.0] + with pytest.raises( + ValueError, match="The 'ps_wts' column contains negative weight values." + ): _check_prep_smd_data( invalid_data, - group='group', - vars=['age', 'gender_label', 'race'], - wt_var='ps_wts' + group="group", + vars=["age", "gender_label", "race"], + wt_var="ps_wts", ) + def test_check_prep_smd_data_non_numeric_weight(sample_data): invalid_data = sample_data.copy() - invalid_data['ps_wts'] = ['a', 'b', 'c', 'd', 'e'] + invalid_data["ps_wts"] = ["a", "b", "c", "d", "e"] with pytest.raises(ValueError, match="The 'ps_wts' column must be numeric."): _check_prep_smd_data( invalid_data, - group='group', - vars=['age', 'weight', 'gender_binary', 'gender_label', 'race'], - wt_var='ps_wts' + group="group", + vars=["age", "weight", "gender_binary", "gender_label", "race"], + wt_var="ps_wts", ) + def test_check_prep_smd_data_non_binary_group(sample_data): invalid_data = sample_data.copy() - invalid_data['group'] = [1, 2, 3, 4, 5] - with pytest.raises(ValueError, match="The 'group' column must be a binary column for valid SMD calculation."): + invalid_data["group"] = [1, 2, 3, 4, 5] + with pytest.raises( + ValueError, + match="The 'group' column must be a binary column for valid SMD calculation.", + ): _check_prep_smd_data( - invalid_data, - group='group', - vars=['age', 'gender_binary', 'race'] + invalid_data, group="group", vars=["age", "gender_binary", "race"] ) + def test_check_prep_smd_data_non_numeric_or_binary_1(sample_data): - with pytest.raises(ValueError, match="The 'date' column must be continuous or binary"): - _check_prep_smd_data( - sample_data, - group='group', - vars=['age', 'date'] - ) + with pytest.raises( + ValueError, match="The 'date' column must be continuous or binary" + ): + _check_prep_smd_data(sample_data, group="group", vars=["age", "date"]) + def test_check_prep_smd_data_non_numeric_or_binary_2(sample_data): - with pytest.raises(ValueError, match="The 'race' column must be continuous or binary"): - _check_prep_smd_data( - sample_data, - group='group', - vars=['age', 'race'] - ) + with pytest.raises( + ValueError, match="The 'race' column must be continuous or binary" + ): + _check_prep_smd_data(sample_data, group="group", vars=["age", "race"]) + def test_check_prep_smd_data_missing_group(sample_data): - with pytest.raises(ValueError, match="The DataFrame is missing the following required columns"): + with pytest.raises( + ValueError, match="The DataFrame is missing the following required columns" + ): _check_prep_smd_data( sample_data, - group='missing_group', - vars=['age', 'race'], - wt_var='ps_weights' + group="missing_group", + vars=["age", "race"], + wt_var="ps_weights", ) + def test_check_prep_smd_data_invalid_cat_vars(sample_data): - with pytest.raises(ValueError, match="The DataFrame is missing the following required columns"): + with pytest.raises( + ValueError, match="The DataFrame is missing the following required columns" + ): _check_prep_smd_data( sample_data, - group='group', - vars=['weight', 'race'], - cat_vars=['invalid_cat'] + group="group", + vars=["weight", "race"], + cat_vars=["invalid_cat"], ) + def test_check_prep_smd_data_missing_values(small_sample_data): data_with_nan = small_sample_data.copy() data_with_nan.loc[0, "binary_var"] = np.nan @@ -248,13 +289,16 @@ def test_check_prep_smd_data_missing_values(small_sample_data): data_with_nan, group="group", vars=["binary_var"], wt_var="weights" ) + # Test --- _calc_smd_covar() ------------------------------------------------------- + def test_calc_smd_covar_binary_unweighted(small_sample_data): smd = _calc_smd_covar(data=small_sample_data, group="group", covar="binary_var") assert isinstance(smd, float) assert smd > 0 + def test_calc_smd_covar_binary_weighted(small_sample_data): smd = _calc_smd_covar( data=small_sample_data, group="group", covar="binary_var", wt_var="weights" @@ -262,11 +306,13 @@ def test_calc_smd_covar_binary_weighted(small_sample_data): assert isinstance(smd, float) assert smd > 0 + def test_calc_smd_covar_continuous_unweighted(small_sample_data): smd = _calc_smd_covar(data=small_sample_data, group="group", covar="cont_var") assert isinstance(smd, float) assert smd > 0 + def test_calc_smd_covar_continuous_weighted(small_sample_data): smd = _calc_smd_covar( data=small_sample_data, group="group", covar="cont_var", wt_var="weights" @@ -277,6 +323,7 @@ def test_calc_smd_covar_continuous_weighted(small_sample_data): # --- zero variance and zero proportion ---------------------------------------------- + def test_calc_smd_covar_zero_variance(): data_zero_variance = pd.DataFrame( { @@ -317,6 +364,7 @@ def test_calc_smd_covar_zero_variance(): data_zero_variance_cont, group="group", covar="cont_var", wt_var="weights" ) + def test_calc_smd_bin_covar_zero_variance(): data_zero_proportion = pd.DataFrame( { @@ -351,6 +399,7 @@ def test_calc_smd_bin_covar_zero_variance(): wt_var="weights", ) + def test_calc_smd_covar_zero_proportion(): data_zero_proportion = pd.DataFrame( { @@ -378,6 +427,7 @@ def test_calc_smd_covar_zero_proportion(): ): _calc_smd_covar(data_zero_proportion_group0, group="group", covar="binary_var") + def test_compute_smd_att_atc(small_sample_data): att_smd = _calc_smd_covar( data=small_sample_data, @@ -400,42 +450,51 @@ def test_compute_smd_att_atc(small_sample_data): # --- Testing compute_smd() -------------------------------------------------------- + def test_compute_smd_invalid_group_type(sample_data): with pytest.raises(TypeError, match="The `group` parameter must be of type str"): compute_smd(sample_data, group=123, vars=["age"]) + def test_compute_smd_invalid_vars_type_1(sample_data): with pytest.raises(TypeError, match="`vars` must be a list of strings"): - compute_smd(sample_data, group="group", vars='age') + compute_smd(sample_data, group="group", vars="age") + def test_compute_smd_invalid_vars_type_2(sample_data): with pytest.raises(TypeError, match="`vars` must be a list of strings"): compute_smd(sample_data, group="group", vars=1) + def test_compute_smd_invalid_cat_vars_type_1(sample_data): with pytest.raises(TypeError, match="`cat_vars` must be a list of strings"): - compute_smd(sample_data, group="group", vars=['age'], cat_vars='race') + compute_smd(sample_data, group="group", vars=["age"], cat_vars="race") + def test_compute_smd_invalid_cat_vars_type_2(sample_data): with pytest.raises(TypeError, match="`cat_vars` must be a list of strings"): - compute_smd(sample_data, group="group", vars=['age'], cat_vars=2) + compute_smd(sample_data, group="group", vars=["age"], cat_vars=2) + def test_compute_smd_invalid_wt_var_type(sample_data): with pytest.raises(TypeError, match="`wt_var` parameter must be of type str"): - compute_smd(sample_data, group="group", vars=['age'], wt_var=2) + compute_smd(sample_data, group="group", vars=["age"], wt_var=2) + def test_compute_smd_invalid_std_binary_type(sample_data): with pytest.raises(TypeError, match="`std_binary` parameter must be of type bool"): - compute_smd(sample_data, group="group", vars=['age'], std_binary='yes') + compute_smd(sample_data, group="group", vars=["age"], std_binary="yes") + def test_compute_smd_warns_when_estimand_is_none(sample_data): - with pytest.warns(UserWarning, match="Estimand can not be None. Results are shown considering 'ATE' as the estimand."): + with pytest.warns( + UserWarning, + match="Estimand can not be None. Results are shown considering 'ATE' as the estimand.", + ): result = compute_smd( data=sample_data, - group='group', - vars=['age', 'weight'], - wt_var='ps_wts', - estimand=None + group="group", + vars=["age", "weight"], + wt_var="ps_wts", + estimand=None, ) - -