diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala
index 1c87c48ae7cb..b66f94ae1107 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.execution
 
-import java.io.{BufferedReader, InputStream, InputStreamReader, OutputStream}
+import java.io._
 import java.nio.charset.StandardCharsets
 import java.util.concurrent.TimeUnit
 
@@ -26,7 +26,7 @@ import scala.util.control.NonFatal
 
 import org.apache.hadoop.conf.Configuration
 
-import org.apache.spark.{SparkException, TaskContext}
+import org.apache.spark.{SparkException, SparkFiles, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
@@ -72,6 +72,10 @@ trait BaseScriptTransformationExec extends UnaryExecNode {
   protected def initProc: (OutputStream, Process, InputStream, CircularBuffer) = {
     val cmd = List("/bin/bash", "-c", script)
     val builder = new ProcessBuilder(cmd.asJava)
+      .directory(new File(SparkFiles.getRootDirectory()))
+    val path = System.getenv("PATH") + File.pathSeparator +
+      SparkFiles.getRootDirectory()
+    builder.environment().put("PATH", path)
 
     val proc = builder.start()
     val inputStream = proc.getInputStream
diff --git a/sql/core/src/test/resources/test_script.py b/sql/core/src/test/resources/test_script.py
index 82ef7b38f0c1..75b4f106d3a1 100644
--- a/sql/core/src/test/resources/test_script.py
+++ b/sql/core/src/test/resources/test_script.py
@@ -1,3 +1,5 @@
+#! /usr/bin/python
+
 # Licensed to the Apache Software Foundation (ASF) under one or more
 # contributor license agreements.  See the NOTICE file distributed with
 # this work for additional information regarding copyright ownership.
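The ProcessBuilder change above is the core of the fix: the transform script now starts in SparkFiles' root directory, and that directory is also appended to the child process's PATH, so scripts shipped via ADD FILE can be invoked by bare name. A minimal standalone sketch of the same pattern, independent of the patch (the name runWithExtraPath and its parameters are illustrative, not part of Spark):

    import java.io.File

    // Launch `script` through bash with `extraDir` as the working directory
    // and appended to the child's PATH (the parent JVM's PATH is untouched).
    def runWithExtraPath(script: String, extraDir: String): Process = {
      val builder = new ProcessBuilder(java.util.Arrays.asList("/bin/bash", "-c", script))
        .directory(new File(extraDir))
      val path = System.getenv("PATH") + File.pathSeparator + extraDir
      builder.environment().put("PATH", path)
      builder.start()
    }

Note that ProcessBuilder.environment() only affects the spawned process, which keeps the PATH modification scoped to the transform script rather than the executor JVM.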
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala
index cf9ee1ef6db7..a25e4b8f8ea0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala
@@ -470,6 +470,119 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU
         Row("3\u00014\u00015") :: Nil)
     }
   }
+
+  test("SPARK-33934: Add SparkFile's root dir to env property PATH") {
+    assume(TestUtils.testCommandAvailable("python"))
+    val scriptFilePath = copyAndGetResourceFile("test_script.py", ".py").getAbsoluteFile
+    withTempView("v") {
+      val df = Seq(
+        (1, "1", 1.0, BigDecimal(1.0), new Timestamp(1)),
+        (2, "2", 2.0, BigDecimal(2.0), new Timestamp(2)),
+        (3, "3", 3.0, BigDecimal(3.0), new Timestamp(3))
+      ).toDF("a", "b", "c", "d", "e") // Note column d's data type is Decimal(38, 18)
+      df.createTempView("v")
+
+      // test 'python /path/to/script.py' with local file
+      checkAnswer(
+        sql(
+          s"""
+             |SELECT
+             |TRANSFORM(a, b, c, d, e)
+             |  ROW FORMAT DELIMITED
+             |  FIELDS TERMINATED BY '\t'
+             |  USING 'python $scriptFilePath' AS (a, b, c, d, e)
+             |  ROW FORMAT DELIMITED
+             |  FIELDS TERMINATED BY '\t'
+             |FROM v
+          """.stripMargin), identity, df.select(
+          'a.cast("string"),
+          'b.cast("string"),
+          'c.cast("string"),
+          'd.cast("string"),
+          'e.cast("string")).collect())
+
+      // test '/path/to/script.py' with script not executable
+      val e1 = intercept[TestFailedException] {
+        checkAnswer(
+          sql(
+            s"""
+               |SELECT
+               |TRANSFORM(a, b, c, d, e)
+               |  ROW FORMAT DELIMITED
+               |  FIELDS TERMINATED BY '\t'
+               |  USING '$scriptFilePath' AS (a, b, c, d, e)
+               |  ROW FORMAT DELIMITED
+               |  FIELDS TERMINATED BY '\t'
+               |FROM v
+            """.stripMargin), identity, df.select(
+            'a.cast("string"),
+            'b.cast("string"),
+            'c.cast("string"),
+            'd.cast("string"),
+            'e.cast("string")).collect())
+      }.getMessage
+      assert(e1.contains("Permission denied"))
+
+      // test '/path/to/script.py' with script executable
+      scriptFilePath.setExecutable(true)
+      checkAnswer(
+        sql(
+          s"""
+             |SELECT
+             |TRANSFORM(a, b, c, d, e)
+             |  ROW FORMAT DELIMITED
+             |  FIELDS TERMINATED BY '\t'
+             |  USING '$scriptFilePath' AS (a, b, c, d, e)
+             |  ROW FORMAT DELIMITED
+             |  FIELDS TERMINATED BY '\t'
+             |FROM v
+          """.stripMargin), identity, df.select(
+          'a.cast("string"),
+          'b.cast("string"),
+          'c.cast("string"),
+          'd.cast("string"),
+          'e.cast("string")).collect())
+
+      scriptFilePath.setExecutable(false)
+      sql(s"ADD FILE ${scriptFilePath.getAbsolutePath}")
+
+      // test `script.py` when file added
+      checkAnswer(
+        sql(
+          s"""
+             |SELECT TRANSFORM(a, b, c, d, e)
+             |  ROW FORMAT DELIMITED
+             |  FIELDS TERMINATED BY '\t'
+             |  USING '${scriptFilePath.getName}' AS (a, b, c, d, e)
+             |  ROW FORMAT DELIMITED
+             |  FIELDS TERMINATED BY '\t'
+             |FROM v
+          """.stripMargin), identity, df.select(
+          'a.cast("string"),
+          'b.cast("string"),
+          'c.cast("string"),
+          'd.cast("string"),
+          'e.cast("string")).collect())
+
+      // test `python script.py` when file added
+      checkAnswer(
+        sql(
+          s"""
+             |SELECT TRANSFORM(a, b, c, d, e)
+             |  ROW FORMAT DELIMITED
+             |  FIELDS TERMINATED BY '\t'
+             |  USING 'python ${scriptFilePath.getName}' AS (a, b, c, d, e)
+             |  ROW FORMAT DELIMITED
+             |  FIELDS TERMINATED BY '\t'
+             |FROM v
+          """.stripMargin), identity, df.select(
+          'a.cast("string"),
+          'b.cast("string"),
+          'c.cast("string"),
+          'd.cast("string"),
+          'e.cast("string")).collect())
+    }
+  }
 }
 
 case class ExceptionInjectingOperator(child: SparkPlan) extends UnaryExecNode {
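Taken together, the new test walks through four invocation styles: 'python /abs/path/script.py', '/abs/path/script.py' without and then with the executable bit (the non-executable case must fail with "Permission denied"; the executable case works thanks to the shebang added to test_script.py), and, after ADD FILE, the bare file name with and without a 'python' prefix. The bare-name cases only pass because ADD FILE materializes the script under SparkFiles' root directory, which the main change places on the child's PATH. A rough sketch of that mechanism, assuming a live SparkSession named spark (the script path is a placeholder):

    import org.apache.spark.SparkFiles

    // SQL `ADD FILE ...` registers the file the same way as addFile.
    spark.sparkContext.addFile("/path/to/test_script.py")

    // On each executor the file lands under the SparkFiles root...
    val root = SparkFiles.getRootDirectory()
    // ...so the transform's child process, started in `root` with `root`
    // on its PATH, can resolve plain `test_script.py`.
    val local = SparkFiles.get("test_script.py")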