Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.SparkContext
import org.apache.spark.{Logging, SparkContext}
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.Vector
Expand All @@ -32,6 +32,7 @@ import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
import org.apache.spark.util.Utils

/**
* :: Experimental ::
Expand Down Expand Up @@ -115,7 +116,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
override protected def formatVersion: String = "1.0"
}

object DecisionTreeModel extends Loader[DecisionTreeModel] {
object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging {

private[tree] object SaveLoadV1_0 {

Expand Down Expand Up @@ -187,6 +188,28 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] {
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

// SPARK-6120: We do a hacky check here so users understand why save() is failing
// when they run the ML guide example.
// TODO: Fix this issue for real.
val memThreshold = 768
if (sc.isLocal) {
val driverMemory = sc.getConf.getOption("spark.driver.memory")
.orElse(Option(System.getenv("SPARK_DRIVER_MEMORY")))
.map(Utils.memoryStringToMb)
.getOrElse(512)
if (driverMemory <= memThreshold) {
logWarning(s"$thisClassName.save() was called, but it may fail because of too little" +
s" driver memory (${driverMemory}m)." +
s" If failure occurs, try setting driver-memory ${memThreshold}m (or larger).")
}
} else {
if (sc.executorMemory <= memThreshold) {
logWarning(s"$thisClassName.save() was called, but it may fail because of too little" +
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Menion executorMemory in the log message.

s" executor memory (${sc.executorMemory}m)." +
s" If failure occurs try setting executor-memory ${memThreshold}m (or larger).")
}
}

// Create JSON metadata.
val metadata = compact(render(
("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.SparkContext
import org.apache.spark.{Logging, SparkContext}
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.Vector
Expand All @@ -34,6 +34,7 @@ import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.util.Utils

/**
* :: Experimental ::
Expand Down Expand Up @@ -250,7 +251,7 @@ private[tree] sealed class TreeEnsembleModel(
def totalNumNodes: Int = trees.map(_.numNodes).sum
}

private[tree] object TreeEnsembleModel {
private[tree] object TreeEnsembleModel extends Logging {

object SaveLoadV1_0 {

Expand All @@ -277,6 +278,28 @@ private[tree] object TreeEnsembleModel {
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

// SPARK-6120: We do a hacky check here so users understand why save() is failing
// when they run the ML guide example.
// TODO: Fix this issue for real.
val memThreshold = 768
if (sc.isLocal) {
val driverMemory = sc.getConf.getOption("spark.driver.memory")
.orElse(Option(System.getenv("SPARK_DRIVER_MEMORY")))
.map(Utils.memoryStringToMb)
.getOrElse(512)
if (driverMemory <= memThreshold) {
logWarning(s"$className.save() was called, but it may fail because of too little" +
s" driver memory (${driverMemory}m)." +
s" If failure occurs, try setting driver-memory ${memThreshold}m (or larger).")
}
} else {
if (sc.executorMemory <= memThreshold) {
logWarning(s"$className.save() was called, but it may fail because of too little" +
s" executor memory (${sc.executorMemory}m)." +
s" If failure occurs try setting executor-memory ${memThreshold}m (or larger).")
}
}

// Create JSON metadata.
implicit val format = DefaultFormats
val ensembleMetadata = Metadata(model.algo.toString, model.trees(0).algo.toString,
Expand Down