diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala
index fee1bcdc9670..3d7bba7a82e8 100644
--- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala
+++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala
@@ -17,19 +17,14 @@
 
 package org.apache.spark.sql.protobuf.utils
 
-import java.io.File
-import java.io.FileNotFoundException
-import java.nio.file.NoSuchFileException
 import java.util.Locale
 
 import scala.jdk.CollectionConverters._
-import scala.util.control.NonFatal
 
 import com.google.protobuf.{DescriptorProtos, Descriptors, InvalidProtocolBufferException, Message}
 import com.google.protobuf.DescriptorProtos.{FileDescriptorProto, FileDescriptorSet}
 import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor}
 import com.google.protobuf.TypeRegistry
-import org.apache.commons.io.FileUtils
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.errors.QueryCompilationErrors
@@ -228,18 +223,6 @@ private[sql] object ProtobufUtils extends Logging {
     }
   }
 
-  def readDescriptorFileContent(filePath: String): Array[Byte] = {
-    try {
-      FileUtils.readFileToByteArray(new File(filePath))
-    } catch {
-      case ex: FileNotFoundException =>
-        throw QueryCompilationErrors.cannotFindDescriptorFileError(filePath, ex)
-      case ex: NoSuchFileException =>
-        throw QueryCompilationErrors.cannotFindDescriptorFileError(filePath, ex)
-      case NonFatal(ex) => throw QueryCompilationErrors.descriptorParseError(ex)
-    }
-  }
-
   private def parseFileDescriptorSet(bytes: Array[Byte]): List[Descriptors.FileDescriptor] = {
     var fileDescriptorSet: DescriptorProtos.FileDescriptorSet = null
     try {
diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala
index ad6a88640140..abae1d622d3c 100644
--- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala
+++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.protobuf.utils.{ProtobufUtils, SchemaConverters}
 import org.apache.spark.sql.sources.{EqualTo, Not}
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.{ProtobufUtils => CommonProtobufUtils}
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.ArrayImplicits._
 
@@ -39,7 +40,7 @@ class ProtobufCatalystDataConversionSuite
   with ProtobufTestBase {
 
   private val testFileDescFile = protobufDescriptorFile("catalyst_types.desc")
-  private val testFileDesc = ProtobufUtils.readDescriptorFileContent(testFileDescFile)
+  private val testFileDesc = CommonProtobufUtils.readDescriptorFileContent(testFileDescFile)
   private val javaClassNamePrefix = "org.apache.spark.sql.protobuf.protos.CatalystTypes$"
 
   private def checkResultWithEval(
@@ -47,7 +48,7 @@ class ProtobufCatalystDataConversionSuite
       descFilePath: String,
       messageName: String,
       expected: Any): Unit = {
-    val descBytes = ProtobufUtils.readDescriptorFileContent(descFilePath)
+    val descBytes = CommonProtobufUtils.readDescriptorFileContent(descFilePath)
     withClue("(Eval check with Java class name)") {
       val className = s"$javaClassNamePrefix$messageName"
       checkEvaluation(
@@ -72,7 +73,7 @@ class ProtobufCatalystDataConversionSuite
       actualSchema: String,
       badSchema: String): Unit = {
 
-    val descBytes = ProtobufUtils.readDescriptorFileContent(descFilePath)
+    val descBytes = CommonProtobufUtils.readDescriptorFileContent(descFilePath)
     val binary = CatalystDataToProtobuf(data, actualSchema, Some(descBytes))
 
     intercept[Exception] {
diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
index 3eaa91e472c4..44a8339ac1f0 100644
--- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
+++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.protobuf.utils.ProtobufOptions
 import org.apache.spark.sql.protobuf.utils.ProtobufUtils
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.{ProtobufUtils => CommonProtobufUtils}
 
 class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with ProtobufTestBase
   with Serializable {
@@ -40,11 +41,11 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
   import testImplicits._
 
   val testFileDescFile = protobufDescriptorFile("functions_suite.desc")
-  private val testFileDesc = ProtobufUtils.readDescriptorFileContent(testFileDescFile)
+  private val testFileDesc = CommonProtobufUtils.readDescriptorFileContent(testFileDescFile)
   private val javaClassNamePrefix = "org.apache.spark.sql.protobuf.protos.SimpleMessageProtos$"
 
   val proto2FileDescFile = protobufDescriptorFile("proto2_messages.desc")
-  val proto2FileDesc = ProtobufUtils.readDescriptorFileContent(proto2FileDescFile)
+  val proto2FileDesc = CommonProtobufUtils.readDescriptorFileContent(proto2FileDescFile)
   private val proto2JavaClassNamePrefix = "org.apache.spark.sql.protobuf.protos.Proto2Messages$"
 
   private def emptyBinaryDF = Seq(Array[Byte]()).toDF("binary")
@@ -467,7 +468,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
 
   test("Handle extra fields : oldProducer -> newConsumer") {
     val catalystTypesFile = protobufDescriptorFile("catalyst_types.desc")
-    val descBytes = ProtobufUtils.readDescriptorFileContent(catalystTypesFile)
+    val descBytes = CommonProtobufUtils.readDescriptorFileContent(catalystTypesFile)
     val oldProducer = ProtobufUtils.buildDescriptor(descBytes, "oldProducer")
     val newConsumer = ProtobufUtils.buildDescriptor(descBytes, "newConsumer")
 
@@ -509,7 +510,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
 
   test("Handle extra fields : newProducer -> oldConsumer") {
     val catalystTypesFile = protobufDescriptorFile("catalyst_types.desc")
-    val descBytes = ProtobufUtils.readDescriptorFileContent(catalystTypesFile)
+    val descBytes = CommonProtobufUtils.readDescriptorFileContent(catalystTypesFile)
     val newProducer = ProtobufUtils.buildDescriptor(descBytes, "newProducer")
     val oldConsumer = ProtobufUtils.buildDescriptor(descBytes, "oldConsumer")
 
diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala
index 2737bb9feb3a..f3bd49e1b24a 100644
--- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala
+++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.Cast.toSQLType
 import org.apache.spark.sql.protobuf.utils.ProtobufUtils
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.sql.util.{ProtobufUtils => CommonProtobufUtils}
 
 /**
  * Tests for [[ProtobufSerializer]] and [[ProtobufDeserializer]] with a more specific focus on
@@ -37,12 +38,12 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase {
   import ProtoSerdeSuite.MatchType._
 
   private val testFileDescFile = protobufDescriptorFile("serde_suite.desc")
-  private val testFileDesc = ProtobufUtils.readDescriptorFileContent(testFileDescFile)
+  private val testFileDesc = CommonProtobufUtils.readDescriptorFileContent(testFileDescFile)
 
   private val javaClassNamePrefix = "org.apache.spark.sql.protobuf.protos.SerdeSuiteProtos$"
 
   private val proto2DescFile = protobufDescriptorFile("proto2_messages.desc")
-  private val proto2Desc = ProtobufUtils.readDescriptorFileContent(proto2DescFile)
+  private val proto2Desc = CommonProtobufUtils.readDescriptorFileContent(proto2DescFile)
 
   test("Test basic conversion") {
     withFieldMatchType { fieldMatch =>
@@ -215,7 +216,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase {
 
     val e1 = intercept[AnalysisException] {
       ProtobufUtils.buildDescriptor(
-        ProtobufUtils.readDescriptorFileContent(fileDescFile),
+        CommonProtobufUtils.readDescriptorFileContent(fileDescFile),
         "SerdeBasicMessage"
       )
     }
@@ -225,7 +226,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase {
       condition = "CANNOT_PARSE_PROTOBUF_DESCRIPTOR")
 
     val basicMessageDescWithoutImports = descriptorSetWithoutImports(
-      ProtobufUtils.readDescriptorFileContent(
+      CommonProtobufUtils.readDescriptorFileContent(
         protobufDescriptorFile("basicmessage.desc")
       ),
       "BasicMessage"
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/protobuf/functions.scala b/sql/api/src/main/scala/org/apache/spark/sql/protobuf/functions.scala
index ea9e3c429d65..fab5cdc8de1b 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/protobuf/functions.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/protobuf/functions.scala
@@ -16,16 +16,12 @@
  */
 package org.apache.spark.sql.protobuf
 
-import java.io.FileNotFoundException
-import java.nio.file.{Files, NoSuchFileException, Paths}
-
 import scala.jdk.CollectionConverters._
-import scala.util.control.NonFatal
 
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.sql.Column
-import org.apache.spark.sql.errors.CompilationErrors
 import org.apache.spark.sql.functions.lit
+import org.apache.spark.sql.util.ProtobufUtils
 
 // scalastyle:off: object.name
 object functions {
@@ -51,7 +47,7 @@ object functions {
       messageName: String,
       descFilePath: String,
       options: java.util.Map[String, String]): Column = {
-    val descriptorFileContent = readDescriptorFileContent(descFilePath)
+    val descriptorFileContent = ProtobufUtils.readDescriptorFileContent(descFilePath)
     from_protobuf(data, messageName, descriptorFileContent, options)
   }
 
@@ -98,7 +94,7 @@ object functions {
    */
  @Experimental
  def from_protobuf(data: Column, messageName: String, descFilePath: String): Column = {
-    val fileContent = readDescriptorFileContent(descFilePath)
+    val fileContent = ProtobufUtils.readDescriptorFileContent(descFilePath)
     from_protobuf(data, messageName, fileContent)
   }
 
@@ -226,7 +222,7 @@
       messageName: String,
       descFilePath: String,
       options: java.util.Map[String, String]): Column = {
-    val fileContent = readDescriptorFileContent(descFilePath)
+    val fileContent = ProtobufUtils.readDescriptorFileContent(descFilePath)
     to_protobuf(data, messageName, fileContent, options)
   }
 
@@ -299,18 +295,4 @@ object functions {
       options: java.util.Map[String, String]): Column = {
     Column.fnWithOptions("to_protobuf", options.asScala.iterator, data, lit(messageClassName))
   }
-
-  // This method is copied from org.apache.spark.sql.protobuf.util.ProtobufUtils
-  private def readDescriptorFileContent(filePath: String): Array[Byte] = {
-    try {
-      Files.readAllBytes(Paths.get(filePath))
-    } catch {
-      case ex: FileNotFoundException =>
-        throw CompilationErrors.cannotFindDescriptorFileError(filePath, ex)
-      case ex: NoSuchFileException =>
-        throw CompilationErrors.cannotFindDescriptorFileError(filePath, ex)
-      case NonFatal(ex) =>
-        throw CompilationErrors.descriptorParseError(ex)
-    }
-  }
 }
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/util/ProtobufUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/util/ProtobufUtils.scala
new file mode 100644
index 000000000000..11f35ceb060c
--- /dev/null
+++ b/sql/api/src/main/scala/org/apache/spark/sql/util/ProtobufUtils.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.util
+
+import java.io.{File, FileNotFoundException}
+import java.nio.file.NoSuchFileException
+
+import scala.util.control.NonFatal
+
+import org.apache.commons.io.FileUtils
+
+import org.apache.spark.sql.errors.CompilationErrors
+
+object ProtobufUtils {
+  def readDescriptorFileContent(filePath: String): Array[Byte] = {
+    try {
+      FileUtils.readFileToByteArray(new File(filePath))
+    } catch {
+      case ex: FileNotFoundException =>
+        throw CompilationErrors.cannotFindDescriptorFileError(filePath, ex)
+      case ex: NoSuchFileException =>
+        throw CompilationErrors.cannotFindDescriptorFileError(filePath, ex)
+      case NonFatal(ex) => throw CompilationErrors.descriptorParseError(ex)
+    }
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/toFromProtobufSqlFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/toFromProtobufSqlFunctions.scala
index ad9610ea0c78..96bcf49dbd09 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/toFromProtobufSqlFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/toFromProtobufSqlFunctions.scala
@@ -17,37 +17,15 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import java.io.File
-import java.io.FileNotFoundException
-import java.nio.file.NoSuchFileException
-
-import scala.util.control.NonFatal
-
-import org.apache.commons.io.FileUtils
-
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.types.{BinaryType, MapType, NullType, StringType}
+import org.apache.spark.sql.util.ProtobufUtils
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils
 
-object ProtobufHelper {
-  def readDescriptorFileContent(filePath: String): Array[Byte] = {
-    try {
-      FileUtils.readFileToByteArray(new File(filePath))
-    } catch {
-      case ex: FileNotFoundException =>
-        throw new RuntimeException(s"Cannot find descriptor file at path: $filePath", ex)
-      case ex: NoSuchFileException =>
-        throw new RuntimeException(s"Cannot find descriptor file at path: $filePath", ex)
-      case NonFatal(ex) =>
-        throw new RuntimeException(s"Failed to read the descriptor file: $filePath", ex)
-    }
-  }
-}
-
 /**
  * Converts a binary column of Protobuf format into its corresponding catalyst value.
  * The Protobuf definition is provided through Protobuf descriptor file.
@@ -163,7 +141,7 @@ case class FromProtobuf(
     }
     val descFilePathValue: Option[Array[Byte]] = descFilePath.eval() match {
       case s: UTF8String if s.toString.isEmpty => None
-      case s: UTF8String => Some(ProtobufHelper.readDescriptorFileContent(s.toString))
+      case s: UTF8String => Some(ProtobufUtils.readDescriptorFileContent(s.toString))
       case bytes: Array[Byte] if bytes.isEmpty => None
       case bytes: Array[Byte] => Some(bytes)
       case null => None
@@ -300,7 +278,7 @@ case class ToProtobuf(
       s.toString
     }
     val descFilePathValue: Option[Array[Byte]] = descFilePath.eval() match {
-      case s: UTF8String => Some(ProtobufHelper.readDescriptorFileContent(s.toString))
+      case s: UTF8String => Some(ProtobufUtils.readDescriptorFileContent(s.toString))
       case bytes: Array[Byte] => Some(bytes)
      case null => None
     }
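For context, here is a minimal usage sketch of the consolidated code path after this patch: `org.apache.spark.sql.util.ProtobufUtils.readDescriptorFileContent` becomes the single place that loads a descriptor set from disk, and its bytes feed the existing `from_protobuf` overloads. The descriptor path, message name, column name, and the `DescriptorUsageSketch` object are illustrative placeholders, not part of this change.

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.protobuf.functions.from_protobuf
import org.apache.spark.sql.util.ProtobufUtils

object DescriptorUsageSketch {
  // Hypothetical inputs: any compiled FileDescriptorSet and a message defined in it.
  private val descFilePath = "/tmp/example_events.desc"
  private val messageName = "ExampleEvent"

  def decode(df: DataFrame): DataFrame = {
    // Shared read path: a missing file is reported via
    // CompilationErrors.cannotFindDescriptorFileError and an unreadable one via
    // CompilationErrors.descriptorParseError, matching the copies removed above.
    val descriptorBytes: Array[Byte] = ProtobufUtils.readDescriptorFileContent(descFilePath)
    df.select(from_protobuf(df("binary"), messageName, descriptorBytes).as("event"))
  }
}
```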