@@ -13,6 +13,14 @@ namespace Microsoft.ML.Auto
1313{
1414 internal static class UserInputValidationUtil
1515 {
16+ // Column purpose names. Each one is passed as the "columnPurpose" argument of the validation helpers in this class, which interpolate it into their ArgumentException messages.
17+ private const string LabelColumnPurposeName = "label" ;
18+ private const string WeightColumnPurposeName = "weight" ;
19+ private const string NumericColumnPurposeName = "numeric" ;
20+ private const string CategoricalColumnPurposeName = "categorical" ;
21+ private const string TextColumnPurposeName = "text" ;
22+ private const string IgnoredColumnPurposeName = "ignored" ;
23+
1624 public static void ValidateExperimentExecuteArgs ( IDataView trainData , ColumnInformation columnInformation ,
1725 IDataView validationData )
1826 {
@@ -55,22 +63,25 @@ private static void ValidateTrainData(IDataView trainData)
/// <summary>
/// Validates the column information on its own, then checks every column it names
/// against the schema of the training data.
/// </summary>
/// <param name="trainData">Training data whose schema the named columns are checked against.</param>
/// <param name="columnInformation">User-supplied column information to validate.</param>
private static void ValidateColumnInformation(IDataView trainData, ColumnInformation columnInformation)
{
    // Checks that need only the column information itself run first.
    ValidateColumnInformation(columnInformation);

    // Label and weight: no allowed-type list is passed, so only existence is checked.
    ValidateTrainDataColumn(trainData, columnInformation.LabelColumn, LabelColumnPurposeName);
    ValidateTrainDataColumn(trainData, columnInformation.WeightColumn, WeightColumnPurposeName);

    // Categorical, numeric and text columns must exist and have an allowed item type.
    DataViewType[] categoricalTypes = { NumberDataViewType.Single, TextDataViewType.Instance };
    ValidateTrainDataColumns(trainData, columnInformation.CategoricalColumns, CategoricalColumnPurposeName, categoricalTypes);

    DataViewType[] numericTypes = { NumberDataViewType.Single, BooleanDataViewType.Instance };
    ValidateTrainDataColumns(trainData, columnInformation.NumericColumns, NumericColumnPurposeName, numericTypes);

    DataViewType[] textTypes = { TextDataViewType.Instance };
    ValidateTrainDataColumns(trainData, columnInformation.TextColumns, TextColumnPurposeName, textTypes);

    // Ignored columns: existence only, no type restriction.
    ValidateTrainDataColumns(trainData, columnInformation.IgnoredColumns, IgnoredColumnPurposeName);
}
6576
6677 private static void ValidateColumnInformation ( ColumnInformation columnInformation )
6778 {
6879 ValidateLabelColumn ( columnInformation . LabelColumn ) ;
6980
70- ValidateColumnInfoEnumerationProperty ( columnInformation . CategoricalColumns , "categorical" ) ;
71- ValidateColumnInfoEnumerationProperty ( columnInformation . NumericColumns , "numeric" ) ;
72- ValidateColumnInfoEnumerationProperty ( columnInformation . TextColumns , "text" ) ;
73- ValidateColumnInfoEnumerationProperty ( columnInformation . IgnoredColumns , "ignored" ) ;
81+ ValidateColumnInfoEnumerationProperty ( columnInformation . CategoricalColumns , CategoricalColumnPurposeName ) ;
82+ ValidateColumnInfoEnumerationProperty ( columnInformation . NumericColumns , NumericColumnPurposeName ) ;
83+ ValidateColumnInfoEnumerationProperty ( columnInformation . TextColumns , TextColumnPurposeName ) ;
84+ ValidateColumnInfoEnumerationProperty ( columnInformation . IgnoredColumns , IgnoredColumnPurposeName ) ;
7485
7586 // keep a list of all columns, to detect duplicates
7687 var allColumns = new List < string > ( ) ;
@@ -88,11 +99,11 @@ private static void ValidateColumnInformation(ColumnInformation columnInformatio
8899 }
89100 }
90101
/// <summary>
/// Rejects a column-name collection that contains a null entry.
/// </summary>
/// <param name="columns">Column names to inspect; a null collection is accepted as-is.</param>
/// <param name="columnPurpose">Purpose name inserted into the error message.</param>
/// <exception cref="ArgumentException">Thrown when <paramref name="columns"/> contains a null string.</exception>
private static void ValidateColumnInfoEnumerationProperty(IEnumerable<string> columns, string columnPurpose)
{
    // Nothing to check when no collection was supplied.
    if (columns == null)
    {
        return;
    }

    var hasNullEntry = columns.Contains(null);
    if (hasNullEntry)
    {
        throw new ArgumentException($"Null column string was specified as {columnPurpose} in column information");
    }
}
98109
@@ -155,7 +166,8 @@ private static void ValidateValidationData(IDataView trainData, IDataView valida
155166 }
156167 }
157168
158- private static void ValidateTrainDataColumnsExist ( IDataView trainData , IEnumerable < string > columnNames )
169+ private static void ValidateTrainDataColumns ( IDataView trainData , IEnumerable < string > columnNames , string columnPurpose ,
170+ IEnumerable < DataViewType > allowedTypes = null )
159171 {
160172 if ( columnNames == null )
161173 {
@@ -164,15 +176,41 @@ private static void ValidateTrainDataColumnsExist(IDataView trainData, IEnumerab
164176
165177 foreach ( var columnName in columnNames )
166178 {
167- ValidateTrainDataColumnExists ( trainData , columnName ) ;
179+ ValidateTrainDataColumn ( trainData , columnName , columnPurpose , allowedTypes ) ;
168180 }
169181 }
170182
/// <summary>
/// Checks that a single column named in the column information exists in the training
/// data and, when <paramref name="allowedTypes"/> is supplied, that the column's item
/// type is one of the allowed types.
/// </summary>
/// <param name="trainData">Training data whose schema is searched for the column.</param>
/// <param name="columnName">Name of the column to validate; when null, nothing is checked.</param>
/// <param name="columnPurpose">Purpose name (for example "label") inserted into error messages.</param>
/// <param name="allowedTypes">Permitted item types, or null to skip the type check.</param>
/// <exception cref="ArgumentException">
/// Thrown when the column is not found in the training data, or when its item type is
/// not among <paramref name="allowedTypes"/>.
/// </exception>
private static void ValidateTrainDataColumn(IDataView trainData, string columnName, string columnPurpose, IEnumerable<DataViewType> allowedTypes = null)
{
    // A null name means this purpose was not assigned a column.
    if (columnName == null)
    {
        return;
    }

    var nullableColumn = trainData.Schema.GetColumnOrNull(columnName);
    if (nullableColumn == null)
    {
        // The column name appears exactly once, quoted (it was previously emitted twice).
        throw new ArgumentException($"Provided {columnPurpose} column '{columnName}' not found in training data.");
    }

    // No type restriction requested: existence is all that is required.
    if (allowedTypes == null)
    {
        return;
    }

    // Materialize once; the sequence is consulted for membership, count and formatting.
    var allowedTypeList = allowedTypes.ToList();
    var itemType = nullableColumn.Value.Type.GetItemType();
    if (allowedTypeList.Contains(itemType))
    {
        return;
    }

    // Singular wording for exactly one allowed type, plural list otherwise.
    var allowedDescription = allowedTypeList.Count == 1
        ? $"but only type {allowedTypeList[0]} is allowed."
        : $"but only types {string.Join(", ", allowedTypeList)} are allowed.";

    throw new ArgumentException($"Provided {columnPurpose} column '{columnName}' was of type {itemType}, " + allowedDescription);
}
178216
0 commit comments