Skip to content

Commit

Permalink
rev ColumnInference API: can take label index; rev output object type…
Browse files Browse the repository at this point in the history
…s; add tests (dotnet#89)
  • Loading branch information
daholste authored Feb 10, 2019
1 parent 6f2bcab commit 31379c4
Show file tree
Hide file tree
Showing 14 changed files with 221 additions and 220 deletions.
11 changes: 2 additions & 9 deletions src/AutoML/API/InferenceException.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,9 @@ namespace Microsoft.ML.Auto
{
public enum InferenceType
{
Seperator,
Header,
Label,
Task,
ColumnDataKind,
ColumnPurpose,
Tranform,
Trainer,
Hyperparams,
ColumnSplit
ColumnSplit,
Label,
}

public class InferenceException : Exception
Expand Down
74 changes: 8 additions & 66 deletions src/AutoML/API/MLContextDataExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,87 +2,29 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.Data.DataView;
using Microsoft.ML.Data;

namespace Microsoft.ML.Auto
{
public static class DataExtensions
{
// Delimiter, header, column datatype inference
public static ColumnInferenceResult InferColumns(this DataOperationsCatalog catalog, string path, string label,
bool hasHeader = false, char? separatorChar = null, bool? allowQuotedStrings = null, bool? supportSparse = null, bool trimWhitespace = false, bool groupColumns = true)
public static (TextLoader.Arguments TextLoaderArgs, IEnumerable<(string Name, ColumnPurpose Purpose)> ColumnPurpopses) InferColumns(this DataOperationsCatalog catalog, string path, string label,
char? separatorChar = null, bool? allowQuotedStrings = null, bool? supportSparse = null, bool trimWhitespace = false, bool groupColumns = true)
{
UserInputValidationUtil.ValidateInferColumnsArgs(path, label);
var mlContext = new MLContext();
return ColumnInferenceApi.InferColumns(mlContext, path, label, hasHeader, separatorChar, allowQuotedStrings, supportSparse, trimWhitespace, groupColumns);
return ColumnInferenceApi.InferColumns(mlContext, path, label, separatorChar, allowQuotedStrings, supportSparse, trimWhitespace, groupColumns);
}

public static IDataView AutoRead(this DataOperationsCatalog catalog, string path, string label,
bool hasHeader = false, char? separatorChar = null, bool? allowQuotedStrings = null, bool? supportSparse = null, bool trimWhitespace = false, bool groupColumns = true)
public static (TextLoader.Arguments TextLoaderArgs, IEnumerable<(string Name, ColumnPurpose Purpose)> ColumnPurpopses) InferColumns(this DataOperationsCatalog catalog, string path, int labelColumnIndex,
bool hasHeader = false, char? separatorChar = null, bool? allowQuotedStrings = null, bool? supportSparse = null,
bool trimWhitespace = false, bool groupColumns = true)
{
UserInputValidationUtil.ValidateAutoReadArgs(path, label);
UserInputValidationUtil.ValidateInferColumnsArgs(path, labelColumnIndex);
var mlContext = new MLContext();
var columnInferenceResult = ColumnInferenceApi.InferColumns(mlContext, path, label, hasHeader, separatorChar, allowQuotedStrings, supportSparse, trimWhitespace, groupColumns);
var textLoader = columnInferenceResult.BuildTextLoader();
return textLoader.Read(path);
}

public static TextLoader CreateTextLoader(this DataOperationsCatalog catalog, ColumnInferenceResult columnInferenceResult)
{
UserInputValidationUtil.ValidateCreateTextReaderArgs(columnInferenceResult);
return columnInferenceResult.BuildTextLoader();
}

// Task inference
public static MachineLearningTaskType InferTask(this DataOperationsCatalog catalog, IDataView dataView)
{
throw new NotImplementedException();
}

public enum MachineLearningTaskType
{
Regression,
BinaryClassification,
MultiClassClassification
}
}

public class ColumnInferenceResult
{
public readonly IEnumerable<(TextLoader.Column, ColumnPurpose)> Columns;
public readonly bool AllowQuotedStrings;
public readonly bool SupportSparse;
public readonly char[] Separators;
public readonly bool HasHeader;
public readonly bool TrimWhitespace;

public ColumnInferenceResult(IEnumerable<(TextLoader.Column, ColumnPurpose)> columns,
bool allowQuotedStrings, bool supportSparse, char[] separators, bool hasHeader, bool trimWhitespace)
{
Columns = columns;
AllowQuotedStrings = allowQuotedStrings;
SupportSparse = supportSparse;
Separators = separators;
HasHeader = hasHeader;
TrimWhitespace = trimWhitespace;
}

internal TextLoader BuildTextLoader()
{
var context = new MLContext();
return new TextLoader(context, new TextLoader.Arguments()
{
AllowQuoting = AllowQuotedStrings,
AllowSparse = SupportSparse,
Column = Columns.Select(c => c.Item1).ToArray(),
Separators = Separators,
HasHeader = HasHeader,
TrimWhitespace = TrimWhitespace
});
return ColumnInferenceApi.InferColumns(mlContext, path, labelColumnIndex, hasHeader, separatorChar, allowQuotedStrings, supportSparse, trimWhitespace, groupColumns);
}
}
}
62 changes: 52 additions & 10 deletions src/AutoML/ColumnInference/ColumnInferenceApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,52 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Data;

namespace Microsoft.ML.Auto
{
internal static class ColumnInferenceApi
{
public static ColumnInferenceResult InferColumns(MLContext context, string path, string label,
/// <summary>
/// Infers text-loader arguments and column purposes for the file at <paramref name="path"/>,
/// identifying the label column by its zero-based position among the inferred columns.
/// </summary>
/// <param name="context">ML context passed through to column type inference and the label-name overload.</param>
/// <param name="path">Path of the text file that is sampled and analyzed.</param>
/// <param name="labelColumnIndex">Zero-based index of the label column within the inferred columns.</param>
/// <param name="hasHeader">Passed to column type inference and to the label-name overload.</param>
/// <param name="separatorChar">Passed to split inference.</param>
/// <param name="allowQuotedStrings">Passed to split inference.</param>
/// <param name="supportSparse">Passed to split inference.</param>
/// <param name="trimWhitespace">Passed to the label-name overload.</param>
/// <param name="groupColumns">Passed to the label-name overload.</param>
/// <returns>The tuple produced by the overload that takes the label column name.</returns>
/// <exception cref="ArgumentOutOfRangeException">
/// Thrown when <paramref name="labelColumnIndex"/> is negative, or is greater than or equal to
/// the number of inferred columns.
/// </exception>
public static (TextLoader.Arguments, IEnumerable<(string, ColumnPurpose)>) InferColumns(MLContext context, string path, int labelColumnIndex,
    bool hasHeader, char? separatorChar, bool? allowQuotedStrings, bool? supportSparse, bool trimWhitespace, bool groupColumns)
{
    // Fail fast on a negative index: without this guard it would reach the column indexer
    // below and surface as an IndexOutOfRangeException rather than an argument error.
    if (labelColumnIndex < 0)
    {
        throw new ArgumentOutOfRangeException(nameof(labelColumnIndex), $"Label column index ({labelColumnIndex}) must be non-negative.");
    }

    var sample = TextFileSample.CreateFromFullFile(path);
    var splitInference = InferSplit(sample, separatorChar, allowQuotedStrings, supportSparse);
    var typeInference = InferColumnTypes(context, sample, splitInference, hasHeader);

    // If label column index >= inferred # of columns, throw error.
    // The count is computed once and reused in the message.
    var inferredColumnCount = typeInference.Columns.Count();
    if (labelColumnIndex >= inferredColumnCount)
    {
        throw new ArgumentOutOfRangeException(nameof(labelColumnIndex), $"Label column index ({labelColumnIndex}) must be less than the number of inferred columns ({inferredColumnCount}).");
    }

    // If no column is named label,
    // rename label column to default ML.NET label column name.
    if (!typeInference.Columns.Any(c => c.SuggestedName == DefaultColumnNames.Label))
    {
        typeInference.Columns[labelColumnIndex].SuggestedName = DefaultColumnNames.Label;
    }

    // Hand off to the label-name overload using whatever name the label column now carries.
    return InferColumns(context, path, typeInference.Columns[labelColumnIndex].SuggestedName,
        hasHeader, splitInference, typeInference, trimWhitespace, groupColumns);
}

/// <summary>
/// Infers text-loader arguments and column purposes for the file at <paramref name="path"/>,
/// identifying the label column by name. This overload always treats the file as having a header.
/// </summary>
/// <param name="context">ML context passed through to column type inference and the core overload.</param>
/// <param name="path">Path of the text file that is sampled and analyzed.</param>
/// <param name="label">Name of the label column, passed to the core overload.</param>
/// <param name="separatorChar">Passed to split inference.</param>
/// <param name="allowQuotedStrings">Passed to split inference.</param>
/// <param name="supportSparse">Passed to split inference.</param>
/// <param name="trimWhitespace">Passed to the core overload.</param>
/// <param name="groupColumns">Passed to the core overload.</param>
/// <returns>The tuple produced by the core overload.</returns>
public static (TextLoader.Arguments, IEnumerable<(string, ColumnPurpose)>) InferColumns(MLContext context, string path, string label,
    char? separatorChar, bool? allowQuotedStrings, bool? supportSparse, bool trimWhitespace, bool groupColumns)
{
    // Both type inference and the core overload are told that a header is present.
    const bool hasHeader = true;

    var fileSample = TextFileSample.CreateFromFullFile(path);
    var columnSplit = InferSplit(fileSample, separatorChar, allowQuotedStrings, supportSparse);
    var columnTypes = InferColumnTypes(context, fileSample, columnSplit, hasHeader);

    return InferColumns(context, path, label, hasHeader, columnSplit, columnTypes, trimWhitespace, groupColumns);
}

public static (TextLoader.Arguments, IEnumerable<(string, ColumnPurpose)>) InferColumns(MLContext context, string path, string label, bool hasHeader,
TextFileContents.ColumnSplitResult splitInference, ColumnTypeInference.InferenceResult typeInference,
bool trimWhitespace, bool groupColumns)
{
var loaderColumns = ColumnTypeInference.GenerateLoaderColumns(typeInference.Columns);
if (!loaderColumns.Any(t => label.Equals(t.Name)))
{
Expand All @@ -34,25 +67,34 @@ public static ColumnInferenceResult InferColumns(MLContext context, string path,

var purposeInferenceResult = PurposeInference.InferPurposes(context, dataView, label);

(TextLoader.Column, ColumnPurpose Purpose)[] inferredColumns = null;
// start building result objects
IEnumerable<TextLoader.Column> columnResults = null;
IEnumerable<(string, ColumnPurpose)> purposeResults = null;

// infer column grouping and generate column names
if (groupColumns)
{
var groupingResult = ColumnGroupingInference.InferGroupingAndNames(context, hasHeader,
typeInference.Columns, purposeInferenceResult);

// build result objects & return
inferredColumns = groupingResult.Select(c => (c.GenerateTextLoaderColumn(), c.Purpose)).ToArray();
columnResults = groupingResult.Select(c => c.GenerateTextLoaderColumn());
purposeResults = groupingResult.Select(c => (c.SuggestedName, c.Purpose));
}
else
{
inferredColumns = new (TextLoader.Column, ColumnPurpose Purpose)[loaderColumns.Length];
for (int i = 0; i < loaderColumns.Length; i++)
{
inferredColumns[i] = (loaderColumns[i], purposeInferenceResult[i].Purpose);
}
columnResults = loaderColumns;
purposeResults = purposeInferenceResult.Select(p => (dataView.Schema[p.ColumnIndex].Name, p.Purpose));
}
return new ColumnInferenceResult(inferredColumns, splitInference.AllowQuote, splitInference.AllowSparse, new char[] { splitInference.Separator.Value }, hasHeader, trimWhitespace);

return (new TextLoader.Arguments()
{
Column = columnResults.ToArray(),
AllowQuoting = splitInference.AllowQuote,
AllowSparse = splitInference.AllowSparse,
Separators = new char[] { splitInference.Separator.Value },
HasHeader = hasHeader,
TrimWhitespace = trimWhitespace
}, purposeResults);
}

private static TextFileContents.ColumnSplitResult InferSplit(TextFileSample sample, char? separatorChar, bool? allowQuotedStrings, bool? supportSparse)
Expand Down
5 changes: 3 additions & 2 deletions src/AutoML/ColumnInference/ColumnTypeInference.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,13 @@ public IntermediateColumn(ReadOnlyMemory<char>[] data, int columnId)
public ReadOnlyMemory<char>[] RawData { get { return _data; } }
}

public readonly struct Column
public struct Column
{
public readonly int ColumnIndex;
public readonly string SuggestedName;
public readonly PrimitiveType ItemType;

public string SuggestedName;

public Column(int columnIndex, string suggestedName, PrimitiveType itemType)
{
ColumnIndex = columnIndex;
Expand Down
52 changes: 19 additions & 33 deletions src/AutoML/Utils/UserInputValidationUtil.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ public static void ValidateAutoFitArgs(IDataView trainData, string label, IDataV
{
ValidateTrainData(trainData);
ValidateValidationData(trainData, validationData);
ValidateLabel(trainData, validationData, label);
ValidateLabel(trainData, label);
ValidateSettings(settings);
ValidatePurposeOverrides(trainData, validationData, label, purposeOverrides);
}
Expand All @@ -28,49 +28,27 @@ public static void ValidateInferColumnsArgs(string path, string label)
ValidatePath(path);
}

public static void ValidateAutoReadArgs(string path, string label)
public static void ValidateInferColumnsArgs(string path, int labelColumnIndex)
{
ValidateLabel(label);
ValidateLabelColumnIndex(labelColumnIndex);
ValidatePath(path);
}

public static void ValidateCreateTextReaderArgs(ColumnInferenceResult columnInferenceResult)
public static void ValidateAutoReadArgs(string path, string label)
{
if(columnInferenceResult == null)
{
throw new ArgumentNullException($"Column inference result cannot be null", nameof(columnInferenceResult));
}

if (columnInferenceResult.Separators == null || !columnInferenceResult.Separators.Any())
{
throw new ArgumentException($"Column inference result cannot have null or empty separators", nameof(columnInferenceResult));
}

if (columnInferenceResult.Columns == null || !columnInferenceResult.Columns.Any())
{
throw new ArgumentException($"Column inference result must contain at least one column", nameof(columnInferenceResult));
}

if(columnInferenceResult.Columns.Any(c => c.Item1 == null))
{
throw new ArgumentException($"Column inference result cannot contain null columns", nameof(columnInferenceResult));
}

if (columnInferenceResult.Columns.Any(c => c.Item1.Name == null || c.Item1.Type == null || c.Item1.Source == null))
{
throw new ArgumentException($"Column inference result cannot contain a column that has a null name, type, or source", nameof(columnInferenceResult));
}
ValidateLabel(label);
ValidatePath(path);
}

private static void ValidateTrainData(IDataView trainData)
{
if(trainData == null)
{
throw new ArgumentNullException("Training data cannot be null", nameof(trainData));
throw new ArgumentNullException(nameof(trainData), "Training data cannot be null");
}
}

private static void ValidateLabel(IDataView trainData, IDataView validationData, string label)
private static void ValidateLabel(IDataView trainData, string label)
{
ValidateLabel(label);

Expand All @@ -84,15 +62,23 @@ private static void ValidateLabel(string label)
{
if (label == null)
{
throw new ArgumentNullException("Provided label cannot be null", nameof(label));
throw new ArgumentNullException(nameof(label), "Provided label cannot be null");
}
}

/// <summary>
/// Ensures the supplied label column index is not negative.
/// </summary>
/// <param name="labelColumnIndex">Label column index supplied by the caller.</param>
/// <exception cref="ArgumentOutOfRangeException">
/// Thrown when <paramref name="labelColumnIndex"/> is less than zero.
/// </exception>
private static void ValidateLabelColumnIndex(int labelColumnIndex)
{
    // Zero and above are acceptable here; nothing further to check.
    if (labelColumnIndex >= 0)
    {
        return;
    }

    throw new ArgumentOutOfRangeException(nameof(labelColumnIndex), $"Provided label column index ({labelColumnIndex}) must be non-negative.");
}

private static void ValidatePath(string path)
{
if (path == null)
{
throw new ArgumentNullException("Provided path cannot be null", nameof(path));
throw new ArgumentNullException(nameof(path), "Provided path cannot be null");
}

var fileInfo = new FileInfo(path);
Expand Down Expand Up @@ -148,7 +134,7 @@ private static void ValidateSettings(AutoFitSettings settings)

if(settings.StoppingCriteria.MaxIterations <= 0)
{
throw new ArgumentOutOfRangeException("Max iterations must be > 0", nameof(settings));
throw new ArgumentOutOfRangeException(nameof(settings), "Max iterations must be > 0");
}
}

Expand Down
12 changes: 6 additions & 6 deletions src/Test/AutoFitTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ public void AutoFitBinaryTest()
{
var context = new MLContext();
var dataPath = DatasetUtil.DownloadUciAdultDataset();
var columnInference = context.Data.InferColumns(dataPath, DatasetUtil.UciAdultLabel, true);
var textLoader = context.Data.CreateTextLoader(columnInference);
var columnInference = context.Data.InferColumns(dataPath, DatasetUtil.UciAdultLabel);
var textLoader = context.Data.CreateTextLoader(columnInference.TextLoaderArgs);
var trainData = textLoader.Read(dataPath);
var validationData = trainData.Take(100);
trainData = trainData.Skip(100);
Expand All @@ -38,8 +38,8 @@ public void AutoFitMultiTest()
{
var context = new MLContext();
var dataPath = DatasetUtil.DownloadTrivialDataset();
var columnInference = context.Data.InferColumns(dataPath, DatasetUtil.TrivialDatasetLabel, true);
var textLoader = context.Data.CreateTextLoader(columnInference);
var columnInference = context.Data.InferColumns(dataPath, DatasetUtil.TrivialDatasetLabel);
var textLoader = context.Data.CreateTextLoader(columnInference.TextLoaderArgs);
var trainData = textLoader.Read(dataPath);
var validationData = trainData.Take(20);
trainData = trainData.Skip(20);
Expand All @@ -61,8 +61,8 @@ public void AutoFitRegressionTest()
{
var context = new MLContext();
var dataPath = DatasetUtil.DownloadMlNetGeneratedRegressionDataset();
var columnInference = context.Data.InferColumns(dataPath, DatasetUtil.MlNetGeneratedRegressionLabel, true);
var textLoader = context.Data.CreateTextLoader(columnInference);
var columnInference = context.Data.InferColumns(dataPath, DatasetUtil.MlNetGeneratedRegressionLabel);
var textLoader = context.Data.CreateTextLoader(columnInference.TextLoaderArgs);
var trainData = textLoader.Read(dataPath);
var validationData = trainData.Take(20);
trainData = trainData.Skip(20);
Expand Down
Loading

0 comments on commit 31379c4

Please sign in to comment.