diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R
index 61773ed3ee8c..7955cc54d4b5 100644
--- a/R/pkg/R/sparkR.R
+++ b/R/pkg/R/sparkR.R
@@ -363,6 +363,7 @@ sparkR.session <- function(
   ...) {
 
   sparkConfigMap <- convertNamedListToEnv(sparkConfig)
+
   namedParams <- list(...)
   if (length(namedParams) > 0) {
     paramMap <- convertNamedListToEnv(namedParams)
@@ -400,11 +401,16 @@ sparkR.session <- function(
                 sparkConfigMap)
   } else {
     jsc <- get(".sparkRjsc", envir = .sparkREnv)
+    # NOTE(shivaram): Pass in a tempdir that is optionally used if the user has not
+    # overridden this. See SPARK-18817 for more details.
+    warehouseTmpDir <- file.path(tempdir(), "spark-warehouse")
+
     sparkSession <- callJStatic("org.apache.spark.sql.api.r.SQLUtils",
                                 "getOrCreateSparkSession",
                                 jsc,
                                 sparkConfigMap,
-                                enableHiveSupport)
+                                enableHiveSupport,
+                                warehouseTmpDir)
     assign(".sparkRsession", sparkSession, envir = .sparkREnv)
   }
   sparkSession
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index 7c096597fea6..3403410a7d12 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -2890,6 +2890,20 @@ test_that("Collect on DataFrame when NAs exists at the top of a timestamp column
   expect_equal(class(ldf3$col3), c("POSIXct", "POSIXt"))
 })
 
+test_that("Default warehouse dir should be set to tempdir", {
+  setHiveContext(sc)
+
+  # Create a temporary database and a table in it
+  sql("CREATE DATABASE db1")
+  sql("USE db1")
+  sql("CREATE TABLE boxes (width INT, length INT, height INT)")
+  # spark-warehouse should be written only to tempdir(), not the current working directory
+  expect_true(file.exists(file.path(tempdir(), "spark-warehouse", "db1.db", "boxes")))
+  sql("DROP TABLE boxes")
+  sql("DROP DATABASE db1")
+  unsetHiveContext(sc)
+})
+
 unlink(parquetPath)
 unlink(orcPath)
 unlink(jsonPath)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index a4c5bf756cd5..e34f35e4d5b7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
 import org.apache.spark.sql.execution.command.ShowTablesCommand
 import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
+import org.apache.spark.sql.internal.StaticSQLConf.WAREHOUSE_PATH
 import org.apache.spark.sql.types._
 
 private[sql] object SQLUtils extends Logging {
@@ -46,7 +47,17 @@ private[sql] object SQLUtils extends Logging {
   def getOrCreateSparkSession(
       jsc: JavaSparkContext,
       sparkConfigMap: JMap[Object, Object],
-      enableHiveSupport: Boolean): SparkSession = {
+      enableHiveSupport: Boolean,
+      warehouseDir: String): SparkSession = {
+
+    // Check if the SparkContext conf or the sparkConfigMap contains spark.sql.warehouse.dir.
+    // If not, set it to the warehouseDir chosen by the R process.
+    // NOTE: We need to do this before creating the SparkSession.
+    val sqlWarehouseKey = WAREHOUSE_PATH.key
+    if (!jsc.sc.conf.contains(sqlWarehouseKey) && !sparkConfigMap.containsKey(sqlWarehouseKey)) {
+      jsc.sc.conf.set(sqlWarehouseKey, warehouseDir)
+    }
+
     val spark = if (SparkSession.hiveClassesArePresent && enableHiveSupport &&
         jsc.sc.conf.get(CATALOG_IMPLEMENTATION.key, "hive").toLowerCase == "hive") {
       SparkSession.builder().sparkContext(withHiveExternalCatalog(jsc.sc)).getOrCreate()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/r/SQLUtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/r/SQLUtilsSuite.scala
index f54e23e3aa6c..40feb0c99c07 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/api/r/SQLUtilsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/api/r/SQLUtilsSuite.scala
@@ -17,13 +17,23 @@
 
 package org.apache.spark.sql.api.r
 
-import org.apache.spark.sql.test.SharedSQLContext
+import java.util.HashMap
 
-class SQLUtilsSuite extends SharedSQLContext {
+import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
+import org.apache.spark.api.java.JavaSparkContext
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.test.SharedSQLContext
 
-  import testImplicits._
+class SQLUtilsSuite extends SparkFunSuite {
 
   test("dfToCols should collect and transpose a data frame") {
+    val sparkSession = SparkSession.builder()
+      .master("local")
+      .config("spark.ui.enabled", value = false)
+      .getOrCreate()
+
+    import sparkSession.implicits._
+
     val df = Seq(
       (1, 2, 3),
       (4, 5, 6)
@@ -33,6 +43,19 @@ class SQLUtilsSuite extends SharedSQLContext {
       Array(2, 5),
       Array(3, 6)
     ))
+    sparkSession.stop()
   }
 
+  test("warehouse path is set correctly by R constructor") {
+    SparkSession.clearDefaultSession()
+    val conf = new SparkConf().setAppName("test").setMaster("local")
+    val sparkContext2 = new SparkContext(conf)
+    val jsc = new JavaSparkContext(sparkContext2)
+    val warehouseDir = "/tmp/test-warehouse-dir"
+    val session = SQLUtils.getOrCreateSparkSession(
+      jsc, new HashMap[Object, Object], false, warehouseDir)
+    assert(session.sessionState.conf.warehousePath == warehouseDir)
+    session.stop()
+    SparkSession.clearDefaultSession()
+  }
 }
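
Usage note (illustrative, not part of the patch): the tempdir() default acts only as a fallback. getOrCreateSparkSession() checks both the SparkContext conf and the passed sparkConfigMap before applying warehouseTmpDir, so an explicit user setting still wins. A minimal sketch of overriding the default from R follows; the warehouse path shown is hypothetical.

    library(SparkR)

    # Explicitly setting spark.sql.warehouse.dir bypasses the tempdir() fallback
    # added by this patch; the path below is a made-up example.
    sparkR.session(sparkConfig = list(spark.sql.warehouse.dir = "/my/warehouse/dir"))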