
Commit c5a18ef

Author: Rayan-Krishnan
LightGBM Unbalanced Data Argument [Issue #3688 Fix] (#3925)
* LightGBM unbalanced data arg added
* unbalanced data argument added
* tests for unbalanced LightGbm and added arg for multiclass
* reverted changes on LightGbmArguments
* wording improvement on unbalanced arg help text
* updated manifest
* removed empty line
* added keytype to test
1 parent 0c0789f commit c5a18ef

File tree (4 files changed, +111 -3 lines):

  src/Microsoft.ML.LightGbm/LightGbmArguments.cs
  src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs
  test/BaselineOutput/Common/EntryPoints/core_manifest.json
  test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs
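
For orientation before the per-file diffs, here is a minimal usage sketch of the option this commit exposes. It is illustrative, not part of the commit: the mlContext, pipeline, and trainingData names are assumed placeholders, while the trainer factory methods, option types, and the UnbalancedSets flag all appear in the diffs and tests below.

// Illustrative only. Assumes an existing MLContext ("mlContext"), a featurization
// pipeline ("pipeline"), and an IDataView ("trainingData") with a key-typed Label column.
var multiclassTrainer = mlContext.MulticlassClassification.Trainers.LightGbm(
    new LightGbmMulticlassTrainer.Options
    {
        NumberOfLeaves = 10,
        UnbalancedSets = true   // new multiclass option in this commit; defaults to false
    });

var binaryTrainer = mlContext.BinaryClassification.Trainers.LightGbm(
    new LightGbmBinaryTrainer.Options
    {
        NumberOfLeaves = 10,
        UnbalancedSets = true   // existing binary flag, now exercised by the new tests
    });

var model = pipeline.Append(multiclassTrainer).Fit(trainingData);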

src/Microsoft.ML.LightGbm/LightGbmArguments.cs

Lines changed: 0 additions & 1 deletion
@@ -58,7 +58,6 @@ public BoosterParameterBase(OptionsBase options)
        public abstract class OptionsBase : IBoosterParameterFactory
        {
            internal BoosterParameterBase GetBooster() { return null; }
-
            /// <summary>
            /// The minimum loss reduction required to make a further partition on a leaf node of the tree.
            /// </summary>

src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs

Lines changed: 7 additions & 0 deletions
@@ -75,6 +75,12 @@ public enum EvaluateMetricType
                LogLoss,
            }

+           /// <summary>
+           /// Whether training data is unbalanced.
+           /// </summary>
+           [Argument(ArgumentType.AtMostOnce, HelpText = "Use for multi-class classification when training data is not balanced", ShortName = "us")]
+           public bool UnbalancedSets = false;
+
            /// <summary>
            /// Whether to use softmax loss.
            /// </summary>
@@ -110,6 +116,7 @@ internal override Dictionary<string, object> ToDictionary(IHost host)
            {
                var res = base.ToDictionary(host);

+               res[GetOptionName(nameof(UnbalancedSets))] = UnbalancedSets;
                res[GetOptionName(nameof(Sigmoid))] = Sigmoid;
                res[GetOptionName(nameof(EvaluateMetricType))] = GetOptionName(EvaluationMetric.ToString());
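
The ToDictionary override above is how the managed option reaches the native booster: GetOptionName maps the C# property name onto the key LightGBM expects, so UnbalancedSets is presumably forwarded as the native unbalanced_sets parameter (an alias of is_unbalance). Below is a rough, self-contained approximation of that PascalCase-to-snake_case mapping, written only to illustrate the idea; it is not the actual ML.NET helper.

// Approximate sketch of the name mapping GetOptionName is expected to perform
// (not the real implementation): "UnbalancedSets" -> "unbalanced_sets".
using System.Text;

static string ToLightGbmOptionName(string propertyName)
{
    var sb = new StringBuilder();
    for (int i = 0; i < propertyName.Length; i++)
    {
        char c = propertyName[i];
        if (char.IsUpper(c))
        {
            if (i > 0)
                sb.Append('_');          // word boundary before each capital except the first
            sb.Append(char.ToLowerInvariant(c));
        }
        else
        {
            sb.Append(c);
        }
    }
    return sb.ToString();
}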

test/BaselineOutput/Common/EntryPoints/core_manifest.json

Lines changed: 12 additions & 0 deletions
@@ -11974,6 +11974,18 @@
          "IsNullable": false,
          "Default": "Auto"
        },
+       {
+         "Name": "UnbalancedSets",
+         "Type": "Bool",
+         "Desc": "Use for multi-class classification when training data is not balanced",
+         "Aliases": [
+           "us"
+         ],
+         "Required": false,
+         "SortOrder": 150.0,
+         "IsNullable": false,
+         "Default": false
+       },
        {
          "Name": "UseSoftmax",
          "Type": "Bool",

test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs

Lines changed: 92 additions & 2 deletions
@@ -55,6 +55,28 @@ public void LightGBMBinaryEstimator()
                NumberOfLeaves = 10,
                NumberOfThreads = 1,
                MinimumExampleCountPerLeaf = 2,
+               UnbalancedSets = false, // default value
+           });
+
+           var pipeWithTrainer = pipe.Append(trainer);
+           TestEstimatorCore(pipeWithTrainer, dataView);
+
+           var transformedDataView = pipe.Fit(dataView).Transform(dataView);
+           var model = trainer.Fit(transformedDataView, transformedDataView);
+           Done();
+       }
+
+       [LightGBMFact]
+       public void LightGBMBinaryEstimatorUnbalanced()
+       {
+           var (pipe, dataView) = GetBinaryClassificationPipeline();
+
+           var trainer = ML.BinaryClassification.Trainers.LightGbm(new LightGbmBinaryTrainer.Options
+           {
+               NumberOfLeaves = 10,
+               NumberOfThreads = 1,
+               MinimumExampleCountPerLeaf = 2,
+               UnbalancedSets = true,
            });

            var pipeWithTrainer = pipe.Append(trainer);
@@ -322,6 +344,44 @@ public void LightGbmMulticlassEstimatorCorrectSigmoid()
            Done();
        }

+       /// <summary>
+       /// LightGbmMulticlass Test of Balanced Data
+       /// </summary>
+       [LightGBMFact]
+       public void LightGbmMulticlassEstimatorBalanced()
+       {
+           var (pipeline, dataView) = GetMulticlassPipeline();
+
+           var trainer = ML.MulticlassClassification.Trainers.LightGbm(new LightGbmMulticlassTrainer.Options
+           {
+               UnbalancedSets = false
+           });
+
+           var pipe = pipeline.Append(trainer)
+                   .Append(new KeyToValueMappingEstimator(Env, "PredictedLabel"));
+           TestEstimatorCore(pipe, dataView);
+           Done();
+       }
+
+       /// <summary>
+       /// LightGbmMulticlass Test of Unbalanced Data
+       /// </summary>
+       [LightGBMFact]
+       public void LightGbmMulticlassEstimatorUnbalanced()
+       {
+           var (pipeline, dataView) = GetMulticlassPipeline();
+
+           var trainer = ML.MulticlassClassification.Trainers.LightGbm(new LightGbmMulticlassTrainer.Options
+           {
+               UnbalancedSets = true
+           });
+
+           var pipe = pipeline.Append(trainer)
+                   .Append(new KeyToValueMappingEstimator(Env, "PredictedLabel"));
+           TestEstimatorCore(pipe, dataView);
+           Done();
+       }
+
        // Number of examples
        private const int _rowNumber = 1000;
        // Number of features
@@ -338,7 +398,7 @@ private class GbmExample
            public float[] Score;
        }

-       private void LightGbmHelper(bool useSoftmax, double sigmoid, out string modelString, out List<GbmExample> mlnetPredictions, out double[] lgbmRawScores, out double[] lgbmProbabilities)
+       private void LightGbmHelper(bool useSoftmax, double sigmoid, out string modelString, out List<GbmExample> mlnetPredictions, out double[] lgbmRawScores, out double[] lgbmProbabilities, bool unbalancedSets = false)
        {
            // Prepare data and train LightGBM model via ML.NET
            // Training matrix. It contains all feature vectors.
@@ -372,7 +432,8 @@ private void LightGbmHelper(bool useSoftmax, double sigmoid, out string modelStr
                MinimumExampleCountPerGroup = 1,
                MinimumExampleCountPerLeaf = 1,
                UseSoftmax = useSoftmax,
-               Sigmoid = sigmoid // Custom sigmoid value.
+               Sigmoid = sigmoid, // Custom sigmoid value.
+               UnbalancedSets = unbalancedSets // false by default
            });

            var gbm = gbmTrainer.Fit(dataView);
@@ -583,6 +644,35 @@ public void LightGbmMulticlassEstimatorCompareSoftMax()
            Done();
        }

+       [LightGBMFact]
+       public void LightGbmMulticlassEstimatorCompareUnbalanced()
+       {
+           // Train ML.NET LightGBM and native LightGBM and apply the trained models to the training set.
+           LightGbmHelper(useSoftmax: true, sigmoid: .5, out string modelString, out List<GbmExample> mlnetPredictions, out double[] nativeResult1, out double[] nativeResult0, unbalancedSets: true);
+
+           // The i-th predictor returned by LightGBM produces the raw score, denoted by z_i, of the i-th class.
+           // Assume that we have n classes in total. The i-th class probability can be computed via
+           // p_i = exp(z_i) / (exp(z_1) + ... + exp(z_n)).
+           Assert.True(modelString != null);
+           // Compare native LightGBM's and ML.NET's LightGBM results example by example
+           for (int i = 0; i < _rowNumber; ++i)
+           {
+               double sum = 0;
+               for (int j = 0; j < _classNumber; ++j)
+               {
+                   Assert.Equal(nativeResult0[j + i * _classNumber], mlnetPredictions[i].Score[j], 6);
+                   sum += Math.Exp((float)nativeResult1[j + i * _classNumber]);
+               }
+               for (int j = 0; j < _classNumber; ++j)
+               {
+                   double prob = Math.Exp(nativeResult1[j + i * _classNumber]);
+                   Assert.Equal(prob / sum, mlnetPredictions[i].Score[j], 6);
+               }
+           }
+
+           Done();
+       }
+
        [LightGBMFact]
        public void LightGbmInDifferentCulture()
        {
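
The LightGbmMulticlassEstimatorCompareUnbalanced test above leans on the relation spelled out in its comments: native LightGBM returns a raw score z_i per class, while the ML.NET scorer returns probabilities, so the test recomputes p_i = exp(z_i) / (exp(z_1) + ... + exp(z_n)) and asserts agreement to six decimal places. A minimal standalone sketch of that conversion follows; the names are illustrative and not taken from the test code.

// Converts one example's per-class raw scores into probabilities, mirroring the
// softmax check performed in LightGbmMulticlassEstimatorCompareUnbalanced.
using System;

static double[] Softmax(double[] rawScores)
{
    var probabilities = new double[rawScores.Length];
    double sum = 0;
    for (int j = 0; j < rawScores.Length; j++)
    {
        probabilities[j] = Math.Exp(rawScores[j]);  // exp(z_j)
        sum += probabilities[j];
    }
    for (int j = 0; j < rawScores.Length; j++)
        probabilities[j] /= sum;                    // p_j = exp(z_j) / sum_k exp(z_k)
    return probabilities;
}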
