@@ -55,6 +55,28 @@ public void LightGBMBinaryEstimator()
                 NumberOfLeaves = 10,
                 NumberOfThreads = 1,
                 MinimumExampleCountPerLeaf = 2,
+                UnbalancedSets = false, // default value
+            });
+
+            var pipeWithTrainer = pipe.Append(trainer);
+            TestEstimatorCore(pipeWithTrainer, dataView);
+
+            var transformedDataView = pipe.Fit(dataView).Transform(dataView);
+            var model = trainer.Fit(transformedDataView, transformedDataView);
+            Done();
+        }
+
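+        /// <summary>
+        /// Same pipeline as LightGBMBinaryEstimator, but with UnbalancedSets enabled.
+        /// </summary>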
+        [LightGBMFact]
+        public void LightGBMBinaryEstimatorUnbalanced()
+        {
+            var (pipe, dataView) = GetBinaryClassificationPipeline();
+
+            var trainer = ML.BinaryClassification.Trainers.LightGbm(new LightGbmBinaryTrainer.Options
+            {
+                NumberOfLeaves = 10,
+                NumberOfThreads = 1,
+                MinimumExampleCountPerLeaf = 2,
+                UnbalancedSets = true,
             });
 
             var pipeWithTrainer = pipe.Append(trainer);
@@ -322,6 +344,44 @@ public void LightGbmMulticlassEstimatorCorrectSigmoid()
             Done();
         }
 
+        /// <summary>
+        /// LightGbmMulticlass test with balanced data (UnbalancedSets = false).
+        /// </summary>
+        [LightGBMFact]
+        public void LightGbmMulticlassEstimatorBalanced()
+        {
+            var (pipeline, dataView) = GetMulticlassPipeline();
+
+            var trainer = ML.MulticlassClassification.Trainers.LightGbm(new LightGbmMulticlassTrainer.Options
+            {
+                UnbalancedSets = false
+            });
+
+            var pipe = pipeline.Append(trainer)
+                    .Append(new KeyToValueMappingEstimator(Env, "PredictedLabel"));
+            TestEstimatorCore(pipe, dataView);
+            Done();
+        }
+
+        /// <summary>
+        /// LightGbmMulticlass test with unbalanced data (UnbalancedSets = true).
+        /// </summary>
+        [LightGBMFact]
+        public void LightGbmMulticlassEstimatorUnbalanced()
+        {
+            var (pipeline, dataView) = GetMulticlassPipeline();
+
+            var trainer = ML.MulticlassClassification.Trainers.LightGbm(new LightGbmMulticlassTrainer.Options
+            {
+                UnbalancedSets = true
+            });
+
+            var pipe = pipeline.Append(trainer)
+                    .Append(new KeyToValueMappingEstimator(Env, "PredictedLabel"));
+            TestEstimatorCore(pipe, dataView);
+            Done();
+        }
+
         // Number of examples
         private const int _rowNumber = 1000;
         // Number of features
@@ -338,7 +398,7 @@ private class GbmExample
             public float[] Score;
         }
 
-        private void LightGbmHelper(bool useSoftmax, double sigmoid, out string modelString, out List<GbmExample> mlnetPredictions, out double[] lgbmRawScores, out double[] lgbmProbabilities)
+        private void LightGbmHelper(bool useSoftmax, double sigmoid, out string modelString, out List<GbmExample> mlnetPredictions, out double[] lgbmRawScores, out double[] lgbmProbabilities, bool unbalancedSets = false)
         {
             // Prepare data and train LightGBM model via ML.NET
             // Training matrix. It contains all feature vectors.
@@ -372,7 +432,8 @@ private void LightGbmHelper(bool useSoftmax, double sigmoid, out string modelStr
                 MinimumExampleCountPerGroup = 1,
                 MinimumExampleCountPerLeaf = 1,
                 UseSoftmax = useSoftmax,
-                Sigmoid = sigmoid // Custom sigmoid value.
+                Sigmoid = sigmoid, // Custom sigmoid value.
+                UnbalancedSets = unbalancedSets // false by default
             });
 
             var gbm = gbmTrainer.Fit(dataView);
@@ -583,6 +644,35 @@ public void LightGbmMulticlassEstimatorCompareSoftMax()
             Done();
         }
 
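+        /// <summary>
+        /// Trains ML.NET's LightGBM with UnbalancedSets enabled and checks that its per-class
+        /// probabilities match the native LightGBM model example by example.
+        /// </summary>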
+        [LightGBMFact]
+        public void LightGbmMulticlassEstimatorCompareUnbalanced()
+        {
+            // Train ML.NET LightGBM and native LightGBM and apply the trained models to the training set.
+            LightGbmHelper(useSoftmax: true, sigmoid: .5, out string modelString, out List<GbmExample> mlnetPredictions, out double[] nativeResult1, out double[] nativeResult0, unbalancedSets: true);
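+            // Per LightGbmHelper's signature, nativeResult1 holds the native raw scores and
+            // nativeResult0 the native class probabilities.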
+
+            // The i-th predictor returned by LightGBM produces the raw score, denoted by z_i, of the i-th class.
+            // Assume that we have n classes in total. The i-th class probability can be computed via
+            // p_i = exp(z_i) / (exp(z_1) + ... + exp(z_n)).
+            Assert.True(modelString != null);
+            // Compare native LightGBM's and ML.NET's LightGBM results example by example
+            for (int i = 0; i < _rowNumber; ++i)
+            {
+                double sum = 0;
+                for (int j = 0; j < _classNumber; ++j)
+                {
+                    Assert.Equal(nativeResult0[j + i * _classNumber], mlnetPredictions[i].Score[j], 6);
+                    sum += Math.Exp((float)nativeResult1[j + i * _classNumber]);
+                }
+                for (int j = 0; j < _classNumber; ++j)
+                {
+                    double prob = Math.Exp(nativeResult1[j + i * _classNumber]);
+                    Assert.Equal(prob / sum, mlnetPredictions[i].Score[j], 6);
+                }
+            }
+
+            Done();
+        }
+
         [LightGBMFact]
         public void LightGbmInDifferentCulture()
         {