@@ -17,19 +17,14 @@

package org.apache.spark.sql.protobuf.utils

import java.io.File
import java.io.FileNotFoundException
import java.nio.file.NoSuchFileException
import java.util.Locale

import scala.jdk.CollectionConverters._
import scala.util.control.NonFatal

import com.google.protobuf.{DescriptorProtos, Descriptors, InvalidProtocolBufferException, Message}
import com.google.protobuf.DescriptorProtos.{FileDescriptorProto, FileDescriptorSet}
import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor}
import com.google.protobuf.TypeRegistry
import org.apache.commons.io.FileUtils

import org.apache.spark.internal.Logging
import org.apache.spark.sql.errors.QueryCompilationErrors
@@ -228,18 +223,6 @@ private[sql] object ProtobufUtils extends Logging {
}
}

def readDescriptorFileContent(filePath: String): Array[Byte] = {
try {
FileUtils.readFileToByteArray(new File(filePath))
} catch {
case ex: FileNotFoundException =>
throw QueryCompilationErrors.cannotFindDescriptorFileError(filePath, ex)
case ex: NoSuchFileException =>
throw QueryCompilationErrors.cannotFindDescriptorFileError(filePath, ex)
case NonFatal(ex) => throw QueryCompilationErrors.descriptorParseError(ex)
}
}

private def parseFileDescriptorSet(bytes: Array[Byte]): List[Descriptors.FileDescriptor] = {
var fileDescriptorSet: DescriptorProtos.FileDescriptorSet = null
try {
@@ -29,6 +29,7 @@ import org.apache.spark.sql.protobuf.utils.{ProtobufUtils, SchemaConverters}
import org.apache.spark.sql.sources.{EqualTo, Not}
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.{ProtobufUtils => CommonProtobufUtils}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.ArrayImplicits._

@@ -39,15 +40,15 @@ class ProtobufCatalystDataConversionSuite
with ProtobufTestBase {

private val testFileDescFile = protobufDescriptorFile("catalyst_types.desc")
private val testFileDesc = ProtobufUtils.readDescriptorFileContent(testFileDescFile)
private val testFileDesc = CommonProtobufUtils.readDescriptorFileContent(testFileDescFile)
private val javaClassNamePrefix = "org.apache.spark.sql.protobuf.protos.CatalystTypes$"

private def checkResultWithEval(
data: Literal,
descFilePath: String,
messageName: String,
expected: Any): Unit = {
val descBytes = ProtobufUtils.readDescriptorFileContent(descFilePath)
val descBytes = CommonProtobufUtils.readDescriptorFileContent(descFilePath)
withClue("(Eval check with Java class name)") {
val className = s"$javaClassNamePrefix$messageName"
checkEvaluation(
@@ -72,7 +73,7 @@ class ProtobufCatalystDataConversionSuite
actualSchema: String,
badSchema: String): Unit = {

val descBytes = ProtobufUtils.readDescriptorFileContent(descFilePath)
val descBytes = CommonProtobufUtils.readDescriptorFileContent(descFilePath)
val binary = CatalystDataToProtobuf(data, actualSchema, Some(descBytes))

intercept[Exception] {
@@ -33,18 +33,19 @@ import org.apache.spark.sql.protobuf.utils.ProtobufOptions
import org.apache.spark.sql.protobuf.utils.ProtobufUtils
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.{ProtobufUtils => CommonProtobufUtils}

class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with ProtobufTestBase
with Serializable {

import testImplicits._

val testFileDescFile = protobufDescriptorFile("functions_suite.desc")
private val testFileDesc = ProtobufUtils.readDescriptorFileContent(testFileDescFile)
private val testFileDesc = CommonProtobufUtils.readDescriptorFileContent(testFileDescFile)
private val javaClassNamePrefix = "org.apache.spark.sql.protobuf.protos.SimpleMessageProtos$"

val proto2FileDescFile = protobufDescriptorFile("proto2_messages.desc")
val proto2FileDesc = ProtobufUtils.readDescriptorFileContent(proto2FileDescFile)
val proto2FileDesc = CommonProtobufUtils.readDescriptorFileContent(proto2FileDescFile)
private val proto2JavaClassNamePrefix = "org.apache.spark.sql.protobuf.protos.Proto2Messages$"

private def emptyBinaryDF = Seq(Array[Byte]()).toDF("binary")
@@ -467,7 +468,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot

test("Handle extra fields : oldProducer -> newConsumer") {
val catalystTypesFile = protobufDescriptorFile("catalyst_types.desc")
val descBytes = ProtobufUtils.readDescriptorFileContent(catalystTypesFile)
val descBytes = CommonProtobufUtils.readDescriptorFileContent(catalystTypesFile)

val oldProducer = ProtobufUtils.buildDescriptor(descBytes, "oldProducer")
val newConsumer = ProtobufUtils.buildDescriptor(descBytes, "newConsumer")
@@ -509,7 +510,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot

test("Handle extra fields : newProducer -> oldConsumer") {
val catalystTypesFile = protobufDescriptorFile("catalyst_types.desc")
val descBytes = ProtobufUtils.readDescriptorFileContent(catalystTypesFile)
val descBytes = CommonProtobufUtils.readDescriptorFileContent(catalystTypesFile)

val newProducer = ProtobufUtils.buildDescriptor(descBytes, "newProducer")
val oldConsumer = ProtobufUtils.buildDescriptor(descBytes, "oldConsumer")
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.Cast.toSQLType
import org.apache.spark.sql.protobuf.utils.ProtobufUtils
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{IntegerType, StructType}
import org.apache.spark.sql.util.{ProtobufUtils => CommonProtobufUtils}

/**
* Tests for [[ProtobufSerializer]] and [[ProtobufDeserializer]] with a more specific focus on
@@ -37,12 +38,12 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase {
import ProtoSerdeSuite.MatchType._

private val testFileDescFile = protobufDescriptorFile("serde_suite.desc")
private val testFileDesc = ProtobufUtils.readDescriptorFileContent(testFileDescFile)
private val testFileDesc = CommonProtobufUtils.readDescriptorFileContent(testFileDescFile)

private val javaClassNamePrefix = "org.apache.spark.sql.protobuf.protos.SerdeSuiteProtos$"

private val proto2DescFile = protobufDescriptorFile("proto2_messages.desc")
private val proto2Desc = ProtobufUtils.readDescriptorFileContent(proto2DescFile)
private val proto2Desc = CommonProtobufUtils.readDescriptorFileContent(proto2DescFile)

test("Test basic conversion") {
withFieldMatchType { fieldMatch =>
@@ -215,7 +216,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase {

val e1 = intercept[AnalysisException] {
ProtobufUtils.buildDescriptor(
ProtobufUtils.readDescriptorFileContent(fileDescFile),
CommonProtobufUtils.readDescriptorFileContent(fileDescFile),
"SerdeBasicMessage"
)
}
@@ -225,7 +226,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase {
condition = "CANNOT_PARSE_PROTOBUF_DESCRIPTOR")

val basicMessageDescWithoutImports = descriptorSetWithoutImports(
ProtobufUtils.readDescriptorFileContent(
CommonProtobufUtils.readDescriptorFileContent(
protobufDescriptorFile("basicmessage.desc")
),
"BasicMessage"
@@ -16,16 +16,12 @@
*/
package org.apache.spark.sql.protobuf

import java.io.FileNotFoundException
import java.nio.file.{Files, NoSuchFileException, Paths}

import scala.jdk.CollectionConverters._
import scala.util.control.NonFatal

import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.Column
import org.apache.spark.sql.errors.CompilationErrors
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.util.ProtobufUtils

// scalastyle:off: object.name
object functions {
@@ -51,7 +47,7 @@
messageName: String,
descFilePath: String,
options: java.util.Map[String, String]): Column = {
val descriptorFileContent = readDescriptorFileContent(descFilePath)
val descriptorFileContent = ProtobufUtils.readDescriptorFileContent(descFilePath)
from_protobuf(data, messageName, descriptorFileContent, options)
}

@@ -98,7 +94,7 @@ object functions {
*/
@Experimental
def from_protobuf(data: Column, messageName: String, descFilePath: String): Column = {
val fileContent = readDescriptorFileContent(descFilePath)
val fileContent = ProtobufUtils.readDescriptorFileContent(descFilePath)
from_protobuf(data, messageName, fileContent)
}

@@ -226,7 +222,7 @@ object functions {
messageName: String,
descFilePath: String,
options: java.util.Map[String, String]): Column = {
val fileContent = readDescriptorFileContent(descFilePath)
val fileContent = ProtobufUtils.readDescriptorFileContent(descFilePath)
to_protobuf(data, messageName, fileContent, options)
}

@@ -299,18 +295,4 @@ object functions {
options: java.util.Map[String, String]): Column = {
Column.fnWithOptions("to_protobuf", options.asScala.iterator, data, lit(messageClassName))
}

// This method is copied from org.apache.spark.sql.protobuf.util.ProtobufUtils
private def readDescriptorFileContent(filePath: String): Array[Byte] = {
try {
Files.readAllBytes(Paths.get(filePath))
} catch {
case ex: FileNotFoundException =>
throw CompilationErrors.cannotFindDescriptorFileError(filePath, ex)
case ex: NoSuchFileException =>
throw CompilationErrors.cannotFindDescriptorFileError(filePath, ex)
case NonFatal(ex) =>
throw CompilationErrors.descriptorParseError(ex)
}
}
}
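
Not part of the diff: a minimal usage sketch of the descriptor-file overloads shown above, for context. It assumes a local SparkSession; the descriptor path, message name, and input data are hypothetical. Both overloads load the descriptor bytes eagerly through the shared ProtobufUtils.readDescriptorFileContent before building the Column.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.protobuf.functions.{from_protobuf, to_protobuf}

object DescriptorFileRoundTrip {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("protobuf-desc-file").getOrCreate()
    import spark.implicits._

    // Hypothetical descriptor set produced by `protoc --descriptor_set_out=... --include_imports`.
    val descFilePath = "/tmp/basicmessage.desc"
    val messageName = "BasicMessage"

    val binaryDF = Seq(Array[Byte]()).toDF("value")

    // The descriptor bytes are read once where the function is called, then passed on
    // to the binary-descriptor overloads that build the actual Column expression.
    val parsed = binaryDF.select(from_protobuf($"value", messageName, descFilePath).as("msg"))
    val reencoded = parsed.select(to_protobuf($"msg", messageName, descFilePath).as("value"))

    reencoded.printSchema()
    spark.stop()
  }
}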
@@ -0,0 +1,41 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.util

import java.io.{File, FileNotFoundException}
import java.nio.file.NoSuchFileException

import scala.util.control.NonFatal

import org.apache.commons.io.FileUtils

import org.apache.spark.sql.errors.CompilationErrors

object ProtobufUtils {
def readDescriptorFileContent(filePath: String): Array[Byte] = {
try {
FileUtils.readFileToByteArray(new File(filePath))
} catch {
case ex: FileNotFoundException =>
throw CompilationErrors.cannotFindDescriptorFileError(filePath, ex)
case ex: NoSuchFileException =>
throw CompilationErrors.cannotFindDescriptorFileError(filePath, ex)
case NonFatal(ex) => throw CompilationErrors.descriptorParseError(ex)
}
}
}
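
Not part of the diff: an illustrative sketch of the new shared helper, for context. The descriptor path is hypothetical; a missing or unreadable file is reported through the CompilationErrors conditions shown above rather than a plain RuntimeException.

import org.apache.spark.sql.util.ProtobufUtils

object ReadDescriptorExample {
  def main(args: Array[String]): Unit = {
    // Hypothetical path to a descriptor set; this is the single entry point that the
    // connector, the functions API, and the built-in expressions now share.
    val descBytes: Array[Byte] =
      ProtobufUtils.readDescriptorFileContent("/tmp/functions_suite.desc")
    println(s"Read ${descBytes.length} bytes from the descriptor set")
  }
}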
@@ -17,37 +17,15 @@

package org.apache.spark.sql.catalyst.expressions

import java.io.File
import java.io.FileNotFoundException
import java.nio.file.NoSuchFileException

import scala.util.control.NonFatal

import org.apache.commons.io.FileUtils

import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types.{BinaryType, MapType, NullType, StringType}
import org.apache.spark.sql.util.ProtobufUtils
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils

object ProtobufHelper {
def readDescriptorFileContent(filePath: String): Array[Byte] = {
try {
FileUtils.readFileToByteArray(new File(filePath))
} catch {
case ex: FileNotFoundException =>
throw new RuntimeException(s"Cannot find descriptor file at path: $filePath", ex)
Member: ah, you migrated the errors on error conditions too. Could you reflect this in PR's description and in the title.

Contributor Author: Yeap, very detailed disclosure, has been updated! thanks!

case ex: NoSuchFileException =>
throw new RuntimeException(s"Cannot find descriptor file at path: $filePath", ex)
case NonFatal(ex) =>
throw new RuntimeException(s"Failed to read the descriptor file: $filePath", ex)
}
}
}

/**
* Converts a binary column of Protobuf format into its corresponding catalyst value.
* The Protobuf definition is provided through Protobuf <i>descriptor file</i>.
@@ -163,7 +141,7 @@ case class FromProtobuf(
}
val descFilePathValue: Option[Array[Byte]] = descFilePath.eval() match {
case s: UTF8String if s.toString.isEmpty => None
case s: UTF8String => Some(ProtobufHelper.readDescriptorFileContent(s.toString))
case s: UTF8String => Some(ProtobufUtils.readDescriptorFileContent(s.toString))
case bytes: Array[Byte] if bytes.isEmpty => None
case bytes: Array[Byte] => Some(bytes)
case null => None
@@ -300,7 +278,7 @@ case class ToProtobuf(
s.toString
}
val descFilePathValue: Option[Array[Byte]] = descFilePath.eval() match {
case s: UTF8String => Some(ProtobufHelper.readDescriptorFileContent(s.toString))
case s: UTF8String => Some(ProtobufUtils.readDescriptorFileContent(s.toString))
case bytes: Array[Byte] => Some(bytes)
case null => None
}
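
Not part of the diff: a sketch of how the expression-level change surfaces to users. With ProtobufHelper removed, a descriptor file path given to FromProtobuf or ToProtobuf is loaded through the shared ProtobufUtils, so failures raise its CompilationErrors conditions instead of the ad-hoc RuntimeExceptions deleted above. The sketch assumes these expressions are exposed as the SQL functions from_protobuf and to_protobuf (the FunctionRegistry import above suggests this, but the registration is not shown in this hunk); the view, column, message name, and descriptor path are hypothetical.

// spark-shell sketch with hypothetical names. A string third argument makes
// FromProtobuf take the `case s: UTF8String` branch above, so a bad path now fails
// with the descriptor-file error condition from the shared helper rather than a
// RuntimeException("Cannot find descriptor file at path: ...").
import spark.implicits._

Seq(Array[Byte]()).toDF("payload").createOrReplaceTempView("events")

spark.sql(
  """SELECT from_protobuf(payload, 'BasicMessage', '/tmp/basicmessage.desc') AS msg
    |FROM events""".stripMargin
).printSchema()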