-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[jvm-packages] XGBoost Spark supports ranking with group data. #3369
[jvm-packages] XGBoost Spark supports ranking with group data. #3369
Conversation
eh....large than I expect...will look at it tmr afternoon (Friday afternoon :-) ) |
@CodingCat @yanboliang Sorry for hijacking this thread, but would #2749 be useful for ranking tasks on Spark? |
@hcho3 Yep, it would be useful for ranking on xgboost-spark, we are in the same direction. This PR expose a new group data API for xgboost-spark, we can update internal implementation if the backend C++ code changed. Thanks. |
@yanboliang That's good to know. I will add some tests to #2749 and merge it. Thanks! |
@@ -21,12 +21,10 @@ import java.nio.file.Files | |||
|
|||
import scala.collection.mutable | |||
import scala.util.Random | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we keep these empty lines to separate xgboost4j and the other imports?
@@ -56,8 +54,8 @@ object XGBoost extends Serializable { | |||
private val logger = LogFactory.getLog("XGBoostSpark") | |||
|
|||
private def removeMissingValues( | |||
denseLabeledPoints: Iterator[XGBLabeledPoint], | |||
missing: Float): Iterator[XGBLabeledPoint] = { | |||
denseLabeledPoints: Seq[XGBLabeledPoint], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if you take sequence, that means you will load a partition into memory entirely which will leads to OOM potentially
@@ -129,7 +127,7 @@ object XGBoost extends Serializable { | |||
rabitEnv.put("DMLC_TASK_ID", taskId) | |||
Rabit.init(rabitEnv) | |||
val watches = Watches(params, | |||
removeMissingValues(labeledPoints, missing), | |||
removeMissingValues(labeledPoints.toSeq, missing), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah, this toSeq is risky to lead to OOM
@@ -308,9 +306,26 @@ private class Watches private( | |||
|
|||
private object Watches { | |||
|
|||
def formatGroups(groups: Seq[Int]): Seq[Int] = { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
buildGroups or formGroups as a better name?
Codecov Report
@@ Coverage Diff @@
## spark_dev_do_not_delete #3369 +/- ##
=============================================================
+ Coverage 44.92% 45.02% +0.09%
- Complexity 186 188 +2
=============================================================
Files 163 163
Lines 12932 12952 +20
Branches 439 443 +4
=============================================================
+ Hits 5810 5831 +21
+ Misses 6921 6920 -1
Partials 201 201
Continue to review full report at Codecov.
|
LGTM |
can you file a PR from spark_dev_do_not_delete to master to ensure that they count as your contribution |
* add back train method but mark as deprecated * add back train method but mark as deprecated * fix scalastyle error * fix scalastyle error * [jvm-packages] XGBoost Spark integration refactor. (#3313) * XGBoost Spark integration refactor. * Make corresponding update for xgboost4j-example * Address comments. * [jvm-packages] Refactor XGBoost-Spark params to make it compatible with both XGBoost and Spark MLLib (#3326) * Refactor XGBoost-Spark params to make it compatible with both XGBoost and Spark MLLib * Fix extra space. * [jvm-packages] XGBoost Spark supports ranking with group data. (#3369) * XGBoost Spark supports ranking with group data. * Use Iterator.duplicate to prevent OOM. * Update CheckpointManagerSuite.scala * Resolve conflicts
XGBoost Spark supports ranking with group data.