Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert TextNormalizer to estimator #1276

Merged
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/Microsoft.ML.Data/Transforms/HashTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1289,10 +1289,10 @@ public override void Process()
/// </summary>
public sealed class HashEstimator : IEstimator<HashTransformer>
{
public const int NumBitsMin = 1;
public const int NumBitsLim = 32;
internal const int NumBitsMin = 1;
internal const int NumBitsLim = 32;

public static class Defaults
internal static class Defaults
{
public const int HashBits = NumBitsLim - 1;
public const uint Seed = 314489979;
Expand Down
19 changes: 8 additions & 11 deletions src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -208,17 +208,14 @@ private static IDataTransform Create(IHostEnvironment env, Arguments args, IData

env.CheckValue(args.Column, nameof(args.Column));
var cols = new ColumnInfo[args.Column.Length];
using (var ch = env.Start("ValidateArgs"))
for (int i = 0; i < cols.Length; i++)
{
for (int i = 0; i < cols.Length; i++)
{
var item = args.Column[i];
var item = args.Column[i];

cols[i] = new ColumnInfo(item.Source ?? item.Name,
item.Name,
item.Bag ?? args.Bag);
};
}
cols[i] = new ColumnInfo(item.Source ?? item.Name,
item.Name,
item.Bag ?? args.Bag);
};
return new KeyToVectorTransform(env, cols).MakeDataTransform(input);
}

Expand Down Expand Up @@ -727,7 +724,7 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src
// Note that one input feature got expanded to a one-hot vector.
opType = "ReduceSum";
var reduceNode = ctx.CreateNode(opType, encodedVariableName, dstVariableName, ctx.GetNodeName(opType), "");
reduceNode.AddAttribute("axes", new long[] { shape.Count - 1});
reduceNode.AddAttribute("axes", new long[] { shape.Count - 1 });
reduceNode.AddAttribute("keepdims", 0);
}
return true;
Expand All @@ -737,7 +734,7 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src

public sealed class KeyToVectorEstimator : TrivialEstimator<KeyToVectorTransform>
{
public static class Defaults
internal static class Defaults
{
public const bool Bag = false;
}
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Legacy/CSharpApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16789,7 +16789,7 @@ public enum TextTransformLanguage
Japanese = 7
}

public enum TextNormalizerTransformCaseNormalizationMode
public enum TextNormalizerEstimatorCaseNormalizationMode
{
Lower = 0,
Upper = 1,
Expand Down Expand Up @@ -16877,7 +16877,7 @@ public void AddColumn(string name, params string[] source)
/// <summary>
/// Casing text using the rules of the invariant culture.
/// </summary>
public TextNormalizerTransformCaseNormalizationMode TextCase { get; set; } = TextNormalizerTransformCaseNormalizationMode.Lower;
public TextNormalizerEstimatorCaseNormalizationMode TextCase { get; set; } = TextNormalizerEstimatorCaseNormalizationMode.Lower;

/// <summary>
/// Whether to keep diacritical marks or remove them.
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Transforms/RffTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,7 @@ private void TransformFeatures(ref VBuffer<float> src, ref VBuffer<float> dst, T
/// </summary>
public sealed class RffEstimator : IEstimator<RffTransform>
{
public static class Defaults
internal static class Defaults
{
public const int NewDim = 1000;
public const bool UseSin = false;
Expand Down
662 changes: 429 additions & 233 deletions src/Microsoft.ML.Transforms/Text/TextNormalizerTransform.cs

Large diffs are not rendered by default.

11 changes: 5 additions & 6 deletions src/Microsoft.ML.Transforms/Text/TextStaticExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
using System;
using System.Collections.Generic;
using static Microsoft.ML.Runtime.TextAnalytics.StopWordsRemoverTransform;
using static Microsoft.ML.Runtime.TextAnalytics.TextNormalizerTransform;

namespace Microsoft.ML.Transforms.Text
{
Expand Down Expand Up @@ -182,7 +181,7 @@ private sealed class OutPipelineColumn : Scalar<string>
{
public readonly Scalar<string> Input;

public OutPipelineColumn(Scalar<string> input, CaseNormalizationMode textCase, bool keepDiacritics, bool keepPunctuations, bool keepNumbers)
public OutPipelineColumn(Scalar<string> input, TextNormalizerEstimator.CaseNormalizationMode textCase, bool keepDiacritics, bool keepPunctuations, bool keepNumbers)
: base(new Reconciler(textCase, keepDiacritics, keepPunctuations, keepNumbers), input)
{
Input = input;
Expand All @@ -191,12 +190,12 @@ public OutPipelineColumn(Scalar<string> input, CaseNormalizationMode textCase, b

private sealed class Reconciler : EstimatorReconciler, IEquatable<Reconciler>
{
private readonly CaseNormalizationMode _textCase;
private readonly TextNormalizerEstimator.CaseNormalizationMode _textCase;
private readonly bool _keepDiacritics;
private readonly bool _keepPunctuations;
private readonly bool _keepNumbers;

public Reconciler(CaseNormalizationMode textCase, bool keepDiacritics, bool keepPunctuations, bool keepNumbers)
public Reconciler(TextNormalizerEstimator.CaseNormalizationMode textCase, bool keepDiacritics, bool keepPunctuations, bool keepNumbers)
{
_textCase = textCase;
_keepDiacritics = keepDiacritics;
Expand Down Expand Up @@ -225,7 +224,7 @@ public override IEstimator<ITransformer> Reconcile(IHostEnvironment env,
foreach (var outCol in toOutput)
pairs.Add((inputNames[((OutPipelineColumn)outCol).Input], outputNames[outCol]));

return new TextNormalizer(env, pairs.ToArray(), _textCase, _keepDiacritics, _keepPunctuations, _keepNumbers);
return new TextNormalizerEstimator(env, pairs.ToArray(), _textCase, _keepDiacritics, _keepPunctuations, _keepNumbers);
}
}

Expand All @@ -238,7 +237,7 @@ public override IEstimator<ITransformer> Reconcile(IHostEnvironment env,
/// <param name="keepPunctuations">Whether to keep punctuation marks or remove them.</param>
/// <param name="keepNumbers">Whether to keep numbers or remove them.</param>
public static Scalar<string> NormalizeText(this Scalar<string> input,
CaseNormalizationMode textCase = CaseNormalizationMode.Lower,
TextNormalizerEstimator.CaseNormalizationMode textCase = TextNormalizerEstimator.CaseNormalizationMode.Lower,
bool keepDiacritics = false,
bool keepPunctuations = true,
bool keepNumbers = true) => new OutPipelineColumn(input, textCase, keepDiacritics, keepPunctuations, keepNumbers);
Expand Down
37 changes: 13 additions & 24 deletions src/Microsoft.ML.Transforms/Text/TextTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.CommandLine;
Expand All @@ -18,6 +14,11 @@
using Microsoft.ML.Runtime.TextAnalytics;
using Microsoft.ML.StaticPipe;
using Microsoft.ML.StaticPipe.Runtime;
using Microsoft.ML.Transforms.Text;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

[assembly: LoadableClass(TextTransform.Summary, typeof(IDataTransform), typeof(TextTransform), typeof(TextTransform.Arguments), typeof(SignatureDataTransform),
TextTransform.UserName, "TextTransform", TextTransform.LoaderSignature)]
Expand All @@ -27,12 +28,8 @@

namespace Microsoft.ML.Runtime.Data
{
using StopWordsArgs = StopWordsRemoverTransform.Arguments;
using TextNormalizerArgs = TextNormalizerTransform.Arguments;
using CaseNormalizationMode = TextNormalizerEstimator.CaseNormalizationMode;
using StopWordsCol = StopWordsRemoverTransform.Column;
using TextNormalizerCol = TextNormalizerTransform.Column;
using StopWordsLang = StopWordsRemoverTransform.Language;
using CaseNormalizationMode = TextNormalizerTransform.CaseNormalizationMode;

// A transform that turns a collection of text documents into numerical feature vectors. The feature vectors are counts
// of (word or character) ngrams in a given text. It offers ngram hashing (finding the ngram token string name to feature
Expand Down Expand Up @@ -97,16 +94,16 @@ public sealed class Arguments : TransformInputBase
public IStopWordsRemoverFactory StopWordsRemover;

[Argument(ArgumentType.AtMostOnce, HelpText = "Casing text using the rules of the invariant culture.", ShortName = "case", SortOrder = 5)]
public CaseNormalizationMode TextCase = CaseNormalizationMode.Lower;
public CaseNormalizationMode TextCase = TextNormalizerEstimator.Defaults.TextCase;

[Argument(ArgumentType.AtMostOnce, HelpText = "Whether to keep diacritical marks or remove them.", ShortName = "diac", SortOrder = 6)]
public bool KeepDiacritics;
public bool KeepDiacritics= TextNormalizerEstimator.Defaults.KeepDiacritics;

[Argument(ArgumentType.AtMostOnce, HelpText = "Whether to keep punctuation marks or remove them.", ShortName = "punc", SortOrder = 7)]
public bool KeepPunctuations = true;
public bool KeepPunctuations = TextNormalizerEstimator.Defaults.KeepPunctuations;

[Argument(ArgumentType.AtMostOnce, HelpText = "Whether to keep numbers or remove them.", ShortName = "num", SortOrder = 8)]
public bool KeepNumbers = true;
public bool KeepNumbers = TextNormalizerEstimator.Defaults.KeepNumbers;

[Argument(ArgumentType.AtMostOnce, HelpText = "Whether to output the transformed text tokens as an additional column.", ShortName = "tokens,showtext,showTransformedText", SortOrder = 9)]
public bool OutputTokens;
Expand Down Expand Up @@ -323,24 +320,16 @@ public ITransformer Fit(IDataView input)

if (tparams.NeedsNormalizeTransform)
{
var xfCols = new TextNormalizerCol[textCols.Length];
var xfCols = new TextNormalizerTransform.ColumnInfo[textCols.Length];
string[] dstCols = new string[textCols.Length];
for (int i = 0; i < textCols.Length; i++)
{
dstCols[i] = GenerateColumnName(view.Schema, textCols[i], "TextNormalizer");
tempCols.Add(dstCols[i]);
xfCols[i] = new TextNormalizerCol() { Source = textCols[i], Name = dstCols[i] };
xfCols[i] = new TextNormalizerTransform.ColumnInfo(textCols[i], dstCols[i], tparams.TextCase, tparams.KeepDiacritics, tparams.KeepPunctuations, tparams.KeepNumbers);
}

view = new TextNormalizerTransform(h,
new TextNormalizerArgs()
{
Column = xfCols,
KeepDiacritics = tparams.KeepDiacritics,
KeepNumbers = tparams.KeepNumbers,
KeepPunctuations = tparams.KeepPunctuations,
TextCase = tparams.TextCase
}, view);
view = new TextNormalizerEstimator(h, xfCols).Fit(view).Transform(view);

textCols = dstCols;
}
Expand Down
85 changes: 0 additions & 85 deletions src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,8 @@
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.TextAnalytics;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using static Microsoft.ML.Runtime.TextAnalytics.LdaTransform;
using static Microsoft.ML.Runtime.TextAnalytics.StopWordsRemoverTransform;
using static Microsoft.ML.Runtime.TextAnalytics.TextNormalizerTransform;

namespace Microsoft.ML.Transforms
{
Expand Down Expand Up @@ -180,87 +176,6 @@ private static TransformWrapper MakeTransformer(IHostEnvironment env, (string in
}
}

/// <summary>
/// Text normalizer allows normalizing text by changing case (Upper/Lower case), removing diacritical marks, punctuation marks and/or numbers.
/// </summary>
public sealed class TextNormalizer : TrivialWrapperEstimator
{
    /// <summary>
    /// Normalizes incoming text in <paramref name="inputColumn"/> by changing case, removing diacritical marks, punctuation marks and/or numbers
    /// and outputs new text as <paramref name="outputColumn"/>.
    /// </summary>
    /// <param name="env">The environment.</param>
    /// <param name="inputColumn">The column containing text to normalize.</param>
    /// <param name="outputColumn">The column containing output tokens. Null means <paramref name="inputColumn"/> is replaced.</param>
    /// <param name="textCase">Casing text using the rules of the invariant culture.</param>
    /// <param name="keepDiacritics">Whether to keep diacritical marks or remove them.</param>
    /// <param name="keepPunctuations">Whether to keep punctuation marks or remove them.</param>
    /// <param name="keepNumbers">Whether to keep numbers or remove them.</param>
    public TextNormalizer(IHostEnvironment env,
        string inputColumn,
        string outputColumn = null,
        CaseNormalizationMode textCase = CaseNormalizationMode.Lower,
        bool keepDiacritics = false,
        bool keepPunctuations = true,
        bool keepNumbers = true)
        : this(env, new[] { (inputColumn, outputColumn ?? inputColumn) }, textCase, keepDiacritics, keepPunctuations, keepNumbers)
    {
    }

    /// <summary>
    /// Normalizes incoming text in input columns by changing case, removing diacritical marks, punctuation marks and/or numbers
    /// and outputs new text as output columns.
    /// </summary>
    /// <param name="env">The environment.</param>
    /// <param name="columns">Pairs of columns to run the text normalization on.</param>
    /// <param name="textCase">Casing text using the rules of the invariant culture.</param>
    /// <param name="keepDiacritics">Whether to keep diacritical marks or remove them.</param>
    /// <param name="keepPunctuations">Whether to keep punctuation marks or remove them.</param>
    /// <param name="keepNumbers">Whether to keep numbers or remove them.</param>
    public TextNormalizer(IHostEnvironment env,
        (string input, string output)[] columns,
        CaseNormalizationMode textCase = CaseNormalizationMode.Lower,
        bool keepDiacritics = false,
        bool keepPunctuations = true,
        bool keepNumbers = true)
        : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(TextNormalizer)),
            MakeTransformer(env, columns, textCase, keepDiacritics, keepPunctuations, keepNumbers))
    {
    }

    /// <summary>
    /// Builds the <see cref="TransformWrapper"/> handed to the base estimator: validates the column pairs,
    /// translates the settings into <see cref="TextNormalizerTransform.Arguments"/>, and instantiates the
    /// transform over an empty data view whose schema declares every input column as text.
    /// </summary>
    /// <param name="env">The environment; the calling constructor has already checked it.</param>
    /// <param name="columns">Pairs of input/output column names; checked to be non-empty.</param>
    /// <param name="textCase">Casing text using the rules of the invariant culture.</param>
    /// <param name="keepDiacritics">Whether to keep diacritical marks or remove them.</param>
    /// <param name="keepPunctuations">Whether to keep punctuation marks or remove them.</param>
    /// <param name="keepNumbers">Whether to keep numbers or remove them.</param>
    /// <returns>A wrapper around a <see cref="TextNormalizerTransform"/> configured with the given settings.</returns>
    private static TransformWrapper MakeTransformer(IHostEnvironment env,
        (string input, string output)[] columns,
        CaseNormalizationMode textCase,
        bool keepDiacritics,
        bool keepPunctuations,
        bool keepNumbers)
    {
        Contracts.AssertValue(env);
        env.CheckNonEmpty(columns, nameof(columns));
        foreach (var (input, output) in columns)
        {
            env.CheckValue(input, nameof(input));
            // Name the parameter that is actually being validated, so a bad output
            // column is not reported under the input's name.
            env.CheckValue(output, nameof(output));
        }

        // Create arguments.
        var args = new TextNormalizerTransform.Arguments
        {
            Column = columns.Select(x => new TextNormalizerTransform.Column { Source = x.input, Name = x.output }).ToArray(),
            TextCase = textCase,
            KeepDiacritics = keepDiacritics,
            KeepPunctuations = keepPunctuations,
            KeepNumbers = keepNumbers
        };

        // Create a valid instance of data.
        var schema = new Schema(columns.Select(x => new Schema.Column(x.input, TextType.Instance, null)));
        var emptyData = new EmptyDataView(env, schema);

        return new TransformWrapper(env, new TextNormalizerTransform(env, args, emptyData));
    }
}

/// <summary>
/// Produces a bag of counts of ngrams (sequences of consecutive words) in a given text.
/// It does so by building a dictionary of ngrams and using the id in the dictionary as the index in the bag.
Expand Down
Loading