diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index f5191fa9132bd..ba6a77f1b3b7d 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -27,6 +27,7 @@ import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal import org.apache.hadoop.hive.metastore.api.FieldSchema +import org.apache.hadoop.hive.ql.session.OperationLog import org.apache.hadoop.hive.shims.Utils import org.apache.hive.service.cli._ import org.apache.hive.service.cli.operation.ExecuteStatementOperation @@ -170,12 +171,15 @@ private[hive] class SparkExecuteStatementOperation( override def run(): Unit = { val doAsAction = new PrivilegedExceptionAction[Unit]() { override def run(): Unit = { + registerCurrentOperationLog() try { execute() } catch { case e: HiveSQLException => setOperationException(e) log.error("Error running hive query: ", e) + } finally { + unregisterOperationLog() } } } @@ -271,6 +275,19 @@ private[hive] class SparkExecuteStatementOperation( HiveThriftServer2.listener.onStatementFinish(statementId) } + private def registerCurrentOperationLog(): Unit = { + if (isOperationLogEnabled) { + if (operationLog == null) { + logWarning("Failed to get current OperationLog object of Operation: " + + getHandle().getHandleIdentifier()) + isOperationLogEnabled = false + } else { + OperationLog.setCurrentOperationLog(operationLog) + } + } + + } + override def cancel(): Unit = { logInfo(s"Cancel '$statement' with $statementId") cleanup(OperationState.CANCELED) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 4997d7f96afa2..08cb962e8bb11 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -22,6 +22,7 @@ import java.net.URL import java.nio.charset.StandardCharsets import java.sql.{Date, DriverManager, SQLException, Statement} +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.concurrent.{ExecutionContext, Future, Promise} @@ -33,9 +34,11 @@ import com.google.common.io.Files import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hive.jdbc.HiveDriver import org.apache.hive.service.auth.PlainSaslHelper -import org.apache.hive.service.cli.{FetchOrientation, FetchType, GetInfoType} +import org.apache.hive.service.cli.GetInfoType +import org.apache.hive.service.cli.thrift.{ThriftCLIServiceClient, TProtocolVersion} import org.apache.hive.service.cli.thrift.TCLIService.Client -import org.apache.hive.service.cli.thrift.ThriftCLIServiceClient +import org.apache.hive.service.cli.FetchOrientation +import org.apache.hive.service.cli.FetchType import org.apache.thrift.protocol.TBinaryProtocol import org.apache.thrift.transport.TSocket import org.scalatest.BeforeAndAfterAll @@ -597,6 +600,48 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { bufferSrc.close() } } + + test("SPARK-21395 registerCurrentOperationLog before execute sql statement") { + withCLIServiceClient { client => + val user = System.getProperty("user.name") + val sessionHandle = client.openSession(user, "") + + withJdbcStatement("test_21395") { statement => + val queries = Seq( + "CREATE TABLE test_21395(key INT, val STRING)", + s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_21395") + queries.foreach(statement.execute) + + val confOverlay = new java.util.HashMap[java.lang.String, java.lang.String] + val operationHandle = client.executeStatementAsync( + sessionHandle, + "SELECT * FROM test_21395", + confOverlay) + assertResult(true, "Fetching OperationLog from HiveServer2") { + while (!client.getOperationStatus(operationHandle).getState().isTerminal()) { + Thread.sleep(3.seconds.toMillis) + } + val rows_next = client.fetchResults( + operationHandle, + FetchOrientation.FETCH_FIRST, + 1000, + FetchType.LOG + ) + + val version = operationHandle.getProtocolVersion() + if (version.getValue() >= TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V6.getValue()) { + val lines = rows_next.toTRowSet().getColumns().get(0) + .getStringVal().getValues().asScala + lines.exists(_.length > 0) + } else { + val lines = rows_next.toTRowSet().getRows().asScala + .map(_.getColVals().get(0).getStringVal().getValue) + lines.exists(_.length > 0) + } + } + } + } + } } class SingleSessionSuite extends HiveThriftJdbcTest { @@ -825,6 +870,7 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost | --hiveconf ${ConfVars.HIVE_SERVER2_TRANSPORT_MODE}=$mode | --hiveconf ${ConfVars.HIVE_SERVER2_LOGGING_OPERATION_LOG_LOCATION}=$operationLogPath + | --hiveconf ${ConfVars.HIVE_SERVER2_LOGGING_OPERATION_LEVEL}=VERBOSE | --hiveconf $portConf=$port | --driver-class-path $driverClassPath | --driver-java-options -Dlog4j.debug