Skip to content

Commit

Permalink
Add validations for heterogeneity_features in causal manager (#1659)
Browse files Browse the repository at this point in the history
* Add validations for heterogeneity_features in causal manager

Signed-off-by: Gaurav Gupta <gaugup@microsoft.com>

* Fix tests

Signed-off-by: Gaurav Gupta <gaugup@microsoft.com>

* Fix linting

Signed-off-by: Gaurav Gupta <gaugup@microsoft.com>

Signed-off-by: Gaurav Gupta <gaugup@microsoft.com>
  • Loading branch information
gaugup authored Aug 23, 2022
1 parent 4b15495 commit fb916fb
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 0 deletions.
8 changes: 8 additions & 0 deletions responsibleai/responsibleai/managers/causal_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,14 @@ def add(
f"not exist in train data: {list(difference_set)}")
raise UserConfigValidationException(message)

if heterogeneity_features is not None:
difference_set = \
set(heterogeneity_features) - set(self._train.columns)
if len(difference_set) > 0:
message = ("Feature names in heterogeneity_features do "
f"not exist in train data: {list(difference_set)}")
raise UserConfigValidationException(message)

if self._task_type == ModelTask.CLASSIFICATION:
is_multiclass = len(np.unique(
self._train[self._target_column].values).tolist()) > 2
Expand Down
22 changes: 22 additions & 0 deletions responsibleai/tests/test_rai_insights_validations.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,28 @@ def test_treatment_features_list_not_having_train_features(self):
with pytest.raises(UserConfigValidationException, match=message):
rai_insights.causal.add(treatment_features=['not_a_feature'])

def test_heterogeneity_features_list_not_having_train_features(self):
X_train, y_train, X_test, y_test, _ = \
create_binary_classification_dataset()

model = create_lightgbm_classifier(X_train, y_train)
X_train[TARGET] = y_train
X_test[TARGET] = y_test

rai_insights = RAIInsights(
model=model,
train=X_train,
test=X_test,
target_column=TARGET,
task_type='classification')

message = ("Feature names in heterogeneity_features "
"do not exist in train data: \\['not_a_feature'\\]")
with pytest.raises(UserConfigValidationException, match=message):
rai_insights.causal.add(
treatment_features=['col1'],
heterogeneity_features=['not_a_feature'])


class TestCounterfactualUserConfigValidations:

Expand Down

0 comments on commit fb916fb

Please sign in to comment.