[jvm-packages] group data is only set for training set and is set incorrectly #3097

Closed
gaofan0905 opened this issue Feb 6, 2018 · 6 comments

Comments

@gaofan0905

When creating a watch, the input data is split into trainMatrix and testMatrix at random, but the input groupData is set only on trainMatrix. Moreover, groupData describes the original (unsplit) data set, so its group sizes no longer fit the split trainMatrix.
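One way to keep group sizes consistent with a random split is to assign whole query groups, rather than individual rows, to each side. The sketch below assumes the rows are ordered so that each group's rows are contiguous and appear in the same order as the group sizes; `GroupAwareSplit.split` is a hypothetical helper, not part of xgboost4j.

```scala
import scala.util.Random

// Hypothetical helper (not part of xgboost4j): assigns whole query groups,
// rather than individual rows, to the training or test side, so each side
// gets a group-size array that matches the rows it actually receives.
object GroupAwareSplit {

  /** Splits `rows` (ordered so that each group is contiguous) according to
    * `groups`, the size of each consecutive group. Each group goes to the
    * training side with probability `trainTestRatio`.
    *
    * @return (trainRows, trainGroups, testRows, testGroups)
    */
  def split[T](
      rows: Seq[T],
      groups: Seq[Int],
      trainTestRatio: Double,
      seed: Long): (Vector[T], Vector[Int], Vector[T], Vector[Int]) = {
    require(groups.sum == rows.size, "group sizes must add up to the number of rows")
    val rng = new Random(seed)
    val trainRows = Vector.newBuilder[T]
    val testRows = Vector.newBuilder[T]
    val trainGroups = Vector.newBuilder[Int]
    val testGroups = Vector.newBuilder[Int]
    var offset = 0 // index of the first row of the current group
    for (size <- groups) {
      val groupRows = rows.slice(offset, offset + size)
      if (rng.nextDouble() <= trainTestRatio) {
        trainRows ++= groupRows
        trainGroups += size
      } else {
        testRows ++= groupRows
        testGroups += size
      }
      offset += size
    }
    (trainRows.result(), trainGroups.result(), testRows.result(), testGroups.result())
  }
}
```

Each side's group sizes then sum to its own row count, which is what `setGroup` on a `DMatrix` expects. Because the ratio is applied per group, the row-level split only approximates `trainTestRatio` when group sizes vary.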

https://github.com/dmlc/xgboost/blob/master/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala#L520

@superbobry
Contributor

Yes, this is a known limitation of the current group data support. The "right" way to fix this is to make the group explicitly available for each row in the input data frame, e.g. via #2749.
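To illustrate why per-row group information sidesteps the problem: if every row carries its own group id, group sizes can be recomputed after any split instead of being carried over from the unsplit data. This is a sketch with a hypothetical `Row` type standing in for a DataFrame row, assuming rows with the same id are stored contiguously; it is not necessarily how #2749 or the JVM wrapper does it.

```scala
import scala.collection.mutable.ArrayBuffer
import scala.util.Random

// Sketch only: `Row` is a hypothetical stand-in for a DataFrame row
// that has a group/qid column.
object PerRowGroups {
  final case class Row(groupId: Int, label: Float)

  /** Collapses consecutive equal group ids into group sizes,
    * e.g. ids 7, 7, 3, 3, 3 -> sizes 2, 3. */
  def groupSizes(groupIds: Seq[Int]): Array[Int] = {
    val sizes = ArrayBuffer.empty[Int]
    var prev: Option[Int] = None
    for (id <- groupIds) {
      if (prev.contains(id)) sizes(sizes.size - 1) += 1 // same run: grow it
      else { sizes += 1; prev = Some(id) }               // new run starts
    }
    sizes.toArray
  }

  /** Randomly splits rows (keeping their order), then derives each side's
    * group sizes from the per-row ids, so the sizes always match the rows. */
  def splitRows(rows: Seq[Row], trainTestRatio: Double, seed: Long)
      : ((Seq[Row], Array[Int]), (Seq[Row], Array[Int])) = {
    val rng = new Random(seed)
    val (train, test) = rows.partition(_ => rng.nextDouble() <= trainTestRatio)
    ((train, groupSizes(train.map(_.groupId))), (test, groupSizes(test.map(_.groupId))))
  }
}
```

A row-level random split like this can still put rows of one query on both sides; the sizes stay consistent, but that query is then partially in each set. Splitting by group id instead avoids that.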

@gaofan0905
Author

It seems #2749 is merged. Is there an example of how to set the group ID in a DataFrame and pass it to XGBoost?

@superbobry
Contributor

No, it has not been merged yet. Exposing this in the JVM wrapper would require a little bit of work as well.

@gaofan0905
Author

I made some local changes like the following. Would that work?

Also, when looking at the results, I found that for some records the prediction differs for the same input. Is that inherent to distributed training, i.e. is only a random subset of all trees used?

```scala
import org.apache.spark.TaskContext

import scala.collection.mutable
import scala.util.Random
import scala.util.control.Breaks.{break, breakable}

// XGBLabeledPoint, DMatrix and the Watches class come from the surrounding XGBoost.scala.
private object Watches {
  def apply(
      params: Map[String, Any],
      labeledPoints: Iterator[XGBLabeledPoint],
      baseMarginsOpt: Option[Array[Float]],
      cacheDirName: Option[String]): Watches = {
    val trainTestRatio = params.get("trainTestRatio").map(_.toString.toDouble).getOrElse(1.0)
    if (params.contains("groupData") && params("groupData") != null) {
      // groupData holds one sequence of group sizes per partition.
      val groups = params("groupData").asInstanceOf[Seq[Seq[Int]]](
        TaskContext.getPartitionId()).toArray
      val total = groups.sum
      // Split at a group boundary: the first groups go to training until their
      // share of rows reaches trainTestRatio; the remaining groups go to test.
      var cnt = 0   // number of rows in the training groups
      var index = 0 // index of the last training group
      breakable {
        for (i <- groups.indices) {
          index = i
          cnt += groups(i)
          if (1.0 * cnt / total >= trainTestRatio) break()
        }
      }
      val (trainGroups, testGroups) = groups.splitAt(index + 1)

      // Rows are assumed to arrive in the same order as the group sizes,
      // so the first `cnt` rows belong to the training groups.
      val trainPoints = mutable.ArrayBuffer.empty[XGBLabeledPoint]
      val testPoints = mutable.ArrayBuffer.empty[XGBLabeledPoint]
      while (labeledPoints.hasNext) {
        val p = labeledPoints.next()
        if (trainPoints.size < cnt) trainPoints += p else testPoints += p
      }

      val trainMatrix = new DMatrix(trainPoints.iterator, cacheDirName.map(_ + "/train").orNull)
      val testMatrix = new DMatrix(testPoints.iterator, cacheDirName.map(_ + "/test").orNull)
      for (baseMargins <- baseMarginsOpt) {
        val (trainMargin, testMargin) = baseMargins.splitAt(cnt)
        trainMatrix.setBaseMargin(trainMargin)
        testMatrix.setBaseMargin(testMargin)
      }
      trainMatrix.setGroup(trainGroups)
      testMatrix.setGroup(testGroups)
      new Watches(trainMatrix, testMatrix, cacheDirName)
    } else {
      // No group data: split rows at random.
      val seed = params.get("seed").map(_.toString.toLong).getOrElse(System.nanoTime())
      val r = new Random(seed)
      val testPoints = mutable.ArrayBuffer.empty[XGBLabeledPoint]
      // Lazy filter: testPoints is filled as a side effect while the
      // DMatrix constructor consumes trainPoints.
      val trainPoints = labeledPoints.filter { labeledPoint =>
        val accepted = r.nextDouble() <= trainTestRatio
        if (!accepted) {
          testPoints += labeledPoint
        }
        accepted
      }
      val trainMatrix = new DMatrix(trainPoints, cacheDirName.map(_ + "/train").orNull)
      val testMatrix = new DMatrix(testPoints.iterator, cacheDirName.map(_ + "/test").orNull)
      // Replay the same random sequence so each base margin follows its row.
      r.setSeed(seed)
      for (baseMargins <- baseMarginsOpt) {
        val (trainMargin, testMargin) = baseMargins.partition(_ => r.nextDouble() <= trainTestRatio)
        trainMatrix.setBaseMargin(trainMargin)
        testMatrix.setBaseMargin(testMargin)
      }
      new Watches(trainMatrix, testMatrix, cacheDirName)
    }
  }
}
```
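The cut-point search in the snippet above can also be written without `breakable` as a pure function, which makes it easier to check on small inputs. `GroupCut.splitGroupsAt` is a hypothetical name; it is meant to return the same group split and training row count (`cnt`) as the loop in the snippet.

```scala
// Hypothetical pure version of the snippet's cut-point loop.
object GroupCut {
  /** @return (trainGroups, testGroups, number of training rows) */
  def splitGroupsAt(groups: Array[Int], trainTestRatio: Double): (Array[Int], Array[Int], Int) = {
    val total = groups.sum
    // Cumulative row counts after each group, e.g. 3, 2, 4 -> 3, 5, 9.
    val cumulative = groups.scanLeft(0)(_ + _).tail
    // First group at which the training share reaches the ratio; if the
    // ratio is never reached, every group goes to training (as in the loop).
    val found = cumulative.indexWhere(c => 1.0 * c / total >= trainTestRatio)
    val index = if (found >= 0) found else groups.length - 1
    val (trainGroups, testGroups) = groups.splitAt(index + 1)
    (trainGroups, testGroups, trainGroups.sum)
  }
}
```

For example, with groups `3, 2, 4` and `trainTestRatio = 0.5`, the cumulative counts are 3, 5, 9; 5/9 is the first share at or above 0.5, so groups `3, 2` (5 rows) go to training and `4` goes to test. Note that this split is deterministic and takes the first groups in input order, unlike the random row split in the `else` branch.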

@hcho3
Collaborator

hcho3 commented Jul 4, 2018

#2749 has been merged now. I think additional work would be needed to enable grouping in the test set.

All feature requests are now consolidated to #3439. This issue should be re-opened if someone decides to actively work on implementing this feature.

hcho3 closed this as completed Jul 4, 2018
hcho3 mentioned this issue Jul 4, 2018
@CodingCat
Member

In the master branch of XGBoost, we now allow the user to provide per-instance group info (like qid); see #3369.
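With per-instance qids, the rows of one query are not necessarily adjacent, whereas the group sizes passed to `setGroup` describe runs of consecutive rows. A minimal sketch, assuming rows are available as `(qid, payload)` pairs in arbitrary order: stable-sort by qid, then count each run. `QidToGroups.contiguousGroups` is a hypothetical helper and not necessarily how #3369 handles it.

```scala
// Hypothetical helper: turns per-row qids into contiguous rows plus group sizes.
object QidToGroups {
  def contiguousGroups[T](rows: Seq[(Int, T)]): (Seq[(Int, T)], Array[Int]) = {
    // sortBy is stable, so rows keep their original order within each query.
    val sorted = rows.sortBy(_._1)
    // Build (qid, count) runs in reverse, then keep only the counts.
    val sizes = sorted.foldLeft(List.empty[(Int, Int)]) {
      case ((qid, n) :: rest, (q, _)) if q == qid => (qid, n + 1) :: rest
      case (acc, (q, _))                          => (q, 1) :: acc
    }.reverse.map(_._2).toArray
    (sorted, sizes)
  }
}
```

For instance, qids `2, 1, 2, 1, 3` become rows ordered by qid with group sizes `2, 2, 1`.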

lock bot locked as resolved and limited conversation to collaborators Oct 24, 2018

4 participants