diff --git a/ZBaselines/Common/EntryPoints/core_ep-list.tsv b/ZBaselines/Common/EntryPoints/core_ep-list.tsv index 47007edaa6..22c2767d7a 100644 --- a/ZBaselines/Common/EntryPoints/core_ep-list.tsv +++ b/ZBaselines/Common/EntryPoints/core_ep-list.tsv @@ -1,7 +1,8 @@ +Data.CustomTextLoader Import a dataset from a text file Microsoft.ML.Runtime.EntryPoints.ImportTextData ImportText Microsoft.ML.Runtime.EntryPoints.ImportTextData+Input Microsoft.ML.Runtime.EntryPoints.ImportTextData+Output Data.DataViewReference Pass dataview from memory to experiment Microsoft.ML.Runtime.EntryPoints.DataViewReference ImportData Microsoft.ML.Runtime.EntryPoints.DataViewReference+Input Microsoft.ML.Runtime.EntryPoints.DataViewReference+Output Data.IDataViewArrayConverter Create and array variable Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro MakeArray Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIDataViewInput Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIDataViewOutput Data.PredictorModelArrayConverter Create and array variable Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro MakeArray Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIPredictorModelInput Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIPredictorModelOutput -Data.TextLoader Import a dataset from a text file Microsoft.ML.Runtime.EntryPoints.ImportTextData ImportText Microsoft.ML.Runtime.EntryPoints.ImportTextData+Input Microsoft.ML.Runtime.EntryPoints.ImportTextData+Output +Data.TextLoader Import a dataset from a text file Microsoft.ML.Runtime.EntryPoints.ImportTextData TextLoader Microsoft.ML.Runtime.EntryPoints.ImportTextData+LoaderInput Microsoft.ML.Runtime.EntryPoints.ImportTextData+Output Models.AnomalyDetectionEvaluator Evaluates an anomaly detection scored dataset. 
Microsoft.ML.Runtime.Data.Evaluate AnomalyDetection Microsoft.ML.Runtime.Data.AnomalyDetectionMamlEvaluator+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+CommonEvaluateOutput Models.BinaryClassificationEvaluator Evaluates a binary classification scored dataset. Microsoft.ML.Runtime.Data.Evaluate Binary Microsoft.ML.Runtime.Data.BinaryClassifierMamlEvaluator+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+ClassificationEvaluateOutput Models.BinaryCrossValidator Cross validation for binary classification Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro CrossValidateBinary Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+MacroOutput`1[Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+Output] diff --git a/ZBaselines/Common/EntryPoints/core_manifest.json b/ZBaselines/Common/EntryPoints/core_manifest.json index a6309fe36a..6eeb1bf709 100644 --- a/ZBaselines/Common/EntryPoints/core_manifest.json +++ b/ZBaselines/Common/EntryPoints/core_manifest.json @@ -1,5 +1,43 @@ { "EntryPoints": [ + { + "Name": "Data.CustomTextLoader", + "Desc": "Import a dataset from a text file", + "FriendlyName": null, + "ShortName": null, + "Inputs": [ + { + "Name": "InputFile", + "Type": "FileHandle", + "Desc": "Location of the input file", + "Aliases": [ + "data" + ], + "Required": true, + "SortOrder": 1.0, + "IsNullable": false + }, + { + "Name": "CustomSchema", + "Type": "String", + "Desc": "Custom schema to use for parsing", + "Aliases": [ + "schema" + ], + "Required": false, + "SortOrder": 2.0, + "IsNullable": false, + "Default": null + } + ], + "Outputs": [ + { + "Name": "Data", + "Type": "DataView", + "Desc": "The resulting data view" + } + ] + }, { "Name": "Data.DataViewReference", "Desc": "Pass dataview from memory to experiment", @@ -99,16 +137,325 @@ "IsNullable": false }, { - "Name": "CustomSchema", - "Type": "String", - "Desc": "Custom schema to use for parsing", + "Name": 
"Arguments", + "Type": { + "Kind": "Struct", + "Fields": [ + { + "Name": "Column", + "Type": { + "Kind": "Array", + "ItemType": { + "Kind": "Struct", + "Fields": [ + { + "Name": "Name", + "Type": "String", + "Desc": "Name of the column", + "Required": false, + "SortOrder": 150.0, + "IsNullable": false, + "Default": null + }, + { + "Name": "Type", + "Type": { + "Kind": "Enum", + "Values": [ + "I1", + "U1", + "I2", + "U2", + "I4", + "U4", + "I8", + "U8", + "R4", + "Num", + "R8", + "TX", + "Text", + "TXT", + "BL", + "Bool", + "TimeSpan", + "TS", + "DT", + "DateTime", + "DZ", + "DateTimeZone", + "UG", + "U16" + ] + }, + "Desc": "Type of the items in the column", + "Required": false, + "SortOrder": 150.0, + "IsNullable": true, + "Default": null + }, + { + "Name": "Source", + "Type": { + "Kind": "Array", + "ItemType": { + "Kind": "Struct", + "Fields": [ + { + "Name": "Min", + "Type": "Int", + "Desc": "First index in the range", + "Required": true, + "SortOrder": 150.0, + "IsNullable": false, + "Default": 0 + }, + { + "Name": "Max", + "Type": "Int", + "Desc": "Last index in the range", + "Required": false, + "SortOrder": 150.0, + "IsNullable": true, + "Default": null + }, + { + "Name": "AutoEnd", + "Type": "Bool", + "Desc": "This range extends to the end of the line, but should be a fixed number of items", + "Aliases": [ + "auto" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": false, + "Default": false + }, + { + "Name": "VariableEnd", + "Type": "Bool", + "Desc": "This range extends to the end of the line, which can vary from line to line", + "Aliases": [ + "var" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": false, + "Default": false + }, + { + "Name": "AllOther", + "Type": "Bool", + "Desc": "This range includes only other indices not specified", + "Aliases": [ + "other" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": false, + "Default": false + }, + { + "Name": "ForceVector", + "Type": "Bool", + "Desc": "Force scalar 
columns to be treated as vectors of length one", + "Aliases": [ + "vector" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": false, + "Default": false + } + ] + } + }, + "Desc": "Source index range(s) of the column", + "Aliases": [ + "src" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": false, + "Default": null + }, + { + "Name": "KeyRange", + "Type": { + "Kind": "Struct", + "Fields": [ + { + "Name": "Min", + "Type": "UInt", + "Desc": "First index in the range", + "Required": false, + "SortOrder": 150.0, + "IsNullable": false, + "Default": 0 + }, + { + "Name": "Max", + "Type": "UInt", + "Desc": "Last index in the range", + "Required": false, + "SortOrder": 150.0, + "IsNullable": true, + "Default": null + }, + { + "Name": "Contiguous", + "Type": "Bool", + "Desc": "Whether the key is contiguous", + "Required": false, + "SortOrder": 150.0, + "IsNullable": false, + "Default": true + } + ] + }, + "Desc": "For a key column, this defines the range of values", + "Aliases": [ + "key" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": false, + "Default": null + } + ] + } + }, + "Desc": "Column groups. Each group is specified as name:type:numeric-ranges, eg, col=Features:R4:1-17,26,35-40", + "Aliases": [ + "col" + ], + "Required": false, + "SortOrder": 1.0, + "IsNullable": false, + "Default": null + }, + { + "Name": "UseThreads", + "Type": "Bool", + "Desc": "Use separate parsing threads?", + "Aliases": [ + "threads" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": false, + "Default": true + }, + { + "Name": "HeaderFile", + "Type": "String", + "Desc": "File containing a header with feature names. 
If specified, header defined in the data file (header+) is ignored.", + "Aliases": [ + "hf" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": false, + "Default": null + }, + { + "Name": "MaxRows", + "Type": "Int", + "Desc": "Maximum number of rows to produce", + "Aliases": [ + "rows" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": true, + "Default": null + }, + { + "Name": "AllowQuoting", + "Type": "Bool", + "Desc": "Whether the input may include quoted values, which can contain separator characters, colons, and distinguish empty values from missing values. When true, consecutive separators denote a missing value and an empty value is denoted by \"\". When false, consecutive separators denote an empty value.", + "Aliases": [ + "quote" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": false, + "Default": true + }, + { + "Name": "AllowSparse", + "Type": "Bool", + "Desc": "Whether the input may include sparse representations", + "Aliases": [ + "sparse" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": false, + "Default": true + }, + { + "Name": "InputSize", + "Type": "Int", + "Desc": "Number of source columns in the text data. Default is that sparse rows contain their size information.", + "Aliases": [ + "size" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": true, + "Default": null + }, + { + "Name": "Separator", + "Type": { + "Kind": "Array", + "ItemType": "Char" + }, + "Desc": "Source column separator.", + "Aliases": [ + "sep" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": false, + "Default": [ + "\t" + ] + }, + { + "Name": "TrimWhitespace", + "Type": "Bool", + "Desc": "Remove trailing whitespace from lines", + "Aliases": [ + "trim" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": false, + "Default": false + }, + { + "Name": "HasHeader", + "Type": "Bool", + "Desc": "Data file has header with feature names. 
Header is read only if options 'hs' and 'hf' are not specified.", + "Aliases": [ + "header" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": false, + "Default": false + } + ] + }, + "Desc": "Arguments", "Aliases": [ - "schema" + "args" ], - "Required": false, + "Required": true, "SortOrder": 2.0, - "IsNullable": false, - "Default": null + "IsNullable": false } ], "Outputs": [ @@ -117,6 +464,9 @@ "Type": "DataView", "Desc": "The resulting data view" } + ], + "InputKind": [ + "ILearningPipelineLoader" ] }, { @@ -21959,6 +22309,10 @@ } ] }, + { + "Kind": "ILearningPipelineLoader", + "Settings": [] + }, { "Kind": "IMulticlassClassificationOutput", "Settings": [] diff --git a/src/Microsoft.ML.Core/EntryPoints/ModuleCatalog.cs b/src/Microsoft.ML.Core/EntryPoints/ModuleCatalog.cs index 3a468ad451..586f6a4b02 100644 --- a/src/Microsoft.ML.Core/EntryPoints/ModuleCatalog.cs +++ b/src/Microsoft.ML.Core/EntryPoints/ModuleCatalog.cs @@ -49,8 +49,10 @@ public sealed class EntryPointInfo public readonly Type OutputType; public readonly Type[] InputKinds; public readonly Type[] OutputKinds; + public readonly ObsoleteAttribute ObsoleteAttribute; - internal EntryPointInfo(IExceptionContext ectx, MethodInfo method, TlcModule.EntryPointAttribute attribute) + internal EntryPointInfo(IExceptionContext ectx, MethodInfo method, + TlcModule.EntryPointAttribute attribute, ObsoleteAttribute obsoleteAttribute) { Contracts.AssertValueOrNull(ectx); ectx.AssertValue(method); @@ -61,6 +63,7 @@ internal EntryPointInfo(IExceptionContext ectx, MethodInfo method, TlcModule.Ent Method = method; ShortName = attribute.ShortName; FriendlyName = attribute.UserName; + ObsoleteAttribute = obsoleteAttribute; // There are supposed to be 2 parameters, env and input for non-macro nodes. // Macro nodes have a 3rd parameter, the entry point node. 
@@ -183,7 +186,10 @@ private ModuleCatalog(IExceptionContext ectx) var attr = methodInfo.GetCustomAttributes(typeof(TlcModule.EntryPointAttribute), false).FirstOrDefault() as TlcModule.EntryPointAttribute; if (attr == null) continue; - var info = new EntryPointInfo(ectx, methodInfo, attr); + + var info = new EntryPointInfo(ectx, methodInfo, attr, + methodInfo.GetCustomAttributes(typeof(ObsoleteAttribute), false).FirstOrDefault() as ObsoleteAttribute); + entryPoints.Add(info); if (_entryPointMap.ContainsKey(info.Name)) { diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs index 3867b18f26..3678c749ba 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs @@ -85,15 +85,19 @@ private bool TryParse(string str) return TryParseSource(rgstr[istr++]); } - private bool TryParseSource(string str) + private bool TryParseSource(string str) => TryParseSourceEx(str, out Source); + + public static bool TryParseSourceEx(string str, out Range[] ranges) { + ranges = null; var strs = str.Split(','); if (str.Length == 0) return false; - Source = new Range[strs.Length]; + + ranges = new Range[strs.Length]; for (int i = 0; i < strs.Length; i++) { - if ((Source[i] = Range.Parse(strs[i])) == null) + if ((ranges[i] = Range.Parse(strs[i])) == null) return false; } return true; @@ -294,9 +298,12 @@ public class ArgumentsCore ShortName = "size")] public int? InputSize; - [Argument(ArgumentType.AtMostOnce, HelpText = "Source column separator. Options: tab, space, comma, single character", ShortName = "sep")] + [Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "Source column separator. 
Options: tab, space, comma, single character", ShortName = "sep")] public string Separator = "tab"; + [Argument(ArgumentType.AtMostOnce, Name = nameof(Separator), Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly, HelpText = "Source column separator.", ShortName = "sep")] + public char[] SeparatorChars = new[] { '\t' }; + [Argument(ArgumentType.Multiple, HelpText = "Column groups. Each group is specified as name:type:numeric-ranges, eg, col=Features:R4:1-17,26,35-40", ShortName = "col", SortOrder = 1)] public Column[] Column; @@ -1005,26 +1012,40 @@ public TextLoader(IHostEnvironment env, Arguments args, IMultiStreamSource files _inputSize = SrcLim - 1; _host.CheckNonEmpty(args.Separator, nameof(args.Separator), "Must specify a separator"); - string sep = args.Separator.ToLowerInvariant(); - if (sep == ",") - _separators = new char[] { ',' }; - else + //Default arg.Separator is tab and default args.SeparatorChars is also a '\t'. + //At a time only one default can be different and whichever is different that will + //be used. + if (args.SeparatorChars.Length > 1 || args.SeparatorChars[0] != '\t') { var separators = new HashSet(); - foreach (string s in sep.Split(',')) - { - if (string.IsNullOrEmpty(s)) - continue; + foreach (char c in args.SeparatorChars) + separators.Add(NormalizeSeparator(c.ToString())); - char c = NormalizeSeparator(s); - separators.Add(c); - } _separators = separators.ToArray(); - - // Handling ",,,," case, that .Split() returns empty strings. - if (_separators.Length == 0) + } + else + { + string sep = args.Separator.ToLowerInvariant(); + if (sep == ",") _separators = new char[] { ',' }; + else + { + var separators = new HashSet(); + foreach (string s in sep.Split(',')) + { + if (string.IsNullOrEmpty(s)) + continue; + + char c = NormalizeSeparator(s); + separators.Add(c); + } + _separators = separators.ToArray(); + + // Handling ",,,," case, that .Split() returns empty strings. 
+ if (_separators.Length == 0) + _separators = new char[] { ',' }; + } } _bindings = new Bindings(this, cols, headerFile); diff --git a/src/Microsoft.ML.PipelineInference/AutoInference.cs b/src/Microsoft.ML.PipelineInference/AutoInference.cs index 894029460a..7a340e5957 100644 --- a/src/Microsoft.ML.PipelineInference/AutoInference.cs +++ b/src/Microsoft.ML.PipelineInference/AutoInference.cs @@ -579,11 +579,13 @@ public static AutoMlMlState InferPipelines(IHostEnvironment env, PipelineOptimiz RecipeInference.InferRecipesFromData(env, trainDataPath, schemaDefinitionFile, out var _, out schemaDefinition, out var _, true); +#pragma warning disable 0618 var data = ImportTextData.ImportText(env, new ImportTextData.Input { InputFile = new SimpleFileHandle(env, trainDataPath, false, false), CustomSchema = schemaDefinition }).Data; +#pragma warning restore 0618 var splitOutput = TrainTestSplit.Split(env, new TrainTestSplit.Input { Data = data, Fraction = 0.8f }); AutoMlMlState amls = new AutoMlMlState(env, metric, autoMlEngine, terminator, trainerKind, splitOutput.TrainData.Take(numOfSampleRows), splitOutput.TestData.Take(numOfSampleRows)); diff --git a/src/Microsoft.ML/CSharpApi.cs b/src/Microsoft.ML/CSharpApi.cs index 2ca1af7159..317ee98db0 100644 --- a/src/Microsoft.ML/CSharpApi.cs +++ b/src/Microsoft.ML/CSharpApi.cs @@ -22,6 +22,30 @@ namespace Runtime { public sealed partial class Experiment { + public Microsoft.ML.Data.CustomTextLoader.Output Add(Microsoft.ML.Data.CustomTextLoader input) + { + var output = new Microsoft.ML.Data.CustomTextLoader.Output(); + Add(input, output); + return output; + } + + public void Add(Microsoft.ML.Data.CustomTextLoader input, Microsoft.ML.Data.CustomTextLoader.Output output) + { + _jsonNodes.Add(Serialize("Data.CustomTextLoader", input, output)); + } + + public Microsoft.ML.Data.DataViewReference.Output Add(Microsoft.ML.Data.DataViewReference input) + { + var output = new Microsoft.ML.Data.DataViewReference.Output(); + Add(input, 
output); + return output; + } + + public void Add(Microsoft.ML.Data.DataViewReference input, Microsoft.ML.Data.DataViewReference.Output output) + { + _jsonNodes.Add(Serialize("Data.DataViewReference", input, output)); + } + public Microsoft.ML.Data.IDataViewArrayConverter.Output Add(Microsoft.ML.Data.IDataViewArrayConverter input) { var output = new Microsoft.ML.Data.IDataViewArrayConverter.Output(); @@ -53,22 +77,11 @@ public Microsoft.ML.Data.TextLoader.Output Add(Microsoft.ML.Data.TextLoader inpu return output; } - public Microsoft.ML.Data.DataViewReference.Output Add(Microsoft.ML.Data.DataViewReference input) - { - var output = new Microsoft.ML.Data.DataViewReference.Output(); - Add(input, output); - return output; - } - public void Add(Microsoft.ML.Data.TextLoader input, Microsoft.ML.Data.TextLoader.Output output) { _jsonNodes.Add(Serialize("Data.TextLoader", input, output)); } - public void Add(Microsoft.ML.Data.DataViewReference input, Microsoft.ML.Data.DataViewReference.Output output) - { - _jsonNodes.Add(Serialize("Data.DataViewReference", input, output)); - } public Microsoft.ML.Models.AnomalyDetectionEvaluator.Output Add(Microsoft.ML.Models.AnomalyDetectionEvaluator input) { var output = new Microsoft.ML.Models.AnomalyDetectionEvaluator.Output(); @@ -453,6 +466,18 @@ public void Add(Microsoft.ML.Trainers.GeneralizedAdditiveModelRegressor input, M _jsonNodes.Add(Serialize("Trainers.GeneralizedAdditiveModelRegressor", input, output)); } + public Microsoft.ML.Trainers.KMeansPlusPlusClusterer.Output Add(Microsoft.ML.Trainers.KMeansPlusPlusClusterer input) + { + var output = new Microsoft.ML.Trainers.KMeansPlusPlusClusterer.Output(); + Add(input, output); + return output; + } + + public void Add(Microsoft.ML.Trainers.KMeansPlusPlusClusterer input, Microsoft.ML.Trainers.KMeansPlusPlusClusterer.Output output) + { + _jsonNodes.Add(Serialize("Trainers.KMeansPlusPlusClusterer", input, output)); + } + public Microsoft.ML.Trainers.LinearSvmBinaryClassifier.Output 
Add(Microsoft.ML.Trainers.LinearSvmBinaryClassifier input) { var output = new Microsoft.ML.Trainers.LinearSvmBinaryClassifier.Output(); @@ -1271,6 +1296,66 @@ public void Add(Microsoft.ML.Transforms.WordTokenizer input, Microsoft.ML.Transf } } + namespace Data + { + + /// + /// Import a dataset from a text file + /// + [Obsolete("Use TextLoader instead.")] + public sealed partial class CustomTextLoader + { + + + /// + /// Location of the input file + /// + public Var InputFile { get; set; } = new Var(); + + /// + /// Custom schema to use for parsing + /// + public string CustomSchema { get; set; } + + + public sealed class Output + { + /// + /// The resulting data view + /// + public Var Data { get; set; } = new Var(); + + } + } + } + + namespace Data + { + + /// + /// Pass dataview from memory to experiment + /// + public sealed partial class DataViewReference + { + + + /// + /// Pointer to IDataView in memory + /// + public Var Data { get; set; } = new Var(); + + + public sealed class Output + { + /// + /// The resulting data view + /// + public Var Data { get; set; } = new Var(); + + } + } + } + namespace Data { @@ -1328,40 +1413,185 @@ public sealed class Output namespace Data { - /// - /// Import a dataset from a text file - /// - public sealed partial class TextLoader + public sealed partial class TextLoaderArguments { + /// + /// Use separate parsing threads? + /// + public bool UseThreads { get; set; } = true; + /// + /// File containing a header with feature names. If specified, header defined in the data file (header+) is ignored. + /// + public string HeaderFile { get; set; } /// - /// Location of the input file + /// Maximum number of rows to produce /// - public Var InputFile { get; set; } = new Var(); + public long? MaxRows { get; set; } /// - /// Custom schema to use for parsing + /// Whether the input may include quoted values, which can contain separator characters, colons, and distinguish empty values from missing values. 
When true, consecutive separators denote a missing value and an empty value is denoted by "". When false, consecutive separators denote an empty value. /// - public string CustomSchema { get; set; } + public bool AllowQuoting { get; set; } = true; + + /// + /// Whether the input may include sparse representations + /// + public bool AllowSparse { get; set; } = true; + /// + /// Number of source columns in the text data. Default is that sparse rows contain their size information. + /// + public int? InputSize { get; set; } - public sealed class Output - { - /// - /// The resulting data view - /// - public Var Data { get; set; } = new Var(); + /// + /// Source column separator. + /// + public char[] Separator { get; set; } = { '\t' }; + + /// + /// Column groups. Each group is specified as name:type:numeric-ranges, eg, col=Features:R4:1-17,26,35-40 + /// + public TextLoaderColumn[] Column { get; set; } + + /// + /// Remove trailing whitespace from lines + /// + public bool TrimWhitespace { get; set; } = false; + + /// + /// Data file has header with feature names. Header is read only if options 'hs' and 'hf' are not specified. + /// + public bool HasHeader { get; set; } = false; - } } - public sealed partial class DataViewReference + public sealed partial class TextLoaderColumn { + /// + /// Name of the column + /// + public string Name { get; set; } + + /// + /// Type of the items in the column + /// + public DataKind? Type { get; set; } + + /// + /// Source index range(s) of the column + /// + public TextLoaderRange[] Source { get; set; } + + /// + /// For a key column, this defines the range of values + /// + public KeyRange KeyRange { get; set; } + + } + + public sealed partial class TextLoaderRange + { + /// + /// First index in the range + /// + public int Min { get; set; } + + /// + /// Last index in the range + /// + public int? 
Max { get; set; } + + /// + /// This range extends to the end of the line, but should be a fixed number of items + /// + public bool AutoEnd { get; set; } = false; + + /// + /// This range extends to the end of the line, which can vary from line to line + /// + public bool VariableEnd { get; set; } = false; + + /// + /// This range includes only other indices not specified + /// + public bool AllOther { get; set; } = false; + + /// + /// Force scalar columns to be treated as vectors of length one + /// + public bool ForceVector { get; set; } = false; + + } + + public sealed partial class KeyRange + { + /// + /// First index in the range + /// + public ulong Min { get; set; } = 0; + + /// + /// Last index in the range + /// + public ulong? Max { get; set; } + + /// + /// Whether the key is contiguous + /// + public bool Contiguous { get; set; } = true; + + } + + /// + /// Import a dataset from a text file + /// + public sealed partial class TextLoader : Microsoft.ML.ILearningPipelineLoader + { + + [JsonIgnore] + private string _inputFilePath = null; + public TextLoader(string filePath) + { + _inputFilePath = filePath; + } + + public void SetInput(IHostEnvironment env, Experiment experiment) + { + IFileHandle inputFile = new SimpleFileHandle(env, _inputFilePath, false, false); + experiment.SetInput(InputFile, inputFile); + } + + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) + { + Contracts.Assert(previousStep == null); + + return new TextLoaderPipelineStep(experiment.Add(this)); + } + + private class TextLoaderPipelineStep : ILearningPipelineDataStep + { + public TextLoaderPipelineStep (Output output) + { + Data = output.Data; + Model = null; + } + + public Var Data { get; } + public Var Model { get; } + } + /// /// Location of the input file /// - public Var Data { get; set; } = new Var(); + public Var InputFile { get; set; } = new Var(); + + /// + /// Arguments + /// + public Data.TextLoaderArguments Arguments { 
get; set; } = new Data.TextLoaderArguments(); + public sealed class Output { @@ -1561,7 +1791,7 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ICla namespace Models { - public sealed class CrossValidationBinaryMacroSubGraphInput + public sealed partial class CrossValidationBinaryMacroSubGraphInput { /// /// The data to be used for training @@ -1570,7 +1800,7 @@ public sealed class CrossValidationBinaryMacroSubGraphInput } - public sealed class CrossValidationBinaryMacroSubGraphOutput + public sealed partial class CrossValidationBinaryMacroSubGraphOutput { /// /// The model @@ -1826,7 +2056,7 @@ public enum MacroUtilsTrainerKinds } - public sealed class CrossValidationMacroSubGraphInput + public sealed partial class CrossValidationMacroSubGraphInput { /// /// The data to be used for training @@ -1835,7 +2065,7 @@ public sealed class CrossValidationMacroSubGraphInput } - public sealed class CrossValidationMacroSubGraphOutput + public sealed partial class CrossValidationMacroSubGraphOutput { /// /// The model @@ -2239,7 +2469,7 @@ public enum CachingOptions } - public sealed class OneVersusAllMacroSubGraphOutput + public sealed partial class OneVersusAllMacroSubGraphOutput { /// /// The predictor model for the subgraph exemplar. 
@@ -2877,7 +3107,7 @@ public sealed class Output namespace Models { - public sealed class TrainTestBinaryMacroSubGraphInput + public sealed partial class TrainTestBinaryMacroSubGraphInput { /// /// The data to be used for training @@ -2886,7 +3116,7 @@ public sealed class TrainTestBinaryMacroSubGraphInput } - public sealed class TrainTestBinaryMacroSubGraphOutput + public sealed partial class TrainTestBinaryMacroSubGraphOutput { /// /// The model @@ -2962,7 +3192,7 @@ public sealed class Output namespace Models { - public sealed class TrainTestMacroSubGraphInput + public sealed partial class TrainTestMacroSubGraphInput { /// /// The data to be used for training @@ -2971,7 +3201,7 @@ public sealed class TrainTestMacroSubGraphInput } - public sealed class TrainTestMacroSubGraphOutput + public sealed partial class TrainTestMacroSubGraphOutput { /// /// The model @@ -5686,6 +5916,107 @@ public GeneralizedAdditiveModelRegressorPipelineStep(Output output) } } + namespace Trainers + { + public enum KMeansPlusPlusTrainerInitAlgorithm + { + KMeansPlusPlus = 0, + Random = 1, + KMeansParallel = 2 + } + + + /// + /// K-means is a popular clustering algorithm. With K-means, the data is clustered into a specified number of clusters in order to minimize the within-cluster sum of squares. K-means++ improves upon K-means by using a better method for choosing the initial cluster centers. + /// + public sealed partial class KMeansPlusPlusClusterer : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.ILearningPipelineItem + { + + + /// + /// The number of clusters + /// + [TlcModule.SweepableDiscreteParamAttribute("K", new object[]{5, 10, 20, 40})] + public int K { get; set; } = 5; + + /// + /// Cluster initialization algorithm + /// + public Trainers.KMeansPlusPlusTrainerInitAlgorithm InitAlgorithm { get; set; } = Trainers.KMeansPlusPlusTrainerInitAlgorithm.KMeansParallel; + + /// + /// Tolerance parameter for trainer convergence. 
Lower = slower, more accurate + /// + public float OptTol { get; set; } = 1E-07f; + + /// + /// Maximum number of iterations. + /// + public int MaxIterations { get; set; } = 1000; + + /// + /// Memory budget (in MBs) to use for KMeans acceleration + /// + public int AccelMemBudgetMb { get; set; } = 4096; + + /// + /// Degree of lock-free parallelism. Defaults to automatic. Determinism not guaranteed. + /// + public int? NumThreads { get; set; } + + /// + /// The data to be used for training + /// + public Var TrainingData { get; set; } = new Var(); + + /// + /// Column to use for features + /// + public string FeatureColumn { get; set; } = "Features"; + + /// + /// Normalize option for the feature column + /// + public Models.NormalizeOption NormalizeFeatures { get; set; } = Models.NormalizeOption.Auto; + + /// + /// Whether learner should cache input training data + /// + public Models.CachingOptions Caching { get; set; } = Models.CachingOptions.Auto; + + + public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IClusteringOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput + { + /// + /// The trained model + /// + public Var PredictorModel { get; set; } = new Var(); + + } + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) + { + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(KMeansPlusPlusClusterer)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } + + TrainingData = dataStep.Data; + Output output = experiment.Add(this); + return new KMeansPlusPlusClustererPipelineStep(output); + } + + private class KMeansPlusPlusClustererPipelineStep : ILearningPipelinePredictorStep + { + public KMeansPlusPlusClustererPipelineStep(Output output) + { + Model = output.PredictorModel; + } + + public Var Model { get; } + } + } + } + namespace Trainers { @@ -7196,7 +7527,7 @@ public 
BinaryPredictionScoreColumnsRenamerPipelineStep(Output output) namespace Transforms { - public sealed class NormalizeTransformBinColumn : OneToOneColumn, IOneToOneColumn + public sealed partial class NormalizeTransformBinColumn : OneToOneColumn, IOneToOneColumn { /// /// Max number of bins, power of 2 recommended @@ -7348,7 +7679,7 @@ public enum CategoricalTransformOutputKind : byte } - public sealed class CategoricalHashTransformColumn : OneToOneColumn, IOneToOneColumn + public sealed partial class CategoricalHashTransformColumn : OneToOneColumn, IOneToOneColumn { /// /// The number of bits to hash into. Must be between 1 and 30, inclusive. @@ -7518,7 +7849,7 @@ public enum TermTransformSortOrder : byte } - public sealed class CategoricalTransformColumn : OneToOneColumn, IOneToOneColumn + public sealed partial class CategoricalTransformColumn : OneToOneColumn, IOneToOneColumn { /// /// Output kind: Bag (multi-set vector), Ind (indicator vector), Key (index), or Binary encoded indicator vector @@ -7682,7 +8013,7 @@ public CategoricalOneHotVectorizerPipelineStep(Output output) namespace Transforms { - public sealed class CharTokenizeTransformColumn : OneToOneColumn, IOneToOneColumn + public sealed partial class CharTokenizeTransformColumn : OneToOneColumn, IOneToOneColumn { /// /// Name of the new column @@ -7801,7 +8132,7 @@ public CharacterTokenizerPipelineStep(Output output) namespace Transforms { - public sealed class ConcatTransformColumn : ManyToOneColumn, IManyToOneColumn + public sealed partial class ConcatTransformColumn : ManyToOneColumn, IManyToOneColumn { /// /// Name of the new column @@ -7891,7 +8222,7 @@ public ColumnConcatenatorPipelineStep(Output output) namespace Transforms { - public sealed class CopyColumnsTransformColumn : OneToOneColumn, IOneToOneColumn + public sealed partial class CopyColumnsTransformColumn : OneToOneColumn, IOneToOneColumn { /// /// Name of the new column @@ -8153,7 +8484,7 @@ public enum DataKind : byte } - public sealed 
class ConvertTransformColumn : OneToOneColumn, IOneToOneColumn + public sealed partial class ConvertTransformColumn : OneToOneColumn, IOneToOneColumn { /// /// The result type @@ -8352,7 +8683,7 @@ public CombinerByContiguousGroupIdPipelineStep(Output output) namespace Transforms { - public sealed class NormalizeTransformAffineColumn : OneToOneColumn, IOneToOneColumn + public sealed partial class NormalizeTransformAffineColumn : OneToOneColumn, IOneToOneColumn { /// /// Whether to map zero to zero, preserving sparsity @@ -8625,7 +8956,7 @@ public sealed class Output namespace Transforms { - public sealed class TermTransformColumn : OneToOneColumn, IOneToOneColumn + public sealed partial class TermTransformColumn : OneToOneColumn, IOneToOneColumn { /// /// Maximum number of terms to keep when auto-training @@ -8979,7 +9310,7 @@ public FeatureSelectorByMutualInformationPipelineStep(Output output) namespace Transforms { - public sealed class LpNormNormalizerTransformGcnColumn : OneToOneColumn, IOneToOneColumn + public sealed partial class LpNormNormalizerTransformGcnColumn : OneToOneColumn, IOneToOneColumn { /// /// Normalize by standard deviation rather than L2 norm @@ -9123,7 +9454,7 @@ public GlobalContrastNormalizerPipelineStep(Output output) namespace Transforms { - public sealed class HashJoinTransformColumn : OneToOneColumn, IOneToOneColumn + public sealed partial class HashJoinTransformColumn : OneToOneColumn, IOneToOneColumn { /// /// Whether the values need to be combined for a single hash @@ -9282,7 +9613,7 @@ public HashConverterPipelineStep(Output output) namespace Transforms { - public sealed class KeyToValueTransformColumn : OneToOneColumn, IOneToOneColumn + public sealed partial class KeyToValueTransformColumn : OneToOneColumn, IOneToOneColumn { /// /// Name of the new column @@ -9461,7 +9792,7 @@ public LabelColumnKeyBooleanConverterPipelineStep(Output output) namespace Transforms { - public sealed class LabelIndicatorTransformColumn : OneToOneColumn, 
IOneToOneColumn + public sealed partial class LabelIndicatorTransformColumn : OneToOneColumn, IOneToOneColumn { /// /// The positive example class for binary classification. @@ -9645,7 +9976,7 @@ public LabelToFloatConverterPipelineStep(Output output) namespace Transforms { - public sealed class NormalizeTransformLogNormalColumn : OneToOneColumn, IOneToOneColumn + public sealed partial class NormalizeTransformLogNormalColumn : OneToOneColumn, IOneToOneColumn { /// /// Max number of examples used to train the normalizer @@ -9782,7 +10113,7 @@ public enum LpNormNormalizerTransformNormalizerKind : byte } - public sealed class LpNormNormalizerTransformColumn : OneToOneColumn, IOneToOneColumn + public sealed partial class LpNormNormalizerTransformColumn : OneToOneColumn, IOneToOneColumn { /// /// The norm to use to normalize each sample @@ -10185,7 +10516,7 @@ public enum NAHandleTransformReplacementKind } - public sealed class NAHandleTransformColumn : OneToOneColumn, IOneToOneColumn + public sealed partial class NAHandleTransformColumn : OneToOneColumn, IOneToOneColumn { /// /// The replacement method to utilize @@ -10329,7 +10660,7 @@ public MissingValueHandlerPipelineStep(Output output) namespace Transforms { - public sealed class NAIndicatorTransformColumn : OneToOneColumn, IOneToOneColumn + public sealed partial class NAIndicatorTransformColumn : OneToOneColumn, IOneToOneColumn { /// /// Name of the new column @@ -10443,7 +10774,7 @@ public MissingValueIndicatorPipelineStep(Output output) namespace Transforms { - public sealed class NADropTransformColumn : OneToOneColumn, IOneToOneColumn + public sealed partial class NADropTransformColumn : OneToOneColumn, IOneToOneColumn { /// /// Name of the new column @@ -10637,7 +10968,7 @@ public enum NAReplaceTransformReplacementKind } - public sealed class NAReplaceTransformColumn : OneToOneColumn, IOneToOneColumn + public sealed partial class NAReplaceTransformColumn : OneToOneColumn, IOneToOneColumn { /// /// Replacement 
value for NAs (uses default value if not given) @@ -10810,7 +11141,7 @@ public enum NgramTransformWeightingCriteria } - public sealed class NgramTransformColumn : OneToOneColumn, IOneToOneColumn + public sealed partial class NgramTransformColumn : OneToOneColumn, IOneToOneColumn { /// /// Maximum ngram length @@ -11149,7 +11480,7 @@ public PredictedLabelColumnOriginalValueConverterPipelineStep(Output output) namespace Transforms { - public sealed class GenerateNumberTransformColumn + public sealed partial class GenerateNumberTransformColumn { /// /// Name of the new column @@ -11888,7 +12219,7 @@ public enum TextTransformTextNormKind } - public sealed class TextTransformColumn : ManyToOneColumn, IManyToOneColumn + public sealed partial class TextTransformColumn : ManyToOneColumn, IManyToOneColumn { /// /// Name of the new column @@ -11902,7 +12233,7 @@ public sealed class TextTransformColumn : ManyToOneColumn, } - public sealed class TermLoaderArguments + public sealed partial class TermLoaderArguments { /// /// List of terms @@ -12317,7 +12648,7 @@ public sealed class Output namespace Transforms { - public sealed class DelimitedTokenizeTransformColumn : OneToOneColumn, IOneToOneColumn + public sealed partial class DelimitedTokenizeTransformColumn : OneToOneColumn, IOneToOneColumn { /// /// Comma separated set of term separator(s). Commonly: 'space', 'comma', 'semicolon' or other single character. diff --git a/src/Microsoft.ML/Data/TextLoader.cs b/src/Microsoft.ML/Data/TextLoader.cs new file mode 100644 index 0000000000..3c8550ef09 --- /dev/null +++ b/src/Microsoft.ML/Data/TextLoader.cs @@ -0,0 +1,179 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Api; +using Microsoft.ML.Runtime.Data; +using System; +using System.Linq; +using System.Reflection; +using System.Text.RegularExpressions; + +namespace Microsoft.ML.Data +{ + public sealed partial class TextLoaderRange + { + public TextLoaderRange() + { + } + + /// + /// Convenience constructor for the scalar case, when a given column + /// in the schema spans only a single column in the dataset. + /// and are set to the single value . + /// + /// Column index in the dataset. + public TextLoaderRange(int ordinal) + { + + Contracts.CheckParam(ordinal >= 0, nameof(ordinal), "Cannot be a negative number"); + + Min = ordinal; + Max = ordinal; + } + + /// + /// Convenience constructor for the vector case, when a given column + /// in the schema spans contiguous columns in the dataset. + /// + /// Starting column index in the dataset. + /// Ending column index in the dataset. + public TextLoaderRange(int min, int max) + { + + Contracts.CheckParam(min >= 0, nameof(min), "Cannot be a negative number."); + Contracts.CheckParam(max >= min, nameof(max), "Cannot be less than " + nameof(min) +"."); + + Min = min; + Max = max; + } + } + + public sealed partial class TextLoader + { + /// + /// Construct a TextLoader object by inferencing the dataset schema from a type. + /// + /// Does the file contains header? + /// Column separator character. Default is '\t' + /// Whether the input may include quoted values, + /// which can contain separator characters, colons, + /// and distinguish empty values from missing values. When true, consecutive separators + /// denote a missing value and an empty value is denoted by \"\". + /// When false, consecutive separators denote an empty value. + /// Whether the input may include sparse representations e.g. 
+ /// if one of the rows contains "5 2:6 4:3" that means there are 5 columns, all zero + /// except for the 3rd and 5th columns, which have values 6 and 3 + /// Remove trailing whitespace from lines + public TextLoader CreateFrom(bool useHeader = false, + char separator = '\t', bool allowQuotedStrings = true, + bool supportSparse = true, bool trimWhitespace = false) + { + var fields = typeof(TInput).GetFields(); // NOTE(review): TInput is not declared in the visible signature, presumably a generic type parameter; confirm + Arguments.Column = new TextLoaderColumn[fields.Length]; + for (int index = 0; index < fields.Length; index++) + { + var field = fields[index]; + var mappingAttr = field.GetCustomAttribute(); + if (mappingAttr == null) + throw Contracts.Except($"{field.Name} is missing ColumnAttribute"); + + if (Regex.Match(mappingAttr.Ordinal, @"[^0-9,\*\-~]+").Success) // '(' and ')' are literals inside a character class, so they are deliberately not listed as allowed + throw Contracts.Except($"{mappingAttr.Ordinal} contains invalid characters. " + + $"Valid characters are 0-9, *, - and ~"); + + var name = mappingAttr.Name ?? field.Name; // the attribute name wins; the field name is the fallback + if (name.Any(c => !Char.IsLetterOrDigit(c))) + throw Contracts.Except($"{name} is not alphanumeric."); + + Runtime.Data.TextLoader.Range[] sources; + if (!Runtime.Data.TextLoader.Column.TryParseSourceEx(mappingAttr.Ordinal, out sources)) + throw Contracts.Except($"{mappingAttr.Ordinal} could not be parsed."); + + Contracts.Assert(sources != null); + + TextLoaderColumn tlc = new TextLoaderColumn(); + tlc.Name = name; + tlc.Source = new TextLoaderRange[sources.Length]; + DataKind dk; + if (!TryGetDataKind(field.FieldType.IsArray ? 
field.FieldType.GetElementType() : field.FieldType, out dk)) + throw Contracts.Except($"{name} is of unsupported type."); + + tlc.Type = dk; + + for (int indexLocal = 0; indexLocal < tlc.Source.Length; indexLocal++) + { + tlc.Source[indexLocal] = new TextLoaderRange + { + AllOther = sources[indexLocal].AllOther, + AutoEnd = sources[indexLocal].AutoEnd, + ForceVector = sources[indexLocal].ForceVector, + VariableEnd = sources[indexLocal].VariableEnd, + Max = sources[indexLocal].Max, + Min = sources[indexLocal].Min + }; + } + + Arguments.Column[index] = tlc; + } + + Arguments.HasHeader = useHeader; + Arguments.Separator = new[] { separator }; + Arguments.AllowQuoting = allowQuotedStrings; + Arguments.AllowSparse = supportSparse; + Arguments.TrimWhitespace = trimWhitespace; + + return this; + } + + /// + /// Try to map a System.Type to a corresponding DataKind value. + /// + private static bool TryGetDataKind(Type type, out DataKind kind) + { + Contracts.AssertValue(type); + + // REVIEW: Make this more efficient. Should we have a global dictionary? 
+ if (type == typeof(DvInt1) || type == typeof(sbyte)) + kind = DataKind.I1; + else if (type == typeof(byte) || type == typeof(char)) + kind = DataKind.U1; + else if (type == typeof(DvInt2) || type == typeof(short)) + kind = DataKind.I2; + else if (type == typeof(ushort)) + kind = DataKind.U2; + else if (type == typeof(DvInt4) || type == typeof(int)) + kind = DataKind.I4; + else if (type == typeof(uint)) + kind = DataKind.U4; + else if (type == typeof(DvInt8) || type == typeof(long)) + kind = DataKind.I8; + else if (type == typeof(ulong)) + kind = DataKind.U8; + else if (type == typeof(Single)) + kind = DataKind.R4; + else if (type == typeof(Double)) + kind = DataKind.R8; + else if (type == typeof(DvText) || type == typeof(string)) + kind = DataKind.TX; + else if (type == typeof(DvBool) || type == typeof(bool)) + kind = DataKind.BL; + else if (type == typeof(DvTimeSpan) || type == typeof(TimeSpan)) + kind = DataKind.TS; + else if (type == typeof(DvDateTime) || type == typeof(DateTime)) + kind = DataKind.DT; + else if (type == typeof(DvDateTimeZone) || type == typeof(TimeZoneInfo)) + kind = DataKind.DZ; + else if (type == typeof(UInt128)) + kind = DataKind.UG; + else + { + kind = default(DataKind); + return false; + } + + return true; + } + } +} diff --git a/src/Microsoft.ML/LearningPipeline.cs b/src/Microsoft.ML/LearningPipeline.cs index 51677afbf4..0e554734ea 100644 --- a/src/Microsoft.ML/LearningPipeline.cs +++ b/src/Microsoft.ML/LearningPipeline.cs @@ -1,4 +1,4 @@ -// Licensed to the .NET Foundation under one or more agreements. +// Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. @@ -68,7 +68,7 @@ public LearningPipeline() /// Possible data loader(s), transforms and trainers options are /// /// Data Loader: - /// + /// /// etc. 
/// /// @@ -154,7 +154,6 @@ public PredictionModel Train() step = currentItem.ApplyStep(step, experiment); if (step is ILearningPipelineDataStep dataStep && dataStep.Model != null) transformModels.Add(dataStep.Model); - else if (step is ILearningPipelinePredictorStep predictorDataStep) { if (lastTransformModel != null) diff --git a/src/Microsoft.ML/Runtime/EntryPoints/ImportTextData.cs b/src/Microsoft.ML/Runtime/EntryPoints/ImportTextData.cs index 8038294398..41048000d8 100644 --- a/src/Microsoft.ML/Runtime/EntryPoints/ImportTextData.cs +++ b/src/Microsoft.ML/Runtime/EntryPoints/ImportTextData.cs @@ -27,13 +27,25 @@ public sealed class Input public string CustomSchema = null; } + [TlcModule.EntryPointKind(typeof(ILearningPipelineLoader))] + public sealed class LoaderInput + { + [Argument(ArgumentType.Required, ShortName = "data", HelpText = "Location of the input file", SortOrder = 1)] + public IFileHandle InputFile; + + [Argument(ArgumentType.Required, ShortName = "args", HelpText = "Arguments", SortOrder = 2)] + public TextLoader.Arguments Arguments = new TextLoader.Arguments(); + } + public sealed class Output { [TlcModule.Output(Desc = "The resulting data view", SortOrder = 1)] public IDataView Data; } - [TlcModule.EntryPoint(Name = "Data.TextLoader", Desc = "Import a dataset from a text file")] +#pragma warning disable 0618 + [Obsolete("Use TextLoader instead.")] + [TlcModule.EntryPoint(Name = "Data.CustomTextLoader", Desc = "Import a dataset from a text file")] public static Output ImportText(IHostEnvironment env, Input input) { Contracts.CheckValue(env, nameof(env)); @@ -43,5 +55,17 @@ public static Output ImportText(IHostEnvironment env, Input input) var loader = host.CreateLoader(string.Format("Text{{{0}}}", input.CustomSchema), new FileHandleSource(input.InputFile)); return new Output { Data = loader }; } +#pragma warning restore 0618 + + [TlcModule.EntryPoint(Name = "Data.TextLoader", Desc = "Import a dataset from a text file")] + public static Output 
TextLoader(IHostEnvironment env, LoaderInput input) + { + Contracts.CheckValue(env, nameof(env)); + var host = env.Register("ImportTextData"); + env.CheckValue(input, nameof(input)); + EntryPointUtils.CheckInputArgs(host, input); + var loader = host.CreateLoader(input.Arguments, new FileHandleSource(input.InputFile)); + return new Output { Data = loader }; + } } } diff --git a/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs b/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs index f1e45fa446..5fabb15840 100644 --- a/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs +++ b/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs @@ -177,6 +177,37 @@ public static string Capitalize(string s) return char.ToUpperInvariant(s[0]) + s.Substring(1); } + private static string GetCharAsString(char value) // Returns the text that GetValue places between single quotes when it emits a C# char literal for value. + { + switch (value) + { + case '\t': + return "\\t"; + case '\n': + return "\\n"; + case '\r': + return "\\r"; + case '\\': + return "\\\\"; // must be an escaped backslash: a lone one would make the emitted literal '\' unterminated + case '\"': + return "\""; // a double quote needs no escape inside a char literal + case '\'': + return "\\'"; + case '\0': + return "\\0"; + case '\a': + return "\\a"; + case '\b': + return "\\b"; + case '\f': + return "\\f"; + case '\v': + return "\\v"; + default: + return value.ToString(); // every other character is emitted as-is + } + } + public static string GetValue(ModuleCatalog catalog, Type fieldType, object fieldValue, Dictionary typesSymbolTable, string rootNameSpace = "") { @@ -264,7 +295,7 @@ public static string GetValue(ModuleCatalog catalog, Type fieldType, object fiel case TlcModule.DataKind.Enum: return GetEnumName(fieldType, typesSymbolTable, rootNameSpace) + "." 
+ fieldValue; case TlcModule.DataKind.Char: - return $"'{(char)fieldValue}'"; + return $"'{GetCharAsString((char)fieldValue)}'"; case TlcModule.DataKind.Component: var type = fieldValue.GetType(); ModuleCatalog.ComponentInfo componentInfo; @@ -685,7 +716,7 @@ private void GenerateStructs(IndentingTextWriter writer, classBase = $" : OneToOneColumn<{_typesSymbolTable[type.FullName].Substring(_typesSymbolTable[type.FullName].LastIndexOf('.') + 1)}>, IOneToOneColumn"; else if (type.IsSubclassOf(typeof(ManyToOneColumn))) classBase = $" : ManyToOneColumn<{_typesSymbolTable[type.FullName].Substring(_typesSymbolTable[type.FullName].LastIndexOf('.') + 1)}>, IManyToOneColumn"; - writer.WriteLine($"public sealed class {_typesSymbolTable[type.FullName].Substring(_typesSymbolTable[type.FullName].LastIndexOf('.') + 1)}{classBase}"); + writer.WriteLine($"public sealed partial class {_typesSymbolTable[type.FullName].Substring(_typesSymbolTable[type.FullName].LastIndexOf('.') + 1)}{classBase}"); writer.WriteLine("{"); writer.Indent(); GenerateInputFields(writer, type, catalog, _typesSymbolTable); @@ -696,6 +727,58 @@ private void GenerateStructs(IndentingTextWriter writer, } } + private void GenerateLoaderAddInputMethod(IndentingTextWriter writer, string className) + { + //Constructor. + writer.WriteLine("[JsonIgnore]"); + writer.WriteLine("private string _inputFilePath = null;"); + writer.WriteLine($"public {className}(string filePath)"); + writer.WriteLine("{"); + writer.Indent(); + writer.WriteLine("_inputFilePath = filePath;"); + writer.Outdent(); + writer.WriteLine("}"); + writer.WriteLine(""); + + //SetInput. 
+ writer.WriteLine($"public void SetInput(IHostEnvironment env, Experiment experiment)"); + writer.WriteLine("{"); + writer.Indent(); + writer.WriteLine("IFileHandle inputFile = new SimpleFileHandle(env, _inputFilePath, false, false);"); + writer.WriteLine("experiment.SetInput(InputFile, inputFile);"); + writer.Outdent(); + writer.WriteLine("}"); + writer.WriteLine(""); + + //Apply. + writer.WriteLine($"public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment)"); + writer.WriteLine("{"); + writer.Indent(); + writer.WriteLine("Contracts.Assert(previousStep == null);"); + writer.WriteLine(""); + writer.WriteLine($"return new {className}PipelineStep(experiment.Add(this));"); + writer.Outdent(); + writer.WriteLine("}"); + writer.WriteLine(""); + + //Pipelinestep class. + writer.WriteLine($"private class {className}PipelineStep : ILearningPipelineDataStep"); + writer.WriteLine("{"); + writer.Indent(); + writer.WriteLine($"public {className}PipelineStep (Output output)"); + writer.WriteLine("{"); + writer.Indent(); + writer.WriteLine("Data = output.Data;"); + writer.WriteLine("Model = null;"); + writer.Outdent(); + writer.WriteLine("}"); + writer.WriteLine(); + writer.WriteLine("public Var Data { get; }"); + writer.WriteLine("public Var Model { get; }"); + writer.Outdent(); + writer.WriteLine("}"); + } + private void GenerateColumnAddMethods(IndentingTextWriter writer, Type inputType, ModuleCatalog catalog, @@ -842,10 +925,11 @@ private void GenerateInput(IndentingTextWriter writer, var classAndMethod = GeneratorUtils.GetClassAndMethodNames(entryPointInfo); string classBase = ""; if (entryPointInfo.InputKinds != null) + { classBase += $" : {string.Join(", ", entryPointInfo.InputKinds.Select(GeneratorUtils.GetCSharpTypeName))}"; - - if (classBase.Contains("ITransformInput") || classBase.Contains("ITrainerInput")) - classBase += ", Microsoft.ML.ILearningPipelineItem"; + if (entryPointInfo.InputKinds.Any(t => 
typeof(ITrainerInput).IsAssignableFrom(t) || typeof(ITransformInput).IsAssignableFrom(t))) + classBase += ", Microsoft.ML.ILearningPipelineItem"; + } GenerateEnums(writer, entryPointInfo.InputType, classAndMethod.Item1); writer.WriteLine(); @@ -854,10 +938,17 @@ private void GenerateInput(IndentingTextWriter writer, foreach (var line in entryPointInfo.Description.Split(new[] { Environment.NewLine }, StringSplitOptions.RemoveEmptyEntries)) writer.WriteLine($"/// {line}"); writer.WriteLine("/// "); + + if(entryPointInfo.ObsoleteAttribute != null) + writer.WriteLine($"[Obsolete(\"{entryPointInfo.ObsoleteAttribute.Message}\")]"); + writer.WriteLine($"public sealed partial class {classAndMethod.Item2}{classBase}"); writer.WriteLine("{"); writer.Indent(); writer.WriteLine(); + if (entryPointInfo.InputKinds != null && entryPointInfo.InputKinds.Any(t => typeof(ILearningPipelineLoader).IsAssignableFrom(t))) + GenerateLoaderAddInputMethod(writer, classAndMethod.Item2); + GenerateColumnAddMethods(writer, entryPointInfo.InputType, catalog, classAndMethod.Item2, out Type transformType); writer.WriteLine(); GenerateInputFields(writer, entryPointInfo.InputType, catalog, _typesSymbolTable); diff --git a/src/Microsoft.ML/TextLoader.cs b/src/Microsoft.ML/TextLoader.cs deleted file mode 100644 index 4e3e3fb8e4..0000000000 --- a/src/Microsoft.ML/TextLoader.cs +++ /dev/null @@ -1,124 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. 
- -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Api; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using System; -using System.Linq; -using System.Reflection; -using System.Text; - -namespace Microsoft.ML -{ - public class TextLoader : ILearningPipelineLoader - { - private string _inputFilePath; - private string CustomSchema; - private Data.TextLoader ImportTextInput; - - /// - /// Construct a TextLoader object - /// - /// Data file path - /// Does the file contains header? - /// How the columns are seperated? - /// Options: separator="tab", separator="space", separator="comma" or separator=[single character]. - /// By default separator=null means "tab" - /// Whether the input may include quoted values, - /// which can contain separator characters, colons, - /// and distinguish empty values from missing values. When true, consecutive separators - /// denote a missing value and an empty value is denoted by \"\". - /// When false, consecutive separators denote an empty value. - /// Whether the input may include sparse representations e.g. 
- /// if one of the row contains "5 2:6 4:3" that's mean there are 5 columns all zero - /// except for 3rd and 5th columns which have values 6 and 3 - /// Remove trailing whitespace from lines - public TextLoader(string inputFilePath, bool useHeader = false, - string separator = null, bool allowQuotedStrings = true, - bool supportSparse = true, bool trimWhitespace = false) - { - _inputFilePath = inputFilePath; - SetCustomStringFromType(useHeader, separator, allowQuotedStrings, supportSparse, trimWhitespace); - } - - private IFileHandle GetTextLoaderFileHandle(IHostEnvironment env, string trainFilePath) => - new SimpleFileHandle(env, trainFilePath, false, false); - - private void SetCustomStringFromType(bool useHeader, string separator, - bool allowQuotedStrings, bool supportSparse, bool trimWhitespace) - { - StringBuilder schemaBuilder = new StringBuilder(CustomSchema); - foreach (var field in typeof(TInput).GetFields()) - { - var mappingAttr = field.GetCustomAttribute(); - if(mappingAttr == null) - throw Contracts.ExceptParam(field.Name, $"{field.Name} is missing ColumnAttribute"); - - schemaBuilder.AppendFormat("col={0}:{1}:{2} ", - mappingAttr.Name ?? field.Name, - TypeToName(field.FieldType.IsArray ? 
field.FieldType.GetElementType() : field.FieldType), - mappingAttr.Ordinal); - } - - if (useHeader) - schemaBuilder.Append(nameof(TextLoader.Arguments.HasHeader)).Append("+ "); - - if (separator != null) - schemaBuilder.Append(nameof(TextLoader.Arguments.Separator)).Append("=").Append(separator).Append(" "); - - if (!allowQuotedStrings) - schemaBuilder.Append(nameof(TextLoader.Arguments.AllowQuoting)).Append("- "); - - if (!supportSparse) - schemaBuilder.Append(nameof(TextLoader.Arguments.AllowSparse)).Append("- "); - - if (trimWhitespace) - schemaBuilder.Append(nameof(TextLoader.Arguments.TrimWhitespace)).Append("+ "); - - schemaBuilder.Length--; - CustomSchema = schemaBuilder.ToString(); - } - - private string TypeToName(Type type) - { - if (type == typeof(string)) - return "TX"; - else if (type == typeof(float) || type == typeof(double)) - return "R4"; - else if (type == typeof(bool)) - return "BL"; - else - throw new System.NotSupportedException("Type ${type.FullName} is not implemented or supported."); //Add more types. 
- } - - public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) - { - Contracts.Assert(previousStep == null); - - ImportTextInput = new Data.TextLoader(); - ImportTextInput.CustomSchema = CustomSchema; - var importOutput = experiment.Add(ImportTextInput); - return new TextLoaderPipelineStep(importOutput.Data); - } - - public void SetInput(IHostEnvironment env, Experiment experiment) - { - IFileHandle inputFile = GetTextLoaderFileHandle(env, _inputFilePath); - experiment.SetInput(ImportTextInput.InputFile, inputFile); - } - - private class TextLoaderPipelineStep : ILearningPipelineDataStep - { - public TextLoaderPipelineStep(Var data) - { - Data = data; - } - - public Var Data { get; } - public Var Model => null; - } - } -} diff --git a/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs b/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs index e0583f58b7..adfa42e50d 100644 --- a/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs +++ b/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs @@ -4,6 +4,7 @@ using BenchmarkDotNet.Attributes; using BenchmarkDotNet.Running; +using Microsoft.ML.Data; using Microsoft.ML.Models; using Microsoft.ML.Runtime.Api; using Microsoft.ML.Trainers; @@ -50,7 +51,7 @@ public void Setup() s_trainedModel = TrainCore(); IrisPrediction prediction = s_trainedModel.Predict(s_example); - var testData = new TextLoader(s_dataPath, useHeader: true, separator: "tab"); + var testData = new TextLoader(s_dataPath).CreateFrom(useHeader: true); var evaluator = new ClassificationEvaluator(); s_metrics = evaluator.Evaluate(s_trainedModel, testData); @@ -70,7 +71,7 @@ private static PredictionModel TrainCore() { var pipeline = new LearningPipeline(); - pipeline.Add(new TextLoader(s_dataPath, useHeader: true, separator: "tab")); + pipeline.Add(new TextLoader(s_dataPath).CreateFrom(useHeader: true)); 
pipeline.Add(new ColumnConcatenator(outputColumn: "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth")); diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs index ee4f56c260..66e241163f 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs @@ -36,7 +36,7 @@ public void TestSimpleExperiment() { var experiment = env.CreateExperiment(); - var importInput = new ML.Data.TextLoader(); + var importInput = new ML.Data.TextLoader(dataPath); var importOutput = experiment.Add(importInput); var normalizeInput = new ML.Transforms.MinMaxNormalizer @@ -67,7 +67,7 @@ public void TestSimpleTrainExperiment() { var experiment = env.CreateExperiment(); - var importInput = new ML.Data.TextLoader(); + var importInput = new ML.Data.TextLoader(dataPath); var importOutput = experiment.Add(importInput); var catInput = new ML.Transforms.CategoricalOneHotVectorizer @@ -165,7 +165,7 @@ public void TestTrainTestMacro() var experiment = env.CreateExperiment(); - var importInput = new ML.Data.TextLoader(); + var importInput = new ML.Data.TextLoader(dataPath); var importOutput = experiment.Add(importInput); var trainTestInput = new ML.Models.TrainTestBinaryEvaluator @@ -235,7 +235,7 @@ public void TestCrossValidationBinaryMacro() var experiment = env.CreateExperiment(); - var importInput = new ML.Data.TextLoader(); + var importInput = new ML.Data.TextLoader(dataPath); var importOutput = experiment.Add(importInput); var crossValidateBinary = new ML.Models.BinaryCrossValidator @@ -295,7 +295,7 @@ public void TestCrossValidationMacro() var modelCombineOutput = subGraph.Add(modelCombine); var experiment = env.CreateExperiment(); - var importInput = new ML.Data.TextLoader(); + var importInput = new ML.Data.TextLoader(dataPath); var importOutput = experiment.Add(importInput); var crossValidate = new ML.Models.CrossValidator diff --git 
a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index e8be6c0370..24e8374b4c 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -33,7 +33,35 @@ public void EntryPointTrainTestSplit() { var dataPath = GetDataPath("breast-cancer.txt"); var inputFile = new SimpleFileHandle(Env, dataPath, false, false); - var dataView = ImportTextData.ImportText(Env, new ImportTextData.Input { InputFile = inputFile, CustomSchema = "col=Label:0 col=Features:TX:1-9" }).Data; + /*var dataView = ImportTextData.ImportText(Env, new ImportTextData.Input + { InputFile = inputFile, CustomSchema = "col=Label:0 col=Features:TX:1-9" }).Data;*/ + + var dataView = ImportTextData.TextLoader(Env, new ImportTextData.LoaderInput() + { + Arguments = + { + SeparatorChars = new []{',' }, + HasHeader = true, + Column = new[] + { + new TextLoader.Column() + { + Name = "Label", + Source = new [] { new TextLoader.Range() { Min = 0, Max = 0} }, + Type = Runtime.Data.DataKind.Text + }, + + new TextLoader.Column() + { + Name = "Features", + Source = new [] { new TextLoader.Range() { Min = 1, Max = 9} }, + Type = Runtime.Data.DataKind.Text + } + } + }, + + InputFile = inputFile + }).Data; var splitOutput = TrainTestSplit.Split(Env, new TrainTestSplit.Input { Data = dataView, Fraction = 0.9f }); @@ -62,7 +90,44 @@ public void EntryPointFeatureCombiner() { var dataPath = GetDataPath("breast-cancer.txt"); var inputFile = new SimpleFileHandle(Env, dataPath, false, false); - var dataView = ImportTextData.ImportText(Env, new ImportTextData.Input { InputFile = inputFile, CustomSchema = "col=Label:0 col=F1:TX:1 col=F2:I4:2 col=Rest:3-9" }).Data; + var dataView = ImportTextData.TextLoader(Env, new ImportTextData.LoaderInput() + { + Arguments = + { + HasHeader = true, + Column = new[] + { + new TextLoader.Column() + { + Name = "Label", + Source = new [] { new 
TextLoader.Range() { Min = 0, Max = 0} } + }, + + new TextLoader.Column() + { + Name = "F1", + Source = new [] { new TextLoader.Range() { Min = 1, Max = 1} }, + Type = Runtime.Data.DataKind.Text + }, + + new TextLoader.Column() + { + Name = "F2", + Source = new [] { new TextLoader.Range() { Min = 2, Max = 2} }, + Type = Runtime.Data.DataKind.I4 + }, + + new TextLoader.Column() + { + Name = "Rest", + Source = new [] { new TextLoader.Range() { Min = 3, Max = 9} } + } + } + }, + + InputFile = inputFile + }).Data; + dataView = Env.CreateTransform("Term{col=F1}", dataView); var result = FeatureCombiner.PrepareFeatures(Env, new FeatureCombiner.FeatureCombinerInput() { Data = dataView, Features = new[] { "F1", "F2", "Rest" } }).OutputData; var expected = Env.CreateTransform("Convert{col=F2 type=R4}", dataView); @@ -82,7 +147,44 @@ public void EntryPointScoring() { var dataPath = GetDataPath("breast-cancer.txt"); var inputFile = new SimpleFileHandle(Env, dataPath, false, false); - var dataView = ImportTextData.ImportText(Env, new ImportTextData.Input { InputFile = inputFile, CustomSchema = "col=Label:0 col=F1:TX:1 col=F2:I4:2 col=Rest:3-9" }).Data; + var dataView = ImportTextData.TextLoader(Env, new ImportTextData.LoaderInput() + { + Arguments = + { + HasHeader = true, + Column = new[] + { + new TextLoader.Column() + { + Name = "Label", + Source = new [] { new TextLoader.Range() { Min = 0, Max = 0} } + }, + + new TextLoader.Column() + { + Name = "F1", + Source = new [] { new TextLoader.Range() { Min = 1, Max = 1} }, + Type = Runtime.Data.DataKind.Text + }, + + new TextLoader.Column() + { + Name = "F2", + Source = new [] { new TextLoader.Range() { Min = 2, Max = 2} }, + Type = Runtime.Data.DataKind.I4 + }, + + new TextLoader.Column() + { + Name = "Rest", + Source = new [] { new TextLoader.Range() { Min = 3, Max = 9} } + } + } + }, + + InputFile = inputFile + }).Data; + dataView = Env.CreateTransform("Term{col=F1}", dataView); var trainData = 
FeatureCombiner.PrepareFeatures(Env, new FeatureCombiner.FeatureCombinerInput() { Data = dataView, Features = new[] { "F1", "F2", "Rest" } }); @@ -105,7 +207,44 @@ public void EntryPointApplyModel() { var dataPath = GetDataPath("breast-cancer.txt"); var inputFile = new SimpleFileHandle(Env, dataPath, false, false); - var dataView = ImportTextData.ImportText(Env, new ImportTextData.Input { InputFile = inputFile, CustomSchema = "col=Label:0 col=F1:TX:1 col=F2:I4:2 col=Rest:3-9" }).Data; + var dataView = ImportTextData.TextLoader(Env, new ImportTextData.LoaderInput() + { + Arguments = + { + HasHeader = true, + Column = new[] + { + new TextLoader.Column() + { + Name = "Label", + Source = new [] { new TextLoader.Range() { Min = 0, Max = 0} }, + }, + + new TextLoader.Column() + { + Name = "F1", + Source = new [] { new TextLoader.Range() { Min = 1, Max = 1} }, + Type = Runtime.Data.DataKind.Text + }, + + new TextLoader.Column() + { + Name = "F2", + Source = new [] { new TextLoader.Range() { Min = 2, Max = 2} }, + Type = Runtime.Data.DataKind.I4 + }, + + new TextLoader.Column() + { + Name = "Rest", + Source = new [] { new TextLoader.Range() { Min = 3, Max = 9} } + } + } + }, + + InputFile = inputFile + }).Data; + dataView = Env.CreateTransform("Term{col=F1}", dataView); var data1 = FeatureCombiner.PrepareFeatures(Env, new FeatureCombiner.FeatureCombinerInput() { Data = dataView, Features = new[] { "F1", "F2", "Rest" } }); @@ -120,7 +259,49 @@ public void EntryPointCaching() { var dataPath = GetDataPath("breast-cancer.txt"); var inputFile = new SimpleFileHandle(Env, dataPath, false, false); - var dataView = ImportTextData.ImportText(Env, new ImportTextData.Input { InputFile = inputFile, CustomSchema = "col=Label:0 col=F1:TX:1 col=F2:I4:2 col=Rest:3-9" }).Data; + /*var dataView = ImportTextData.ImportText(Env, new ImportTextData.Input { InputFile = inputFile, + CustomSchema = "col=Label:0 col=F1:TX:1 col=F2:I4:2 col=Rest:3-9" }).Data; + */ + + var dataView = 
ImportTextData.TextLoader(Env, new ImportTextData.LoaderInput() + { + Arguments = + { + SeparatorChars = new []{',' }, + HasHeader = true, + Column = new[] + { + new TextLoader.Column() + { + Name = "Label", + Source = new [] { new TextLoader.Range() { Min = 0, Max = 0} } + }, + + new TextLoader.Column() + { + Name = "F1", + Source = new [] { new TextLoader.Range() { Min = 1, Max = 1} }, + Type = Runtime.Data.DataKind.Text + }, + + new TextLoader.Column() + { + Name = "F2", + Source = new [] { new TextLoader.Range() { Min = 2, Max = 2} }, + Type = Runtime.Data.DataKind.I4 + }, + + new TextLoader.Column() + { + Name = "Rest", + Source = new [] { new TextLoader.Range() { Min = 3, Max = 9} } + } + } + }, + + InputFile = inputFile + }).Data; + dataView = Env.CreateTransform("Term{col=F1}", dataView); var cached1 = Cache.CacheData(Env, new Cache.CacheInput() { Data = dataView, Caching = Cache.CachingType.Memory }); @@ -305,7 +486,7 @@ public void EntryPointOptionalParams() { 'Nodes': [ { - 'Name': 'Data.TextLoader', + 'Name': 'Data.CustomTextLoader', 'Inputs': { 'InputFile': '$file1' }, @@ -355,7 +536,7 @@ public void EntryPointExecGraphCommand() {{ 'Nodes': [ {{ - 'Name': 'Data.TextLoader', + 'Name': 'Data.CustomTextLoader', 'Inputs': {{ 'InputFile': '$file1' }}, @@ -512,7 +693,7 @@ public void EntryPointParseColumns() {{ 'Nodes': [ {{ - 'Name': 'Data.TextLoader', + 'Name': 'Data.CustomTextLoader', 'Inputs': {{ 'InputFile': '$file1' }}, @@ -562,7 +743,7 @@ public void EntryPointCountFeatures() {{ 'Nodes': [ {{ - 'Name': 'Data.TextLoader', + 'Name': 'Data.CustomTextLoader', 'Inputs': {{ 'InputFile': '$file1' }}, @@ -607,7 +788,7 @@ public void EntryPointMutualSelectFeatures() {{ 'Nodes': [ {{ - 'Name': 'Data.TextLoader', + 'Name': 'Data.CustomTextLoader', 'Inputs': {{ 'InputFile': '$file1' }}, @@ -653,7 +834,7 @@ public void EntryPointTextToKeyToText() {{ 'Nodes': [ {{ - 'Name': 'Data.TextLoader', + 'Name': 'Data.CustomTextLoader', 'Inputs': {{ 'InputFile': '$file1', 
'CustomSchema': 'sep=comma col=Cat:TX:4' @@ -735,7 +916,7 @@ private void RunTrainScoreEvaluate(string learner, string evaluator, string data {{ 'Nodes': [ {{ - 'Name': 'Data.TextLoader', + 'Name': 'Data.CustomTextLoader', 'Inputs': {{ 'InputFile': '$file' }}, @@ -1214,7 +1395,7 @@ internal void TestEntryPointPipelineRoutine(string dataFile, string schema, stri {{ 'Nodes': [ {{ - 'Name': 'Data.TextLoader', + 'Name': 'Data.CustomTextLoader', 'Inputs': {{ 'InputFile': '$file1', 'CustomSchema': '{schema}' @@ -1287,7 +1468,7 @@ internal void TestEntryPointRoutine(string dataFile, string trainerName, string {{ 'Nodes': [ {{ - 'Name': 'Data.TextLoader', + 'Name': 'Data.CustomTextLoader', 'Inputs': {{ 'InputFile': '$file1' {3} @@ -1459,7 +1640,7 @@ public void EntryPointNormalizeIfNeeded() { 'Nodes': [ { - 'Name': 'Data.TextLoader', + 'Name': 'Data.CustomTextLoader', 'Inputs': { 'InputFile': '$file' }, @@ -1522,7 +1703,7 @@ public void EntryPointTrainTestBinaryMacro() { 'Nodes': [ { - 'Name': 'Data.TextLoader', + 'Name': 'Data.CustomTextLoader', 'Inputs': { 'InputFile': '$file' }, @@ -1630,7 +1811,7 @@ public void EntryPointTrainTestMacroNoTransformInput() { 'Nodes': [ { - 'Name': 'Data.TextLoader', + 'Name': 'Data.CustomTextLoader', 'Inputs': { 'InputFile': '$file' }, @@ -1744,7 +1925,7 @@ public void EntryPointTrainTestMacro() { 'Nodes': [ { - 'Name': 'Data.TextLoader', + 'Name': 'Data.CustomTextLoader', 'Inputs': { 'InputFile': '$file' }, @@ -1843,7 +2024,7 @@ public void EntryPointChainedTrainTestMacros() { 'Nodes': [ { - 'Name': 'Data.TextLoader', + 'Name': 'Data.CustomTextLoader', 'Inputs': { 'InputFile': '$file' }, @@ -2019,7 +2200,7 @@ public void EntryPointChainedCrossValMacros() { 'Nodes': [ { - 'Name': 'Data.TextLoader', + 'Name': 'Data.CustomTextLoader', 'Inputs': { 'InputFile': '$file' }, @@ -2214,7 +2395,7 @@ public void EntryPointMacroEarlyExpansion() { 'Nodes': [ { - 'Name': 'Data.TextLoader', + 'Name': 'Data.CustomTextLoader', 'Inputs': { 'InputFile': 
'$file' }, @@ -2302,7 +2483,7 @@ public void EntryPointSerialization() { 'Nodes': [ { - 'Name': 'Data.TextLoader', + 'Name': 'Data.CustomTextLoader', 'Inputs': { 'InputFile': '$file' }, @@ -2368,7 +2549,7 @@ public void EntryPointNodeSchedulingFields() { 'Nodes': [ { - 'Name': 'Data.TextLoader', + 'Name': 'Data.CustomTextLoader', 'StageId': '5063dee8f19c4dd89a1fc3a9da5351a7', 'Inputs': { 'InputFile': '$file' @@ -2437,7 +2618,7 @@ public void EntryPointPrepareLabelConvertPredictedLabel() {{ 'Nodes': [ {{ - 'Name': 'Data.TextLoader', + 'Name': 'Data.CustomTextLoader', 'Inputs': {{ 'InputFile': '$file1', 'CustomSchema': 'sep=comma col=Label:TX:4 col=Features:Num:0-3' @@ -2527,7 +2708,9 @@ public void EntryPointTreeLeafFeaturizer() { var dataPath = GetDataPath(@"adult.tiny.with-schema.txt"); var inputFile = new SimpleFileHandle(Env, dataPath, false, false); +#pragma warning disable 0618 var dataView = ImportTextData.ImportText(Env, new ImportTextData.Input { InputFile = inputFile }).Data; +#pragma warning restore 0618 var cat = Categorical.CatTransformDict(Env, new CategoricalTransform.Arguments() { Data = dataView, diff --git a/test/Microsoft.ML.Predictor.Tests/TestAutoInference.cs b/test/Microsoft.ML.Predictor.Tests/TestAutoInference.cs index 5166540ccd..f0e7d8ec73 100644 --- a/test/Microsoft.ML.Predictor.Tests/TestAutoInference.cs +++ b/test/Microsoft.ML.Predictor.Tests/TestAutoInference.cs @@ -49,12 +49,13 @@ public void TestLearn() // Use best pipeline for another task var inputFileTrain = new SimpleFileHandle(env, pathData, false, false); +#pragma warning disable 0618 var datasetTrain = ImportTextData.ImportText(env, new ImportTextData.Input { InputFile = inputFileTrain, CustomSchema = schema }).Data; var inputFileTest = new SimpleFileHandle(env, pathDataTest, false, false); var datasetTest = ImportTextData.ImportText(env, new ImportTextData.Input { InputFile = inputFileTest, CustomSchema = schema }).Data; - +#pragma warning restore 0618 // REVIEW: Theoretically, 
it could be the case that a new, very bad learner is introduced and // we get unlucky and only select it every time, such that this test fails. Not // likely at all, but a non-zero probability. Should be ok, since all current learners are returning d > .80. @@ -77,11 +78,13 @@ public void EntryPointPipelineSweepSerialization() "sep=, col=Features:R4:0,2,4,10-12 col=workclass:TX:1 col=education:TX:3 col=marital_status:TX:5 col=occupation:TX:6 " + "col=relationship:TX:7 col=ethnicity:TX:8 col=sex:TX:9 col=native_country:TX:13 col=label_IsOver50K_:R4:14 header=+"; var inputFileTrain = new SimpleFileHandle(Env, pathData, false, false); +#pragma warning disable 0618 var datasetTrain = ImportTextData.ImportText(Env, new ImportTextData.Input { InputFile = inputFileTrain, CustomSchema = schema }).Data.Take(numOfSampleRows); var inputFileTest = new SimpleFileHandle(Env, pathDataTest, false, false); var datasetTest = ImportTextData.ImportText(Env, new ImportTextData.Input { InputFile = inputFileTest, CustomSchema = schema }).Data.Take(numOfSampleRows); +#pragma warning restore 0618 // Define entrypoint graph string inputGraph = @" @@ -143,12 +146,13 @@ public void EntryPointPipelineSweep() const int numOfSampleRows = 1000; int numIterations = 4; var inputFileTrain = new SimpleFileHandle(Env, pathData, false, false); +#pragma warning disable 0618 var datasetTrain = ImportTextData.ImportText(Env, new ImportTextData.Input { InputFile = inputFileTrain }).Data.Take(numOfSampleRows); var inputFileTest = new SimpleFileHandle(Env, pathDataTest, false, false); var datasetTest = ImportTextData.ImportText(Env, new ImportTextData.Input { InputFile = inputFileTest }).Data.Take(numOfSampleRows); - +#pragma warning restore 0618 // Define entrypoint graph string inputGraph = @" { diff --git a/test/Microsoft.ML.TestFramework/ModelHelper.cs b/test/Microsoft.ML.TestFramework/ModelHelper.cs index dca360c4e3..1b0ab4eb8e 100644 --- a/test/Microsoft.ML.TestFramework/ModelHelper.cs +++ 
b/test/Microsoft.ML.TestFramework/ModelHelper.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using Microsoft.ML.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; @@ -40,24 +41,187 @@ public static void WriteKcHousePriceModel(string dataPath, Stream stream) public static IDataView GetKcHouseDataView(string dataPath) { - var dataSchema = "col=Id:TX:0 col=Date:TX:1 col=Label:R4:2 col=Bedrooms:R4:3 col=Bathrooms:R4:4 col=SqftLiving:R4:5 col=SqftLot:R4:6 col=Floors:R4:7 col=Waterfront:R4:8 col=View:R4:9 col=Condition:R4:10 col=Grade:R4:11 col=SqftAbove:R4:12 col=SqftBasement:R4:13 col=YearBuilt:R4:14 col=YearRenovated:R4:15 col=Zipcode:R4:16 col=Lat:R4:17 col=Long:R4:18 col=SqftLiving15:R4:19 col=SqftLot15:R4:20 header+ sep=,"; - var txtArgs = new TextLoader.Arguments(); + var dataSchema = "col=Id:TX:0 col=Date:TX:1 col=Label:R4:2 col=Bedrooms:R4:3 " + + "col=Bathrooms:R4:4 col=SqftLiving:R4:5 col=SqftLot:R4:6 col=Floors:R4:7 " + + "col=Waterfront:R4:8 col=View:R4:9 col=Condition:R4:10 col=Grade:R4:11 " + + "col=SqftAbove:R4:12 col=SqftBasement:R4:13 col=YearBuilt:R4:14 " + + "col=YearRenovated:R4:15 col=Zipcode:R4:16 col=Lat:R4:17 col=Long:R4:18 " + + "col=SqftLiving15:R4:19 col=SqftLot15:R4:20 header+ sep=,"; + + var txtArgs = new Runtime.Data.TextLoader.Arguments(); bool parsed = CmdParser.ParseArguments(s_environment, dataSchema, txtArgs); s_environment.Assert(parsed); - var txtLoader = new TextLoader(s_environment, txtArgs, new MultiFileSource(dataPath)); + var txtLoader = new Runtime.Data.TextLoader(s_environment, txtArgs, new MultiFileSource(dataPath)); return txtLoader; } private static ITransformModel CreateKcHousePricePredictorModel(string dataPath) { - var dataSchema = "col=Id:TX:0 col=Date:TX:1 col=Label:R4:2 col=Bedrooms:R4:3 col=Bathrooms:R4:4 col=SqftLiving:R4:5 col=SqftLot:R4:6 
col=Floors:R4:7 col=Waterfront:R4:8 col=View:R4:9 col=Condition:R4:10 col=Grade:R4:11 col=SqftAbove:R4:12 col=SqftBasement:R4:13 col=YearBuilt:R4:14 col=YearRenovated:R4:15 col=Zipcode:R4:16 col=Lat:R4:17 col=Long:R4:18 col=SqftLiving15:R4:19 col=SqftLot15:R4:20 header+ sep=,"; - Experiment experiment = s_environment.CreateExperiment(); - var importData = new Data.TextLoader(); - importData.CustomSchema = dataSchema; - Data.TextLoader.Output imported = experiment.Add(importData); + var importData = new Data.TextLoader(dataPath) + { + Arguments = new TextLoaderArguments + { + Separator = new[] { ',' }, + HasHeader = true, + Column = new[] + { + new TextLoaderColumn() + { + Name = "Id", + Source = new [] { new TextLoaderRange(0) }, + Type = Runtime.Data.DataKind.Text + }, + + new TextLoaderColumn() + { + Name = "Date", + Source = new [] { new TextLoaderRange(1) }, + Type = Runtime.Data.DataKind.Text + }, + + new TextLoaderColumn() + { + Name = "Label", + Source = new [] { new TextLoaderRange(2) }, + Type = Runtime.Data.DataKind.Num + }, + + new TextLoaderColumn() + { + Name = "Bedrooms", + Source = new [] { new TextLoaderRange(3) }, + Type = Runtime.Data.DataKind.Num + }, + + new TextLoaderColumn() + { + Name = "Bathrooms", + Source = new [] { new TextLoaderRange(4) }, + Type = Runtime.Data.DataKind.Num + }, + + new TextLoaderColumn() + { + Name = "SqftLiving", + Source = new [] { new TextLoaderRange(5) }, + Type = Runtime.Data.DataKind.Num + }, + + new TextLoaderColumn() + { + Name = "SqftLot", + Source = new [] { new TextLoaderRange(6) }, + Type = Runtime.Data.DataKind.Num + }, + + new TextLoaderColumn() + { + Name = "Floors", + Source = new [] { new TextLoaderRange(7) }, + Type = Runtime.Data.DataKind.Num + }, + + new TextLoaderColumn() + { + Name = "Waterfront", + Source = new [] { new TextLoaderRange(8) }, + Type = Runtime.Data.DataKind.Num + }, + + new TextLoaderColumn() + { + Name = "View", + Source = new [] { new TextLoaderRange(9) }, + Type = 
Runtime.Data.DataKind.Num + }, + + new TextLoaderColumn() + { + Name = "Condition", + Source = new [] { new TextLoaderRange(10) }, + Type = Runtime.Data.DataKind.Num + }, + new TextLoaderColumn() + { + Name = "Grade", + Source = new [] { new TextLoaderRange(11) }, + Type = Runtime.Data.DataKind.Num + }, + + new TextLoaderColumn() + { + Name = "SqftAbove", + Source = new [] { new TextLoaderRange(12) }, + Type = Runtime.Data.DataKind.Num + }, + + new TextLoaderColumn() + { + Name = "SqftBasement", + Source = new [] { new TextLoaderRange(13) }, + Type = Runtime.Data.DataKind.Num + }, + + new TextLoaderColumn() + { + Name = "YearBuilt", + Source = new [] { new TextLoaderRange(14) }, + Type = Runtime.Data.DataKind.Num + }, + + new TextLoaderColumn() + { + Name = "YearRenovated", + Source = new [] { new TextLoaderRange(15) }, + Type = Runtime.Data.DataKind.Num + }, + + new TextLoaderColumn() + { + Name = "Zipcode", + Source = new [] { new TextLoaderRange(16) }, + Type = Runtime.Data.DataKind.Num + }, + + new TextLoaderColumn() + { + Name = "Lat", + Source = new [] { new TextLoaderRange(17) }, + Type = Runtime.Data.DataKind.Num + }, + + new TextLoaderColumn() + { + Name = "Long", + Source = new [] { new TextLoaderRange(18) }, + Type = Runtime.Data.DataKind.Num + }, + + new TextLoaderColumn() + { + Name = "SqftLiving15", + Source = new [] { new TextLoaderRange(19) }, + Type = Runtime.Data.DataKind.Num + }, + + new TextLoaderColumn() + { + Name = "SqftLot15", + Source = new [] { new TextLoaderRange(20) }, + Type = Runtime.Data.DataKind.Num + }, + } + } + + //new Data.CustomTextLoader(); + // importData.CustomSchema = dataSchema; + // + }; + + Data.TextLoader.Output imported = experiment.Add(importData); var numericalConcatenate = new Transforms.ColumnConcatenator(); numericalConcatenate.Data = imported.Data; numericalConcatenate.AddColumn("NumericalFeatures", "SqftLiving", "SqftLot", "SqftAbove", "SqftBasement", "Lat", "Long", "SqftLiving15", "SqftLot15"); diff --git 
a/test/Microsoft.ML.Tests/LearningPipelineTests.cs b/test/Microsoft.ML.Tests/LearningPipelineTests.cs index 4519fc5285..3ccc36255f 100644 --- a/test/Microsoft.ML.Tests/LearningPipelineTests.cs +++ b/test/Microsoft.ML.Tests/LearningPipelineTests.cs @@ -5,6 +5,7 @@ using Microsoft.ML; using Microsoft.ML.Data; using Microsoft.ML.Runtime.Api; +using Microsoft.ML.Runtime.Data; using Microsoft.ML.TestFramework; using Microsoft.ML.Trainers; using Microsoft.ML.Transforms; @@ -66,7 +67,7 @@ public void TransformOnlyPipeline() { const string _dataPath = @"..\..\Data\breast-cancer.txt"; var pipeline = new LearningPipeline(); - pipeline.Add(new TextLoader(_dataPath, useHeader: false)); + pipeline.Add(new ML.Data.TextLoader(_dataPath).CreateFrom(useHeader: false)); pipeline.Add(new CategoricalHashOneHotVectorizer("F1") { HashBits = 10, Seed = 314489979, OutputKind = CategoricalTransformOutputKind.Bag }); var model = pipeline.Train(); var predictionModel = model.Predict(new InputData() { F1 = "5" }); @@ -95,7 +96,7 @@ public class Data public class Prediction { [ColumnName("PredictedLabel")] - public bool PredictedLabel; + public DvBool PredictedLabel; } [Fact] diff --git a/test/Microsoft.ML.Tests/Scenarios/HousePriceTrainAndPredictionTests.cs b/test/Microsoft.ML.Tests/Scenarios/HousePriceTrainAndPredictionTests.cs index 38ec6ce073..31fc4fdd6d 100644 --- a/test/Microsoft.ML.Tests/Scenarios/HousePriceTrainAndPredictionTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/HousePriceTrainAndPredictionTests.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
+using Microsoft.ML.Data; using Microsoft.ML.Models; using Microsoft.ML.Runtime.Api; using Microsoft.ML.TestFramework; @@ -21,7 +22,7 @@ public void TrainAndPredictHousePriceModelTest() var pipeline = new LearningPipeline(); - pipeline.Add(new TextLoader(dataPath, useHeader: true, separator: ",")); + pipeline.Add(new TextLoader(dataPath).CreateFrom(useHeader: true, separator: ',')); pipeline.Add(new ColumnConcatenator(outputColumn: "NumericalFeatures", "SqftLiving", "SqftLot", "SqftAbove", "SqftBasement", "Lat", "Long", "SqftLiving15", "SqftLot15")); @@ -61,7 +62,7 @@ public void TrainAndPredictHousePriceModelTest() Assert.InRange(prediction.Price, 260_000, 330_000); string testDataPath = GetDataPath("kc_house_test.csv"); - var testData = new TextLoader(testDataPath, useHeader: true, separator: ","); + var testData = new TextLoader(testDataPath).CreateFrom(useHeader: true, separator: ','); var evaluator = new RegressionEvaluator(); RegressionMetrics metrics = evaluator.Evaluate(model, testData); diff --git a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs index de7c602047..5dcbf3a588 100644 --- a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
+using Microsoft.ML.Data; using Microsoft.ML.Models; using Microsoft.ML.Runtime.Api; using Microsoft.ML.Trainers; @@ -19,7 +20,7 @@ public void TrainAndPredictIrisModelTest() var pipeline = new LearningPipeline(); - pipeline.Add(new TextLoader(dataPath, useHeader: false, separator: "tab")); + pipeline.Add(new TextLoader(dataPath).CreateFrom(useHeader: false)); pipeline.Add(new ColumnConcatenator(outputColumn: "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth")); @@ -66,7 +67,7 @@ public void TrainAndPredictIrisModelTest() // Note: Testing against the same data set as a simple way to test evaluation. // This isn't appropriate in real-world scenarios. string testDataPath = GetDataPath("iris.txt"); - var testData = new TextLoader(testDataPath, useHeader: false, separator: "tab"); + var testData = new TextLoader(testDataPath).CreateFrom(useHeader: false); var evaluator = new ClassificationEvaluator(); evaluator.OutputTopKAcc = 3; diff --git a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs index 79cc2fc137..ebddc33b03 100644 --- a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
+using Microsoft.ML.Data; using Microsoft.ML.Models; using Microsoft.ML.Runtime.Api; using Microsoft.ML.Trainers; @@ -19,7 +20,7 @@ public void TrainAndPredictIrisModelWithStringLabelTest() var pipeline = new LearningPipeline(); - pipeline.Add(new TextLoader(dataPath, useHeader: false, separator: ",")); + pipeline.Add(new TextLoader(dataPath).CreateFrom(useHeader: false, separator: ',')); pipeline.Add(new Dictionarizer("Label")); // "IrisPlantType" is used as "Label" because of column attribute name on the field. @@ -69,7 +70,7 @@ public void TrainAndPredictIrisModelWithStringLabelTest() // Note: Testing against the same data set as a simple way to test evaluation. // This isn't appropriate in real-world scenarios. string testDataPath = GetDataPath("iris.data"); - var testData = new TextLoader(testDataPath, useHeader: false, separator: ","); + var testData = new TextLoader(testDataPath).CreateFrom(useHeader: false, separator: ','); var evaluator = new ClassificationEvaluator(); evaluator.OutputTopKAcc = 3; diff --git a/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs b/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs index 608cbef144..80947644e9 100644 --- a/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs @@ -2,9 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
+using Microsoft.ML.Data; using Microsoft.ML.Models; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Api; +using Microsoft.ML.Runtime.Data; using Microsoft.ML.Trainers; using Microsoft.ML.Transforms; using System.Collections.Generic; @@ -23,7 +25,32 @@ public void TrainAndPredictSentimentModelTest() { string dataPath = GetDataPath(SentimentDataPath); var pipeline = new LearningPipeline(); - pipeline.Add(new TextLoader(dataPath, useHeader: true, separator: "tab")); + + pipeline.Add(new Data.TextLoader(dataPath) + { + Arguments = new TextLoaderArguments + { + Separator = new[] { '\t' }, + HasHeader = true, + Column = new[] + { + new TextLoaderColumn() + { + Name = "Label", + Source = new [] { new TextLoaderRange(0) }, + Type = Runtime.Data.DataKind.Num + }, + + new TextLoaderColumn() + { + Name = "SentimentText", + Source = new [] { new TextLoaderRange(1) }, + Type = Runtime.Data.DataKind.Text + } + } + } + }); + pipeline.Add(new TextFeaturizer("Features", "SentimentText") { KeepDiacritics = false, @@ -56,12 +83,34 @@ public void TrainAndPredictSentimentModelTest() IEnumerable predictions = model.Predict(sentiments); Assert.Equal(2, predictions.Count()); - Assert.False(predictions.ElementAt(0).Sentiment); - Assert.True(predictions.ElementAt(1).Sentiment); + Assert.True(predictions.ElementAt(0).Sentiment.IsFalse); + Assert.True(predictions.ElementAt(1).Sentiment.IsTrue); string testDataPath = GetDataPath(SentimentTestPath); - var testData = new TextLoader(testDataPath, useHeader: true, separator: "tab"); - + var testData = new Data.TextLoader(testDataPath) + { + Arguments = new TextLoaderArguments + { + Separator = new[] { '\t' }, + HasHeader = true, + Column = new[] + { + new TextLoaderColumn() + { + Name = "Label", + Source = new [] { new TextLoaderRange(0) }, + Type = Runtime.Data.DataKind.Num + }, + + new TextLoaderColumn() + { + Name = "SentimentText", + Source = new [] { new TextLoaderRange(1) }, + Type = Runtime.Data.DataKind.Text + } + } + } + }; var 
evaluator = new BinaryClassificationEvaluator(); BinaryClassificationMetrics metrics = evaluator.Evaluate(model, testData); @@ -105,7 +154,7 @@ public class SentimentData public class SentimentPrediction { [ColumnName("PredictedLabel")] - public bool Sentiment; + public DvBool Sentiment; } } } diff --git a/test/Microsoft.ML.Tests/TextLoaderTests.cs b/test/Microsoft.ML.Tests/TextLoaderTests.cs index 96075b625a..40c0b6525f 100644 --- a/test/Microsoft.ML.Tests/TextLoaderTests.cs +++ b/test/Microsoft.ML.Tests/TextLoaderTests.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using Microsoft.ML; +using Microsoft.ML.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.Data; @@ -24,19 +25,19 @@ public TextLoaderTests(ITestOutputHelper output) [Fact] public void ConstructorDoesntThrow() { - Assert.NotNull(new TextLoader("fakeFile.txt")); - Assert.NotNull(new TextLoader("fakeFile.txt", useHeader: true)); - Assert.NotNull(new TextLoader("fakeFile.txt", separator: "tab")); - Assert.NotNull(new TextLoader("fakeFile.txt", useHeader: false, separator: "tab")); - Assert.NotNull(new TextLoader("fakeFile.txt", useHeader: false, separator: "tab", false, false)); - Assert.NotNull(new TextLoader("fakeFile.txt", useHeader: false, separator: "tab", supportSparse: false)); - Assert.NotNull(new TextLoader("fakeFile.txt", useHeader: false, separator: "tab", allowQuotedStrings: false)); + Assert.NotNull(new Data.TextLoader("fakeFile.txt").CreateFrom()); + Assert.NotNull(new Data.TextLoader("fakeFile.txt").CreateFrom(useHeader:true)); + Assert.NotNull(new Data.TextLoader("fakeFile.txt").CreateFrom()); + Assert.NotNull(new Data.TextLoader("fakeFile.txt").CreateFrom(useHeader: false)); + Assert.NotNull(new Data.TextLoader("fakeFile.txt").CreateFrom(useHeader: false, supportSparse: false, trimWhitespace: false)); + Assert.NotNull(new Data.TextLoader("fakeFile.txt").CreateFrom(useHeader: false, supportSparse: false)); + 
Assert.NotNull(new Data.TextLoader("fakeFile.txt").CreateFrom(useHeader: false, allowQuotedStrings: false)); } [Fact] public void CanSuccessfullyApplyATransform() { - var loader = new TextLoader("fakeFile.txt"); + var loader = new Data.TextLoader("fakeFile.txt").CreateFrom(); using (var environment = new TlcEnvironment()) { @@ -53,7 +54,7 @@ public void CanSuccessfullyApplyATransform() public void CanSuccessfullyRetrieveQuotedData() { string dataPath = GetDataPath("QuotingData.csv"); - var loader = new TextLoader(dataPath, useHeader: true, separator: ",", allowQuotedStrings: true, supportSparse: false); + var loader = new Data.TextLoader(dataPath).CreateFrom(useHeader: true, separator: ',', allowQuotedStrings: true, supportSparse: false); using (var environment = new TlcEnvironment()) { @@ -111,7 +112,7 @@ public void CanSuccessfullyRetrieveQuotedData() public void CanSuccessfullyRetrieveSparseData() { string dataPath = GetDataPath("SparseData.txt"); - var loader = new TextLoader(dataPath, useHeader: true, separator: "tab", allowQuotedStrings: false, supportSparse: true); + var loader = new Data.TextLoader(dataPath).CreateFrom(useHeader: true, allowQuotedStrings: false, supportSparse: true); using (var environment = new TlcEnvironment()) { @@ -176,7 +177,7 @@ public void CanSuccessfullyRetrieveSparseData() public void CanSuccessfullyTrimSpaces() { string dataPath = GetDataPath("TrimData.csv"); - var loader = new TextLoader(dataPath, useHeader: true, separator: ",", allowQuotedStrings: false, supportSparse: false, trimWhitespace: true); + var loader = new Data.TextLoader(dataPath).CreateFrom(useHeader: true, separator: ',', allowQuotedStrings: false, supportSparse: false, trimWhitespace: true); using (var environment = new TlcEnvironment()) { @@ -223,7 +224,7 @@ public void CanSuccessfullyTrimSpaces() [Fact] public void ThrowsExceptionWithPropertyName() { - Exception ex = Assert.Throws( () => new TextLoader("fakefile.txt") ); + Exception ex = Assert.Throws( () => 
new Data.TextLoader("fakefile.txt").CreateFrom() ); Assert.StartsWith("String1 is missing ColumnAttribute", ex.Message); }