Skip to content

Commit

Permalink
chore: Update VerifyLightGBMClassifier.scala (#1313)
Browse files Browse the repository at this point in the history
  • Loading branch information
mhamilton723 authored Dec 16, 2021
1 parent e89420a commit b268271
Showing 1 changed file with 32 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -429,36 +429,38 @@ class VerifyLightGBMClassifier extends Benchmarks with EstimatorFuzzing[LightGBM
}

test("Verify LightGBM Classifier with validation dataset") {
val df = au3DF.orderBy(rand()).withColumn(validationCol, lit(false))

val Array(train, validIntermediate, test) = df.randomSplit(Array(0.5, 0.2, 0.3), seed)
val valid = validIntermediate.withColumn(validationCol, lit(true))
val trainAndValid = train.union(valid.orderBy(rand()))

// model1 should overfit on the given dataset
val model1 = baseModel
.setNumLeaves(100)
.setNumIterations(100)
.setLearningRate(0.9)
.setMinDataInLeaf(2)
.setValidationIndicatorCol(validationCol)
.setEarlyStoppingRound(100)

// model2 should terminate early before overfitting
val model2 = baseModel
.setNumLeaves(100)
.setNumIterations(100)
.setLearningRate(0.9)
.setMinDataInLeaf(2)
.setValidationIndicatorCol(validationCol)
.setEarlyStoppingRound(5)

// Assert evaluation metric improves
Array("auc", "binary_logloss", "binary_error").foreach { metric =>
assertBinaryImprovement(
model1.setMetric(metric), trainAndValid, test,
model2.setMetric(metric), trainAndValid, test
)
tryWithRetries(Array(0, 100, 500)) {() =>
val df = au3DF.orderBy(rand()).withColumn(validationCol, lit(false))

val Array(train, validIntermediate, test) = df.randomSplit(Array(0.5, 0.2, 0.3), seed)
val valid = validIntermediate.withColumn(validationCol, lit(true))
val trainAndValid = train.union(valid.orderBy(rand()))

// model1 should overfit on the given dataset
val model1 = baseModel
.setNumLeaves(100)
.setNumIterations(100)
.setLearningRate(0.9)
.setMinDataInLeaf(2)
.setValidationIndicatorCol(validationCol)
.setEarlyStoppingRound(100)

// model2 should terminate early before overfitting
val model2 = baseModel
.setNumLeaves(100)
.setNumIterations(100)
.setLearningRate(0.9)
.setMinDataInLeaf(2)
.setValidationIndicatorCol(validationCol)
.setEarlyStoppingRound(5)

// Assert evaluation metric improves
Array("auc", "binary_logloss", "binary_error").foreach { metric =>
assertBinaryImprovement(
model1.setMetric(metric), trainAndValid, test,
model2.setMetric(metric), trainAndValid, test
)
}
}
}

Expand Down

0 comments on commit b268271

Please sign in to comment.