Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
Signed-off-by: minmingzhu <minming.zhu@intel.com>
  • Loading branch information
minmingzhu committed Oct 11, 2023
1 parent 97a470d commit 4c4c993
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,14 @@ class RandomForestClassifierDALImpl(val uid: String,
}.count()
rfcTimer.record("OneCCL Init")

val results = labeledPointsTables.mapPartitionsWithIndex {
(rank: Int, tables: Iterator[(String, String)]) =>
val results = labeledPointsTables.mapPartitionsWithIndex { (rank, tables) =>
val (feature, label) = tables.next()
val (featureTabAddr : Long, featureRows : Long, featureColumns : Long) = {
val parts = feature.toString.split("_")
(parts(0).toLong, parts(1).toLong, parts(2).toLong)
}
val (labelTabAddr : Long, labelRows : Long, labelColumns : Long) = {
val parts = feature.toString.split("_")
val parts = label.toString.split("_")
(parts(0).toLong, parts(1).toLong, parts(2).toLong)
}
val gpuIndices = if (useDevice == "GPU") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,7 @@ class LinearRegressionDALImpl( val fitIntercept: Boolean,
}
lrTimer.record("Data Convertion")

val results = labeledPointsTables.mapPartitionsWithIndex {
case (rank: Int, tables: Iterator[(Any, Any)]) =>
val results = labeledPointsTables.mapPartitionsWithIndex { (rank, tables) =>
val (feature, label) = tables.next()
val (featureTabAddr : Long, featureRows : Long, featureColumns : Long) =
if (useDevice == "GPU") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,14 @@ class RandomForestRegressorDALImpl(val uid: String,
}.count()
rfrTimer.record("OneCCL Init")

val results = labeledPointsTables.mapPartitionsWithIndex {
(rank: Int, tables: Iterator[(String, String)]) =>
val results = labeledPointsTables.mapPartitionsWithIndex { (rank, tables) =>
val (feature, label) = tables.next()
val (featureTabAddr : Long, featureRows : Long, featureColumns : Long) = {
val parts = feature.toString.split("_")
(parts(0).toLong, parts(1).toLong, parts(2).toLong)
}
val (labelTabAddr : Long, labelRows : Long, labelColumns : Long) = {
val parts = feature.toString.split("_")
val parts = label.toString.split("_")
(parts(0).toLong, parts(1).toLong, parts(2).toLong)
}
val gpuIndices = if (useDevice == "GPU") {
Expand Down

0 comments on commit 4c4c993

Please sign in to comment.