Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix error analysis cohort filter for string label values
Browse files Browse the repository at this point in the history
imatiach-msft committed Oct 6, 2022
1 parent c044957 commit fa7e498
Showing 2 changed files with 49 additions and 3 deletions.
12 changes: 11 additions & 1 deletion erroranalysis/erroranalysis/_internal/cohort_filter.py
Original file line number Diff line number Diff line change
@@ -52,7 +52,8 @@ def filter_from_cohort(analyzer, filters, composite_filters,
categories=analyzer.categories,
true_y=analyzer.true_y,
pred_y=pred_y,
model_task=analyzer.model_task)
model_task=analyzer.model_task,
classes=analyzer.classes)

return filter_data_with_cohort.filter_data_from_cohort(
filters=filters,
@@ -462,13 +463,22 @@ def _build_bounds_query(self, filter, colname, method,
is_categorical = False
if categorical_features:
is_categorical = colname in categorical_features
is_label = False
if colname == TRUE_Y or colname == PRED_Y:
is_label = True
for arg in filter[ARG]:
if is_categorical:
cat_idx = categorical_features.index(colname)
if isinstance(categories[cat_idx][arg], str):
arg_val = "'{}'".format(str(categories[cat_idx][arg]))
else:
arg_val = "{}".format(str(categories[cat_idx][arg]))
elif is_label and self.classes is not None:
class_value = self.classes[arg]
format_str = "{}"
if isinstance(class_value, str):
format_str = "'{}'"
arg_val = format_str.format(str(self.classes[arg]))
else:
arg_val = arg
bounds.append("`{}`{}{}".format(colname, operator, arg_val))
40 changes: 38 additions & 2 deletions erroranalysis/tests/test_cohort_filter.py
Original file line number Diff line number Diff line change
@@ -83,6 +83,27 @@ def test_cohort_filter_true_y(self):
model_task,
filters=filters)

def test_cohort_filter_string_encoded_true_y(self):
    """Filter on 'True Y' by class index when labels are class values.

    The filter argument is an index into ``classes``; the validation
    data handed to ``run_error_analyzer`` holds only the rows whose
    true label equals ``classes[2]``.
    """
    (X_train, X_test, y_train, y_test,
     feature_names, classes) = create_iris_str_y()
    target_class_index = 2
    true_y_filter = {'arg': [target_class_index],
                     'column': 'True Y',
                     'method': 'includes'}
    # Keep only the rows labelled with the selected class value.
    expected = create_validation_data(X_test, y_test)
    expected = expected.loc[y_test == classes[target_class_index]]
    svm_model = create_sklearn_svm_classifier(X_train, y_train)
    run_error_analyzer(expected,
                       svm_model,
                       X_test,
                       y_test,
                       feature_names,
                       [],
                       ModelTask.CLASSIFICATION,
                       filters=[true_y_filter],
                       classes=classes)

def test_cohort_filter_less(self):
X_train, X_test, y_train, y_test, feature_names = create_iris_pandas()
filters = [{'arg': [2.8],
@@ -355,6 +376,19 @@ def create_iris_pandas():
return X_train, X_test, y_train, y_test, feature_names


def create_iris_str_y():
    """Return the ``create_iris_data`` fixture with class values as labels.

    Both feature matrices are wrapped in DataFrames whose columns are
    ``feature_names``, and every label ``y`` is replaced by
    ``classes[y]`` (expected to be strings, per the function name).

    :return: X_train, X_test, y_train, y_test, feature_names, classes
    :rtype: tuple
    """
    (train_features, test_features, train_labels, test_labels,
     feature_names, classes) = create_iris_data()

    def _to_class_values(labels):
        # Look every label up in classes to get its class value.
        return np.array([classes[label] for label in labels])

    return (pd.DataFrame(train_features, columns=feature_names),
            pd.DataFrame(test_features, columns=feature_names),
            _to_class_values(train_labels),
            _to_class_values(test_labels),
            feature_names,
            classes)


def create_validation_data(X_test, y_test, pred_y=None):
validation_data = X_test.copy()
validation_data[TRUE_Y] = y_test
@@ -373,13 +407,15 @@ def run_error_analyzer(validation_data,
model_task,
filters=None,
composite_filters=None,
is_empty_validation_data=False):
is_empty_validation_data=False,
classes=None):
error_analyzer = ModelAnalyzer(model,
X_test,
y_test,
feature_names,
categorical_features,
model_task=model_task)
model_task=model_task,
classes=classes)
filtered_data = filter_from_cohort(error_analyzer,
filters,
composite_filters)

0 comments on commit fa7e498

Please sign in to comment.