diff --git a/linkis-computation-governance/linkis-computation-governance-common/src/main/scala/org/apache/linkis/governance/common/constant/job/JobRequestConstants.scala b/linkis-computation-governance/linkis-computation-governance-common/src/main/scala/org/apache/linkis/governance/common/constant/job/JobRequestConstants.scala index 9f2879d571..9f11419fb3 100644 --- a/linkis-computation-governance/linkis-computation-governance-common/src/main/scala/org/apache/linkis/governance/common/constant/job/JobRequestConstants.scala +++ b/linkis-computation-governance/linkis-computation-governance-common/src/main/scala/org/apache/linkis/governance/common/constant/job/JobRequestConstants.scala @@ -38,4 +38,7 @@ object JobRequestConstants { val LINKIS_JDBC_DEFAULT_DB = "linkis.jdbc.default.db" + val ENABLE_DIRECT_PUSH = "enableDirectPush" + + val DIRECT_PUSH_FETCH_SIZE = "direct_push_fetch_size" } diff --git a/linkis-computation-governance/linkis-engineconn/linkis-computation-engineconn/src/main/scala/org/apache/linkis/engineconn/computation/executor/execute/ComputationExecutor.scala b/linkis-computation-governance/linkis-engineconn/linkis-computation-engineconn/src/main/scala/org/apache/linkis/engineconn/computation/executor/execute/ComputationExecutor.scala index bd6d44e4a6..86ad2a1bf5 100644 --- a/linkis-computation-governance/linkis-engineconn/linkis-computation-engineconn/src/main/scala/org/apache/linkis/engineconn/computation/executor/execute/ComputationExecutor.scala +++ b/linkis-computation-governance/linkis-engineconn/linkis-computation-engineconn/src/main/scala/org/apache/linkis/engineconn/computation/executor/execute/ComputationExecutor.scala @@ -36,6 +36,7 @@ import org.apache.linkis.engineconn.core.EngineConnObject import org.apache.linkis.engineconn.core.executor.ExecutorManager import org.apache.linkis.engineconn.executor.entity.{LabelExecutor, ResourceExecutor} import org.apache.linkis.engineconn.executor.listener.ExecutorListenerBusContext +import org.apache.linkis.governance.common.constant.job.JobRequestConstants import org.apache.linkis.governance.common.entity.ExecutionNodeStatus import org.apache.linkis.governance.common.paser.CodeParser import org.apache.linkis.governance.common.protocol.task.{EngineConcurrentInfo, RequestTask} @@ -88,7 +89,7 @@ abstract class ComputationExecutor(val outputPrintLimit: Int = 1000) protected val failedTasks: Count = new Count - private var lastTask: EngineConnTask = _ + protected var lastTask: EngineConnTask = _ private val MAX_TASK_EXECUTE_NUM = ComputationExecutorConf.ENGINE_MAX_TASK_EXECUTE_NUM.getValue @@ -232,11 +233,13 @@ abstract class ComputationExecutor(val outputPrintLimit: Int = 1000) } val code = codes(index) engineExecutionContext.setCurrentParagraph(index + 1) + response = Utils.tryCatch(if (incomplete.nonEmpty) { executeCompletely(engineExecutionContext, code, incomplete.toString()) } else executeLine(engineExecutionContext, code)) { t => ErrorExecuteResponse(ExceptionUtils.getRootCauseMessage(t), t) } + incomplete ++= code response match { case e: ErrorExecuteResponse => @@ -355,6 +358,12 @@ abstract class ComputationExecutor(val outputPrintLimit: Int = 1000) engineConnTask.getProperties.get(RequestTask.RESULT_SET_STORE_PATH).toString ) } + if (engineConnTask.getProperties.containsKey(JobRequestConstants.ENABLE_DIRECT_PUSH)) { + engineExecutionContext.setEnableDirectPush( + engineConnTask.getProperties.get(JobRequestConstants.ENABLE_DIRECT_PUSH).toString.toBoolean + ) + logger.info(s"Enable direct push in engineTask 
${engineConnTask.getTaskId}.") + } logger.info(s"StorePath : ${engineExecutionContext.getStorePath.orNull}.") engineExecutionContext.setJobId(engineConnTask.getTaskId) engineExecutionContext.getProperties.putAll(engineConnTask.getProperties) diff --git a/linkis-computation-governance/linkis-engineconn/linkis-computation-engineconn/src/main/scala/org/apache/linkis/engineconn/computation/executor/execute/EngineExecutionContext.scala b/linkis-computation-governance/linkis-engineconn/linkis-computation-engineconn/src/main/scala/org/apache/linkis/engineconn/computation/executor/execute/EngineExecutionContext.scala index 377c32c193..7767af9797 100644 --- a/linkis-computation-governance/linkis-engineconn/linkis-computation-engineconn/src/main/scala/org/apache/linkis/engineconn/computation/executor/execute/EngineExecutionContext.scala +++ b/linkis-computation-governance/linkis-engineconn/linkis-computation-engineconn/src/main/scala/org/apache/linkis/engineconn/computation/executor/execute/EngineExecutionContext.scala @@ -67,6 +67,7 @@ class EngineExecutionContext(executor: ComputationExecutor, executorUser: String private var totalParagraph = 0 private var currentParagraph = 0 + private var enableDirectPush = false def getTotalParagraph: Int = totalParagraph @@ -76,6 +77,11 @@ class EngineExecutionContext(executor: ComputationExecutor, executorUser: String def setCurrentParagraph(currentParagraph: Int): Unit = this.currentParagraph = currentParagraph + def setEnableDirectPush(enable: Boolean): Unit = + this.enableDirectPush = enable + + def isEnableDirectPush: Boolean = enableDirectPush + def pushProgress(progress: Float, progressInfo: Array[JobProgressInfo]): Unit = if (!executor.isInternalExecute) { val listenerBus = getEngineSyncListenerBus diff --git a/linkis-engineconn-plugins/spark/pom.xml b/linkis-engineconn-plugins/spark/pom.xml index 8d79a9bf72..62f3da22f3 100644 --- a/linkis-engineconn-plugins/spark/pom.xml +++ b/linkis-engineconn-plugins/spark/pom.xml @@ -344,6 +344,11 @@ + + org.apache.arrow + arrow-vector + ${arrow.version} + org.codehaus.janino janino diff --git a/linkis-engineconn-plugins/spark/src/main/java/org/apache/linkis/engineplugin/spark/DirectPushRestfulApi.java b/linkis-engineconn-plugins/spark/src/main/java/org/apache/linkis/engineplugin/spark/DirectPushRestfulApi.java new file mode 100644 index 0000000000..826b412d2c --- /dev/null +++ b/linkis-engineconn-plugins/spark/src/main/java/org/apache/linkis/engineplugin/spark/DirectPushRestfulApi.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.linkis.engineplugin.spark; + +import org.apache.linkis.engineplugin.spark.utils.DataFrameResponse; +import org.apache.linkis.engineplugin.spark.utils.DirectPushCache; +import org.apache.linkis.server.Message; + +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RequestMethod; +import org.springframework.web.bind.annotation.RestController; + +import javax.servlet.http.HttpServletRequest; + +import java.util.Map; + +import io.swagger.annotations.Api; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +@Api(tags = "DirectPush") +@RestController +@RequestMapping(path = "directpush") +public class DirectPushRestfulApi { + private static final Logger logger = LoggerFactory.getLogger(DirectPushRestfulApi.class); + + @RequestMapping(path = "pull", method = RequestMethod.POST) + public Message getDirectPushResult( + HttpServletRequest req, @RequestBody Map json) { + Message message = null; + try { + String taskId = (String) json.getOrDefault("taskId", null); + if (taskId == null) { + message = Message.error("taskId is null"); + return message; + } + int fetchSize = (int) json.getOrDefault("fetchSize", 1000); + + DataFrameResponse response = DirectPushCache.fetchResultSetOfDataFrame(taskId, fetchSize); + if (response.dataFrame() == null) { + message = Message.error("No result found for taskId: " + taskId); + } else { + message = + Message.ok() + .data("data", response.dataFrame()) + .data("hasMoreData", response.hasMoreData()); + } + } catch (Exception e) { + logger.error("Failed to get direct push result", e); + message = Message.error("Failed to get direct push result: " + e.getMessage()); + } + return message; + } +} diff --git a/linkis-engineconn-plugins/spark/src/main/scala/org/apache/linkis/engineplugin/spark/executor/SparkSqlExecutor.scala b/linkis-engineconn-plugins/spark/src/main/scala/org/apache/linkis/engineplugin/spark/executor/SparkSqlExecutor.scala index 33b3dadb38..f435314d5d 100644 --- a/linkis-engineconn-plugins/spark/src/main/scala/org/apache/linkis/engineplugin/spark/executor/SparkSqlExecutor.scala +++ b/linkis-engineconn-plugins/spark/src/main/scala/org/apache/linkis/engineplugin/spark/executor/SparkSqlExecutor.scala @@ -22,16 +22,13 @@ import org.apache.linkis.engineconn.computation.executor.execute.EngineExecution import org.apache.linkis.engineplugin.spark.common.{Kind, SparkSQL} import org.apache.linkis.engineplugin.spark.config.SparkConfiguration import org.apache.linkis.engineplugin.spark.entity.SparkEngineSession -import org.apache.linkis.engineplugin.spark.utils.EngineUtils +import org.apache.linkis.engineplugin.spark.utils.{ArrowUtils, DirectPushCache, EngineUtils} import org.apache.linkis.governance.common.constant.job.JobRequestConstants import org.apache.linkis.governance.common.paser.SQLCodeParser -import org.apache.linkis.scheduler.executer.{ - ErrorExecuteResponse, - ExecuteResponse, - SuccessExecuteResponse -} +import org.apache.linkis.scheduler.executer._ import org.apache.commons.lang3.exception.ExceptionUtils +import org.apache.spark.sql.DataFrame import java.lang.reflect.InvocationTargetException @@ -47,6 +44,16 @@ class SparkSqlExecutor(sparkEngineSession: SparkEngineSession, id: Long) override protected def getKind: Kind = SparkSQL() + // Only used in the scenario of direct pushing, dataFrame won't be fetched at a time, + // It will cache the lazy dataFrame in memory and return the result when client . 
+ private def submitResultSetIterator(taskId: String, df: DataFrame): Unit = { + if (!DirectPushCache.isTaskCached(taskId)) { + DirectPushCache.submitExecuteResult(taskId, df) + } else { + logger.error(s"Task $taskId already exists in resultSet cache.") + } + } + override protected def runCode( executor: SparkEngineConnExecutor, code: String, @@ -89,14 +96,19 @@ class SparkSqlExecutor(sparkEngineSession: SparkEngineSession, id: Long) ) ) ) - SQLSession.showDF( - sparkEngineSession.sparkContext, - jobGroup, - df, - null, - SparkConfiguration.SHOW_DF_MAX_RES.getValue, - engineExecutionContext - ) + + if (engineExecutionContext.isEnableDirectPush) { + submitResultSetIterator(lastTask.getTaskId, df) + } else { + SQLSession.showDF( + sparkEngineSession.sparkContext, + jobGroup, + df, + null, + SparkConfiguration.SHOW_DF_MAX_RES.getValue, + engineExecutionContext + ) + } SuccessExecuteResponse() } catch { case e: InvocationTargetException => diff --git a/linkis-engineconn-plugins/spark/src/main/scala/org/apache/linkis/engineplugin/spark/utils/ArrowUtils.scala b/linkis-engineconn-plugins/spark/src/main/scala/org/apache/linkis/engineplugin/spark/utils/ArrowUtils.scala new file mode 100644 index 0000000000..2a588398aa --- /dev/null +++ b/linkis-engineconn-plugins/spark/src/main/scala/org/apache/linkis/engineplugin/spark/utils/ArrowUtils.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.linkis.engineplugin.spark.utils + +import org.apache.arrow.memory.RootAllocator +import org.apache.arrow.vector._ +import org.apache.arrow.vector.ipc.ArrowStreamWriter +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.types._ + +import java.io.ByteArrayOutputStream +import java.util + +object ArrowUtils { + + def toArrow(df: DataFrame): Array[Byte] = { + val allocator = new RootAllocator(Long.MaxValue) + val (root, fieldVectors) = createArrowVectors(df, allocator) + val outStream = new ByteArrayOutputStream() + val writer = new ArrowStreamWriter(root, null, outStream) + + writer.start() + writer.writeBatch() + writer.end() + writer.close() + + val arrowBytes = outStream.toByteArray + fieldVectors.foreach(_.close()) + allocator.close() + arrowBytes + } + + private def createArrowVectors( + df: DataFrame, + allocator: RootAllocator + ): (VectorSchemaRoot, List[FieldVector]) = { + val schema = df.schema + val fieldVectors = schema.fields.map { field => + field.dataType match { + case IntegerType => + val vector = new IntVector(field.name, allocator) + vector.allocateNew(df.count().toInt) + vector + case LongType => + val vector = new BigIntVector(field.name, allocator) + vector.allocateNew(df.count().toInt) + vector + case DoubleType => + val vector = new Float8Vector(field.name, allocator) + vector.allocateNew(df.count().toInt) + vector + case BooleanType => + val vector = new BitVector(field.name, allocator) + vector.allocateNew(df.count().toInt) + vector + case _ => + val vector: VarCharVector = new VarCharVector(field.name, allocator) + vector.allocateNew(df.count().toInt) + vector + } + }.toList + + df.collect().zipWithIndex.foreach { case (row, i) => + for (j <- fieldVectors.indices) { + val vector = fieldVectors(j) + row.schema.fields(j).dataType match { + case IntegerType => vector.asInstanceOf[IntVector].setSafe(i, row.getInt(j)) + case LongType => vector.asInstanceOf[BigIntVector].setSafe(i, row.getLong(j)) + case DoubleType => vector.asInstanceOf[Float8Vector].setSafe(i, row.getDouble(j)) + case BooleanType => + vector.asInstanceOf[BitVector].setSafe(i, if (row.getBoolean(j)) 1 else 0) + case _ => + vector.asInstanceOf[VarCharVector].setSafe(i, row.getString(j).getBytes) + } + vector.setValueCount(vector.getValueCount + 1) + } + } + + val javaFieldVectors: util.ArrayList[FieldVector] = new util.ArrayList[FieldVector]() + fieldVectors.foreach(javaFieldVectors.add) + val root = new VectorSchemaRoot(javaFieldVectors) + + (root, fieldVectors) + } + +} diff --git a/linkis-engineconn-plugins/spark/src/main/scala/org/apache/linkis/engineplugin/spark/utils/DirectPushCache.scala b/linkis-engineconn-plugins/spark/src/main/scala/org/apache/linkis/engineplugin/spark/utils/DirectPushCache.scala new file mode 100644 index 0000000000..6a540a9c60 --- /dev/null +++ b/linkis-engineconn-plugins/spark/src/main/scala/org/apache/linkis/engineplugin/spark/utils/DirectPushCache.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.linkis.engineplugin.spark.utils + +import org.apache.linkis.engineconn.common.conf.{EngineConnConf, EngineConnConstant} + +import org.apache.spark.sql.DataFrame + +import java.util.concurrent.TimeUnit + +import com.google.common.cache.{Cache, CacheBuilder} + +case class DataFrameResponse(dataFrame: DataFrame, hasMoreData: Boolean) + +object DirectPushCache { + + private val resultSet: Cache[String, DataFrame] = CacheBuilder + .newBuilder() + .expireAfterAccess(EngineConnConf.ENGINE_TASK_EXPIRE_TIME.getValue, TimeUnit.MILLISECONDS) + .maximumSize(EngineConnConstant.MAX_TASK_NUM) + .build() + + // This method is not idempotent. After fetching a result set of size fetchSize each time, the corresponding results will be removed from the cache. + def fetchResultSetOfDataFrame(taskId: String, fetchSize: Int): DataFrameResponse = { + val df = DirectPushCache.resultSet.getIfPresent(taskId) + if (df == null) { + throw new IllegalAccessException(s"Task $taskId not exists in resultSet cache.") + } else { + val batchDf = df.limit(fetchSize) + if (batchDf.count() < fetchSize) { + // All the data in df has been consumed. + DirectPushCache.resultSet.invalidate(taskId) + DataFrameResponse(batchDf, hasMoreData = false) + } else { + // Update df with consumed one. + DirectPushCache.resultSet.put(taskId, df.except(batchDf)) + DataFrameResponse(batchDf, hasMoreData = true) + } + } + } + + def isTaskCached(taskId: String): Boolean = { + DirectPushCache.resultSet.getIfPresent(taskId) != null + } + + def submitExecuteResult(taskId: String, df: DataFrame): Unit = { + DirectPushCache.resultSet.put(taskId, df) + } + +} diff --git a/linkis-engineconn-plugins/spark/src/test/scala/org/apache/linkis/engineplugin/spark/executor/TestArrowUtil.scala b/linkis-engineconn-plugins/spark/src/test/scala/org/apache/linkis/engineplugin/spark/executor/TestArrowUtil.scala new file mode 100644 index 0000000000..99e5d693f9 --- /dev/null +++ b/linkis-engineconn-plugins/spark/src/test/scala/org/apache/linkis/engineplugin/spark/executor/TestArrowUtil.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.linkis.engineplugin.spark.executor + +import org.apache.linkis.engineplugin.spark.utils.ArrowUtils + +import org.apache.arrow.memory.RootAllocator +import org.apache.arrow.vector.VectorSchemaRoot +import org.apache.arrow.vector.ipc.ArrowStreamReader +import org.apache.spark.sql.SparkSession + +import java.io.ByteArrayInputStream + +import org.junit.jupiter.api.{Assertions, Test} + +class TestArrowUtil { + + @Test + def testToArrow(): Unit = { + val sparkSession = SparkSession + .builder() + .master("local[1]") + .appName("testToArrow") + .getOrCreate() + val dataFrame = sparkSession + .createDataFrame( + Seq(("test1", 23, 552214221L), ("test2", 19, 41189877L), ("test3", 241, 1555223L)) + ) + .toDF("name", "age", "id") + val arrowBytes = ArrowUtils.toArrow(dataFrame) + + // read arrow bytes for checking + val allocator = new RootAllocator(Long.MaxValue) + val byteArrayInputStream = new ByteArrayInputStream(arrowBytes) + val streamReader = new ArrowStreamReader(byteArrayInputStream, allocator) + + try { + val root: VectorSchemaRoot = streamReader.getVectorSchemaRoot + val expectedData = + Seq(("test1", 23, 552214221L), ("test2", 19, 41189877L), ("test3", 241, 1555223L)) + + var rowIndex = 0 + while (streamReader.loadNextBatch()) { + for (i <- 0 until root.getRowCount) { + val name = root.getVector("name").getObject(i).toString + val age = root.getVector("age").getObject(i).asInstanceOf[Int] + val id = root.getVector("id").getObject(i).asInstanceOf[Long] + + val (expectedName, expectedAge, expectedId) = expectedData(rowIndex) + Assertions.assertEquals(name, expectedName) + Assertions.assertEquals(age, expectedAge) + Assertions.assertEquals(id, expectedId) + rowIndex += 1 + } + } + Assertions.assertEquals(rowIndex, expectedData.length) + } finally { + streamReader.close() + allocator.close() + } + } + +} diff --git a/pom.xml b/pom.xml index abca4c7e8b..b49645fdb9 100644 --- a/pom.xml +++ b/pom.xml @@ -114,6 +114,7 @@ 2.1.42 3.1.3 3.2.1 + 2.0.0 2.7.2 org.apache.hadoop hadoop-common diff --git a/tool/dependencies/known-dependencies.txt b/tool/dependencies/known-dependencies.txt index b8c73b6c88..6587cbbb5b 100644 --- a/tool/dependencies/known-dependencies.txt +++ b/tool/dependencies/known-dependencies.txt @@ -25,8 +25,11 @@ antlr-runtime-3.5.2.jar aopalliance-1.0.jar aopalliance-repackaged-2.4.0-b34.jar arrow-format-0.8.0.jar +arrow-format-2.0.0.jar arrow-memory-0.8.0.jar +arrow-memory-core-2.0.0.jar arrow-vector-0.8.0.jar +arrow-vector-2.0.0.jar asm-9.3.jar asm-analysis-9.3.jar asm-commons-9.3.jar @@ -138,6 +141,7 @@ feign-form-spring-3.8.0.jar feign-slf4j-11.10.jar findbugs-annotations-1.3.9-1.jar flatbuffers-1.2.0-3f79e055.jar +flatbuffers-java-1.9.0.jar flink-annotations-1.12.2.jar flink-annotations-1.16.2.jar flink-cep-1.16.2.jar
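
Below is a minimal sketch (not part of the patch) of how the pieces introduced here are meant to compose on the engine side: `SparkSqlExecutor` caches the lazy `DataFrame` under the task id when direct push is enabled, and the `directpush/pull` endpoint then drains it in `fetchSize` batches, each of which can be serialized to an Arrow IPC stream with `ArrowUtils.toArrow`. The task id, page size, and local `SparkSession` are illustrative; only `DirectPushCache`, `DataFrameResponse`, and `ArrowUtils` come from this patch, and in practice the paging loop is driven by the client over HTTP rather than inlined like this.

```scala
import org.apache.linkis.engineplugin.spark.utils.{ArrowUtils, DirectPushCache}
import org.apache.spark.sql.SparkSession

object DirectPushSketch {

  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .master("local[1]")
      .appName("direct-push-sketch")
      .getOrCreate()

    // A stand-in for the DataFrame produced by the user's SQL in SparkSqlExecutor.
    val df = spark.range(0, 10000).toDF("id")

    // What the executor does when enableDirectPush is set:
    // cache the lazy DataFrame under the task id instead of writing a result set.
    val taskId = "task-1" // illustrative; in Linkis this is engineConnTask.getTaskId
    DirectPushCache.submitExecuteResult(taskId, df)

    // What each POST to directpush/pull does: take one page from the cache
    // and report whether more data remains for this task.
    var hasMore = true
    while (hasMore) {
      val page = DirectPushCache.fetchResultSetOfDataFrame(taskId, fetchSize = 1000)
      // Serialize the page to Arrow IPC stream bytes, as ArrowUtils.toArrow does for a DataFrame.
      val arrowBytes = ArrowUtils.toArrow(page.dataFrame)
      println(s"serialized ${arrowBytes.length} bytes, hasMoreData=${page.hasMoreData}")
      hasMore = page.hasMoreData
    }

    spark.stop()
  }

}
```

Note that `fetchResultSetOfDataFrame` is not idempotent: once `hasMoreData` comes back `false`, the entry has been invalidated and a further pull for the same task id will fail, so a caller should stop paging at that point.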