Skip to content

Commit 4328c3f

Browse files
codemzseerhardt
authored andcommitted
Set culture to culture invariant in LightGBM (dotnet#454)
* Make options culture agnostic.
1 parent 72a8a89 commit 4328c3f

File tree

6 files changed

+34
-33
lines changed

6 files changed

+34
-33
lines changed

src/Microsoft.ML.LightGBM/LightGbmArguments.cs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ public interface ISupportBoosterParameterFactory : IComponentFactory<IBoosterPar
3232
}
3333
public interface IBoosterParameter
3434
{
35-
void UpdateParameters(Dictionary<string, string> res);
35+
void UpdateParameters(Dictionary<string, object> res);
3636
}
3737

3838
/// <summary>
@@ -54,7 +54,7 @@ protected BoosterParameter(TArgs args)
5454
/// <summary>
5555
/// Update the parameters by specific Booster, will update parameters into "res" directly.
5656
/// </summary>
57-
public virtual void UpdateParameters(Dictionary<string, string> res)
57+
public virtual void UpdateParameters(Dictionary<string, object> res)
5858
{
5959
FieldInfo[] fields = Args.GetType().GetFields();
6060
foreach (var field in fields)
@@ -163,7 +163,7 @@ public TreeBooster(Arguments args)
163163
Contracts.CheckUserArg(Args.ScalePosWeight > 0 && Args.ScalePosWeight <= 1, nameof(Args.ScalePosWeight), "must be in (0,1].");
164164
}
165165

166-
public override void UpdateParameters(Dictionary<string, string> res)
166+
public override void UpdateParameters(Dictionary<string, object> res)
167167
{
168168
base.UpdateParameters(res);
169169
res["boosting_type"] = Name;
@@ -207,7 +207,7 @@ public DartBooster(Arguments args)
207207
Contracts.CheckUserArg(Args.SkipDrop >= 0 && Args.SkipDrop < 1, nameof(Args.SkipDrop), "must be in [0,1).");
208208
}
209209

210-
public override void UpdateParameters(Dictionary<string, string> res)
210+
public override void UpdateParameters(Dictionary<string, object> res)
211211
{
212212
base.UpdateParameters(res);
213213
res["boosting_type"] = Name;
@@ -244,7 +244,7 @@ public GossBooster(Arguments args)
244244
Contracts.Check(Args.TopRate + Args.OtherRate <= 1, "Sum of topRate and otherRate cannot be larger than 1.");
245245
}
246246

247-
public override void UpdateParameters(Dictionary<string, string> res)
247+
public override void UpdateParameters(Dictionary<string, object> res)
248248
{
249249
base.UpdateParameters(res);
250250
res["boosting_type"] = Name;
@@ -355,11 +355,11 @@ public enum EvalMetricType
355355
[Argument(ArgumentType.Multiple, HelpText = "Parallel LightGBM Learning Algorithm", ShortName = "parag")]
356356
public ISupportParallel ParallelTrainer = new SingleTrainerFactory();
357357

358-
internal Dictionary<string, string> ToDictionary(IHost host)
358+
internal Dictionary<string, object> ToDictionary(IHost host)
359359
{
360360
Contracts.CheckValue(host, nameof(host));
361361
Contracts.CheckUserArg(MaxBin > 0, nameof(MaxBin), "must be > 0.");
362-
Dictionary<string, string> res = new Dictionary<string, string>();
362+
Dictionary<string, object> res = new Dictionary<string, object>();
363363

364364
var boosterParams = Booster.CreateComponent(host);
365365
boosterParams.UpdateParameters(res);

src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6+
using System.Globalization;
67
using Microsoft.ML.Runtime;
78
using Microsoft.ML.Runtime.Data;
89
using Microsoft.ML.Runtime.EntryPoints;
@@ -130,9 +131,9 @@ protected override void ConvertNaNLabels(IChannel ch, RoleMappedData data, float
130131
protected override void GetDefaultParameters(IChannel ch, int numRow, bool hasCategorical, int totalCats, bool hiddenMsg=false)
131132
{
132133
base.GetDefaultParameters(ch, numRow, hasCategorical, totalCats, true);
133-
int numLeaves = int.Parse(Options["num_leaves"]);
134+
int numLeaves = (int)Options["num_leaves"];
134135
int minDataPerLeaf = Args.MinDataPerLeaf ?? DefaultMinDataPerLeaf(numRow, numLeaves, _numClass);
135-
Options["min_data_per_leaf"] = minDataPerLeaf.ToString();
136+
Options["min_data_per_leaf"] = minDataPerLeaf;
136137
if (!hiddenMsg)
137138
{
138139
if (!Args.LearningRate.HasValue)
@@ -149,7 +150,7 @@ protected override void CheckAndUpdateParametersBeforeTraining(IChannel ch, Role
149150
Host.AssertValue(ch);
150151
ch.Assert(PredictionKind == PredictionKind.MultiClassClassification);
151152
ch.Assert(_numClass > 1);
152-
Options["num_class"] = _numClass.ToString();
153+
Options["num_class"] = _numClass;
153154
bool useSoftmax = false;
154155

155156
if (Args.UseSoftmax.HasValue)

src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,10 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6-
using System.Linq;
7-
using System.Threading;
8-
using System.Threading.Tasks;
96
using System.Collections.Generic;
107
using Microsoft.ML.Runtime.Data;
118
using Microsoft.ML.Runtime.Internal.Utilities;
129
using Microsoft.ML.Runtime.Training;
13-
using Microsoft.ML.Runtime.FastTree.Internal;
1410

1511
namespace Microsoft.ML.Runtime.LightGBM
1612
{
@@ -49,7 +45,13 @@ private sealed class CategoricalMetaData
4945

5046
protected readonly IHost Host;
5147
protected readonly LightGbmArguments Args;
52-
protected readonly Dictionary<string, string> Options;
48+
49+
/// <summary>
50+
/// Stores argumments as objects to convert them to invariant string type in the end so that
51+
/// the code is culture agnostic. When retrieving key value from this dictionary as string
52+
/// please convert to string invariant by string.Format(CultureInfo.InvariantCulture, "{0}", Option[key]).
53+
/// </summary>
54+
protected readonly Dictionary<string, object> Options;
5355
protected readonly IParallel ParallelTraining;
5456

5557
// Store _featureCount and _trainedEnsemble to construct predictor.
@@ -159,9 +161,9 @@ protected virtual void GetDefaultParameters(IChannel ch, int numRow, bool hasCat
159161
double learningRate = Args.LearningRate ?? DefaultLearningRate(numRow, hasCategarical, totalCats);
160162
int numLeaves = Args.NumLeaves ?? DefaultNumLeaves(numRow, hasCategarical, totalCats);
161163
int minDataPerLeaf = Args.MinDataPerLeaf ?? DefaultMinDataPerLeaf(numRow, numLeaves, 1);
162-
Options["learning_rate"] = learningRate.ToString();
163-
Options["num_leaves"] = numLeaves.ToString();
164-
Options["min_data_per_leaf"] = minDataPerLeaf.ToString();
164+
Options["learning_rate"] = learningRate;
165+
Options["num_leaves"] = numLeaves;
166+
Options["min_data_per_leaf"] = minDataPerLeaf;
165167
if (!hiddenMsg)
166168
{
167169
if (!Args.LearningRate.HasValue)
@@ -192,7 +194,7 @@ private static List<int> GetCategoricalBoundires(int[] categoricalFeatures, int
192194
{
193195
if (j < categoricalFeatures.Length && curFidx == categoricalFeatures[j])
194196
{
195-
if (curFidx > catBoundaries.Last())
197+
if (curFidx > catBoundaries[catBoundaries.Count - 1])
196198
catBoundaries.Add(curFidx);
197199
if (categoricalFeatures[j + 1] - categoricalFeatures[j] >= 0)
198200
{
@@ -219,7 +221,7 @@ private static List<int> GetCategoricalBoundires(int[] categoricalFeatures, int
219221
private static List<string> ConstructCategoricalFeatureMetaData(int[] categoricalFeatures, int rawNumCol, ref CategoricalMetaData catMetaData)
220222
{
221223
List<int> catBoundaries = GetCategoricalBoundires(categoricalFeatures, rawNumCol);
222-
catMetaData.NumCol = catBoundaries.Count() - 1;
224+
catMetaData.NumCol = catBoundaries.Count - 1;
223225
catMetaData.CategoricalBoudaries = catBoundaries.ToArray();
224226
catMetaData.IsCategoricalFeature = new bool[catMetaData.NumCol];
225227
catMetaData.OnehotIndices = new int[rawNumCol];
@@ -279,7 +281,7 @@ private CategoricalMetaData GetCategoricalMetaData(IChannel ch, RoleMappedData t
279281
{
280282
var catIndices = ConstructCategoricalFeatureMetaData(categoricalFeatures, rawNumCol, ref catMetaData);
281283
// Set categorical features
282-
Options["categorical_feature"] = String.Join(",", catIndices);
284+
Options["categorical_feature"] = string.Join(",", catIndices);
283285
}
284286
return catMetaData;
285287
}
@@ -527,13 +529,13 @@ private void GetFeatureValueSparse(IChannel ch, FloatLabelCursor cursor,
527529
++nhot;
528530
var prob = rand.NextSingle();
529531
if (prob < 1.0f / nhot)
530-
values[values.Count() - 1] = fv;
532+
values[values.Count - 1] = fv;
531533
}
532534
lastIdx = newColIdx;
533535
}
534536
indices = featureIndices.ToArray();
535537
featureValues = values.ToArray();
536-
cnt = featureIndices.Count();
538+
cnt = featureIndices.Count;
537539
}
538540
else
539541
{

src/Microsoft.ML.LightGBM/WrappedLightGbmBooster.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ internal sealed class Booster : IDisposable
2020
public IntPtr Handle { get; private set; }
2121
public int BestIteration { get; set; }
2222

23-
public Booster(Dictionary<string, string> parameters, Dataset trainset, Dataset validset = null)
23+
public Booster(Dictionary<string, object> parameters, Dataset trainset, Dataset validset = null)
2424
{
2525
var param = LightGbmInterfaceUtils.JoinParameters(parameters);
2626
var handle = IntPtr.Zero;

src/Microsoft.ML.LightGBM/WrappedLightGbmInterface.cs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6-
using System.Runtime.InteropServices;
76
using System.Collections.Generic;
7+
using System.Globalization;
8+
using System.Runtime.InteropServices;
89

910
namespace Microsoft.ML.Runtime.LightGBM
1011
{
@@ -199,13 +200,13 @@ public static void Check(int res)
199200
/// <summary>
200201
/// Join the parameters to key=value format.
201202
/// </summary>
202-
public static string JoinParameters(Dictionary<string, string> parameters)
203+
public static string JoinParameters(Dictionary<string, object> parameters)
203204
{
204205
if (parameters == null)
205206
return "";
206207
List<string> res = new List<string>();
207208
foreach (var keyVal in parameters)
208-
res.Add(keyVal.Key + "=" + keyVal.Value);
209+
res.Add(keyVal.Key + "=" + string.Format(CultureInfo.InvariantCulture, "{0}", keyVal.Value));
209210
return string.Join(" ", res);
210211
}
211212

src/Microsoft.ML.LightGBM/WrappedLightGbmTraining.cs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ internal static class WrappedLightGbmTraining
1616
/// Train and return a booster.
1717
/// </summary>
1818
public static Booster Train(IChannel ch, IProgressChannel pch,
19-
Dictionary<string, string> parameters, Dataset dtrain, Dataset dvalid = null, int numIteration = 100,
19+
Dictionary<string, object> parameters, Dataset dtrain, Dataset dvalid = null, int numIteration = 100,
2020
bool verboseEval = true, int earlyStoppingRound = 0)
2121
{
2222
// create Booster.
@@ -33,12 +33,9 @@ public static Booster Train(IChannel ch, IProgressChannel pch,
3333
double bestScore = double.MaxValue;
3434
double factorToSmallerBetter = 1.0;
3535

36-
if (earlyStoppingRound > 0 && (parameters["metric"] == "auc"
37-
|| parameters["metric"] == "ndcg"
38-
|| parameters["metric"] == "map"))
39-
{
36+
var metric = (string)parameters["metric"];
37+
if (earlyStoppingRound > 0 && (metric == "auc" || metric == "ndcg" || metric == "map"))
4038
factorToSmallerBetter = -1.0;
41-
}
4239

4340
const int evalFreq = 50;
4441

0 commit comments

Comments
 (0)