diff --git a/src/Microsoft.ML.LightGBM/LightGbmArguments.cs b/src/Microsoft.ML.LightGBM/LightGbmArguments.cs
index e6c5ed360a..0612135ce3 100644
--- a/src/Microsoft.ML.LightGBM/LightGbmArguments.cs
+++ b/src/Microsoft.ML.LightGBM/LightGbmArguments.cs
@@ -32,7 +32,7 @@ public interface ISupportBoosterParameterFactory : IComponentFactory<IBoosterParameter>
-        void UpdateParameters(Dictionary<string, string> res);
+        void UpdateParameters(Dictionary<string, object> res);
     }
 
     /// <summary>
@@ -54,7 +54,7 @@ protected BoosterParameter(TArgs args)
         /// <summary>
         /// Update the parameters by specific Booster, will update parameters into "res" directly.
         /// </summary>
-        public virtual void UpdateParameters(Dictionary<string, string> res)
+        public virtual void UpdateParameters(Dictionary<string, object> res)
         {
             FieldInfo[] fields = Args.GetType().GetFields();
             foreach (var field in fields)
@@ -163,7 +163,7 @@ public TreeBooster(Arguments args)
             Contracts.CheckUserArg(Args.ScalePosWeight > 0 && Args.ScalePosWeight <= 1, nameof(Args.ScalePosWeight), "must be in (0,1].");
         }
 
-        public override void UpdateParameters(Dictionary<string, string> res)
+        public override void UpdateParameters(Dictionary<string, object> res)
         {
             base.UpdateParameters(res);
             res["boosting_type"] = Name;
@@ -207,7 +207,7 @@ public DartBooster(Arguments args)
             Contracts.CheckUserArg(Args.SkipDrop >= 0 && Args.SkipDrop < 1, nameof(Args.SkipDrop), "must be in [0,1).");
         }
 
-        public override void UpdateParameters(Dictionary<string, string> res)
+        public override void UpdateParameters(Dictionary<string, object> res)
        {
             base.UpdateParameters(res);
             res["boosting_type"] = Name;
@@ -244,7 +244,7 @@ public GossBooster(Arguments args)
             Contracts.Check(Args.TopRate + Args.OtherRate <= 1, "Sum of topRate and otherRate cannot be larger than 1.");
         }
 
-        public override void UpdateParameters(Dictionary<string, string> res)
+        public override void UpdateParameters(Dictionary<string, object> res)
         {
             base.UpdateParameters(res);
             res["boosting_type"] = Name;
@@ -355,11 +355,11 @@ public enum EvalMetricType
         [Argument(ArgumentType.Multiple, HelpText = "Parallel LightGBM Learning Algorithm", ShortName = "parag")]
         public ISupportParallel ParallelTrainer = new SingleTrainerFactory();
 
-        internal Dictionary<string, string> ToDictionary(IHost host)
+        internal Dictionary<string, object> ToDictionary(IHost host)
         {
             Contracts.CheckValue(host, nameof(host));
             Contracts.CheckUserArg(MaxBin > 0, nameof(MaxBin), "must be > 0.");
-            Dictionary<string, string> res = new Dictionary<string, string>();
+            Dictionary<string, object> res = new Dictionary<string, object>();
 
             var boosterParams = Booster.CreateComponent(host);
             boosterParams.UpdateParameters(res);
diff --git a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs
index 5acd90e83d..3b840367ac 100644
--- a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs
+++ b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs
@@ -3,6 +3,7 @@
 // See the LICENSE file in the project root for more information.
 
 using System;
+using System.Globalization;
 using Microsoft.ML.Runtime;
 using Microsoft.ML.Runtime.Data;
 using Microsoft.ML.Runtime.EntryPoints;
@@ -130,9 +131,9 @@ protected override void ConvertNaNLabels(IChannel ch, RoleMappedData data, float
         protected override void GetDefaultParameters(IChannel ch, int numRow, bool hasCategorical, int totalCats, bool hiddenMsg=false)
         {
             base.GetDefaultParameters(ch, numRow, hasCategorical, totalCats, true);
-            int numLeaves = int.Parse(Options["num_leaves"]);
+            int numLeaves = (int)Options["num_leaves"];
             int minDataPerLeaf = Args.MinDataPerLeaf ?? DefaultMinDataPerLeaf(numRow, numLeaves, _numClass);
-            Options["min_data_per_leaf"] = minDataPerLeaf.ToString();
+            Options["min_data_per_leaf"] = minDataPerLeaf;
             if (!hiddenMsg)
             {
                 if (!Args.LearningRate.HasValue)
@@ -149,7 +150,7 @@ protected override void CheckAndUpdateParametersBeforeTraining(IChannel ch, Role
             Host.AssertValue(ch);
             ch.Assert(PredictionKind == PredictionKind.MultiClassClassification);
             ch.Assert(_numClass > 1);
-            Options["num_class"] = _numClass.ToString();
+            Options["num_class"] = _numClass;
             bool useSoftmax = false;
 
             if (Args.UseSoftmax.HasValue)
diff --git a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs
index aff3befa0d..c778c4ee23 100644
--- a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs
+++ b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs
@@ -3,14 +3,10 @@
 // See the LICENSE file in the project root for more information.
 
 using System;
-using System.Linq;
-using System.Threading;
-using System.Threading.Tasks;
 using System.Collections.Generic;
 using Microsoft.ML.Runtime.Data;
 using Microsoft.ML.Runtime.Internal.Utilities;
 using Microsoft.ML.Runtime.Training;
-using Microsoft.ML.Runtime.FastTree.Internal;
 
 namespace Microsoft.ML.Runtime.LightGBM
 {
@@ -49,7 +45,13 @@ private sealed class CategoricalMetaData
         protected readonly IHost Host;
         protected readonly LightGbmArguments Args;
-        protected readonly Dictionary<string, string> Options;
+
+        /// <summary>
+        /// Stores arguments as objects and converts them to invariant strings only at the end, so that
+        /// the code is culture agnostic. When retrieving a value from this dictionary as a string,
+        /// convert it with string.Format(CultureInfo.InvariantCulture, "{0}", Options[key]).
+        /// </summary>
+        protected readonly Dictionary<string, object> Options;
         protected readonly IParallel ParallelTraining;
 
         // Store _featureCount and _trainedEnsemble to construct predictor.
@@ -159,9 +161,9 @@ protected virtual void GetDefaultParameters(IChannel ch, int numRow, bool hasCat
             double learningRate = Args.LearningRate ?? DefaultLearningRate(numRow, hasCategarical, totalCats);
             int numLeaves = Args.NumLeaves ?? DefaultNumLeaves(numRow, hasCategarical, totalCats);
             int minDataPerLeaf = Args.MinDataPerLeaf ?? DefaultMinDataPerLeaf(numRow, numLeaves, 1);
-            Options["learning_rate"] = learningRate.ToString();
-            Options["num_leaves"] = numLeaves.ToString();
-            Options["min_data_per_leaf"] = minDataPerLeaf.ToString();
+            Options["learning_rate"] = learningRate;
+            Options["num_leaves"] = numLeaves;
+            Options["min_data_per_leaf"] = minDataPerLeaf;
             if (!hiddenMsg)
             {
                 if (!Args.LearningRate.HasValue)
@@ -192,7 +194,7 @@ private static List<int> GetCategoricalBoundires(int[] categoricalFeatures, int
             {
                 if (j < categoricalFeatures.Length && curFidx == categoricalFeatures[j])
                 {
-                    if (curFidx > catBoundaries.Last())
+                    if (curFidx > catBoundaries[catBoundaries.Count - 1])
                         catBoundaries.Add(curFidx);
                     if (categoricalFeatures[j + 1] - categoricalFeatures[j] >= 0)
                     {
@@ -219,7 +221,7 @@ private static List<int> GetCategoricalBoundires(int[] categoricalFeatures, int
         private static List<string> ConstructCategoricalFeatureMetaData(int[] categoricalFeatures, int rawNumCol, ref CategoricalMetaData catMetaData)
         {
             List<int> catBoundaries = GetCategoricalBoundires(categoricalFeatures, rawNumCol);
-            catMetaData.NumCol = catBoundaries.Count() - 1;
+            catMetaData.NumCol = catBoundaries.Count - 1;
             catMetaData.CategoricalBoudaries = catBoundaries.ToArray();
             catMetaData.IsCategoricalFeature = new bool[catMetaData.NumCol];
             catMetaData.OnehotIndices = new int[rawNumCol];
@@ -279,7 +281,7 @@ private CategoricalMetaData GetCategoricalMetaData(IChannel ch, RoleMappedData t
             {
                 var catIndices = ConstructCategoricalFeatureMetaData(categoricalFeatures, rawNumCol, ref catMetaData);
                 // Set categorical features
-                Options["categorical_feature"] = String.Join(",", catIndices);
+                Options["categorical_feature"] = string.Join(",", catIndices);
             }
             return catMetaData;
         }
@@ -527,13 +529,13 @@ private void GetFeatureValueSparse(IChannel ch, FloatLabelCursor cursor,
                         ++nhot;
                         var prob = rand.NextSingle();
                         if (prob < 1.0f / nhot)
-                            values[values.Count() - 1] = fv;
+                            values[values.Count - 1] = fv;
                     }
                     lastIdx = newColIdx;
                 }
                 indices = featureIndices.ToArray();
                 featureValues = values.ToArray();
-                cnt = featureIndices.Count();
+                cnt = featureIndices.Count;
             }
             else
             {
diff --git a/src/Microsoft.ML.LightGBM/WrappedLightGbmBooster.cs b/src/Microsoft.ML.LightGBM/WrappedLightGbmBooster.cs
index 116fd1c4a9..c9f3128434 100644
--- a/src/Microsoft.ML.LightGBM/WrappedLightGbmBooster.cs
+++ b/src/Microsoft.ML.LightGBM/WrappedLightGbmBooster.cs
@@ -20,7 +20,7 @@ internal sealed class Booster : IDisposable
         public IntPtr Handle { get; private set; }
         public int BestIteration { get; set; }
 
-        public Booster(Dictionary<string, string> parameters, Dataset trainset, Dataset validset = null)
+        public Booster(Dictionary<string, object> parameters, Dataset trainset, Dataset validset = null)
         {
             var param = LightGbmInterfaceUtils.JoinParameters(parameters);
             var handle = IntPtr.Zero;
diff --git a/src/Microsoft.ML.LightGBM/WrappedLightGbmInterface.cs b/src/Microsoft.ML.LightGBM/WrappedLightGbmInterface.cs
index 6d5e13bbeb..eec00d9bd1 100644
--- a/src/Microsoft.ML.LightGBM/WrappedLightGbmInterface.cs
+++ b/src/Microsoft.ML.LightGBM/WrappedLightGbmInterface.cs
@@ -3,8 +3,9 @@
 // See the LICENSE file in the project root for more information.
 
 using System;
-using System.Runtime.InteropServices;
 using System.Collections.Generic;
+using System.Globalization;
+using System.Runtime.InteropServices;
 
 namespace Microsoft.ML.Runtime.LightGBM
 {
@@ -199,13 +200,13 @@ public static void Check(int res)
         /// <summary>
         /// Join the parameters to key=value format.
         /// </summary>
-        public static string JoinParameters(Dictionary<string, string> parameters)
+        public static string JoinParameters(Dictionary<string, object> parameters)
         {
             if (parameters == null)
                 return "";
             List<string> res = new List<string>();
             foreach (var keyVal in parameters)
-                res.Add(keyVal.Key + "=" + keyVal.Value);
+                res.Add(keyVal.Key + "=" + string.Format(CultureInfo.InvariantCulture, "{0}", keyVal.Value));
             return string.Join(" ", res);
         }
 
diff --git a/src/Microsoft.ML.LightGBM/WrappedLightGbmTraining.cs b/src/Microsoft.ML.LightGBM/WrappedLightGbmTraining.cs
index 9699581a13..8b2036fb11 100644
--- a/src/Microsoft.ML.LightGBM/WrappedLightGbmTraining.cs
+++ b/src/Microsoft.ML.LightGBM/WrappedLightGbmTraining.cs
@@ -16,7 +16,7 @@ internal static class WrappedLightGbmTraining
         /// Train and return a booster.
         /// </summary>
         public static Booster Train(IChannel ch, IProgressChannel pch,
-            Dictionary<string, string> parameters, Dataset dtrain, Dataset dvalid = null, int numIteration = 100,
+            Dictionary<string, object> parameters, Dataset dtrain, Dataset dvalid = null, int numIteration = 100,
             bool verboseEval = true, int earlyStoppingRound = 0)
         {
             // create Booster.
@@ -33,12 +33,9 @@ public static Booster Train(IChannel ch, IProgressChannel pch,
             double bestScore = double.MaxValue;
             double factorToSmallerBetter = 1.0;
 
-            if (earlyStoppingRound > 0 && (parameters["metric"] == "auc"
-                || parameters["metric"] == "ndcg"
-                || parameters["metric"] == "map"))
-            {
+            var metric = (string)parameters["metric"];
+            if (earlyStoppingRound > 0 && (metric == "auc" || metric == "ndcg" || metric == "map"))
                 factorToSmallerBetter = -1.0;
-            }
 
             const int evalFreq = 50;
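
Reviewer note, not part of the patch: the standalone sketch below illustrates the locale problem this change works around; the class and variable names are invented for the example. It mirrors the new pattern in which option values stay typed in a Dictionary<string, object> and are only rendered as text with CultureInfo.InvariantCulture, so a process running under, say, de-DE cannot produce "0,1" where LightGBM's native key=value parser expects "0.1".

using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;

internal static class InvariantOptionsSketch
{
    private static void Main()
    {
        // Pretend the process runs under a German locale, where the default
        // double.ToString() uses a comma as the decimal separator.
        CultureInfo.CurrentCulture = new CultureInfo("de-DE");

        double learningRate = 0.1;
        Console.WriteLine(learningRate.ToString());
        // -> "0,1"  (what the old string-typed Options dictionary would have stored)
        Console.WriteLine(string.Format(CultureInfo.InvariantCulture, "{0}", learningRate));
        // -> "0.1"  (what the native LightGBM parser expects)

        // Keep values typed until serialization, as the new Options dictionary does.
        var options = new Dictionary<string, object>
        {
            ["learning_rate"] = learningRate,
            ["num_leaves"] = 20,
            ["metric"] = "auc",
        };

        // Same idea as the updated JoinParameters: format each value invariantly.
        string joined = string.Join(" ", options.Select(kv =>
            kv.Key + "=" + string.Format(CultureInfo.InvariantCulture, "{0}", kv.Value)));
        Console.WriteLine(joined);
        // -> "learning_rate=0.1 num_leaves=20 metric=auc"
    }
}

This is also why JoinParameters now formats keyVal.Value with string.Format(CultureInfo.InvariantCulture, "{0}", ...) instead of relying on the value's culture-sensitive default ToString().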