From 16f7883933b56f8fd86077bf0fd262b24374e9d0 Mon Sep 17 00:00:00 2001 From: Zeeshan Ahmed Date: Tue, 3 Jul 2018 17:20:26 -0700 Subject: [PATCH 1/6] Added convenience constructors for set of transforms (Part 2). --- .../Transforms/ChooseColumnsTransform.cs | 12 +++++ .../Transforms/ConvertTransform.cs | 29 +++++++++++ .../Transforms/DropSlotsTransform.cs | 12 +++++ .../Transforms/GenerateNumberTransform.cs | 22 ++++++++- .../Transforms/HashTransform.cs | 49 +++++++++++++++++-- .../Transforms/KeyToValueTransform.cs | 13 +++++ .../Transforms/KeyToVectorTransform.cs | 24 ++++++++- .../Transforms/LabelConvertTransform.cs | 12 +++++ .../Transforms/LabelIndicatorTransform.cs | 24 ++++++++- .../Transforms/RangeFilter.cs | 13 +++++ .../Transforms/ShuffleTransform.cs | 30 ++++++++++-- .../Transforms/SkipTakeFilter.cs | 28 ++++++++++- .../Transforms/TermTransform.cs | 27 +++++++++- 13 files changed, 280 insertions(+), 15 deletions(-) diff --git a/src/Microsoft.ML.Data/Transforms/ChooseColumnsTransform.cs b/src/Microsoft.ML.Data/Transforms/ChooseColumnsTransform.cs index 5efc9264f1..71482036f2 100644 --- a/src/Microsoft.ML.Data/Transforms/ChooseColumnsTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/ChooseColumnsTransform.cs @@ -442,6 +442,18 @@ private static VersionInfo GetVersionInfo() private const string RegistrationName = "ChooseColumns"; + /// + /// Convenience constructor for public facing API. + /// + /// Host Environment. + /// Input . This is the output from previous transform or loader. + /// Name of the output column. + /// Name of the selected column. If this is null '' will be used. + public ChooseColumnsTransform(IHostEnvironment env, IDataView input, string name, string source = null) + : this(env, new Arguments() { Column = new[] { new Column() { Source = source ?? name, Name = name } } }, input) + { + } + /// /// Public constructor corresponding to SignatureDataTransform. /// diff --git a/src/Microsoft.ML.Data/Transforms/ConvertTransform.cs b/src/Microsoft.ML.Data/Transforms/ConvertTransform.cs index c37f0a6983..89b836cd34 100644 --- a/src/Microsoft.ML.Data/Transforms/ConvertTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/ConvertTransform.cs @@ -108,6 +108,16 @@ public bool TryUnparse(StringBuilder sb) public class Arguments : TransformInputBase { + public Arguments() + { + + } + + public Arguments(string name, string source) + { + Column = new[] { new Column() { Source = source ?? name, Name = name } }; + } + [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:type:src)", ShortName = "col", SortOrder = 1)] public Column[] Column; @@ -169,6 +179,25 @@ private static VersionInfo GetVersionInfo() // This is parallel to Infos. private readonly ColInfoEx[] _exes; + /// + /// Convenience constructor for public facing API. + /// + /// Host Environment. + /// Input . This is the output from previous transform or loader. + /// Name of the output column. + /// Name of the column to be converted. If this is null '' will be used. + /// The expected type of the converted column. + /// For a key column, this defines the range of values. + public ConvertTransform(IHostEnvironment env, + IDataView input, + string name, + string source = null, + DataKind? resultType = null, + KeyRange keyRange = null) + : this(env, new Arguments(name, source) { ResultType = resultType, KeyRange = keyRange }, input) + { + } + public ConvertTransform(IHostEnvironment env, Arguments args, IDataView input) : base(env, RegistrationName, env.CheckRef(args, nameof(args)).Column, input, null) diff --git a/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs b/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs index 9a40f404ea..6a4621fadc 100644 --- a/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs @@ -216,6 +216,18 @@ public ColInfoEx(SlotDropper slotDropper, bool suppressed, ColumnType typeDst, i private readonly ColInfoEx[] _exes; + /// + /// Convenience constructor for public facing API. + /// + /// Host Environment. + /// Input . This is the output from previous transform or loader. + /// Name of the output column. + /// Name of the input column. If this is null '' will be used. + public DropSlotsTransform(IHostEnvironment env, IDataView input, string name, string source = null) + : this(env, new Arguments() { Column = new[] { new Column() { Source = source ?? name, Name = name } } }, input) + { + } + /// /// Public constructor corresponding to SignatureDataTransform. /// diff --git a/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs b/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs index f80589bdab..713f85f9df 100644 --- a/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs @@ -77,16 +77,22 @@ private bool TryParse(string str) } } + private static class Defaults + { + public const bool UseCounter = false; + public const uint Seed = 42; + } + public sealed class Arguments : TransformInputBase { [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:seed)", ShortName = "col", SortOrder = 1)] public Column[] Column; [Argument(ArgumentType.AtMostOnce, HelpText = "Use an auto-incremented integer starting at zero instead of a random number", ShortName = "cnt")] - public bool UseCounter; + public bool UseCounter = Defaults.UseCounter; [Argument(ArgumentType.AtMostOnce, HelpText = "The random seed")] - public uint Seed = 42; + public uint Seed = Defaults.Seed; } private sealed class Bindings : ColumnBindingsBase @@ -250,6 +256,18 @@ private static VersionInfo GetVersionInfo() private const string RegistrationName = "GenerateNumber"; + /// + /// Convenience constructor for public facing API. + /// + /// Host Environment. + /// Input . This is the output from previous transform or loader. + /// Name of the output column. + /// Use an auto-incremented integer starting at zero instead of a random number. + public GenerateNumberTransform(IHostEnvironment env, IDataView input, string name, bool useCounter = Defaults.UseCounter) + : this(env, new Arguments() { Column = new[] { new Column() { Name = name } }, UseCounter = useCounter }, input) + { + } + /// /// Public constructor corresponding to SignatureDataTransform. /// diff --git a/src/Microsoft.ML.Data/Transforms/HashTransform.cs b/src/Microsoft.ML.Data/Transforms/HashTransform.cs index ca959069f7..59f5029cfb 100644 --- a/src/Microsoft.ML.Data/Transforms/HashTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/HashTransform.cs @@ -33,26 +33,48 @@ public sealed class HashTransform : OneToOneTransformBase, ITransformTemplate public const int NumBitsMin = 1; public const int NumBitsLim = 32; + private static class Defaults + { + public const int HashBits = NumBitsLim - 1; + public const uint Seed = 314489979; + public const bool Ordered = false; + public const int InvertHash = 0; + } + public sealed class Arguments { + public Arguments() + { + + } + + public Arguments(string name, string source) + { + Column = new[] { new Column(){ + Source = source ?? name, + Name = name + } + }; + } + [Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:src)", ShortName = "col", SortOrder = 1)] public Column[] Column; [Argument(ArgumentType.AtMostOnce, HelpText = "Number of bits to hash into. Must be between 1 and 31, inclusive", ShortName = "bits", SortOrder = 2)] - public int HashBits = NumBitsLim - 1; + public int HashBits = Defaults.HashBits; [Argument(ArgumentType.AtMostOnce, HelpText = "Hashing seed")] - public uint Seed = 314489979; + public uint Seed = Defaults.Seed; [Argument(ArgumentType.AtMostOnce, HelpText = "Whether the position of each term should be included in the hash", ShortName = "ord")] - public bool Ordered; + public bool Ordered = Defaults.Ordered; [Argument(ArgumentType.AtMostOnce, HelpText = "Limit the number of keys used to generate the slot name to this many. 0 means no invert hashing, -1 means no limit.", ShortName = "ih")] - public int InvertHash; + public int InvertHash = Defaults.InvertHash; } public sealed class Column : OneToOneColumn @@ -234,6 +256,25 @@ public override void Save(ModelSaveContext ctx) TextModelHelper.SaveAll(Host, ctx, Infos.Length, _keyValues); } + /// + /// Convenience constructor for public facing API. + /// + /// Host Environment. + /// Input . This is the output from previous transform or loader. + /// Name of the output column. + /// Name of the column to be transformed. If this is null '' will be used. + /// Number of bits to hash into. Must be between 1 and 31, inclusive. + /// Limit the number of keys used to generate the slot name to this many. 0 means no invert hashing, -1 means no limit. + public HashTransform(IHostEnvironment env, + IDataView input, + string name, + string source = null, + int hashBits = Defaults.HashBits, + int invertHash = Defaults.InvertHash) + : this(env, new Arguments(name, source) { HashBits = hashBits, InvertHash = invertHash }, input) + { + } + public HashTransform(IHostEnvironment env, Arguments args, IDataView input) : base(Contracts.CheckRef(env, nameof(env)), RegistrationName, env.CheckRef(args, nameof(args)).Column, input, TestType) diff --git a/src/Microsoft.ML.Data/Transforms/KeyToValueTransform.cs b/src/Microsoft.ML.Data/Transforms/KeyToValueTransform.cs index 165ab7e0df..7c1fa19c10 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToValueTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToValueTransform.cs @@ -73,6 +73,19 @@ private static VersionInfo GetVersionInfo() private readonly ColumnType[] _types; private KeyToValueMap[] _kvMaps; + /// + /// Convenience constructor for public facing API. + /// + /// Host Environment. + /// Input . This is the output from previous transform or loader. + /// Name of the output column. + /// Name of the input column. If this is null '' will be used. + public KeyToValueTransform(IHostEnvironment env, IDataView input, string name, string source = null) + : this(env, new Arguments() { Column = new[] { new Column() { Source = source ?? name, Name = name } } }, input) + { + } + + /// /// Public constructor corresponding to SignatureDataTransform. /// diff --git a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs index bffbaa881c..a0c24413b1 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs @@ -70,6 +70,11 @@ public bool TryUnparse(StringBuilder sb) } } + private static class Defaults + { + public const bool Bag = false; + } + public sealed class Arguments { [Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:src)", ShortName = "col", SortOrder = 1)] @@ -77,7 +82,7 @@ public sealed class Arguments [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to combine multiple indicator vectors into a single bag vector instead of concatenating them. This is only relevant when the input is a vector.")] - public bool Bag; + public bool Bag = Defaults.Bag; } internal const string Summary = "Converts a key column to an indicator vector."; @@ -112,6 +117,23 @@ private static VersionInfo GetVersionInfo() private readonly bool[] _concat; private readonly VectorType[] _types; + /// + /// Convenience constructor for public facing API. + /// + /// Host Environment. + /// Input . This is the output from previous transform or loader. + /// Name of the output column. + /// Name of the input column. If this is null '' will be used. + /// Whether to combine multiple indicator vectors into a single bag vector instead of concatenating them. This is only relevant when the input is a vector. + public KeyToVectorTransform(IHostEnvironment env, + IDataView input, + string name, + string source = null, + bool bag = Defaults.Bag) + : this(env, new Arguments() { Column = new[] { new Column() { Source = source ?? name, Name = name } }, Bag = bag }, input) + { + } + /// /// Public constructor corresponding to SignatureDataTransform. /// diff --git a/src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs b/src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs index 5329d89a57..8817833f40 100644 --- a/src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs @@ -64,6 +64,18 @@ private static VersionInfo GetVersionInfo() private const string RegistrationName = "LabelConvert"; private VectorType _slotType; + /// + /// Convenience constructor for public facing API. + /// + /// Host Environment. + /// Input . This is the output from previous transform or loader. + /// Name of the output column. + /// Name of the input column. If this is null '' will be used. + public LabelConvertTransform(IHostEnvironment env, IDataView input, string name, string source = null) + : this(env, new Arguments() { Column = new[] { new Column() { Source = source ?? name, Name = name } } }, input) + { + } + public LabelConvertTransform(IHostEnvironment env, Arguments args, IDataView input) : base(env, RegistrationName, Contracts.CheckRef(args, nameof(args)).Column, input, RowCursorUtils.TestGetLabelGetter) { diff --git a/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs b/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs index 81a91b5f17..e171468048 100644 --- a/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs @@ -64,13 +64,18 @@ public bool TryUnparse(StringBuilder sb) } } + private static class Defaults + { + public const int ClassIndex = 0; + } + public sealed class Arguments : TransformInputBase { [Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:src)", ShortName = "col", SortOrder = 1)] public Column[] Column; [Argument(ArgumentType.AtMostOnce, HelpText = "Label of the positive class.", ShortName = "index")] - public int ClassIndex; + public int ClassIndex = Defaults.ClassIndex; } public static LabelIndicatorTransform Create(IHostEnvironment env, @@ -111,6 +116,23 @@ private static string TestIsMulticlassLabel(ColumnType type) return $"Label column type is not supported for binary remapping: {type}. Supported types: key, float, double."; } + /// + /// Convenience constructor for public facing API. + /// + /// Host Environment. + /// Input . This is the output from previous transform or loader. + /// Name of the output column. + /// Name of the input column. If this is null '' will be used. + /// Label of the positive class. + public LabelIndicatorTransform(IHostEnvironment env, + IDataView input, + string name, + string source = null, + int classIndex = Defaults.ClassIndex) + : this(env, new Arguments() { Column = new[] { new Column() { Source = source ?? name, Name = name } }, ClassIndex = classIndex }, input) + { + } + public LabelIndicatorTransform(IHostEnvironment env, Arguments args, IDataView input) : base(env, LoadName, Contracts.CheckRef(args, nameof(args)).Column, input, TestIsMulticlassLabel) diff --git a/src/Microsoft.ML.Data/Transforms/RangeFilter.cs b/src/Microsoft.ML.Data/Transforms/RangeFilter.cs index b9ab10f4c1..589a635aff 100644 --- a/src/Microsoft.ML.Data/Transforms/RangeFilter.cs +++ b/src/Microsoft.ML.Data/Transforms/RangeFilter.cs @@ -77,6 +77,19 @@ private static VersionInfo GetVersionInfo() private readonly bool _includeMin; private readonly bool _includeMax; + /// + /// Convenience constructor for public facing API. + /// + /// Host Environment. + /// Input . This is the output from previous transform or loader. + /// Name of the input column. + /// Minimum value (0 to 1 for key types). + /// Maximum value (0 to 1 for key types). + public RangeFilter(IHostEnvironment env, IDataView input, string column, Double? minimum = null, Double? maximum = null) + : this(env, new Arguments() { Column = column, Min = minimum, Max = maximum }, input) + { + } + public RangeFilter(IHostEnvironment env, Arguments args, IDataView input) : base(env, RegistrationName, input) { diff --git a/src/Microsoft.ML.Data/Transforms/ShuffleTransform.cs b/src/Microsoft.ML.Data/Transforms/ShuffleTransform.cs index 37e52ee2da..5080208335 100644 --- a/src/Microsoft.ML.Data/Transforms/ShuffleTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/ShuffleTransform.cs @@ -33,18 +33,25 @@ namespace Microsoft.ML.Runtime.Data /// public sealed class ShuffleTransform : RowToRowTransformBase { + private static class Defaults + { + public const int PoolRows = 1000; + public const bool PoolOnly = false; + public const bool ForceShuffle = false; + } + public sealed class Arguments { // REVIEW: A more intelligent heuristic, based on the expected size of the inputs, perhaps? [Argument(ArgumentType.LastOccurenceWins, HelpText = "The pool will have this many rows", ShortName = "rows")] - public int PoolRows = 1000; + public int PoolRows = Defaults.PoolRows; // REVIEW: Come up with a better way to specify the desired set of functionality. [Argument(ArgumentType.LastOccurenceWins, HelpText = "If true, the transform will not attempt to shuffle the input cursor but only shuffle based on the pool. This parameter has no effect if the input data was not itself shufflable.", ShortName = "po")] - public bool PoolOnly; + public bool PoolOnly = Defaults.PoolOnly; [Argument(ArgumentType.LastOccurenceWins, HelpText = "If true, the transform will always provide a shuffled view.", ShortName = "force")] - public bool ForceShuffle; + public bool ForceShuffle = Defaults.ForceShuffle; [Argument(ArgumentType.LastOccurenceWins, HelpText = "If true, the transform will always shuffle the input. The default value is the same as forceShuffle.", ShortName = "forceSource")] public bool? ForceShuffleSource; @@ -79,6 +86,23 @@ private static VersionInfo GetVersionInfo() // know how to copy other types of values. private readonly IDataView _subsetInput; + /// + /// Convenience constructor for public facing API. + /// + /// Host Environment. + /// Input . This is the output from previous transform or loader. + /// The pool will have this many rows + /// If true, the transform will not attempt to shuffle the input cursor but only shuffle based on the pool. This parameter has no effect if the input data was not itself shufflable. + /// If true, the transform will always provide a shuffled view. + public ShuffleTransform(IHostEnvironment env, + IDataView input, + int poolRows = Defaults.PoolRows, + bool poolOnly = Defaults.PoolOnly, + bool forceShuffle = Defaults.ForceShuffle) + : this(env, new Arguments() { PoolRows = poolRows, PoolOnly = poolOnly, ForceShuffle = forceShuffle }, input) + { + } + /// /// Public constructor corresponding to SignatureDataTransform. /// diff --git a/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs b/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs index 278f3ee418..5ae40b7c2a 100644 --- a/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs +++ b/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs @@ -60,13 +60,13 @@ public sealed class Arguments : TransformInputBase public sealed class TakeArguments : TransformInputBase { [Argument(ArgumentType.Required, HelpText = Arguments.TakeHelp, ShortName = "c,n,t", SortOrder = 1)] - public long Count = long.MaxValue; + public long Count = Arguments.DefaultTake; } public sealed class SkipArguments : TransformInputBase { [Argument(ArgumentType.Required, HelpText = Arguments.SkipHelp, ShortName = "c,n,s", SortOrder = 1)] - public long Count = 0; + public long Count = Arguments.DefaultSkip; } private static VersionInfo GetVersionInfo() @@ -108,6 +108,18 @@ public static SkipTakeFilter Create(IHostEnvironment env, Arguments args, IDataV return new SkipTakeFilter(skip, take, env, input); } + /// + /// A helper method to create 'SkipFilter' for public facing API. + /// + /// Host Environment. + /// >Input . This is the output from previous transform or loader. + /// Number of rows to skip + /// + public static SkipTakeFilter CreateSkipFilter(IHostEnvironment env, IDataView input, long count = Arguments.DefaultSkip) + { + return Create(env, new SkipArguments() { Count = count }, input); + } + public static SkipTakeFilter Create(IHostEnvironment env, SkipArguments args, IDataView input) { Contracts.CheckValue(env, nameof(env)); @@ -116,6 +128,18 @@ public static SkipTakeFilter Create(IHostEnvironment env, SkipArguments args, ID return new SkipTakeFilter(args.Count, Arguments.DefaultTake, env, input); } + /// + /// A helper method to create 'TakeFilter' for public facing API. + /// + /// Host Environment. + /// >Input . This is the output from previous transform or loader. + /// Number of rows to take + /// + public static SkipTakeFilter CreateTakeFilter(IHostEnvironment env, IDataView input, long count = Arguments.DefaultTake) + { + return Create(env, new TakeArguments() { Count = count }, input); + } + public static SkipTakeFilter Create(IHostEnvironment env, TakeArguments args, IDataView input) { Contracts.CheckValue(env, nameof(env)); diff --git a/src/Microsoft.ML.Data/Transforms/TermTransform.cs b/src/Microsoft.ML.Data/Transforms/TermTransform.cs index da7442f90e..194405c6a3 100644 --- a/src/Microsoft.ML.Data/Transforms/TermTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/TermTransform.cs @@ -97,10 +97,16 @@ public enum SortOrder : byte // other things, like case insensitive (where appropriate), culturally aware, etc.? } + private static class Defaults + { + public const int MaxNumTerms = 1000000; + public const SortOrder Sort = SortOrder.Occurrence; + } + public abstract class ArgumentsBase : TransformInputBase { [Argument(ArgumentType.AtMostOnce, HelpText = "Maximum number of terms to keep per column when auto-training", ShortName = "max", SortOrder = 5)] - public int MaxNumTerms = 1000000; + public int MaxNumTerms = Defaults.MaxNumTerms; [Argument(ArgumentType.AtMostOnce, HelpText = "Comma separated list of terms", SortOrder = 105, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly)] public string Terms; @@ -124,7 +130,7 @@ public abstract class ArgumentsBase : TransformInputBase // REVIEW: Should we always sort? Opinions are mixed. See work item 7797429. [Argument(ArgumentType.AtMostOnce, HelpText = "How items should be ordered when vectorized. By default, they will be in the order encountered. " + "If by value items are sorted according to their default comparison, e.g., text sorting will be case sensitive (e.g., 'A' then 'Z' then 'a').", SortOrder = 113)] - public SortOrder Sort = SortOrder.Occurrence; + public SortOrder Sort = Defaults.Sort; // REVIEW: Should we do this here, or correct the various pieces of code here and in MRS etc. that // assume key-values will be string? Once we correct these things perhaps we can see about removing it. @@ -196,6 +202,23 @@ private CodecFactory CodecFactory public override bool CanSavePfa => true; public override bool CanSaveOnnx => true; + /// + /// Convenience constructor for public facing API. + /// + /// Host Environment. + /// Input . This is the output from previous transform or loader. + /// Name of the output column. + /// Name of the column to be transformed. If this is null '' will be used. + /// Maximum number of terms to keep per column when auto-training. + public TermTransform(IHostEnvironment env, + IDataView input, + string name, + string source = null, + int maxNumTerms = Defaults.MaxNumTerms) + : this(env, new Arguments() { Column = new[] { new Column() { Name = name, Source = source ?? name } }, MaxNumTerms = maxNumTerms }, input) + { + } + /// /// Public constructor corresponding to SignatureDataTransform. /// From 9c8ca3d498828ad6f2f8e63599a7900f75b58073 Mon Sep 17 00:00:00 2001 From: Zeeshan Ahmed Date: Thu, 5 Jul 2018 13:31:38 -0700 Subject: [PATCH 2/6] Addressed reviewers' comments. --- .../Transforms/ChooseColumnsTransform.cs | 21 ++++++-- .../Transforms/ConvertTransform.cs | 16 +----- .../Transforms/DropSlotsTransform.cs | 12 ----- .../Transforms/HashTransform.cs | 18 ++----- .../Transforms/SkipTakeFilter.cs | 50 ++++++++++--------- .../Transforms/TermTransform.cs | 7 ++- 6 files changed, 53 insertions(+), 71 deletions(-) diff --git a/src/Microsoft.ML.Data/Transforms/ChooseColumnsTransform.cs b/src/Microsoft.ML.Data/Transforms/ChooseColumnsTransform.cs index 71482036f2..96971c297b 100644 --- a/src/Microsoft.ML.Data/Transforms/ChooseColumnsTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/ChooseColumnsTransform.cs @@ -58,6 +58,20 @@ public bool TryUnparse(StringBuilder sb) public sealed class Arguments { + public Arguments() + { + + } + + public Arguments(params string[] columns) + { + Column = new Column[columns.Length]; + for (int i = 0; i < columns.Length; i++) + { + Column[i] = new Column() { Source = columns[i], Name = columns[i] }; + } + } + [Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:src)", ShortName = "col", SortOrder = 1)] public Column[] Column; @@ -447,10 +461,9 @@ private static VersionInfo GetVersionInfo() /// /// Host Environment. /// Input . This is the output from previous transform or loader. - /// Name of the output column. - /// Name of the selected column. If this is null '' will be used. - public ChooseColumnsTransform(IHostEnvironment env, IDataView input, string name, string source = null) - : this(env, new Arguments() { Column = new[] { new Column() { Source = source ?? name, Name = name } } }, input) + /// Names of the columns to choose. + public ChooseColumnsTransform(IHostEnvironment env, IDataView input, params string[] columns) + : this(env, new Arguments(columns), input) { } diff --git a/src/Microsoft.ML.Data/Transforms/ConvertTransform.cs b/src/Microsoft.ML.Data/Transforms/ConvertTransform.cs index 89b836cd34..2a7b20e95c 100644 --- a/src/Microsoft.ML.Data/Transforms/ConvertTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/ConvertTransform.cs @@ -108,16 +108,6 @@ public bool TryUnparse(StringBuilder sb) public class Arguments : TransformInputBase { - public Arguments() - { - - } - - public Arguments(string name, string source) - { - Column = new[] { new Column() { Source = source ?? name, Name = name } }; - } - [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:type:src)", ShortName = "col", SortOrder = 1)] public Column[] Column; @@ -187,14 +177,12 @@ private static VersionInfo GetVersionInfo() /// Name of the output column. /// Name of the column to be converted. If this is null '' will be used. /// The expected type of the converted column. - /// For a key column, this defines the range of values. public ConvertTransform(IHostEnvironment env, IDataView input, string name, string source = null, - DataKind? resultType = null, - KeyRange keyRange = null) - : this(env, new Arguments(name, source) { ResultType = resultType, KeyRange = keyRange }, input) + DataKind? resultType = DataKind.Num) + : this(env, new Arguments() { Column = new[] { new Column() { Source = source ?? name, Name = name } }, ResultType = resultType }, input) { } diff --git a/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs b/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs index 6a4621fadc..9a40f404ea 100644 --- a/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs @@ -216,18 +216,6 @@ public ColInfoEx(SlotDropper slotDropper, bool suppressed, ColumnType typeDst, i private readonly ColInfoEx[] _exes; - /// - /// Convenience constructor for public facing API. - /// - /// Host Environment. - /// Input . This is the output from previous transform or loader. - /// Name of the output column. - /// Name of the input column. If this is null '' will be used. - public DropSlotsTransform(IHostEnvironment env, IDataView input, string name, string source = null) - : this(env, new Arguments() { Column = new[] { new Column() { Source = source ?? name, Name = name } } }, input) - { - } - /// /// Public constructor corresponding to SignatureDataTransform. /// diff --git a/src/Microsoft.ML.Data/Transforms/HashTransform.cs b/src/Microsoft.ML.Data/Transforms/HashTransform.cs index 59f5029cfb..23ba5592b7 100644 --- a/src/Microsoft.ML.Data/Transforms/HashTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/HashTransform.cs @@ -43,20 +43,6 @@ private static class Defaults public sealed class Arguments { - public Arguments() - { - - } - - public Arguments(string name, string source) - { - Column = new[] { new Column(){ - Source = source ?? name, - Name = name - } - }; - } - [Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:src)", ShortName = "col", SortOrder = 1)] public Column[] Column; @@ -271,7 +257,9 @@ public HashTransform(IHostEnvironment env, string source = null, int hashBits = Defaults.HashBits, int invertHash = Defaults.InvertHash) - : this(env, new Arguments(name, source) { HashBits = hashBits, InvertHash = invertHash }, input) + : this(env, new Arguments() { + Column = new[] { new Column() { Source = source ?? name, Name = name } }, + HashBits = hashBits, InvertHash = invertHash }, input) { } diff --git a/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs b/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs index 5ae40b7c2a..3dbb467eae 100644 --- a/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs +++ b/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs @@ -108,18 +108,6 @@ public static SkipTakeFilter Create(IHostEnvironment env, Arguments args, IDataV return new SkipTakeFilter(skip, take, env, input); } - /// - /// A helper method to create 'SkipFilter' for public facing API. - /// - /// Host Environment. - /// >Input . This is the output from previous transform or loader. - /// Number of rows to skip - /// - public static SkipTakeFilter CreateSkipFilter(IHostEnvironment env, IDataView input, long count = Arguments.DefaultSkip) - { - return Create(env, new SkipArguments() { Count = count }, input); - } - public static SkipTakeFilter Create(IHostEnvironment env, SkipArguments args, IDataView input) { Contracts.CheckValue(env, nameof(env)); @@ -128,18 +116,6 @@ public static SkipTakeFilter Create(IHostEnvironment env, SkipArguments args, ID return new SkipTakeFilter(args.Count, Arguments.DefaultTake, env, input); } - /// - /// A helper method to create 'TakeFilter' for public facing API. - /// - /// Host Environment. - /// >Input . This is the output from previous transform or loader. - /// Number of rows to take - /// - public static SkipTakeFilter CreateTakeFilter(IHostEnvironment env, IDataView input, long count = Arguments.DefaultTake) - { - return Create(env, new TakeArguments() { Count = count }, input); - } - public static SkipTakeFilter Create(IHostEnvironment env, TakeArguments args, IDataView input) { Contracts.CheckValue(env, nameof(env)); @@ -294,4 +270,30 @@ protected override bool MoveManyCore(long count) } } } + + public static class SkipFilter + { + /// + /// A helper method to create for public facing API. + /// + /// Host Environment. + /// >Input . This is the output from previous transform or loader. + /// Number of rows to skip + public static IDataTransform Create(IHostEnvironment env, IDataView input, long count = SkipTakeFilter.Arguments.DefaultSkip) + => SkipTakeFilter.Create(env, new SkipTakeFilter.SkipArguments() { Count = count }, input); + } + + public static class TakeFilter + { + + + /// + /// A helper method to create for public facing API. + /// + /// Host Environment. + /// >Input . This is the output from previous transform or loader. + /// Number of rows to take + public static IDataTransform Create(IHostEnvironment env, IDataView input, long count = SkipTakeFilter.Arguments.DefaultTake) + => SkipTakeFilter.Create(env, new SkipTakeFilter.TakeArguments() { Count = count }, input); + } } \ No newline at end of file diff --git a/src/Microsoft.ML.Data/Transforms/TermTransform.cs b/src/Microsoft.ML.Data/Transforms/TermTransform.cs index 194405c6a3..842f608d98 100644 --- a/src/Microsoft.ML.Data/Transforms/TermTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/TermTransform.cs @@ -210,12 +210,15 @@ private CodecFactory CodecFactory /// Name of the output column. /// Name of the column to be transformed. If this is null '' will be used. /// Maximum number of terms to keep per column when auto-training. + /// How items should be ordered when vectorized. By default, they will be in the order encountered. + /// If by value items are sorted according to their default comparison, e.g., text sorting will be case sensitive (e.g., 'A' then 'Z' then 'a'). public TermTransform(IHostEnvironment env, IDataView input, string name, string source = null, - int maxNumTerms = Defaults.MaxNumTerms) - : this(env, new Arguments() { Column = new[] { new Column() { Name = name, Source = source ?? name } }, MaxNumTerms = maxNumTerms }, input) + int maxNumTerms = Defaults.MaxNumTerms, + SortOrder sort = Defaults.Sort) + : this(env, new Arguments() { Column = new[] { new Column() { Name = name, Source = source ?? name } }, MaxNumTerms = maxNumTerms, Sort = sort }, input) { } From 3f7713dfedb003d4701fcc4d5be135e4b100cd34 Mon Sep 17 00:00:00 2001 From: Zeeshan Ahmed Date: Thu, 5 Jul 2018 16:53:00 -0700 Subject: [PATCH 3/6] Addressed reviewers' comments. --- .../Transforms/ConvertTransform.cs | 6 +++--- .../Transforms/LabelIndicatorTransform.cs | 13 ++++--------- src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs | 4 ++-- 3 files changed, 9 insertions(+), 14 deletions(-) diff --git a/src/Microsoft.ML.Data/Transforms/ConvertTransform.cs b/src/Microsoft.ML.Data/Transforms/ConvertTransform.cs index 2a7b20e95c..52005c7558 100644 --- a/src/Microsoft.ML.Data/Transforms/ConvertTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/ConvertTransform.cs @@ -174,14 +174,14 @@ private static VersionInfo GetVersionInfo() /// /// Host Environment. /// Input . This is the output from previous transform or loader. + /// The expected type of the converted column. /// Name of the output column. /// Name of the column to be converted. If this is null '' will be used. - /// The expected type of the converted column. public ConvertTransform(IHostEnvironment env, IDataView input, + DataKind resultType, string name, - string source = null, - DataKind? resultType = DataKind.Num) + string source = null) : this(env, new Arguments() { Column = new[] { new Column() { Source = source ?? name, Name = name } }, ResultType = resultType }, input) { } diff --git a/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs b/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs index e171468048..a7672b5a1c 100644 --- a/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs @@ -64,18 +64,13 @@ public bool TryUnparse(StringBuilder sb) } } - private static class Defaults - { - public const int ClassIndex = 0; - } - public sealed class Arguments : TransformInputBase { [Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:src)", ShortName = "col", SortOrder = 1)] public Column[] Column; [Argument(ArgumentType.AtMostOnce, HelpText = "Label of the positive class.", ShortName = "index")] - public int ClassIndex = Defaults.ClassIndex; + public int ClassIndex; } public static LabelIndicatorTransform Create(IHostEnvironment env, @@ -121,14 +116,14 @@ private static string TestIsMulticlassLabel(ColumnType type) /// /// Host Environment. /// Input . This is the output from previous transform or loader. + /// Label of the positive class. /// Name of the output column. /// Name of the input column. If this is null '' will be used. - /// Label of the positive class. public LabelIndicatorTransform(IHostEnvironment env, IDataView input, + int classIndex, string name, - string source = null, - int classIndex = Defaults.ClassIndex) + string source = null) : this(env, new Arguments() { Column = new[] { new Column() { Source = source ?? name, Name = name } }, ClassIndex = classIndex }, input) { } diff --git a/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs b/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs index 3dbb467eae..20390c32a7 100644 --- a/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs +++ b/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs @@ -274,7 +274,7 @@ protected override bool MoveManyCore(long count) public static class SkipFilter { /// - /// A helper method to create for public facing API. + /// A helper method to create'SkipFilter' transform by skipping the number of rows defined by the parameter. /// /// Host Environment. /// >Input . This is the output from previous transform or loader. @@ -288,7 +288,7 @@ public static class TakeFilter /// - /// A helper method to create for public facing API. + /// A helper method to create 'TakeFilter' transform by taking the top rows defined by the parameter. /// /// Host Environment. /// >Input . This is the output from previous transform or loader. From 1414750bc9b14725a8dc4d95119a7b0200800098 Mon Sep 17 00:00:00 2001 From: Zeeshan Ahmed Date: Mon, 9 Jul 2018 11:06:01 -0700 Subject: [PATCH 4/6] Made ChooseColumnsTransform.Arguments constructor internal. --- src/Microsoft.ML.Data/Transforms/ChooseColumnsTransform.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.ML.Data/Transforms/ChooseColumnsTransform.cs b/src/Microsoft.ML.Data/Transforms/ChooseColumnsTransform.cs index 96971c297b..1459f55cab 100644 --- a/src/Microsoft.ML.Data/Transforms/ChooseColumnsTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/ChooseColumnsTransform.cs @@ -63,7 +63,7 @@ public Arguments() } - public Arguments(params string[] columns) + internal Arguments(params string[] columns) { Column = new Column[columns.Length]; for (int i = 0; i < columns.Length; i++) From 1580c203031478945725272408100f39676faab2 Mon Sep 17 00:00:00 2001 From: Zeeshan Ahmed Date: Tue, 10 Jul 2018 12:25:17 -0700 Subject: [PATCH 5/6] Improved comments on SkipTakeFilter helper methods. --- src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs b/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs index 20390c32a7..8e730447b3 100644 --- a/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs +++ b/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs @@ -274,7 +274,8 @@ protected override bool MoveManyCore(long count) public static class SkipFilter { /// - /// A helper method to create'SkipFilter' transform by skipping the number of rows defined by the parameter. + /// A helper method to create transform for skipping the number of rows defined by the parameter. + /// when created with behaves as 'SkipFilter'. /// /// Host Environment. /// >Input . This is the output from previous transform or loader. @@ -288,7 +289,8 @@ public static class TakeFilter /// - /// A helper method to create 'TakeFilter' transform by taking the top rows defined by the parameter. + /// A helper method to create transform by taking the top rows defined by the parameter. + /// when created with behaves as 'TakeFilter'. /// /// Host Environment. /// >Input . This is the output from previous transform or loader. From 6dc2bc02f670c83fe72158716c69c0f2e9021bea Mon Sep 17 00:00:00 2001 From: Zeeshan Ahmed Date: Wed, 11 Jul 2018 10:53:41 -0700 Subject: [PATCH 6/6] Addressed reviewers' comments. --- src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs | 2 -- src/Microsoft.ML.Data/Transforms/TermTransform.cs | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs b/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs index 8e730447b3..bfd3522f73 100644 --- a/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs +++ b/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs @@ -286,8 +286,6 @@ public static IDataTransform Create(IHostEnvironment env, IDataView input, long public static class TakeFilter { - - /// /// A helper method to create transform by taking the top rows defined by the parameter. /// when created with behaves as 'TakeFilter'. diff --git a/src/Microsoft.ML.Data/Transforms/TermTransform.cs b/src/Microsoft.ML.Data/Transforms/TermTransform.cs index f89e16ed55..7591179588 100644 --- a/src/Microsoft.ML.Data/Transforms/TermTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/TermTransform.cs @@ -218,7 +218,7 @@ public TermTransform(IHostEnvironment env, string source = null, int maxNumTerms = Defaults.MaxNumTerms, SortOrder sort = Defaults.Sort) - : this(env, new Arguments() { Column = new[] { new Column() { Name = name, Source = source ?? name } }, MaxNumTerms = maxNumTerms, Sort = sort }, input) + : this(env, new Arguments() { Column = new[] { new Column() { Source = source ?? name, Name = name } }, MaxNumTerms = maxNumTerms, Sort = sort }, input) { }