diff --git a/src/python/nimbusml/model_selection/cv.py b/src/python/nimbusml/model_selection/cv.py
index f2a0dd8b..746f8aee 100644
--- a/src/python/nimbusml/model_selection/cv.py
+++ b/src/python/nimbusml/model_selection/cv.py
@@ -8,7 +8,7 @@
 from pandas import DataFrame
 
-from .. import Pipeline
+from .. import Pipeline, FileDataStream
 from ..internal.entrypoints.models_crossvalidator import \
     models_crossvalidator
 from ..internal.entrypoints.transforms_manyheterogeneousmodelcombiner \
@@ -450,13 +450,22 @@ def fit(
         # Need to infer from group_id, bug 284886
         groups = groups or group_id
         if groups is not None:
-            if groups not in cv_aux_info[0]['data_import'][0].inputs[
-                    'CustomSchema']:
-                raise Exception(
-                    'Default stratification column: ' +
-                    str(groups) +
-                    ' cannot be found in the origin data, please specify '
-                    'groups in .fit() function.')
+            if isinstance(X, FileDataStream):
+                if groups not in cv_aux_info[0]['data_import'][0].inputs[
+                        'CustomSchema']:
+                    raise Exception(
+                        'Default stratification column: ' +
+                        str(groups) +
+                        ' cannot be found in the origin data, please specify '
+                        'groups in .fit() function.')
+            elif isinstance(X, DataFrame):
+                if groups not in X.columns:
+                    raise Exception(
+                        'Default stratification column: ' +
+                        str(groups) +
+                        ' cannot be found in the origin data, please specify '
+                        'groups in .fit() function.')
+
 
         split_index = self._process_split_start(split_start)
         graph_sections = cv_aux_info.graph_sections
diff --git a/src/python/nimbusml/tests/model_selection/test_cv.py b/src/python/nimbusml/tests/model_selection/test_cv.py
index f26f326a..3d7587e9 100644
--- a/src/python/nimbusml/tests/model_selection/test_cv.py
+++ b/src/python/nimbusml/tests/model_selection/test_cv.py
@@ -375,6 +375,12 @@ def data(self, label_name, group_id, features):
         data._set_role(Role.Label, label_name)
         return data
 
+    def data_pandas(self):
+        simpleinput_file = get_dataset("gen_tickettrain").as_filepath()
+        data = pd.read_csv(simpleinput_file)
+        data['group'] = data['group'].astype(str)
+        return data
+
     def data_wt_rename(self, label_name, group_id, features):
         simpleinput_file = get_dataset("gen_tickettrain").as_filepath()
         file_schema = 'sep=, col={label}:R4:0 col={group_id}:TX:1 ' \
@@ -402,6 +408,29 @@ def check_cv_with_defaults2(
         data = self.data_wt_rename(label_name, group_id, features)
         check_cv(pipeline=Pipeline(steps), X=data, **params)
 
+    @unittest.skipIf(os.name != "nt", "random crashes on linux")
+    def check_cv_with_defaults_df(
+            self,
+            label_name='rank',
+            group_id='group',
+            features=['price','Class','dep_day','nbr_stops','duration'],
+            **params):
+        steps = [
+            OneHotHashVectorizer(
+                output_kind='Key') << {
+                group_id: group_id},
+            LightGbmRanker(
+                min_data_per_leaf=1,
+                feature=features,
+                label='rank', group_id='group'
+            )]
+        data = self.data_pandas()
+        check_cv(pipeline=Pipeline(steps), X=data, **params)
+
+    @unittest.skipIf(os.name != "nt", "random crashes on linux")
+    def test_default_df(self):
+        self.check_cv_with_defaults_df()
+
     @unittest.skipIf(os.name != "nt", "random crashes on linux")
     def test_default_label2(self):
         self.check_cv_with_defaults2(split_start='try_all')