Skip to content

Commit 4e15684

Browse files
daholste authored and Dmitry-A committed
User input column type validation (dotnet#218)
1 parent bf4ece8 commit 4e15684

File tree

2 files changed

+71
-17
lines changed

2 files changed

+71
-17
lines changed

src/Microsoft.ML.Auto/Utils/UserInputValidationUtil.cs

Lines changed: 55 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,14 @@ namespace Microsoft.ML.Auto
1313
{
1414
internal static class UserInputValidationUtil
1515
{
16+
// column purpose names
17+
private const string LabelColumnPurposeName = "label";
18+
private const string WeightColumnPurposeName = "weight";
19+
private const string NumericColumnPurposeName = "numeric";
20+
private const string CategoricalColumnPurposeName = "categorical";
21+
private const string TextColumnPurposeName = "text";
22+
private const string IgnoredColumnPurposeName = "ignored";
23+
1624
public static void ValidateExperimentExecuteArgs(IDataView trainData, ColumnInformation columnInformation,
1725
IDataView validationData)
1826
{
@@ -55,22 +63,25 @@ private static void ValidateTrainData(IDataView trainData)
5563
private static void ValidateColumnInformation(IDataView trainData, ColumnInformation columnInformation)
5664
{
5765
ValidateColumnInformation(columnInformation);
58-
ValidateTrainDataColumnExists(trainData, columnInformation.LabelColumn);
59-
ValidateTrainDataColumnExists(trainData, columnInformation.WeightColumn);
60-
ValidateTrainDataColumnsExist(trainData, columnInformation.CategoricalColumns);
61-
ValidateTrainDataColumnsExist(trainData, columnInformation.NumericColumns);
62-
ValidateTrainDataColumnsExist(trainData, columnInformation.TextColumns);
63-
ValidateTrainDataColumnsExist(trainData, columnInformation.IgnoredColumns);
66+
ValidateTrainDataColumn(trainData, columnInformation.LabelColumn, LabelColumnPurposeName);
67+
ValidateTrainDataColumn(trainData, columnInformation.WeightColumn, WeightColumnPurposeName);
68+
ValidateTrainDataColumns(trainData, columnInformation.CategoricalColumns, CategoricalColumnPurposeName,
69+
new DataViewType[] { NumberDataViewType.Single, TextDataViewType.Instance });
70+
ValidateTrainDataColumns(trainData, columnInformation.NumericColumns, NumericColumnPurposeName,
71+
new DataViewType[] { NumberDataViewType.Single, BooleanDataViewType.Instance });
72+
ValidateTrainDataColumns(trainData, columnInformation.TextColumns, TextColumnPurposeName,
73+
new DataViewType[] { TextDataViewType.Instance });
74+
ValidateTrainDataColumns(trainData, columnInformation.IgnoredColumns, IgnoredColumnPurposeName);
6475
}
6576

6677
private static void ValidateColumnInformation(ColumnInformation columnInformation)
6778
{
6879
ValidateLabelColumn(columnInformation.LabelColumn);
6980

70-
ValidateColumnInfoEnumerationProperty(columnInformation.CategoricalColumns, "categorical");
71-
ValidateColumnInfoEnumerationProperty(columnInformation.NumericColumns, "numeric");
72-
ValidateColumnInfoEnumerationProperty(columnInformation.TextColumns, "text");
73-
ValidateColumnInfoEnumerationProperty(columnInformation.IgnoredColumns, "ignored");
81+
ValidateColumnInfoEnumerationProperty(columnInformation.CategoricalColumns, CategoricalColumnPurposeName);
82+
ValidateColumnInfoEnumerationProperty(columnInformation.NumericColumns, NumericColumnPurposeName);
83+
ValidateColumnInfoEnumerationProperty(columnInformation.TextColumns, TextColumnPurposeName);
84+
ValidateColumnInfoEnumerationProperty(columnInformation.IgnoredColumns, IgnoredColumnPurposeName);
7485

7586
// keep a list of all columns, to detect duplicates
7687
var allColumns = new List<string>();
@@ -88,11 +99,11 @@ private static void ValidateColumnInformation(ColumnInformation columnInformatio
8899
}
89100
}
90101

91-
private static void ValidateColumnInfoEnumerationProperty(IEnumerable<string> columns, string propertyName)
102+
private static void ValidateColumnInfoEnumerationProperty(IEnumerable<string> columns, string columnPurpose)
92103
{
93104
if (columns?.Contains(null) == true)
94105
{
95-
throw new ArgumentException($"Null column string was specified as {propertyName} in column information");
106+
throw new ArgumentException($"Null column string was specified as {columnPurpose} in column information");
96107
}
97108
}
98109

@@ -155,7 +166,8 @@ private static void ValidateValidationData(IDataView trainData, IDataView valida
155166
}
156167
}
157168

158-
private static void ValidateTrainDataColumnsExist(IDataView trainData, IEnumerable<string> columnNames)
169+
private static void ValidateTrainDataColumns(IDataView trainData, IEnumerable<string> columnNames, string columnPurpose,
170+
IEnumerable<DataViewType> allowedTypes = null)
159171
{
160172
if (columnNames == null)
161173
{
@@ -164,15 +176,41 @@ private static void ValidateTrainDataColumnsExist(IDataView trainData, IEnumerab
164176

165177
foreach (var columnName in columnNames)
166178
{
167-
ValidateTrainDataColumnExists(trainData, columnName);
179+
ValidateTrainDataColumn(trainData, columnName, columnPurpose, allowedTypes);
168180
}
169181
}
170182

171-
private static void ValidateTrainDataColumnExists(IDataView trainData, string columnName)
183+
private static void ValidateTrainDataColumn(IDataView trainData, string columnName, string columnPurpose, IEnumerable<DataViewType> allowedTypes = null)
172184
{
173-
if (columnName != null && trainData.Schema.GetColumnOrNull(columnName) == null)
185+
if (columnName == null)
186+
{
187+
return;
188+
}
189+
190+
var nullableColumn = trainData.Schema.GetColumnOrNull(columnName);
191+
if (nullableColumn == null)
192+
{
193+
throw new ArgumentException($"Provided {columnPurpose} column '{columnName}' not found in training data.");
194+
}
195+
196+
if(allowedTypes == null)
174197
{
175-
throw new ArgumentException($"Provided column '{columnName}' not found in training data.");
198+
return;
199+
}
200+
var column = nullableColumn.Value;
201+
var itemType = column.Type.GetItemType();
202+
if (!allowedTypes.Contains(itemType))
203+
{
204+
if (allowedTypes.Count() == 1)
205+
{
206+
throw new ArgumentException($"Provided {columnPurpose} column '{columnName}' was of type {itemType}, " +
207+
$"but only type {allowedTypes.First()} is allowed.");
208+
}
209+
else
210+
{
211+
throw new ArgumentException($"Provided {columnPurpose} column '{columnName}' was of type {itemType}, " +
212+
$"but only types {string.Join(", ", allowedTypes)} are allowed.");
213+
}
176214
}
177215
}
178216

src/Test/UserInputValidationTests.cs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,5 +161,21 @@ public void ValidateFeaturesColInvalidType()
161161
var dataView = new EmptyDataView(new MLContext(), schema);
162162
UserInputValidationUtil.ValidateExperimentExecuteArgs(dataView, new ColumnInformation(), null);
163163
}
164+
165+
[TestMethod]
166+
[ExpectedException(typeof(ArgumentException))]
167+
public void ValidateTextColumnNotText()
168+
{
169+
const string TextPurposeColName = "TextColumn";
170+
var schemaBuilder = new SchemaBuilder();
171+
schemaBuilder.AddColumn(DefaultColumnNames.Features, NumberDataViewType.Single);
172+
schemaBuilder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single);
173+
schemaBuilder.AddColumn(TextPurposeColName, NumberDataViewType.Double);
174+
var schema = schemaBuilder.GetSchema();
175+
var dataView = new EmptyDataView(new MLContext(), schema);
176+
UserInputValidationUtil.ValidateExperimentExecuteArgs(dataView,
177+
new ColumnInformation() { TextColumns = new[] { TextPurposeColName } },
178+
null);
179+
}
164180
}
165181
}

0 commit comments

Comments
 (0)