From b53e2d2ec14f591c67e5ffd5d3d4cf07639f92ba Mon Sep 17 00:00:00 2001 From: Civitaspo Date: Sun, 10 May 2020 07:15:14 +0900 Subject: [PATCH 1/5] Introduce Java <-> Scala implicit conversion --- .../s3_parquet/CatalogRegistrator.scala | 11 ++-- .../s3_parquet/S3ParquetOutputPlugin.scala | 49 ++++++++---------- .../embulk/output/s3_parquet/implicits.scala | 51 +++++++++++++++++++ .../parquet/LogicalTypeHandlerStore.scala | 10 ++-- .../parquet/ParquetFileWriteSupport.scala | 6 +-- 5 files changed, 85 insertions(+), 42 deletions(-) create mode 100644 src/main/scala/org/embulk/output/s3_parquet/implicits.scala diff --git a/src/main/scala/org/embulk/output/s3_parquet/CatalogRegistrator.scala b/src/main/scala/org/embulk/output/s3_parquet/CatalogRegistrator.scala index 4ad1dec..630a65f 100644 --- a/src/main/scala/org/embulk/output/s3_parquet/CatalogRegistrator.scala +++ b/src/main/scala/org/embulk/output/s3_parquet/CatalogRegistrator.scala @@ -27,7 +27,6 @@ import org.embulk.spi.`type`.{ } import org.slf4j.{Logger, LoggerFactory} -import scala.jdk.CollectionConverters._ import scala.util.Try object CatalogRegistrator { @@ -90,6 +89,8 @@ class CatalogRegistrator( parquetColumnLogicalTypes: Map[String, String] = Map.empty ) { + import implicits._ + val logger: Logger = loggerOption.getOrElse(LoggerFactory.getLogger(classOf[CatalogRegistrator])) @@ -156,7 +157,7 @@ class CatalogRegistrator( "EXTERNAL" -> "TRUE", "classification" -> "parquet", "parquet.compression" -> compressionCodec.name() - ).asJava + ) ) .withStorageDescriptor( new StorageDescriptor() @@ -174,7 +175,7 @@ class CatalogRegistrator( .withSerializationLibrary( "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe" ) - .withParameters(Map("serialization.format" -> "1").asJava) + .withParameters(Map("serialization.format" -> "1")) ) ) ) @@ -183,8 +184,8 @@ class CatalogRegistrator( private def getGlueSchema: Seq[Column] = { val columnOptions: Map[String, ColumnOptions] = - task.getColumnOptions.asScala.toMap - schema.getColumns.asScala.toSeq.map { c => + task.getColumnOptions + schema.getColumns.map { c => val cType: String = if (columnOptions.contains(c.getName)) columnOptions(c.getName).getType else if (parquetColumnLogicalTypes.contains(c.getName)) diff --git a/src/main/scala/org/embulk/output/s3_parquet/S3ParquetOutputPlugin.scala b/src/main/scala/org/embulk/output/s3_parquet/S3ParquetOutputPlugin.scala index 5cb1253..eee6649 100644 --- a/src/main/scala/org/embulk/output/s3_parquet/S3ParquetOutputPlugin.scala +++ b/src/main/scala/org/embulk/output/s3_parquet/S3ParquetOutputPlugin.scala @@ -25,7 +25,8 @@ import org.embulk.config.{ } import org.embulk.output.s3_parquet.S3ParquetOutputPlugin.{ ColumnOptionTask, - PluginTask + PluginTask, + TypeOptionTask } import org.embulk.output.s3_parquet.aws.Aws import org.embulk.output.s3_parquet.parquet.{ @@ -44,7 +45,6 @@ import org.embulk.spi.time.TimestampFormatter.TimestampColumnOption import org.embulk.spi.util.Timestamps import org.slf4j.{Logger, LoggerFactory} -import scala.jdk.CollectionConverters._ import scala.util.chaining._ object S3ParquetOutputPlugin { @@ -70,9 +70,8 @@ object S3ParquetOutputPlugin { @ConfigDefault("\"uncompressed\"") def getCompressionCodecString: String - def setCompressionCodec(v: CompressionCodecName): Unit - def getCompressionCodec: CompressionCodecName + def setCompressionCodec(v: CompressionCodecName): Unit @Config("column_options") @ConfigDefault("{}") @@ -131,6 +130,8 @@ object S3ParquetOutputPlugin { class S3ParquetOutputPlugin extends OutputPlugin 
{ + import implicits._ + val logger: Logger = LoggerFactory.getLogger(classOf[S3ParquetOutputPlugin]) override def transaction( @@ -149,21 +150,17 @@ class S3ParquetOutputPlugin extends OutputPlugin { s"s3://${task.getBucket}/${task.getPathPrefix.replaceFirst("(.*/)[^/]+$", "$1")}" val parquetColumnLogicalTypes: Map[String, String] = Map.newBuilder[String, String].pipe { builder => - val cOptions = task.getColumnOptions.asScala - val tOptions = task.getTypeOptions.asScala - schema.getColumns.asScala.foreach { c => - cOptions.get(c.getName) - if (cOptions - .contains(c.getName) && cOptions(c.getName).getLogicalType.isPresent) { - builder - .addOne(c.getName -> cOptions(c.getName).getLogicalType.get()) - } - else if (tOptions.contains(c.getType.getName) && tOptions( - c.getType.getName - ).getLogicalType.isPresent) { - builder.addOne( - c.getName -> tOptions(c.getType.getName).getLogicalType.get() - ) + val cOptions: Map[String, ColumnOptionTask] = task.getColumnOptions + val tOptions: Map[String, TypeOptionTask] = task.getTypeOptions + schema.getColumns.foreach { c => + { + for (o <- cOptions.get(c.getName); + logicalType <- o.getLogicalType) + yield builder.addOne(c.getName -> logicalType) + }.orElse { + for (o <- tOptions.get(c.getType.getName); + logicalType <- o.getLogicalType) + yield builder.addOne(c.getName -> logicalType) } } builder.result() @@ -217,7 +214,7 @@ class S3ParquetOutputPlugin extends OutputPlugin { task.getColumnOptions.forEach { (k: String, opt: ColumnOptionTask) => val c = schema.lookupColumn(k) val useTimestampOption = - opt.getFormat.isPresent || opt.getTimeZoneId.isPresent + opt.getFormat.isDefined || opt.getTimeZoneId.isDefined if (!c.getType.getName.equals("timestamp") && useTimestampOption) { throw new ConfigException(s"column:$k is not 'timestamp' type.") } @@ -257,7 +254,7 @@ class S3ParquetOutputPlugin extends OutputPlugin { taskCount: Int, successTaskReports: JList[TaskReport] ): Unit = { - successTaskReports.forEach { tr => + successTaskReports.foreach { tr => logger.info( s"Created: s3://${tr.get(classOf[String], "bucket")}/${tr.get(classOf[String], "key")}, " + s"version_id: ${tr.get(classOf[String], "version_id", null)}, " @@ -272,24 +269,20 @@ class S3ParquetOutputPlugin extends OutputPlugin { taskIndex: Int ): TransactionalPageOutput = { val task = taskSource.loadTask(classOf[PluginTask]) - val bufferDir: String = task.getBufferDir.orElse( + val bufferDir: String = task.getBufferDir.getOrElse( Files.createTempDirectory("embulk-output-s3_parquet-").toString ) val bufferFile: String = Paths .get(bufferDir, s"embulk-output-s3_parquet-task-$taskIndex-0.parquet") .toString val destS3bucket: String = task.getBucket - val destS3Key: String = task.getPathPrefix + String.format( - task.getSequenceFormat, - taskIndex: Integer, - 0: Integer - ) + task.getFileExt + val destS3Key: String = + s"${task.getPathPrefix}${task.getSequenceFormat.format(taskIndex, 0)}${task.getFileExt}" val pageReader: PageReader = new PageReader(schema) val aws: Aws = Aws(task) val timestampFormatters: Seq[TimestampFormatter] = Timestamps .newTimestampColumnFormatters(task, schema, task.getColumnOptions) - .toSeq val logicalTypeHandlers = LogicalTypeHandlerStore.fromEmbulkOptions( task.getTypeOptions, task.getColumnOptions diff --git a/src/main/scala/org/embulk/output/s3_parquet/implicits.scala b/src/main/scala/org/embulk/output/s3_parquet/implicits.scala new file mode 100644 index 0000000..eb50a59 --- /dev/null +++ b/src/main/scala/org/embulk/output/s3_parquet/implicits.scala @@ -0,0 
+1,51 @@ +package org.embulk.output.s3_parquet + +import java.util.{Optional, Iterator => JIterator, List => JList, Map => JMap} + +import com.google.common.base.{Optional => GoogleOptional} + +import scala.jdk.CollectionConverters._ +import scala.language.implicitConversions + +case object implicits { + implicit def JList2Seq[A](a: JList[A]): Seq[A] = a.asScala.toSeq + implicit def Seq2JList[A](a: Seq[A]): JList[A] = a.asJava + implicit def JIte2Ite[A](a: JIterator[A]): Iterator[A] = a.asScala + implicit def Ite2JIte[A](a: Iterator[A]): JIterator[A] = a.asJava + + implicit def OptionalJList2OptionSeq[A]( + a: Optional[JList[A]] + ): Option[Seq[A]] = a.map(JList2Seq(_)) + + implicit def OptionSeq2OptionalJList[A]( + a: Option[Seq[A]] + ): Optional[JList[A]] = a.map(Seq2JList) + implicit def JMap2Map[K, V](a: JMap[K, V]): Map[K, V] = a.asScala.toMap + implicit def Map2JMap[K, V](a: Map[K, V]): JMap[K, V] = a.asJava + + implicit def OptionalJMap2OptionMap[K, V]( + a: Optional[JMap[K, V]] + ): Option[Map[K, V]] = a.map(JMap2Map(_)) + + implicit def OptionMap2Optional2JMap[K, V]( + a: Option[Map[K, V]] + ): Optional[JMap[K, V]] = a.map(Map2JMap) + + implicit def Optional2Option[A](a: Optional[A]): Option[A] = + if (a.isPresent) Some(a.get()) else None + + implicit def Option2Optional[A](a: Option[A]): Optional[A] = a match { + case Some(v) => Optional.of(v) + case None => Optional.empty() + } + + implicit def GoogleOptional2Option[A](a: GoogleOptional[A]): Option[A] = + Option(a.orNull()) + + implicit def Option2GoogleOptional[A](a: Option[A]): GoogleOptional[A] = + a match { + case Some(v) => GoogleOptional.of(v) + case None => GoogleOptional.absent() + } + +} diff --git a/src/main/scala/org/embulk/output/s3_parquet/parquet/LogicalTypeHandlerStore.scala b/src/main/scala/org/embulk/output/s3_parquet/parquet/LogicalTypeHandlerStore.scala index e0dc4b3..6597923 100644 --- a/src/main/scala/org/embulk/output/s3_parquet/parquet/LogicalTypeHandlerStore.scala +++ b/src/main/scala/org/embulk/output/s3_parquet/parquet/LogicalTypeHandlerStore.scala @@ -9,8 +9,6 @@ import org.embulk.output.s3_parquet.S3ParquetOutputPlugin.{ } import org.embulk.spi.`type`.{Type, Types} -import scala.jdk.CollectionConverters._ - /** * A storage has mapping from logical type query (column name, type) to handler. 
* @@ -38,6 +36,8 @@ case class LogicalTypeHandlerStore private ( object LogicalTypeHandlerStore { + import org.embulk.output.s3_parquet.implicits._ + private val STRING_TO_EMBULK_TYPE = Map[String, Type]( "boolean" -> Types.BOOLEAN, "long" -> Types.LONG, @@ -74,7 +74,7 @@ object LogicalTypeHandlerStore { typeOpts: JMap[String, TypeOptionTask], columnOpts: JMap[String, ColumnOptionTask] ): LogicalTypeHandlerStore = { - val fromEmbulkType = typeOpts.asScala + val fromEmbulkType = typeOpts .filter(_._2.getLogicalType.isPresent) .map[Type, LogicalTypeHandler] { case (k, v) => @@ -86,9 +86,8 @@ object LogicalTypeHandlerStore { throw new ConfigException("invalid logical types in type_options") } } - .toMap - val fromColumnName = columnOpts.asScala + val fromColumnName = columnOpts .filter(_._2.getLogicalType.isPresent) .map[String, LogicalTypeHandler] { case (k, v) => @@ -101,7 +100,6 @@ object LogicalTypeHandlerStore { ) } } - .toMap LogicalTypeHandlerStore(fromEmbulkType, fromColumnName) } diff --git a/src/main/scala/org/embulk/output/s3_parquet/parquet/ParquetFileWriteSupport.scala b/src/main/scala/org/embulk/output/s3_parquet/parquet/ParquetFileWriteSupport.scala index 76ff8dc..3a437a8 100644 --- a/src/main/scala/org/embulk/output/s3_parquet/parquet/ParquetFileWriteSupport.scala +++ b/src/main/scala/org/embulk/output/s3_parquet/parquet/ParquetFileWriteSupport.scala @@ -8,14 +8,14 @@ import org.apache.parquet.schema.MessageType import org.embulk.spi.{PageReader, Schema} import org.embulk.spi.time.TimestampFormatter -import scala.jdk.CollectionConverters._ - private[parquet] case class ParquetFileWriteSupport( schema: Schema, timestampFormatters: Seq[TimestampFormatter], logicalTypeHandlers: LogicalTypeHandlerStore = LogicalTypeHandlerStore.empty ) extends WriteSupport[PageReader] { + import org.embulk.output.s3_parquet.implicits._ + private var currentParquetFileWriter: ParquetFileWriter = _ override def init(configuration: Configuration): WriteContext = { @@ -25,7 +25,7 @@ private[parquet] case class ParquetFileWriteSupport( .withLogicalTypeHandlers(logicalTypeHandlers) .build() val metadata: Map[String, String] = Map.empty // NOTE: When is this used? - new WriteContext(messageType, metadata.asJava) + new WriteContext(messageType, metadata) } override def prepareForWrite(recordConsumer: RecordConsumer): Unit = { From a780442e9f6bcc665e1787f2d31dee5214f50f8c Mon Sep 17 00:00:00 2001 From: Civitaspo Date: Mon, 11 May 2020 08:54:22 +0900 Subject: [PATCH 2/5] As more settings have been added, I extracted the PluginTask from the Plugin to make it easier to understand when reading the code. 
--- .../embulk/output/s3_parquet/PluginTask.scala | 90 +++++++++++++++++ .../s3_parquet/S3ParquetOutputPlugin.scala | 96 +------------------ .../parquet/LogicalTypeHandlerStore.scala | 2 +- 3 files changed, 93 insertions(+), 95 deletions(-) create mode 100644 src/main/scala/org/embulk/output/s3_parquet/PluginTask.scala diff --git a/src/main/scala/org/embulk/output/s3_parquet/PluginTask.scala b/src/main/scala/org/embulk/output/s3_parquet/PluginTask.scala new file mode 100644 index 0000000..2edbdda --- /dev/null +++ b/src/main/scala/org/embulk/output/s3_parquet/PluginTask.scala @@ -0,0 +1,90 @@ +package org.embulk.output.s3_parquet + +import java.util.{Optional, Map => JMap} + +import com.amazonaws.services.s3.model.CannedAccessControlList +import org.apache.parquet.hadoop.metadata.CompressionCodecName +import org.embulk.config.{Config, ConfigDefault, Task} +import org.embulk.output.s3_parquet.aws.Aws +import org.embulk.spi.time.TimestampFormatter +import org.embulk.spi.time.TimestampFormatter.TimestampColumnOption + +object PluginTask { + + trait ColumnOptionTask + extends Task + with TimestampColumnOption + with LogicalTypeOption + + trait TypeOptionTask extends Task with LogicalTypeOption + + trait LogicalTypeOption { + + @Config("logical_type") + def getLogicalType: Optional[String] + } +} + +trait PluginTask extends Task with TimestampFormatter.Task with Aws.Task { + + @Config("bucket") + def getBucket: String + + @Config("path_prefix") + @ConfigDefault("\"\"") + def getPathPrefix: String + + @Config("sequence_format") + @ConfigDefault("\"%03d.%02d.\"") + def getSequenceFormat: String + + @Config("file_ext") + @ConfigDefault("\"parquet\"") + def getFileExt: String + + @Config("compression_codec") + @ConfigDefault("\"uncompressed\"") + def getCompressionCodecString: String + + def getCompressionCodec: CompressionCodecName + def setCompressionCodec(v: CompressionCodecName): Unit + + @Config("column_options") + @ConfigDefault("{}") + def getColumnOptions: JMap[String, PluginTask.ColumnOptionTask] + + @Config("canned_acl") + @ConfigDefault("\"private\"") + def getCannedAclString: String + + def getCannedAcl: CannedAccessControlList + def setCannedAcl(v: CannedAccessControlList): Unit + + @Config("block_size") + @ConfigDefault("null") + def getBlockSize: Optional[Int] + + @Config("page_size") + @ConfigDefault("null") + def getPageSize: Optional[Int] + + @Config("max_padding_size") + @ConfigDefault("null") + def getMaxPaddingSize: Optional[Int] + + @Config("enable_dictionary_encoding") + @ConfigDefault("null") + def getEnableDictionaryEncoding: Optional[Boolean] + + @Config("buffer_dir") + @ConfigDefault("null") + def getBufferDir: Optional[String] + + @Config("catalog") + @ConfigDefault("null") + def getCatalog: Optional[CatalogRegistrator.Task] + + @Config("type_options") + @ConfigDefault("{}") + def getTypeOptions: JMap[String, PluginTask.TypeOptionTask] +} diff --git a/src/main/scala/org/embulk/output/s3_parquet/S3ParquetOutputPlugin.scala b/src/main/scala/org/embulk/output/s3_parquet/S3ParquetOutputPlugin.scala index eee6649..256aa54 100644 --- a/src/main/scala/org/embulk/output/s3_parquet/S3ParquetOutputPlugin.scala +++ b/src/main/scala/org/embulk/output/s3_parquet/S3ParquetOutputPlugin.scala @@ -1,31 +1,21 @@ package org.embulk.output.s3_parquet import java.nio.file.{Files, Paths} -import java.util.{ - IllegalFormatException, - Locale, - Optional, - List => JList, - Map => JMap -} +import java.util.{IllegalFormatException, Locale, List => JList} import 
com.amazonaws.services.s3.model.CannedAccessControlList import org.apache.parquet.column.ParquetProperties import org.apache.parquet.hadoop.ParquetWriter import org.apache.parquet.hadoop.metadata.CompressionCodecName import org.embulk.config.{ - Config, - ConfigDefault, ConfigDiff, ConfigException, ConfigSource, - Task, TaskReport, TaskSource } -import org.embulk.output.s3_parquet.S3ParquetOutputPlugin.{ +import org.embulk.output.s3_parquet.PluginTask.{ ColumnOptionTask, - PluginTask, TypeOptionTask } import org.embulk.output.s3_parquet.aws.Aws @@ -41,93 +31,11 @@ import org.embulk.spi.{ TransactionalPageOutput } import org.embulk.spi.time.TimestampFormatter -import org.embulk.spi.time.TimestampFormatter.TimestampColumnOption import org.embulk.spi.util.Timestamps import org.slf4j.{Logger, LoggerFactory} import scala.util.chaining._ -object S3ParquetOutputPlugin { - - trait PluginTask extends Task with TimestampFormatter.Task with Aws.Task { - - @Config("bucket") - def getBucket: String - - @Config("path_prefix") - @ConfigDefault("\"\"") - def getPathPrefix: String - - @Config("sequence_format") - @ConfigDefault("\"%03d.%02d.\"") - def getSequenceFormat: String - - @Config("file_ext") - @ConfigDefault("\"parquet\"") - def getFileExt: String - - @Config("compression_codec") - @ConfigDefault("\"uncompressed\"") - def getCompressionCodecString: String - - def getCompressionCodec: CompressionCodecName - def setCompressionCodec(v: CompressionCodecName): Unit - - @Config("column_options") - @ConfigDefault("{}") - def getColumnOptions: JMap[String, ColumnOptionTask] - - @Config("canned_acl") - @ConfigDefault("\"private\"") - def getCannedAclString: String - - def setCannedAcl(v: CannedAccessControlList): Unit - - def getCannedAcl: CannedAccessControlList - - @Config("block_size") - @ConfigDefault("null") - def getBlockSize: Optional[Int] - - @Config("page_size") - @ConfigDefault("null") - def getPageSize: Optional[Int] - - @Config("max_padding_size") - @ConfigDefault("null") - def getMaxPaddingSize: Optional[Int] - - @Config("enable_dictionary_encoding") - @ConfigDefault("null") - def getEnableDictionaryEncoding: Optional[Boolean] - - @Config("buffer_dir") - @ConfigDefault("null") - def getBufferDir: Optional[String] - - @Config("catalog") - @ConfigDefault("null") - def getCatalog: Optional[CatalogRegistrator.Task] - - @Config("type_options") - @ConfigDefault("{}") - def getTypeOptions: JMap[String, TypeOptionTask] - } - - trait ColumnOptionTask - extends Task - with TimestampColumnOption - with LogicalTypeOption - - trait TypeOptionTask extends Task with LogicalTypeOption - - trait LogicalTypeOption { - - @Config("logical_type") - def getLogicalType: Optional[String] - } -} - class S3ParquetOutputPlugin extends OutputPlugin { import implicits._ diff --git a/src/main/scala/org/embulk/output/s3_parquet/parquet/LogicalTypeHandlerStore.scala b/src/main/scala/org/embulk/output/s3_parquet/parquet/LogicalTypeHandlerStore.scala index 6597923..abfbc14 100644 --- a/src/main/scala/org/embulk/output/s3_parquet/parquet/LogicalTypeHandlerStore.scala +++ b/src/main/scala/org/embulk/output/s3_parquet/parquet/LogicalTypeHandlerStore.scala @@ -3,7 +3,7 @@ package org.embulk.output.s3_parquet.parquet import java.util.{Map => JMap} import org.embulk.config.ConfigException -import org.embulk.output.s3_parquet.S3ParquetOutputPlugin.{ +import org.embulk.output.s3_parquet.PluginTask.{ ColumnOptionTask, TypeOptionTask } From d2d045b244be22915e9b0e7ff2ca6465779cf531 Mon Sep 17 00:00:00 2001 From: Civitaspo Date: Fri, 15 
May 2020 11:36:42 +0900 Subject: [PATCH 3/5] Set newlines.alwaysBeforeTopLevelStatements = false --- .scalafmt.conf | 1 - 1 file changed, 1 deletion(-) diff --git a/.scalafmt.conf b/.scalafmt.conf index 82d30e6..a8582cc 100644 --- a/.scalafmt.conf +++ b/.scalafmt.conf @@ -2,5 +2,4 @@ version = "2.4.2" newlines.alwaysBeforeElseAfterCurlyIf = true -newlines.alwaysBeforeTopLevelStatements = true assumeStandardLibraryStripMargin = true From 6b3b7521600723228337c67d943e2834b5e83d56 Mon Sep 17 00:00:00 2001 From: Civitaspo Date: Mon, 25 May 2020 09:21:32 +0900 Subject: [PATCH 4/5] Support more detailed logical type configuration & converted type --- .../s3_parquet/CatalogRegistrator.scala | 253 ------------- .../embulk/output/s3_parquet/PluginTask.scala | 98 +++-- .../s3_parquet/S3ParquetOutputPlugin.scala | 130 +------ .../s3_parquet/S3ParquetPageOutput.scala | 4 +- .../catalog/CatalogRegistrator.scala | 197 ++++++++++ .../s3_parquet/catalog/GlueDataType.scala | 57 +++ .../s3_parquet/parquet/DateLogicalType.scala | 66 ++++ .../parquet/DecimalLogicalType.scala | 104 ++++++ .../parquet/DefaultColumnType.scala | 77 ++++ .../parquet/EmbulkMessageType.scala | 114 ------ .../s3_parquet/parquet/IntLogicalType.scala | 175 +++++++++ .../s3_parquet/parquet/JsonLogicalType.scala | 75 ++++ .../parquet/LogicalTypeHandler.scala | 184 --------- .../parquet/LogicalTypeHandlerStore.scala | 106 ------ .../s3_parquet/parquet/LogicalTypeProxy.scala | 84 +++++ .../parquet/ParquetColumnType.scala | 278 ++++++++++++++ .../parquet/ParquetFileWriteSupport.scala | 198 +++++++++- .../parquet/ParquetFileWriter.scala | 167 --------- .../s3_parquet/parquet/TimeLogicalType.scala | 118 ++++++ .../parquet/TimestampLogicalType.scala | 95 +++++ .../s3_parquet/EmbulkPluginTestHelper.scala | 101 ++--- .../TestS3ParquetOutputPlugin.scala | 174 ++------- ...S3ParquetOutputPluginConfigException.scala | 12 +- .../parquet/MockParquetRecordConsumer.scala | 59 +++ .../parquet/ParquetColumnTypeTestHelper.scala | 17 + .../parquet/TestDateLogicalType.scala | 178 +++++++++ .../parquet/TestDecimalLogicalType.scala | 179 +++++++++ .../parquet/TestDefaultColumnType.scala | 157 ++++++++ .../parquet/TestIntLogicalType.scala | 349 ++++++++++++++++++ .../parquet/TestJsonLogicalType.scala | 148 ++++++++ .../parquet/TestLogicalTypeHandler.scala | 101 ----- .../parquet/TestLogicalTypeHandlerStore.scala | 177 --------- .../parquet/TestTimeLogicalType.scala | 223 +++++++++++ .../parquet/TestTimestampLogicalType.scala | 189 ++++++++++ 34 files changed, 3173 insertions(+), 1471 deletions(-) delete mode 100644 src/main/scala/org/embulk/output/s3_parquet/CatalogRegistrator.scala create mode 100644 src/main/scala/org/embulk/output/s3_parquet/catalog/CatalogRegistrator.scala create mode 100644 src/main/scala/org/embulk/output/s3_parquet/catalog/GlueDataType.scala create mode 100644 src/main/scala/org/embulk/output/s3_parquet/parquet/DateLogicalType.scala create mode 100644 src/main/scala/org/embulk/output/s3_parquet/parquet/DecimalLogicalType.scala create mode 100644 src/main/scala/org/embulk/output/s3_parquet/parquet/DefaultColumnType.scala delete mode 100644 src/main/scala/org/embulk/output/s3_parquet/parquet/EmbulkMessageType.scala create mode 100644 src/main/scala/org/embulk/output/s3_parquet/parquet/IntLogicalType.scala create mode 100644 src/main/scala/org/embulk/output/s3_parquet/parquet/JsonLogicalType.scala delete mode 100644 src/main/scala/org/embulk/output/s3_parquet/parquet/LogicalTypeHandler.scala delete mode 100644 
src/main/scala/org/embulk/output/s3_parquet/parquet/LogicalTypeHandlerStore.scala create mode 100644 src/main/scala/org/embulk/output/s3_parquet/parquet/LogicalTypeProxy.scala create mode 100644 src/main/scala/org/embulk/output/s3_parquet/parquet/ParquetColumnType.scala delete mode 100644 src/main/scala/org/embulk/output/s3_parquet/parquet/ParquetFileWriter.scala create mode 100644 src/main/scala/org/embulk/output/s3_parquet/parquet/TimeLogicalType.scala create mode 100644 src/main/scala/org/embulk/output/s3_parquet/parquet/TimestampLogicalType.scala create mode 100644 src/test/scala/org/embulk/output/s3_parquet/parquet/MockParquetRecordConsumer.scala create mode 100644 src/test/scala/org/embulk/output/s3_parquet/parquet/ParquetColumnTypeTestHelper.scala create mode 100644 src/test/scala/org/embulk/output/s3_parquet/parquet/TestDateLogicalType.scala create mode 100644 src/test/scala/org/embulk/output/s3_parquet/parquet/TestDecimalLogicalType.scala create mode 100644 src/test/scala/org/embulk/output/s3_parquet/parquet/TestDefaultColumnType.scala create mode 100644 src/test/scala/org/embulk/output/s3_parquet/parquet/TestIntLogicalType.scala create mode 100644 src/test/scala/org/embulk/output/s3_parquet/parquet/TestJsonLogicalType.scala delete mode 100644 src/test/scala/org/embulk/output/s3_parquet/parquet/TestLogicalTypeHandler.scala delete mode 100644 src/test/scala/org/embulk/output/s3_parquet/parquet/TestLogicalTypeHandlerStore.scala create mode 100644 src/test/scala/org/embulk/output/s3_parquet/parquet/TestTimeLogicalType.scala create mode 100644 src/test/scala/org/embulk/output/s3_parquet/parquet/TestTimestampLogicalType.scala diff --git a/src/main/scala/org/embulk/output/s3_parquet/CatalogRegistrator.scala b/src/main/scala/org/embulk/output/s3_parquet/CatalogRegistrator.scala deleted file mode 100644 index 630a65f..0000000 --- a/src/main/scala/org/embulk/output/s3_parquet/CatalogRegistrator.scala +++ /dev/null @@ -1,253 +0,0 @@ -package org.embulk.output.s3_parquet - -import java.util.{Optional, Map => JMap} - -import com.amazonaws.services.glue.model.{ - Column, - CreateTableRequest, - DeleteTableRequest, - GetTableRequest, - SerDeInfo, - StorageDescriptor, - TableInput -} -import org.apache.parquet.hadoop.metadata.CompressionCodecName -import org.embulk.config.{Config, ConfigDefault, ConfigException} -import org.embulk.output.s3_parquet.aws.Aws -import org.embulk.output.s3_parquet.CatalogRegistrator.ColumnOptions -import org.embulk.spi.Schema -import org.embulk.spi.`type`.{ - BooleanType, - DoubleType, - JsonType, - LongType, - StringType, - TimestampType, - Type -} -import org.slf4j.{Logger, LoggerFactory} - -import scala.util.Try - -object CatalogRegistrator { - - trait Task extends org.embulk.config.Task { - - @Config("catalog_id") - @ConfigDefault("null") - def getCatalogId: Optional[String] - - @Config("database") - def getDatabase: String - - @Config("table") - def getTable: String - - @Config("column_options") - @ConfigDefault("{}") - def getColumnOptions: JMap[String, ColumnOptions] - - @Config("operation_if_exists") - @ConfigDefault("\"delete\"") - def getOperationIfExists: String - } - - trait ColumnOptions { - - @Config("type") - def getType: String - } - - def apply( - aws: Aws, - task: Task, - schema: Schema, - location: String, - compressionCodec: CompressionCodecName, - loggerOption: Option[Logger] = None, - parquetColumnLogicalTypes: Map[String, String] = Map.empty - ): CatalogRegistrator = { - new CatalogRegistrator( - aws, - task, - schema, - location, - 
compressionCodec, - loggerOption, - parquetColumnLogicalTypes - ) - } -} - -class CatalogRegistrator( - aws: Aws, - task: CatalogRegistrator.Task, - schema: Schema, - location: String, - compressionCodec: CompressionCodecName, - loggerOption: Option[Logger] = None, - parquetColumnLogicalTypes: Map[String, String] = Map.empty -) { - - import implicits._ - - val logger: Logger = - loggerOption.getOrElse(LoggerFactory.getLogger(classOf[CatalogRegistrator])) - - def run(): Unit = { - if (doesTableExists()) { - task.getOperationIfExists match { - case "skip" => - logger.info( - s"Skip to register the table: ${task.getDatabase}.${task.getTable}" - ) - return - - case "delete" => - logger.info(s"Delete the table: ${task.getDatabase}.${task.getTable}") - deleteTable() - - case unknown => - throw new ConfigException(s"Unsupported operation: $unknown") - } - } - registerNewParquetTable() - showNewTableInfo() - } - - def showNewTableInfo(): Unit = { - val req = new GetTableRequest() - task.getCatalogId.ifPresent(cid => req.setCatalogId(cid)) - req.setDatabaseName(task.getDatabase) - req.setName(task.getTable) - - val t = aws.withGlue(_.getTable(req)).getTable - logger.info(s"Created a table: ${t.toString}") - } - - def doesTableExists(): Boolean = { - val req = new GetTableRequest() - task.getCatalogId.ifPresent(cid => req.setCatalogId(cid)) - req.setDatabaseName(task.getDatabase) - req.setName(task.getTable) - - Try(aws.withGlue(_.getTable(req))).isSuccess - } - - def deleteTable(): Unit = { - val req = new DeleteTableRequest() - task.getCatalogId.ifPresent(cid => req.setCatalogId(cid)) - req.setDatabaseName(task.getDatabase) - req.setName(task.getTable) - aws.withGlue(_.deleteTable(req)) - } - - def registerNewParquetTable(): Unit = { - logger.info(s"Create a new table: ${task.getDatabase}.${task.getTable}") - val req = new CreateTableRequest() - task.getCatalogId.ifPresent(cid => req.setCatalogId(cid)) - req.setDatabaseName(task.getDatabase) - req.setTableInput( - new TableInput() - .withName(task.getTable) - .withDescription("Created by embulk-output-s3_parquet") - .withTableType("EXTERNAL_TABLE") - .withParameters( - Map( - "EXTERNAL" -> "TRUE", - "classification" -> "parquet", - "parquet.compression" -> compressionCodec.name() - ) - ) - .withStorageDescriptor( - new StorageDescriptor() - .withColumns(getGlueSchema: _*) - .withLocation(location) - .withCompressed(isCompressed) - .withInputFormat( - "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat" - ) - .withOutputFormat( - "org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat" - ) - .withSerdeInfo( - new SerDeInfo() - .withSerializationLibrary( - "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe" - ) - .withParameters(Map("serialization.format" -> "1")) - ) - ) - ) - aws.withGlue(_.createTable(req)) - } - - private def getGlueSchema: Seq[Column] = { - val columnOptions: Map[String, ColumnOptions] = - task.getColumnOptions - schema.getColumns.map { c => - val cType: String = - if (columnOptions.contains(c.getName)) columnOptions(c.getName).getType - else if (parquetColumnLogicalTypes.contains(c.getName)) - convertParquetLogicalTypeToGlueType( - parquetColumnLogicalTypes(c.getName) - ) - else convertEmbulkTypeToGlueType(c.getType) - new Column() - .withName(c.getName) - .withType(cType) - } - } - - private def convertParquetLogicalTypeToGlueType(t: String): String = { - t match { - case "timestamp-millis" => "timestamp" - case "timestamp-micros" => - "bigint" // Glue cannot recognize timestamp-micros. 
- case "timestamp-nanos" => - "bigint" // Glue cannot recognize timestamp-nanos. - case "int8" => "tinyint" - case "int16" => "smallint" - case "int32" => "int" - case "int64" => "bigint" - case "uint8" => - "smallint" // Glue tinyint is a minimum value of -2^7 and a maximum value of 2^7-1 - case "uint16" => - "int" // Glue smallint is a minimum value of -2^15 and a maximum value of 2^15-1. - case "uint32" => - "bigint" // Glue int is a minimum value of-2^31 and a maximum value of 2^31-1. - case "uint64" => - throw new ConfigException( - "Cannot convert uint64 to Glue data types automatically" + - " because the Glue bigint supports a 64-bit signed integer." + - " Please use `catalog.column_options` to define the type." - ) - case "json" => "string" - case _ => - throw new ConfigException( - s"Unsupported a parquet logical type: $t. Please use `catalog.column_options` to define the type." - ) - } - - } - - private def convertEmbulkTypeToGlueType(t: Type): String = { - t match { - case _: BooleanType => "boolean" - case _: LongType => "bigint" - case _: DoubleType => "double" - case _: StringType => "string" - case _: TimestampType => "string" - case _: JsonType => "string" - case unknown => - throw new ConfigException( - s"Unsupported embulk type: ${unknown.getName}" - ) - } - } - - private def isCompressed: Boolean = { - !compressionCodec.equals(CompressionCodecName.UNCOMPRESSED) - } - -} diff --git a/src/main/scala/org/embulk/output/s3_parquet/PluginTask.scala b/src/main/scala/org/embulk/output/s3_parquet/PluginTask.scala index 2edbdda..db8b67d 100644 --- a/src/main/scala/org/embulk/output/s3_parquet/PluginTask.scala +++ b/src/main/scala/org/embulk/output/s3_parquet/PluginTask.scala @@ -1,31 +1,22 @@ package org.embulk.output.s3_parquet -import java.util.{Optional, Map => JMap} +import java.util.{Locale, MissingFormatArgumentException, Optional} import com.amazonaws.services.s3.model.CannedAccessControlList import org.apache.parquet.hadoop.metadata.CompressionCodecName -import org.embulk.config.{Config, ConfigDefault, Task} -import org.embulk.output.s3_parquet.aws.Aws -import org.embulk.spi.time.TimestampFormatter -import org.embulk.spi.time.TimestampFormatter.TimestampColumnOption - -object PluginTask { - - trait ColumnOptionTask - extends Task - with TimestampColumnOption - with LogicalTypeOption - - trait TypeOptionTask extends Task with LogicalTypeOption - - trait LogicalTypeOption { - - @Config("logical_type") - def getLogicalType: Optional[String] - } +import org.embulk.config.{ + Config, + ConfigDefault, + ConfigException, + ConfigSource, + Task, + TaskSource } +import org.embulk.output.s3_parquet.aws.Aws +import org.embulk.output.s3_parquet.catalog.CatalogRegistrator +import org.embulk.output.s3_parquet.parquet.ParquetFileWriteSupport -trait PluginTask extends Task with TimestampFormatter.Task with Aws.Task { +trait PluginTask extends Task with ParquetFileWriteSupport.Task with Aws.Task { @Config("bucket") def getBucket: String @@ -49,10 +40,6 @@ trait PluginTask extends Task with TimestampFormatter.Task with Aws.Task { def getCompressionCodec: CompressionCodecName def setCompressionCodec(v: CompressionCodecName): Unit - @Config("column_options") - @ConfigDefault("{}") - def getColumnOptions: JMap[String, PluginTask.ColumnOptionTask] - @Config("canned_acl") @ConfigDefault("\"private\"") def getCannedAclString: String @@ -83,8 +70,63 @@ trait PluginTask extends Task with TimestampFormatter.Task with Aws.Task { @Config("catalog") @ConfigDefault("null") def getCatalog: 
Optional[CatalogRegistrator.Task] +} + +object PluginTask { + + def loadConfig(config: ConfigSource): PluginTask = { + val task = config.loadConfig(classOf[PluginTask]) + // sequence_format + try task.getSequenceFormat.format(0, 0) + catch { + case e: MissingFormatArgumentException => + throw new ConfigException( + s"Invalid sequence_format: ${task.getSequenceFormat}", + e + ) + } + + // compression_codec + CompressionCodecName + .values() + .find( + _.name() + .toLowerCase(Locale.ENGLISH) + .equals(task.getCompressionCodecString) + ) match { + case Some(v) => task.setCompressionCodec(v) + case None => + val unsupported: String = task.getCompressionCodecString + val supported: String = CompressionCodecName + .values() + .map(v => s"'${v.name().toLowerCase}'") + .mkString(", ") + throw new ConfigException( + s"'$unsupported' is unsupported: `compression_codec` must be one of [$supported]." + ) + } + + // canned_acl + CannedAccessControlList + .values() + .find(_.toString.equals(task.getCannedAclString)) match { + case Some(v) => task.setCannedAcl(v) + case None => + val unsupported: String = task.getCannedAclString + val supported: String = CannedAccessControlList + .values() + .map(v => s"'${v.toString}'") + .mkString(", ") + throw new ConfigException( + s"'$unsupported' is unsupported: `canned_acl` must be one of [$supported]." + ) + } + + ParquetFileWriteSupport.configure(task) + task + } + + def loadTask(taskSource: TaskSource): PluginTask = + taskSource.loadTask(classOf[PluginTask]) - @Config("type_options") - @ConfigDefault("{}") - def getTypeOptions: JMap[String, PluginTask.TypeOptionTask] } diff --git a/src/main/scala/org/embulk/output/s3_parquet/S3ParquetOutputPlugin.scala b/src/main/scala/org/embulk/output/s3_parquet/S3ParquetOutputPlugin.scala index 256aa54..6512eaa 100644 --- a/src/main/scala/org/embulk/output/s3_parquet/S3ParquetOutputPlugin.scala +++ b/src/main/scala/org/embulk/output/s3_parquet/S3ParquetOutputPlugin.scala @@ -1,28 +1,14 @@ package org.embulk.output.s3_parquet import java.nio.file.{Files, Paths} -import java.util.{IllegalFormatException, Locale, List => JList} +import java.util.{List => JList} -import com.amazonaws.services.s3.model.CannedAccessControlList import org.apache.parquet.column.ParquetProperties import org.apache.parquet.hadoop.ParquetWriter -import org.apache.parquet.hadoop.metadata.CompressionCodecName -import org.embulk.config.{ - ConfigDiff, - ConfigException, - ConfigSource, - TaskReport, - TaskSource -} -import org.embulk.output.s3_parquet.PluginTask.{ - ColumnOptionTask, - TypeOptionTask -} +import org.embulk.config.{ConfigDiff, ConfigSource, TaskReport, TaskSource} import org.embulk.output.s3_parquet.aws.Aws -import org.embulk.output.s3_parquet.parquet.{ - LogicalTypeHandlerStore, - ParquetFileWriter -} +import org.embulk.output.s3_parquet.catalog.CatalogRegistrator +import org.embulk.output.s3_parquet.parquet.ParquetFileWriteSupport import org.embulk.spi.{ Exec, OutputPlugin, @@ -30,12 +16,8 @@ import org.embulk.spi.{ Schema, TransactionalPageOutput } -import org.embulk.spi.time.TimestampFormatter -import org.embulk.spi.util.Timestamps import org.slf4j.{Logger, LoggerFactory} -import scala.util.chaining._ - class S3ParquetOutputPlugin extends OutputPlugin { import implicits._ @@ -48,38 +30,22 @@ class S3ParquetOutputPlugin extends OutputPlugin { taskCount: Int, control: OutputPlugin.Control ): ConfigDiff = { - val task: PluginTask = config.loadConfig(classOf[PluginTask]) - - configure(task, schema) + val task: PluginTask = 
PluginTask.loadConfig(config) + val support: ParquetFileWriteSupport = ParquetFileWriteSupport(task, schema) + support.showOutputSchema(logger) control.run(task.dump) task.getCatalog.ifPresent { catalog => val location = s"s3://${task.getBucket}/${task.getPathPrefix.replaceFirst("(.*/)[^/]+$", "$1")}" - val parquetColumnLogicalTypes: Map[String, String] = - Map.newBuilder[String, String].pipe { builder => - val cOptions: Map[String, ColumnOptionTask] = task.getColumnOptions - val tOptions: Map[String, TypeOptionTask] = task.getTypeOptions - schema.getColumns.foreach { c => - { - for (o <- cOptions.get(c.getName); - logicalType <- o.getLogicalType) - yield builder.addOne(c.getName -> logicalType) - }.orElse { - for (o <- tOptions.get(c.getType.getName); - logicalType <- o.getLogicalType) - yield builder.addOne(c.getName -> logicalType) - } - } - builder.result() - } - val cr = CatalogRegistrator( - aws = Aws(task), + val cr = CatalogRegistrator.fromTask( task = catalog, + aws = Aws(task), schema = schema, location = location, compressionCodec = task.getCompressionCodec, - parquetColumnLogicalTypes = parquetColumnLogicalTypes + defaultGlueTypes = + support.parquetSchema.transform((k, v) => v.glueDataType(k)) ) cr.run() } @@ -87,64 +53,6 @@ class S3ParquetOutputPlugin extends OutputPlugin { Exec.newConfigDiff } - private def configure(task: PluginTask, schema: Schema): Unit = { - // sequence_format - try String.format(task.getSequenceFormat, 0: Integer, 0: Integer) - catch { - case e: IllegalFormatException => - throw new ConfigException( - s"Invalid sequence_format: ${task.getSequenceFormat}", - e - ) - } - - // compression_codec - CompressionCodecName - .values() - .find(v => - v.name() - .toLowerCase(Locale.ENGLISH) - .equals(task.getCompressionCodecString) - ) match { - case Some(v) => task.setCompressionCodec(v) - case None => - val unsupported: String = task.getCompressionCodecString - val supported: String = CompressionCodecName - .values() - .map(v => s"'${v.name().toLowerCase}'") - .mkString(", ") - throw new ConfigException( - s"'$unsupported' is unsupported: `compression_codec` must be one of [$supported]." - ) - } - - // column_options - task.getColumnOptions.forEach { (k: String, opt: ColumnOptionTask) => - val c = schema.lookupColumn(k) - val useTimestampOption = - opt.getFormat.isDefined || opt.getTimeZoneId.isDefined - if (!c.getType.getName.equals("timestamp") && useTimestampOption) { - throw new ConfigException(s"column:$k is not 'timestamp' type.") - } - } - - // canned_acl - CannedAccessControlList - .values() - .find(v => v.toString.equals(task.getCannedAclString)) match { - case Some(v) => task.setCannedAcl(v) - case None => - val unsupported: String = task.getCannedAclString - val supported: String = CannedAccessControlList - .values() - .map(v => s"'${v.toString}'") - .mkString(", ") - throw new ConfigException( - s"'$unsupported' is unsupported: `canned_acl` must be one of [$supported]." 
- ) - } - } - override def resume( taskSource: TaskSource, schema: Schema, @@ -176,7 +84,7 @@ class S3ParquetOutputPlugin extends OutputPlugin { schema: Schema, taskIndex: Int ): TransactionalPageOutput = { - val task = taskSource.loadTask(classOf[PluginTask]) + val task = PluginTask.loadTask(taskSource) val bufferDir: String = task.getBufferDir.getOrElse( Files.createTempDirectory("embulk-output-s3_parquet-").toString ) @@ -189,20 +97,10 @@ class S3ParquetOutputPlugin extends OutputPlugin { val pageReader: PageReader = new PageReader(schema) val aws: Aws = Aws(task) - val timestampFormatters: Seq[TimestampFormatter] = Timestamps - .newTimestampColumnFormatters(task, schema, task.getColumnOptions) - val logicalTypeHandlers = LogicalTypeHandlerStore.fromEmbulkOptions( - task.getTypeOptions, - task.getColumnOptions - ) val parquetWriter: ParquetWriter[PageReader] = ContextClassLoaderSwapper.usingPluginClass { - ParquetFileWriter - .builder() - .withPath(bufferFile) - .withSchema(schema) - .withLogicalTypeHandlers(logicalTypeHandlers) - .withTimestampFormatters(timestampFormatters) + ParquetFileWriteSupport(task, schema) + .newWriterBuilder(bufferFile) .withCompressionCodec(task.getCompressionCodec) .withDictionaryEncoding( task.getEnableDictionaryEncoding.orElse( diff --git a/src/main/scala/org/embulk/output/s3_parquet/S3ParquetPageOutput.scala b/src/main/scala/org/embulk/output/s3_parquet/S3ParquetPageOutput.scala index eb0cc22..d9e8293 100644 --- a/src/main/scala/org/embulk/output/s3_parquet/S3ParquetPageOutput.scala +++ b/src/main/scala/org/embulk/output/s3_parquet/S3ParquetPageOutput.scala @@ -24,7 +24,9 @@ case class S3ParquetPageOutput( override def add(page: Page): Unit = { reader.setPage(page) while (reader.nextRecord()) { - writer.write(reader) + ContextClassLoaderSwapper.usingPluginClass { + writer.write(reader) + } } } diff --git a/src/main/scala/org/embulk/output/s3_parquet/catalog/CatalogRegistrator.scala b/src/main/scala/org/embulk/output/s3_parquet/catalog/CatalogRegistrator.scala new file mode 100644 index 0000000..b157360 --- /dev/null +++ b/src/main/scala/org/embulk/output/s3_parquet/catalog/CatalogRegistrator.scala @@ -0,0 +1,197 @@ +package org.embulk.output.s3_parquet.catalog + +import java.util.{Optional, Map => JMap} + +import com.amazonaws.services.glue.model.{ + Column, + CreateTableRequest, + DeleteTableRequest, + GetTableRequest, + SerDeInfo, + StorageDescriptor, + TableInput +} +import org.apache.parquet.hadoop.metadata.CompressionCodecName +import org.embulk.config.{Config, ConfigDefault, ConfigException} +import org.embulk.output.s3_parquet.aws.Aws +import org.embulk.output.s3_parquet.implicits +import org.embulk.spi.{Schema, Column => EmbulkColumn} +import org.slf4j.{Logger, LoggerFactory} + +import scala.util.Try + +object CatalogRegistrator { + + trait Task extends org.embulk.config.Task { + @Config("catalog_id") + @ConfigDefault("null") + def getCatalogId: Optional[String] + + @Config("database") + def getDatabase: String + + @Config("table") + def getTable: String + + @Config("column_options") + @ConfigDefault("{}") + def getColumnOptions: JMap[String, ColumnOption] + + @Config("operation_if_exists") + @ConfigDefault("\"delete\"") + def getOperationIfExists: String + } + + trait ColumnOption { + @Config("type") + def getType: String + } + + import implicits._ + + def fromTask( + task: CatalogRegistrator.Task, + aws: Aws, + schema: Schema, + location: String, + compressionCodec: CompressionCodecName, + defaultGlueTypes: Map[EmbulkColumn, GlueDataType] = 
Map.empty + ): CatalogRegistrator = + CatalogRegistrator( + aws = aws, + catalogId = task.getCatalogId, + database = task.getDatabase, + table = task.getTable, + operationIfExists = task.getOperationIfExists, + location = location, + compressionCodec = compressionCodec, + schema = schema, + columnOptions = task.getColumnOptions, + defaultGlueTypes = defaultGlueTypes + ) +} + +case class CatalogRegistrator( + aws: Aws, + catalogId: Option[String] = None, + database: String, + table: String, + operationIfExists: String, + location: String, + compressionCodec: CompressionCodecName, + schema: Schema, + columnOptions: Map[String, CatalogRegistrator.ColumnOption], + defaultGlueTypes: Map[EmbulkColumn, GlueDataType] = Map.empty +) { + + import implicits._ + + private val logger: Logger = + LoggerFactory.getLogger(classOf[CatalogRegistrator]) + + def run(): Unit = { + if (doesTableExists()) { + operationIfExists match { + case "skip" => + logger.info( + s"Skip to register the table: ${database}.${table}" + ) + return + + case "delete" => + logger.info(s"Delete the table: ${database}.${table}") + deleteTable() + + case unknown => + throw new ConfigException(s"Unsupported operation: $unknown") + } + } + registerNewParquetTable() + showNewTableInfo() + } + + def showNewTableInfo(): Unit = { + val req = new GetTableRequest() + catalogId.foreach(req.setCatalogId) + req.setDatabaseName(database) + req.setName(table) + + val t = aws.withGlue(_.getTable(req)).getTable + logger.info(s"Created a table: ${t.toString}") + } + + def doesTableExists(): Boolean = { + val req = new GetTableRequest() + catalogId.foreach(req.setCatalogId) + req.setDatabaseName(database) + req.setName(table) + + Try(aws.withGlue(_.getTable(req))).isSuccess + } + + def deleteTable(): Unit = { + val req = new DeleteTableRequest() + catalogId.foreach(req.setCatalogId) + req.setDatabaseName(database) + req.setName(table) + aws.withGlue(_.deleteTable(req)) + } + + def registerNewParquetTable(): Unit = { + logger.info(s"Create a new table: ${database}.${table}") + val req = new CreateTableRequest() + catalogId.foreach(req.setCatalogId) + req.setDatabaseName(database) + req.setTableInput( + new TableInput() + .withName(table) + .withDescription("Created by embulk-output-s3_parquet") + .withTableType("EXTERNAL_TABLE") + .withParameters( + Map( + "EXTERNAL" -> "TRUE", + "classification" -> "parquet", + "parquet.compression" -> compressionCodec.name() + ) + ) + .withStorageDescriptor( + new StorageDescriptor() + .withColumns(getGlueSchema: _*) + .withLocation(location) + .withCompressed(isCompressed) + .withInputFormat( + "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat" + ) + .withOutputFormat( + "org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat" + ) + .withSerdeInfo( + new SerDeInfo() + .withSerializationLibrary( + "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe" + ) + .withParameters(Map("serialization.format" -> "1")) + ) + ) + ) + aws.withGlue(_.createTable(req)) + } + + private def getGlueSchema: Seq[Column] = { + schema.getColumns.map { c: EmbulkColumn => + new Column() + .withName(c.getName) + .withType( + columnOptions + .get(c.getName) + .map(_.getType) + .getOrElse(defaultGlueTypes(c).name) + ) + } + } + + private def isCompressed: Boolean = { + !compressionCodec.equals(CompressionCodecName.UNCOMPRESSED) + } + +} diff --git a/src/main/scala/org/embulk/output/s3_parquet/catalog/GlueDataType.scala b/src/main/scala/org/embulk/output/s3_parquet/catalog/GlueDataType.scala new file mode 100644 
index 0000000..f3d85d2 --- /dev/null +++ b/src/main/scala/org/embulk/output/s3_parquet/catalog/GlueDataType.scala @@ -0,0 +1,57 @@ +package org.embulk.output.s3_parquet.catalog + +// https://docs.aws.amazon.com/athena/latest/ug/data-types.html + +sealed abstract class GlueDataType(val name: String) +object GlueDataType { + sealed abstract class AbstractIntGlueDataType(name: String, val bitWidth: Int) + extends GlueDataType(name) + + // BOOLEAN – Values are true and false. + case object BOOLEAN extends GlueDataType("BOOLEAN") + // TINYINT – An 8-bit signed INTEGER in two’s complement format, with a minimum value of -2^7 and a maximum value of 2^7-1. + case object TINYINT extends AbstractIntGlueDataType("TINYINT", bitWidth = 8) + // SMALLINT – A 16-bit signed INTEGER in two’s complement format, with a minimum value of -2^15 and a maximum value of 2^15-1. + case object SMALLINT + extends AbstractIntGlueDataType("SMALLINT", bitWidth = 16) + // INT and INTEGER – Athena combines two different implementations of the integer data type, as follows: + // * INT – In Data Definition Language (DDL) queries, Athena uses the INT data type. + // * INTEGER – In DML queries, Athena uses the INTEGER data type. INTEGER is represented as a 32-bit signed value in two's complement format, with a minimum value of -2^31 and a maximum value of 2^31-1. + case object INT extends AbstractIntGlueDataType("INT", bitWidth = 32) + // BIGINT – A 64-bit signed INTEGER in two’s complement format, with a minimum value of -2^63 and a maximum value of 2^63-1. + case object BIGINT extends AbstractIntGlueDataType("BIGINT", bitWidth = 64) + // DOUBLE – A 64-bit double-precision floating point number. + case object DOUBLE extends GlueDataType("DOUBLE") + // FLOAT – A 32-bit single-precision floating point number. Equivalent to the REAL in Presto. + case object FLOAT extends GlueDataType("FLOAT") + // DECIMAL(precision, scale) – precision is the total number of digits. scale (optional) is the number of digits in fractional part with a default of 0. For example, use these type definitions: DECIMAL(11,5), DECIMAL(15). + case class DECIMAL(precision: Int, scale: Int) + extends GlueDataType(s"DECIMAL($precision, $scale)") + // STRING – A string literal enclosed in single or double quotes. For more information, see STRING Hive Data Type. + case object STRING extends GlueDataType("STRING") + // CHAR – Fixed length character data, with a specified length between 1 and 255, such as char(10). For more information, see CHAR Hive Data Type. + case class CHAR(length: Int) extends GlueDataType(s"CHAR($length)") + // VARCHAR – Variable length character data, with a specified length between 1 and 65535, such as varchar(10). For more information, see VARCHAR Hive Data Type. + case class VARCHAR(length: Int) extends GlueDataType(s"VARCHAR($length)") + // BINARY – Used for data in Parquet. + case object BINARY extends GlueDataType("BINARY") + // DATE – A date in UNIX format, such as YYYY-MM-DD. + case object DATE extends GlueDataType("DATE") + // TIMESTAMP – Date and time instant in the UNIX format, such as yyyy-mm-dd hh:mm:ss[.f...]. For example, TIMESTAMP '2008-09-15 03:04:05.324'. This format uses the session time zone.
+ case object TIMESTAMP extends GlueDataType("TIMESTAMP") + // ARRAY + case class ARRAY(dataType: GlueDataType) + extends GlueDataType(s"ARRAY<${dataType.name}>") + // MAP + case class MAP(keyDataType: GlueDataType, valueDataType: GlueDataType) + extends GlueDataType(s"MAP<${keyDataType.name}, ${valueDataType.name}>") + // STRUCT + case class STRUCT(struct: Map[String, GlueDataType]) + extends GlueDataType({ + val columns = struct + .map { + case (columnName, glueType) => s"$columnName : ${glueType.name}" + } + s"STRUCT<${columns.mkString(", ")}>" + }) +} diff --git a/src/main/scala/org/embulk/output/s3_parquet/parquet/DateLogicalType.scala b/src/main/scala/org/embulk/output/s3_parquet/parquet/DateLogicalType.scala new file mode 100644 index 0000000..2b1c4a1 --- /dev/null +++ b/src/main/scala/org/embulk/output/s3_parquet/parquet/DateLogicalType.scala @@ -0,0 +1,66 @@ +package org.embulk.output.s3_parquet.parquet + +import java.time.{Duration, Instant} + +import org.apache.parquet.io.api.RecordConsumer +import org.apache.parquet.schema.{LogicalTypeAnnotation, PrimitiveType, Types} +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName +import org.embulk.config.ConfigException +import org.embulk.output.s3_parquet.catalog.GlueDataType +import org.embulk.spi.`type`.{ + BooleanType, + DoubleType, + JsonType, + LongType, + StringType, + TimestampType +} +import org.embulk.spi.time.{Timestamp, TimestampFormatter} +import org.embulk.spi.Column +import org.msgpack.value.Value + +object DateLogicalType extends ParquetColumnType { + override def primitiveType(column: Column): PrimitiveType = { + column.getType match { + case _: LongType | _: TimestampType => + Types + .optional(PrimitiveTypeName.INT32) + .as(LogicalTypeAnnotation.dateType()) + .named(column.getName) + case _: BooleanType | _: DoubleType | _: StringType | _: JsonType | _ => + throw new ConfigException(s"Unsupported column type: ${column.getName}") + } + } + + override def glueDataType(column: Column): GlueDataType = + column.getType match { + case _: LongType | _: TimestampType => GlueDataType.DATE + case _: BooleanType | _: DoubleType | _: StringType | _: JsonType | _ => + throw new ConfigException(s"Unsupported column type: ${column.getName}") + } + + override def consumeBoolean(consumer: RecordConsumer, v: Boolean): Unit = + throw newUnsupportedMethodException("consumeBoolean") + + override def consumeString(consumer: RecordConsumer, v: String): Unit = + throw newUnsupportedMethodException("consumeString") + + override def consumeLong(consumer: RecordConsumer, v: Long): Unit = + consumeLongAsInteger(consumer, v) + + override def consumeDouble(consumer: RecordConsumer, v: Double): Unit = + throw newUnsupportedMethodException("consumeDouble") + + override def consumeTimestamp( + consumer: RecordConsumer, + v: Timestamp, + formatter: TimestampFormatter + ): Unit = + consumeLongAsInteger( + consumer, + Duration.between(Instant.EPOCH, v.getInstant).toDays + ) + + override def consumeJson(consumer: RecordConsumer, v: Value): Unit = + throw newUnsupportedMethodException("consumeJson") +} diff --git a/src/main/scala/org/embulk/output/s3_parquet/parquet/DecimalLogicalType.scala b/src/main/scala/org/embulk/output/s3_parquet/parquet/DecimalLogicalType.scala new file mode 100644 index 0000000..a8699a9 --- /dev/null +++ b/src/main/scala/org/embulk/output/s3_parquet/parquet/DecimalLogicalType.scala @@ -0,0 +1,104 @@ +package org.embulk.output.s3_parquet.parquet + +import java.math.{MathContext, RoundingMode => JRoundingMode} + +import 
org.apache.parquet.io.api.{Binary, RecordConsumer} +import org.apache.parquet.schema.{LogicalTypeAnnotation, PrimitiveType, Types} +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName +import org.embulk.config.ConfigException +import org.embulk.output.s3_parquet.catalog.GlueDataType +import org.embulk.spi.{Column, DataException} +import org.embulk.spi.`type`.{ + BooleanType, + DoubleType, + JsonType, + LongType, + StringType, + TimestampType +} +import org.embulk.spi.time.{Timestamp, TimestampFormatter} +import org.msgpack.value.Value + +import scala.math.BigDecimal.RoundingMode + +case class DecimalLogicalType(scale: Int, precision: Int) + extends ParquetColumnType { + // ref. https://github.com/apache/parquet-format/blob/apache-parquet-format-2.8.0/LogicalTypes.md#decimal + require(scale >= 0, "Scale must be zero or a positive integer.") + require( + scale < precision, + "Scale must be a positive integer less than the precision." + ) + require( + precision > 0, + "Precision is required and must be a non-zero positive integer." + ) + + override def primitiveType(column: Column): PrimitiveType = + column.getType match { + case _: LongType if 1 <= precision && precision <= 9 => + Types + .optional(PrimitiveTypeName.INT32) + .as(LogicalTypeAnnotation.decimalType(scale, precision)) + .named(column.getName) + case _: LongType if 10 <= precision && precision <= 18 => + Types + .optional(PrimitiveTypeName.INT64) + .as(LogicalTypeAnnotation.decimalType(scale, precision)) + .named(column.getName) + case _: StringType | _: DoubleType => + Types + .optional(PrimitiveTypeName.BINARY) + .as(LogicalTypeAnnotation.decimalType(scale, precision)) + .named(column.getName) + case _: BooleanType | _: TimestampType | _: JsonType | _ => + throw new ConfigException( + s"Unsupported column type: ${column.getName} (scale: $scale, precision: $precision)" + ) + } + + override def glueDataType(column: Column): GlueDataType = + column.getType match { + case _: StringType | _: LongType | _: DoubleType => + GlueDataType.DECIMAL(scale = scale, precision = precision) + case _: BooleanType | _: TimestampType | _: JsonType | _ => + throw new ConfigException( + s"Unsupported column type: ${column.getName} (scale: $scale, precision: $precision)" + ) + } + + override def consumeBoolean(consumer: RecordConsumer, v: Boolean): Unit = + throw newUnsupportedMethodException("consumeBoolean") + override def consumeString(consumer: RecordConsumer, v: String): Unit = + try consumeBigDecimal(consumer, BigDecimal.exact(v)) + catch { + case ex: NumberFormatException => + throw new DataException(s"Failed to cast String: $v to BigDecimal.", ex) + } + override def consumeLong(consumer: RecordConsumer, v: Long): Unit = + if (1 <= precision && precision <= 9) consumeLongAsInteger(consumer, v) + else if (10 <= precision && precision <= 18) consumer.addLong(v) + else + throw new ConfigException( + s"precision must be 1 <= precision <= 18 when consuming long values but precision is $precision." 
+ ) + override def consumeDouble(consumer: RecordConsumer, v: Double): Unit = + consumeBigDecimal(consumer, BigDecimal.exact(v)) + override def consumeTimestamp( + consumer: RecordConsumer, + v: Timestamp, + formatter: TimestampFormatter + ): Unit = throw newUnsupportedMethodException("consumeTimestamp") + override def consumeJson(consumer: RecordConsumer, v: Value): Unit = + throw newUnsupportedMethodException("consumeJson") + + private def consumeBigDecimal(consumer: RecordConsumer, v: BigDecimal): Unit = + // TODO: Make RoundingMode configurable? + consumer.addBinary( + Binary.fromString( + v.setScale(scale, RoundingMode.HALF_UP) + .round(new MathContext(precision, JRoundingMode.HALF_UP)) + .toString() + ) + ) +} diff --git a/src/main/scala/org/embulk/output/s3_parquet/parquet/DefaultColumnType.scala b/src/main/scala/org/embulk/output/s3_parquet/parquet/DefaultColumnType.scala new file mode 100644 index 0000000..efac158 --- /dev/null +++ b/src/main/scala/org/embulk/output/s3_parquet/parquet/DefaultColumnType.scala @@ -0,0 +1,77 @@ +package org.embulk.output.s3_parquet.parquet + +import org.apache.parquet.io.api.{Binary, RecordConsumer} +import org.apache.parquet.schema.{LogicalTypeAnnotation, PrimitiveType, Types} +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName +import org.embulk.config.ConfigException +import org.embulk.output.s3_parquet.catalog.GlueDataType +import org.embulk.spi.time.{Timestamp, TimestampFormatter} +import org.embulk.spi.Column +import org.embulk.spi.`type`.{ + BooleanType, + DoubleType, + JsonType, + LongType, + StringType, + TimestampType +} +import org.msgpack.value.Value + +object DefaultColumnType extends ParquetColumnType { + override def primitiveType(column: Column): PrimitiveType = + column.getType match { + case _: BooleanType => + Types.optional(PrimitiveTypeName.BOOLEAN).named(column.getName) + case _: LongType => + Types.optional(PrimitiveTypeName.INT64).named(column.getName) + case _: DoubleType => + Types.optional(PrimitiveTypeName.DOUBLE).named(column.getName) + case _: StringType => + Types + .optional(PrimitiveTypeName.BINARY) + .as(LogicalTypeAnnotation.stringType()) + .named(column.getName) + case _: TimestampType => + Types + .optional(PrimitiveTypeName.BINARY) + .as(LogicalTypeAnnotation.stringType()) + .named(column.getName) + case _: JsonType => + Types + .optional(PrimitiveTypeName.BINARY) + .as(LogicalTypeAnnotation.stringType()) + .named(column.getName) + case _ => + throw new ConfigException(s"Unsupported column type: ${column.getName}") + } + + override def glueDataType(column: Column): GlueDataType = + column.getType match { + case _: BooleanType => + GlueDataType.BOOLEAN + case _: LongType => + GlueDataType.BIGINT + case _: DoubleType => + GlueDataType.DOUBLE + case _: StringType | _: TimestampType | _: JsonType => + GlueDataType.STRING + case _ => + throw new ConfigException(s"Unsupported column type: ${column.getName}") + } + + override def consumeBoolean(consumer: RecordConsumer, v: Boolean): Unit = + consumer.addBoolean(v) + override def consumeString(consumer: RecordConsumer, v: String): Unit = + consumer.addBinary(Binary.fromString(v)) + override def consumeLong(consumer: RecordConsumer, v: Long): Unit = + consumer.addLong(v) + override def consumeDouble(consumer: RecordConsumer, v: Double): Unit = + consumer.addDouble(v) + override def consumeTimestamp( + consumer: RecordConsumer, + v: Timestamp, + formatter: TimestampFormatter + ): Unit = consumer.addBinary(Binary.fromString(formatter.format(v))) + override def 
consumeJson(consumer: RecordConsumer, v: Value): Unit = + consumer.addBinary(Binary.fromString(v.toJson)) +} diff --git a/src/main/scala/org/embulk/output/s3_parquet/parquet/EmbulkMessageType.scala b/src/main/scala/org/embulk/output/s3_parquet/parquet/EmbulkMessageType.scala deleted file mode 100644 index b6193dc..0000000 --- a/src/main/scala/org/embulk/output/s3_parquet/parquet/EmbulkMessageType.scala +++ /dev/null @@ -1,114 +0,0 @@ -package org.embulk.output.s3_parquet.parquet - -import com.google.common.collect.ImmutableList -import org.apache.parquet.schema.{ - LogicalTypeAnnotation, - MessageType, - Type, - Types -} -import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName -import org.embulk.config.ConfigException -import org.embulk.spi.{Column, ColumnVisitor, Schema} - -object EmbulkMessageType { - - def builder(): Builder = { - Builder() - } - - case class Builder( - name: String = "embulk", - schema: Schema = Schema.builder().build(), - logicalTypeHandlers: LogicalTypeHandlerStore = - LogicalTypeHandlerStore.empty - ) { - - def withName(name: String): Builder = copy(name = name) - - def withSchema(schema: Schema): Builder = copy(schema = schema) - - def withLogicalTypeHandlers( - logicalTypeHandlers: LogicalTypeHandlerStore - ): Builder = copy(logicalTypeHandlers = logicalTypeHandlers) - - def build(): MessageType = { - val builder: ImmutableList.Builder[Type] = ImmutableList.builder[Type]() - schema.visitColumns( - EmbulkMessageTypeColumnVisitor(builder, logicalTypeHandlers) - ) - new MessageType("embulk", builder.build()) - } - - } - - private case class EmbulkMessageTypeColumnVisitor( - builder: ImmutableList.Builder[Type], - logicalTypeHandlers: LogicalTypeHandlerStore = - LogicalTypeHandlerStore.empty - ) extends ColumnVisitor { - - private def addTypeByLogicalTypeHandlerOrDefault( - column: Column, - default: => Type - ): Unit = { - builder.add( - logicalTypeHandlers.get(column.getName, column.getType) match { - case Some(handler) if handler.isConvertible(column.getType) => - handler.newSchemaFieldType(column.getName) - case Some(handler) => - throw new ConfigException( - s"${column.getType} is not convertible by ${handler.getClass.getName}." 
- ) - case _ => default - } - ) - } - - override def booleanColumn(column: Column): Unit = { - addTypeByLogicalTypeHandlerOrDefault(column, default = { - Types.optional(PrimitiveTypeName.BOOLEAN).named(column.getName) - }) - } - - override def longColumn(column: Column): Unit = { - addTypeByLogicalTypeHandlerOrDefault(column, default = { - Types.optional(PrimitiveTypeName.INT64).named(column.getName) - }) - } - - override def doubleColumn(column: Column): Unit = { - addTypeByLogicalTypeHandlerOrDefault(column, default = { - Types.optional(PrimitiveTypeName.DOUBLE).named(column.getName) - }) - } - - override def stringColumn(column: Column): Unit = { - addTypeByLogicalTypeHandlerOrDefault(column, default = { - Types - .optional(PrimitiveTypeName.BINARY) - .as(LogicalTypeAnnotation.stringType()) - .named(column.getName) - }) - } - - override def timestampColumn(column: Column): Unit = { - addTypeByLogicalTypeHandlerOrDefault(column, default = { - Types - .optional(PrimitiveTypeName.BINARY) - .as(LogicalTypeAnnotation.stringType()) - .named(column.getName) - }) - } - - override def jsonColumn(column: Column): Unit = { - addTypeByLogicalTypeHandlerOrDefault(column, default = { - Types - .optional(PrimitiveTypeName.BINARY) - .as(LogicalTypeAnnotation.stringType()) - .named(column.getName) - }) - } - } - -} diff --git a/src/main/scala/org/embulk/output/s3_parquet/parquet/IntLogicalType.scala b/src/main/scala/org/embulk/output/s3_parquet/parquet/IntLogicalType.scala new file mode 100644 index 0000000..1db19b4 --- /dev/null +++ b/src/main/scala/org/embulk/output/s3_parquet/parquet/IntLogicalType.scala @@ -0,0 +1,175 @@ +package org.embulk.output.s3_parquet.parquet + +import org.apache.parquet.io.api.RecordConsumer +import org.apache.parquet.schema.{LogicalTypeAnnotation, PrimitiveType, Types} +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName +import org.embulk.config.ConfigException +import org.embulk.output.s3_parquet.catalog.GlueDataType +import org.embulk.output.s3_parquet.catalog.GlueDataType.AbstractIntGlueDataType +import org.embulk.spi.{Column, DataException} +import org.embulk.spi.`type`.{ + BooleanType, + DoubleType, + JsonType, + LongType, + StringType, + TimestampType +} +import org.embulk.spi.time.{Timestamp, TimestampFormatter} +import org.msgpack.value.Value +import org.slf4j.{Logger, LoggerFactory} + +import scala.math.BigDecimal.RoundingMode + +case class IntLogicalType(bitWidth: Int, isSigned: Boolean) + extends ParquetColumnType { + require( + Seq(8, 16, 32, 64).contains(bitWidth), + s"bitWidth value must be one of (8, 16, 32, 64)." 
+ ) + + private val logger: Logger = LoggerFactory.getLogger(classOf[IntLogicalType]) + + private val SIGNED_64BIT_INT_MAX_VALUE = BigInt("9223372036854775807") + private val SIGNED_64BIT_INT_MIN_VALUE = BigInt("-9223372036854775808") + private val SIGNED_32BIT_INT_MAX_VALUE = BigInt("2147483647") + private val SIGNED_32BIT_INT_MIN_VALUE = BigInt("-2147483648") + private val SIGNED_16BIT_INT_MAX_VALUE = BigInt("32767") + private val SIGNED_16BIT_INT_MIN_VALUE = BigInt("-32768") + private val SIGNED_8BIT_INT_MAX_VALUE = BigInt("127") + private val SIGNED_8BIT_INT_MIN_VALUE = BigInt("-128") + private val UNSIGNED_64BIT_INT_MAX_VALUE = BigInt("18446744073709551615") + private val UNSIGNED_64BIT_INT_MIN_VALUE = BigInt("0") + private val UNSIGNED_32BIT_INT_MAX_VALUE = BigInt("4294967295") + private val UNSIGNED_32BIT_INT_MIN_VALUE = BigInt("0") + private val UNSIGNED_16BIT_INT_MAX_VALUE = BigInt("65535") + private val UNSIGNED_16BIT_INT_MIN_VALUE = BigInt("0") + private val UNSIGNED_8BIT_INT_MAX_VALUE = BigInt("255") + private val UNSIGNED_8BIT_INT_MIN_VALUE = BigInt("0") + + private def isINT32: Boolean = bitWidth < 64 + + override def primitiveType(column: Column): PrimitiveType = + column.getType match { + case _: BooleanType | _: LongType | _: DoubleType | _: StringType => + Types + .optional( + if (isINT32) PrimitiveTypeName.INT32 + else PrimitiveTypeName.INT64 + ) + .as(LogicalTypeAnnotation.intType(bitWidth, isSigned)) + .named(column.getName) + case _: TimestampType | _: JsonType | _ => + throw new ConfigException(s"Unsupported column type: ${column.getName}") + } + + override def glueDataType(column: Column): GlueDataType = + column.getType match { + case _: BooleanType | _: LongType | _: DoubleType | _: StringType => + (bitWidth, isSigned) match { + case (8, true) => GlueDataType.TINYINT + case (16, true) => GlueDataType.SMALLINT + case (32, true) => GlueDataType.INT + case (64, true) => GlueDataType.BIGINT + case (8, false) => + warningWhenConvertingUnsignedIntegerToGlueType( + GlueDataType.SMALLINT + ) + GlueDataType.SMALLINT + case (16, false) => + warningWhenConvertingUnsignedIntegerToGlueType(GlueDataType.INT) + GlueDataType.INT + case (32, false) => + warningWhenConvertingUnsignedIntegerToGlueType(GlueDataType.BIGINT) + GlueDataType.BIGINT + case (64, false) => + warningWhenConvertingUnsignedIntegerToGlueType(GlueDataType.BIGINT) + GlueDataType.BIGINT + case (_, _) => + throw new ConfigException( + s"Unsupported column type: ${column.getName} (bitWidth: $bitWidth, isSigned: $isSigned)" + ) + } + case _: TimestampType | _: JsonType | _ => + throw new ConfigException(s"Unsupported column type: ${column.getName}") + } + + override def consumeBoolean(consumer: RecordConsumer, v: Boolean): Unit = + if (isINT32) + consumer.addInteger( + if (v) 1 + else 0 + ) + else + consumer.addLong( + if (v) 1 + else 0 + ) + + override def consumeString(consumer: RecordConsumer, v: String): Unit = + try consumeBigDecimal(consumer, BigDecimal.exact(v)) + catch { + case ex: NumberFormatException => + throw new DataException(s"Failed to cast String: $v to BigDecimal.", ex) + } + override def consumeLong(consumer: RecordConsumer, v: Long): Unit = + consumeBigInt(consumer, BigInt(v)) + override def consumeDouble(consumer: RecordConsumer, v: Double): Unit = + consumeBigDecimal(consumer, BigDecimal.exact(v)) + override def consumeTimestamp( + consumer: RecordConsumer, + v: Timestamp, + formatter: TimestampFormatter + ): Unit = throw newUnsupportedMethodException("consumeTimestamp") + override def 
consumeJson(consumer: RecordConsumer, v: Value): Unit = + throw newUnsupportedMethodException("consumeJson") + + private def warningWhenConvertingUnsignedIntegerToGlueType( + glueType: AbstractIntGlueDataType + ): Unit = { + logger.warn { + s"int(bit_width = $bitWidth, is_signed $isSigned) is converted to Glue ${glueType.name}" + + s" but this is not represented correctly, because the Glue ${glueType.name} represents" + + s" a ${glueType.bitWidth}-bit signed integer. Please use `catalog.column_options` to define the type." + } + } + + private def consumeBigDecimal(consumer: RecordConsumer, v: BigDecimal): Unit = + // TODO: Make RoundingMode configurable? + consumeBigInt(consumer, v.setScale(0, RoundingMode.HALF_UP).toBigInt) + + private def consumeBigInt(consumer: RecordConsumer, v: BigInt): Unit = { + def consume(min: BigInt, max: BigInt): Unit = + if (min <= v && v <= max) + if (isINT32) consumer.addInteger(v.toInt) + else consumer.addLong(v.toLong) + else + throw new DataException( + s"The value is out of the range: that is '$min <= value <= $max'" + + s" in the case of int(bit_width = $bitWidth, is_signed $isSigned)" + + s", but the value is $v." + ) + (bitWidth, isSigned) match { + case (8, true) => + consume(SIGNED_8BIT_INT_MIN_VALUE, SIGNED_8BIT_INT_MAX_VALUE) + case (16, true) => + consume(SIGNED_16BIT_INT_MIN_VALUE, SIGNED_16BIT_INT_MAX_VALUE) + case (32, true) => + consume(SIGNED_32BIT_INT_MIN_VALUE, SIGNED_32BIT_INT_MAX_VALUE) + case (64, true) => + consume(SIGNED_64BIT_INT_MIN_VALUE, SIGNED_64BIT_INT_MAX_VALUE) + case (8, false) => + consume(UNSIGNED_8BIT_INT_MIN_VALUE, UNSIGNED_8BIT_INT_MAX_VALUE) + case (16, false) => + consume(UNSIGNED_16BIT_INT_MIN_VALUE, UNSIGNED_16BIT_INT_MAX_VALUE) + case (32, false) => + consume(UNSIGNED_32BIT_INT_MIN_VALUE, UNSIGNED_32BIT_INT_MAX_VALUE) + case (64, false) => + consume(UNSIGNED_64BIT_INT_MIN_VALUE, UNSIGNED_64BIT_INT_MAX_VALUE) + case _ => + throw new ConfigException( + s"int(bit_width = $bitWidth, is_signed $isSigned) is unsupported." 
+ ) + } + } +} diff --git a/src/main/scala/org/embulk/output/s3_parquet/parquet/JsonLogicalType.scala b/src/main/scala/org/embulk/output/s3_parquet/parquet/JsonLogicalType.scala new file mode 100644 index 0000000..bbeae4a --- /dev/null +++ b/src/main/scala/org/embulk/output/s3_parquet/parquet/JsonLogicalType.scala @@ -0,0 +1,75 @@ +package org.embulk.output.s3_parquet.parquet +import org.apache.parquet.io.api.{Binary, RecordConsumer} +import org.apache.parquet.schema.{LogicalTypeAnnotation, PrimitiveType, Types} +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName +import org.embulk.config.ConfigException +import org.embulk.output.s3_parquet.catalog.GlueDataType +import org.embulk.spi.Column +import org.embulk.spi.`type`.{ + BooleanType, + DoubleType, + JsonType, + LongType, + StringType, + TimestampType +} +import org.embulk.spi.time.{Timestamp, TimestampFormatter} +import org.msgpack.value.{Value, ValueFactory} +import org.slf4j.{Logger, LoggerFactory} + +object JsonLogicalType extends ParquetColumnType { + private val logger: Logger = LoggerFactory.getLogger(JsonLogicalType.getClass) + override def primitiveType(column: Column): PrimitiveType = + column.getType match { + case _: BooleanType | _: LongType | _: DoubleType | _: StringType | + _: JsonType => + Types + .optional(PrimitiveTypeName.BINARY) + .as(LogicalTypeAnnotation.jsonType()) + .named(column.getName) + case _: TimestampType | _ => + throw new ConfigException(s"Unsupported column type: ${column.getName}") + } + + override def glueDataType(column: Column): GlueDataType = + column.getType match { + case _: BooleanType | _: LongType | _: DoubleType | _: StringType | + _: JsonType => + warningWhenConvertingJsonToGlueType(GlueDataType.STRING) + GlueDataType.STRING + case _: TimestampType | _ => + throw new ConfigException(s"Unsupported column type: ${column.getName}") + } + + override def consumeBoolean(consumer: RecordConsumer, v: Boolean): Unit = + consumeJson(consumer, ValueFactory.newBoolean(v)) + + override def consumeString(consumer: RecordConsumer, v: String): Unit = + consumeJson(consumer, ValueFactory.newString(v)) + + override def consumeLong(consumer: RecordConsumer, v: Long): Unit = + consumeJson(consumer, ValueFactory.newInteger(v)) + + override def consumeDouble(consumer: RecordConsumer, v: Double): Unit = + consumeJson(consumer, ValueFactory.newFloat(v)) + + override def consumeTimestamp( + consumer: RecordConsumer, + v: Timestamp, + formatter: TimestampFormatter + ): Unit = throw newUnsupportedMethodException("consumeTimestamp") + + override def consumeJson(consumer: RecordConsumer, v: Value): Unit = + consumer.addBinary(Binary.fromString(v.toJson)) + + private def warningWhenConvertingJsonToGlueType( + glueType: GlueDataType + ): Unit = { + logger.warn( + s"json is converted" + + s" to Glue ${glueType.name} but this is not represented correctly, because Glue" + + s" does not support json type. Please use `catalog.column_options` to define the type." 
+ ) + } + +} diff --git a/src/main/scala/org/embulk/output/s3_parquet/parquet/LogicalTypeHandler.scala b/src/main/scala/org/embulk/output/s3_parquet/parquet/LogicalTypeHandler.scala deleted file mode 100644 index b417c64..0000000 --- a/src/main/scala/org/embulk/output/s3_parquet/parquet/LogicalTypeHandler.scala +++ /dev/null @@ -1,184 +0,0 @@ -package org.embulk.output.s3_parquet.parquet - -import org.apache.parquet.io.api.{Binary, RecordConsumer} -import org.apache.parquet.schema.{LogicalTypeAnnotation, PrimitiveType, Types} -import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName -import org.embulk.spi.DataException -import org.embulk.spi.`type`.{Type => EmbulkType, Types => EmbulkTypes} -import org.embulk.spi.time.Timestamp -import org.msgpack.value.Value - -/** - * Handle Apache Parquet 'Logical Types' on schema/value conversion. - * ref. https://github.com/apache/parquet-format/blob/master/LogicalTypes.md - * - * It focuses on only older representation because newer supported since 1.11 is not used actually yet. - * TODO Support both of older and newer representation after 1.11+ is published and other middleware supports it. - * - */ -sealed trait LogicalTypeHandler { - def isConvertible(t: EmbulkType): Boolean - - def newSchemaFieldType(name: String): PrimitiveType - - def consume(orig: Any, recordConsumer: RecordConsumer): Unit -} - -abstract class IntLogicalTypeHandler(logicalType: LogicalTypeAnnotation) - extends LogicalTypeHandler { - - override def isConvertible(t: EmbulkType): Boolean = { - t == EmbulkTypes.LONG - } - - override def newSchemaFieldType(name: String): PrimitiveType = { - Types.optional(PrimitiveTypeName.INT64).as(logicalType).named(name) - } - - override def consume(orig: Any, recordConsumer: RecordConsumer): Unit = { - orig match { - case v: Long => recordConsumer.addLong(v) - case _ => - throw new DataException( - "given mismatched type value; expected type is long" - ) - } - } -} - -object TimestampMillisLogicalTypeHandler extends LogicalTypeHandler { - - override def isConvertible(t: EmbulkType): Boolean = { - t == EmbulkTypes.TIMESTAMP - } - - override def newSchemaFieldType(name: String): PrimitiveType = { - Types - .optional(PrimitiveTypeName.INT64) - .as( - LogicalTypeAnnotation - .timestampType(true, LogicalTypeAnnotation.TimeUnit.MILLIS) - ) - .named(name) - } - - override def consume(orig: Any, recordConsumer: RecordConsumer): Unit = { - orig match { - case ts: Timestamp => recordConsumer.addLong(ts.toEpochMilli) - case _ => - throw new DataException( - "given mismatched type value; expected type is timestamp" - ) - } - } -} - -object TimestampMicrosLogicalTypeHandler extends LogicalTypeHandler { - - override def isConvertible(t: EmbulkType): Boolean = { - t == EmbulkTypes.TIMESTAMP - } - - override def newSchemaFieldType(name: String): PrimitiveType = { - Types - .optional(PrimitiveTypeName.INT64) - .as( - LogicalTypeAnnotation - .timestampType(true, LogicalTypeAnnotation.TimeUnit.MICROS) - ) - .named(name) - } - - override def consume(orig: Any, recordConsumer: RecordConsumer): Unit = { - orig match { - case ts: Timestamp => - val v = (ts.getEpochSecond * 1_000_000L) + (ts.getNano - .asInstanceOf[Long] / 1_000L) - recordConsumer.addLong(v) - case _ => - throw new DataException( - "given mismatched type value; expected type is timestamp" - ) - } - } -} - -object TimestampNanosLogicalTypeHandler extends LogicalTypeHandler { - - override def isConvertible(t: EmbulkType): Boolean = { - t == EmbulkTypes.TIMESTAMP - } - - override def 
newSchemaFieldType(name: String): PrimitiveType = { - Types - .optional(PrimitiveTypeName.INT64) - .as( - LogicalTypeAnnotation - .timestampType(true, LogicalTypeAnnotation.TimeUnit.NANOS) - ) - .named(name) - } - - override def consume(orig: Any, recordConsumer: RecordConsumer): Unit = { - orig match { - case ts: Timestamp => - val v = - (ts.getEpochSecond * 1_000_000_000L) + ts.getNano.asInstanceOf[Long] - recordConsumer.addLong(v) - case _ => - throw new DataException( - "given mismatched type value; expected type is timestamp" - ) - } - } -} - -object Int8LogicalTypeHandler - extends IntLogicalTypeHandler(LogicalTypeAnnotation.intType(8, true)) - -object Int16LogicalTypeHandler - extends IntLogicalTypeHandler(LogicalTypeAnnotation.intType(16, true)) - -object Int32LogicalTypeHandler - extends IntLogicalTypeHandler(LogicalTypeAnnotation.intType(32, true)) - -object Int64LogicalTypeHandler - extends IntLogicalTypeHandler(LogicalTypeAnnotation.intType(64, true)) - -object Uint8LogicalTypeHandler - extends IntLogicalTypeHandler(LogicalTypeAnnotation.intType(8, false)) - -object Uint16LogicalTypeHandler - extends IntLogicalTypeHandler(LogicalTypeAnnotation.intType(16, false)) - -object Uint32LogicalTypeHandler - extends IntLogicalTypeHandler(LogicalTypeAnnotation.intType(32, false)) - -object Uint64LogicalTypeHandler - extends IntLogicalTypeHandler(LogicalTypeAnnotation.intType(64, false)) - -object JsonLogicalTypeHandler extends LogicalTypeHandler { - - override def isConvertible(t: EmbulkType): Boolean = { - t == EmbulkTypes.JSON - } - - override def newSchemaFieldType(name: String): PrimitiveType = { - Types - .optional(PrimitiveTypeName.BINARY) - .as(LogicalTypeAnnotation.jsonType()) - .named(name) - } - - override def consume(orig: Any, recordConsumer: RecordConsumer): Unit = { - orig match { - case msgPack: Value => - val bin = Binary.fromString(msgPack.toJson) - recordConsumer.addBinary(bin) - case _ => - throw new DataException( - "given mismatched type value; expected type is json" - ) - } - } -} diff --git a/src/main/scala/org/embulk/output/s3_parquet/parquet/LogicalTypeHandlerStore.scala b/src/main/scala/org/embulk/output/s3_parquet/parquet/LogicalTypeHandlerStore.scala deleted file mode 100644 index abfbc14..0000000 --- a/src/main/scala/org/embulk/output/s3_parquet/parquet/LogicalTypeHandlerStore.scala +++ /dev/null @@ -1,106 +0,0 @@ -package org.embulk.output.s3_parquet.parquet - -import java.util.{Map => JMap} - -import org.embulk.config.ConfigException -import org.embulk.output.s3_parquet.PluginTask.{ - ColumnOptionTask, - TypeOptionTask -} -import org.embulk.spi.`type`.{Type, Types} - -/** - * A storage has mapping from logical type query (column name, type) to handler. 
- * - * @param fromEmbulkType - * @param fromColumnName - */ -case class LogicalTypeHandlerStore private ( - fromEmbulkType: Map[Type, LogicalTypeHandler], - fromColumnName: Map[String, LogicalTypeHandler] -) { - - // Try column name lookup, then column type - def get(n: String, t: Type): Option[LogicalTypeHandler] = { - get(n).orElse(get(t)) - } - - def get(t: Type): Option[LogicalTypeHandler] = { - fromEmbulkType.get(t) - } - - def get(n: String): Option[LogicalTypeHandler] = { - fromColumnName.get(n) - } -} - -object LogicalTypeHandlerStore { - - import org.embulk.output.s3_parquet.implicits._ - - private val STRING_TO_EMBULK_TYPE = Map[String, Type]( - "boolean" -> Types.BOOLEAN, - "long" -> Types.LONG, - "double" -> Types.DOUBLE, - "string" -> Types.STRING, - "timestamp" -> Types.TIMESTAMP, - "json" -> Types.JSON - ) - - // Listed only older logical types that we can convert from embulk type - private val STRING_TO_LOGICAL_TYPE = Map[String, LogicalTypeHandler]( - "timestamp-millis" -> TimestampMillisLogicalTypeHandler, - "timestamp-micros" -> TimestampMicrosLogicalTypeHandler, - "timestamp-nanos" -> TimestampNanosLogicalTypeHandler, - "int8" -> Int8LogicalTypeHandler, - "int16" -> Int16LogicalTypeHandler, - "int32" -> Int32LogicalTypeHandler, - "int64" -> Int64LogicalTypeHandler, - "uint8" -> Uint8LogicalTypeHandler, - "uint16" -> Uint16LogicalTypeHandler, - "uint32" -> Uint32LogicalTypeHandler, - "uint64" -> Uint64LogicalTypeHandler, - "json" -> JsonLogicalTypeHandler - ) - - def empty: LogicalTypeHandlerStore = { - LogicalTypeHandlerStore( - Map.empty[Type, LogicalTypeHandler], - Map.empty[String, LogicalTypeHandler] - ) - } - - def fromEmbulkOptions( - typeOpts: JMap[String, TypeOptionTask], - columnOpts: JMap[String, ColumnOptionTask] - ): LogicalTypeHandlerStore = { - val fromEmbulkType = typeOpts - .filter(_._2.getLogicalType.isPresent) - .map[Type, LogicalTypeHandler] { - case (k, v) => - { - for (t <- STRING_TO_EMBULK_TYPE.get(k); - h <- STRING_TO_LOGICAL_TYPE.get(v.getLogicalType.get)) - yield (t, h) - }.getOrElse { - throw new ConfigException("invalid logical types in type_options") - } - } - - val fromColumnName = columnOpts - .filter(_._2.getLogicalType.isPresent) - .map[String, LogicalTypeHandler] { - case (k, v) => - { - for (h <- STRING_TO_LOGICAL_TYPE.get(v.getLogicalType.get)) - yield (k, h) - }.getOrElse { - throw new ConfigException( - "invalid logical types in column_options" - ) - } - } - - LogicalTypeHandlerStore(fromEmbulkType, fromColumnName) - } -} diff --git a/src/main/scala/org/embulk/output/s3_parquet/parquet/LogicalTypeProxy.scala b/src/main/scala/org/embulk/output/s3_parquet/parquet/LogicalTypeProxy.scala new file mode 100644 index 0000000..74ec012 --- /dev/null +++ b/src/main/scala/org/embulk/output/s3_parquet/parquet/LogicalTypeProxy.scala @@ -0,0 +1,84 @@ +package org.embulk.output.s3_parquet.parquet + +import java.time.ZoneId +import java.util.Locale + +import org.apache.parquet.io.api.RecordConsumer +import org.apache.parquet.schema.LogicalTypeAnnotation.TimeUnit +import org.apache.parquet.schema.LogicalTypeAnnotation.TimeUnit.MILLIS +import org.apache.parquet.schema.PrimitiveType +import org.embulk.config.ConfigException +import org.embulk.output.s3_parquet.catalog.GlueDataType +import org.embulk.spi.Column +import org.embulk.spi.time.{Timestamp, TimestampFormatter} +import org.msgpack.value.Value + +object LogicalTypeProxy { + private val DEFAULT_SCALE: Int = 0 + private val DEFAULT_BID_WIDTH: Int = 64 + private val DEFAULT_IS_SIGNED: Boolean = 
true + private val DEFAULT_IS_ADJUSTED_TO_UTC: Boolean = true + private val DEFAULT_TIME_UNIT: TimeUnit = MILLIS + private val DEFAULT_TIME_ZONE: ZoneId = ZoneId.of("UTC") +} + +case class LogicalTypeProxy( + name: String, + scale: Option[Int] = None, + precision: Option[Int] = None, + bitWidth: Option[Int] = None, + isSigned: Option[Boolean] = None, + isAdjustedToUtc: Option[Boolean] = None, + timeUnit: Option[TimeUnit] = None, + timeZone: Option[ZoneId] = None +) extends ParquetColumnType { + private def getScale: Int = scale.getOrElse(LogicalTypeProxy.DEFAULT_SCALE) + private def getPrecision: Int = precision.getOrElse { + throw new ConfigException("\"precision\" must be set.") + } + private def getBidWith: Int = + bitWidth.getOrElse(LogicalTypeProxy.DEFAULT_BID_WIDTH) + private def getIsSigned: Boolean = + isSigned.getOrElse(LogicalTypeProxy.DEFAULT_IS_SIGNED) + private def getIsAdjustedToUtc: Boolean = + isAdjustedToUtc.getOrElse(LogicalTypeProxy.DEFAULT_IS_ADJUSTED_TO_UTC) + private def getTimeUnit: TimeUnit = + timeUnit.getOrElse(LogicalTypeProxy.DEFAULT_TIME_UNIT) + private def getTimeZone: ZoneId = + timeZone.getOrElse(LogicalTypeProxy.DEFAULT_TIME_ZONE) + + lazy val logicalType: ParquetColumnType = { + name.toUpperCase(Locale.ENGLISH) match { + case "INT" => IntLogicalType(getBidWith, getIsSigned) + case "TIMESTAMP" => + TimestampLogicalType(getIsAdjustedToUtc, getTimeUnit, getTimeZone) + case "TIME" => + TimeLogicalType(getIsAdjustedToUtc, getTimeUnit, getTimeZone) + case "DECIMAL" => DecimalLogicalType(getScale, getPrecision) + case "DATE" => DateLogicalType + case "JSON" => JsonLogicalType + case _ => + throw new ConfigException(s"Unsupported logical_type.name: $name.") + } + } + + override def primitiveType(column: Column): PrimitiveType = + logicalType.primitiveType(column) + override def glueDataType(column: Column): GlueDataType = + logicalType.glueDataType(column) + override def consumeBoolean(consumer: RecordConsumer, v: Boolean): Unit = + logicalType.consumeBoolean(consumer, v) + override def consumeString(consumer: RecordConsumer, v: String): Unit = + logicalType.consumeString(consumer, v) + override def consumeLong(consumer: RecordConsumer, v: Long): Unit = + logicalType.consumeLong(consumer, v) + override def consumeDouble(consumer: RecordConsumer, v: Double): Unit = + logicalType.consumeDouble(consumer, v) + override def consumeTimestamp( + consumer: RecordConsumer, + v: Timestamp, + formatter: TimestampFormatter + ): Unit = logicalType.consumeTimestamp(consumer, v, formatter) + override def consumeJson(consumer: RecordConsumer, v: Value): Unit = + logicalType.consumeJson(consumer, v) +} diff --git a/src/main/scala/org/embulk/output/s3_parquet/parquet/ParquetColumnType.scala b/src/main/scala/org/embulk/output/s3_parquet/parquet/ParquetColumnType.scala new file mode 100644 index 0000000..e89c06b --- /dev/null +++ b/src/main/scala/org/embulk/output/s3_parquet/parquet/ParquetColumnType.scala @@ -0,0 +1,278 @@ +package org.embulk.output.s3_parquet.parquet + +import java.time.ZoneId +import java.util.{Locale, Optional} + +import org.apache.parquet.format.ConvertedType +import org.apache.parquet.io.api.RecordConsumer +import org.apache.parquet.schema.LogicalTypeAnnotation.TimeUnit +import org.apache.parquet.schema.LogicalTypeAnnotation.TimeUnit.{ + MICROS, + MILLIS, + NANOS +} +import org.apache.parquet.schema.PrimitiveType +import org.embulk.config.{ + Config, + ConfigDefault, + ConfigException, + ConfigSource, + Task => EmbulkTask +} +import 
org.embulk.output.s3_parquet.catalog.GlueDataType +import org.embulk.output.s3_parquet.implicits +import org.embulk.spi.{Column, DataException, Exec} +import org.embulk.spi.time.{Timestamp, TimestampFormatter} +import org.embulk.spi.time.TimestampFormatter.TimestampColumnOption +import org.msgpack.value.Value +import org.slf4j.{Logger, LoggerFactory} + +import scala.util.{Failure, Success, Try} +import scala.util.chaining._ + +object ParquetColumnType { + + import implicits._ + + private val logger: Logger = + LoggerFactory.getLogger(classOf[ParquetColumnType]) + + trait Task extends EmbulkTask with TimestampColumnOption { + @Config("logical_type") + @ConfigDefault("null") + def getLogicalType: Optional[LogicalTypeOption] + } + + trait LogicalTypeOption extends EmbulkTask { + @Config("name") + def getName: String + + @Config("scale") + @ConfigDefault("null") + def getScale: Optional[Int] + + @Config("precision") + @ConfigDefault("null") + def getPrecision: Optional[Int] + + @Config("bit_width") + @ConfigDefault("null") + def getBitWidth: Optional[Int] + + @Config("is_signed") + @ConfigDefault("null") + def getIsSigned: Optional[Boolean] + + @Config("is_adjusted_to_utc") + @ConfigDefault("null") + def getIsAdjustedToUtc: Optional[Boolean] + + @Config("time_unit") + @ConfigDefault("null") + def getTimeUnit: Optional[TimeUnit] + } + + object LogicalTypeOption { + case class ConfigBuilder private () { + case class Attributes private ( + name: Option[String] = None, + precision: Option[Int] = None, + scale: Option[Int] = None, + bitWidth: Option[Int] = None, + isSigned: Option[Boolean] = None, + isAdjustedToUtc: Option[Boolean] = None, + timeUnit: Option[TimeUnit] = None + ) { + def toOnelineYaml: String = { + val builder = Seq.newBuilder[String] + name.foreach(v => builder.addOne(s"name: ${v}")) + precision.foreach(v => builder.addOne(s"precision: ${v}")) + scale.foreach(v => builder.addOne(s"scale: ${v}")) + bitWidth.foreach(v => builder.addOne(s"bit_width: ${v}")) + isSigned.foreach(v => builder.addOne(s"is_signed: ${v}")) + isAdjustedToUtc.foreach(v => + builder.addOne(s"is_adjusted_to_utc: ${v}") + ) + timeUnit.foreach(tu => builder.addOne(s"time_unit: ${tu.name()}")) + "{" + builder.result().mkString(", ") + "}" + } + + def build(): ConfigSource = { + val c = Exec.newConfigSource() + name.foreach(c.set("name", _)) + precision.foreach(c.set("precision", _)) + scale.foreach(c.set("scale", _)) + bitWidth.foreach(c.set("bit_width", _)) + isSigned.foreach(c.set("is_signed", _)) + isAdjustedToUtc.foreach(c.set("is_adjusted_to_utc", _)) + timeUnit.foreach(tu => c.set("time_unit", tu.name())) + c + } + } + var attrs: Attributes = Attributes() + + def name(name: String): ConfigBuilder = + this.tap(_ => attrs = attrs.copy(name = Option(name))) + def scale(scale: Int): ConfigBuilder = + this.tap(_ => attrs = attrs.copy(scale = Option(scale))) + def precision(precision: Int): ConfigBuilder = + this.tap(_ => attrs = attrs.copy(precision = Option(precision))) + def bitWidth(bitWidth: Int): ConfigBuilder = + this.tap(_ => attrs = attrs.copy(bitWidth = Option(bitWidth))) + def isSigned(isSigned: Boolean): ConfigBuilder = + this.tap(_ => attrs = attrs.copy(isSigned = Option(isSigned))) + def isAdjustedToUtc(isAdjustedToUtc: Boolean): ConfigBuilder = + this.tap(_ => + attrs = attrs.copy(isAdjustedToUtc = Option(isAdjustedToUtc)) + ) + def timeUnit(timeUnit: TimeUnit): ConfigBuilder = + this.tap(_ => attrs = attrs.copy(timeUnit = Option(timeUnit))) + + def toOnelineYaml: String = attrs.toOnelineYaml + + def 
build(): ConfigSource = attrs.build() + } + + def builder(): ConfigBuilder = ConfigBuilder() + } + + def loadConfig(c: ConfigSource): Task = { + if (c.has("logical_type")) { + Try(c.get(classOf[String], "logical_type")).foreach { v => + logger.warn( + "[DEPRECATED] Now, it is deprecated to use the \"logical_type\" option in this usage." + + " Use \"converted_type\" instead." + ) + logger.warn( + s"[DEPRECATED] Translate {logical_type: $v} => {converted_type: $v}" + ) + c.remove("logical_type") + c.set("converted_type", v) + } + } + if (c.has("converted_type")) { + if (c.has("logical_type")) + throw new ConfigException( + "\"converted_type\" and \"logical_type\" options cannot be used at the same time." + ) + Try(c.get(classOf[String], "converted_type")) match { + case Success(convertedType) => + val logicalTypeConfig: ConfigSource = + translateConvertedType2LogicalType(convertedType) + c.setNested("logical_type", logicalTypeConfig) + case Failure(ex) => + throw new ConfigException( + "The value of \"converted_type\" option must be string.", + ex + ) + } + } + c.loadConfig(classOf[Task]) + } + + private def translateConvertedType2LogicalType( + convertedType: String + ): ConfigSource = { + val builder = LogicalTypeOption.builder() + val normalizedConvertedType: String = normalizeConvertedType(convertedType) + if (normalizedConvertedType == "TIMESTAMP_NANOS") { + builder.name("timestamp").isAdjustedToUtc(true).timeUnit(NANOS) + logger.warn( + s"[DEPRECATED] $convertedType is deprecated because this is not one of" + + s" ConvertedTypes actually. Please use 'logical_type: ${builder.toOnelineYaml}'" + ) + } + else { + + ConvertedType.valueOf(normalizedConvertedType) match { + case ConvertedType.UTF8 => builder.name("string") + case ConvertedType.DATE => builder.name("date") + case ConvertedType.TIME_MILLIS => + builder.name("time").isAdjustedToUtc(true).timeUnit(MILLIS) + case ConvertedType.TIME_MICROS => + builder.name("time").isAdjustedToUtc(true).timeUnit(MICROS) + case ConvertedType.TIMESTAMP_MILLIS => + builder.name("timestamp").isAdjustedToUtc(true).timeUnit(MILLIS) + case ConvertedType.TIMESTAMP_MICROS => + builder.name("timestamp").isAdjustedToUtc(true).timeUnit(MICROS) + case ConvertedType.UINT_8 => + builder.name("int").bitWidth(8).isSigned(false) + case ConvertedType.UINT_16 => + builder.name("int").bitWidth(16).isSigned(false) + case ConvertedType.UINT_32 => + builder.name("int").bitWidth(32).isSigned(false) + case ConvertedType.UINT_64 => + builder.name("int").bitWidth(64).isSigned(false) + case ConvertedType.INT_8 => + builder.name("int").bitWidth(8).isSigned(true) + case ConvertedType.INT_16 => + builder.name("int").bitWidth(16).isSigned(true) + case ConvertedType.INT_32 => + builder.name("int").bitWidth(32).isSigned(true) + case ConvertedType.INT_64 => + builder.name("int").bitWidth(64).isSigned(true) + case ConvertedType.JSON => builder.name("json") + case _ => + // MAP, MAP_KEY_VALUE, LIST, ENUM, DECIMAL, BSON, INTERVAL + throw new ConfigException( + s"converted_type: $convertedType is not supported." 
+ ) + } + } + logger.info( + s"Translate {converted_type: $convertedType} => {logical_type: ${builder.toOnelineYaml}}" + ) + builder.build() + } + + private def normalizeConvertedType(convertedType: String): String = { + convertedType + .toUpperCase(Locale.ENGLISH) + .replaceAll("-", "_") + .replaceAll("INT(\\d)", "INT_$1") + } + + def fromTask(task: Task): Option[LogicalTypeProxy] = { + task.getLogicalType.map { o => + LogicalTypeProxy( + name = o.getName, + scale = o.getScale, + precision = o.getPrecision, + bitWidth = o.getBitWidth, + isSigned = o.getIsSigned, + isAdjustedToUtc = o.getIsAdjustedToUtc, + timeUnit = o.getTimeUnit, + timeZone = task.getTimeZoneId.map(ZoneId.of) + ) + } + } +} + +trait ParquetColumnType { + def primitiveType(column: Column): PrimitiveType + def glueDataType(column: Column): GlueDataType + def consumeBoolean(consumer: RecordConsumer, v: Boolean): Unit + def consumeString(consumer: RecordConsumer, v: String): Unit + def consumeLong(consumer: RecordConsumer, v: Long): Unit + def consumeDouble(consumer: RecordConsumer, v: Double): Unit + def consumeTimestamp( + consumer: RecordConsumer, + v: Timestamp, + formatter: TimestampFormatter + ): Unit + def consumeJson(consumer: RecordConsumer, v: Value): Unit + def newUnsupportedMethodException(methodName: String) = + new ConfigException(s"${getClass.getName}#$methodName is unsupported.") + + protected def consumeLongAsInteger( + consumer: RecordConsumer, + v: Long + ): Unit = { + if (v < Int.MinValue || v > Int.MaxValue) + throw new DataException( + s"Failed to cast Long: $v to Int, " + + s"because $v exceeds ${Int.MaxValue} (Int.MaxValue) or ${Int.MinValue} (Int.MinValue)" + ) + consumer.addInteger(v.toInt) + } +} diff --git a/src/main/scala/org/embulk/output/s3_parquet/parquet/ParquetFileWriteSupport.scala b/src/main/scala/org/embulk/output/s3_parquet/parquet/ParquetFileWriteSupport.scala index 3a437a8..a666320 100644 --- a/src/main/scala/org/embulk/output/s3_parquet/parquet/ParquetFileWriteSupport.scala +++ b/src/main/scala/org/embulk/output/s3_parquet/parquet/ParquetFileWriteSupport.scala @@ -1,43 +1,201 @@ package org.embulk.output.s3_parquet.parquet +import java.lang.{StringBuilder => JStringBuilder} +import java.util.{Map => JMap} + import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path import org.apache.parquet.hadoop.api.WriteSupport import org.apache.parquet.hadoop.api.WriteSupport.WriteContext +import org.apache.parquet.hadoop.ParquetWriter import org.apache.parquet.io.api.RecordConsumer import org.apache.parquet.schema.MessageType -import org.embulk.spi.{PageReader, Schema} +import org.embulk.config.{ + Config, + ConfigDefault, + ConfigException, + ConfigSource, + Task => EmbulkTask +} +import org.embulk.output.s3_parquet.implicits +import org.embulk.output.s3_parquet.parquet.ParquetFileWriteSupport.WriterBuilder +import org.embulk.spi.{Column, ColumnVisitor, PageReader, Schema} +import org.embulk.spi.`type`.{TimestampType, Type, Types} import org.embulk.spi.time.TimestampFormatter +import org.embulk.spi.util.Timestamps +import org.slf4j.Logger + +object ParquetFileWriteSupport { + + import implicits._ + + trait Task extends TimestampFormatter.Task with EmbulkTask { + @Config("column_options") + @ConfigDefault("{}") + def getRawColumnOptions: JMap[String, ConfigSource] + + def getColumnOptions: JMap[String, ParquetColumnType.Task] + def setColumnOptions( + columnOptions: JMap[String, ParquetColumnType.Task] + ): Unit + + @Config("type_options") + @ConfigDefault("{}") + def 
getRawTypeOptions: JMap[String, ConfigSource] + + def getTypeOptions: JMap[String, ParquetColumnType.Task] + def setTypeOptions(typeOptions: JMap[String, ParquetColumnType.Task]): Unit + } -private[parquet] case class ParquetFileWriteSupport( + case class WriterBuilder(path: Path, writeSupport: ParquetFileWriteSupport) + extends ParquetWriter.Builder[PageReader, WriterBuilder](path) { + override def self(): WriterBuilder = this + override def getWriteSupport( + conf: Configuration + ): WriteSupport[PageReader] = writeSupport + } + + def configure(task: Task): Unit = { + task.setColumnOptions(task.getRawColumnOptions.map { + case (columnName, config) => + columnName -> ParquetColumnType.loadConfig(config) + }) + task.setTypeOptions(task.getRawTypeOptions.map { + case (columnType, config) => + columnType -> ParquetColumnType.loadConfig(config) + }) + } + + private def validateTask(task: Task, schema: Schema): Unit = { + if (task.getColumnOptions == null || task.getTypeOptions == null) + assert(false) + + task.getTypeOptions.keys.foreach( + embulkType + ) // throw ConfigException if unknown type name is found. + + task.getColumnOptions.foreach { + case (c: String, t: ParquetColumnType.Task) => + val column: Column = schema.lookupColumn(c) // throw ConfigException if columnName does not exist. + + if (t.getFormat.isDefined || t.getTimeZoneId.isDefined) { + if (!column.getType.isInstanceOf[TimestampType]) { + // NOTE: Warning is better instead of throwing. + throw new ConfigException( + s"The type of column{name:${column.getName},type:${column.getType.getName}} is not 'timestamp'," + + " but timestamp options (\"format\" or \"timezone\") are set." + ) + } + } + } + } + + private def embulkType(typeName: String): Type = { + Seq( + Types.BOOLEAN, + Types.STRING, + Types.LONG, + Types.DOUBLE, + Types.TIMESTAMP, + Types.JSON + ).foreach { embulkType => + if (embulkType.getName.equals(typeName)) return embulkType + } + throw new ConfigException(s"Unknown embulk type: $typeName.") + } + + def apply(task: Task, schema: Schema): ParquetFileWriteSupport = { + validateTask(task, schema) + + val parquetSchema: Map[Column, ParquetColumnType] = schema.getColumns.map { + c: Column => + c -> task.getColumnOptions.toMap + .get(c.getName) + .orElse(task.getTypeOptions.toMap.get(c.getType.getName)) + .flatMap(ParquetColumnType.fromTask) + .getOrElse(DefaultColumnType) + }.toMap + val timestampFormatters: Seq[TimestampFormatter] = Timestamps + .newTimestampColumnFormatters(task, schema, task.getColumnOptions) + new ParquetFileWriteSupport(schema, parquetSchema, timestampFormatters) + } +} + +case class ParquetFileWriteSupport private ( schema: Schema, - timestampFormatters: Seq[TimestampFormatter], - logicalTypeHandlers: LogicalTypeHandlerStore = LogicalTypeHandlerStore.empty + parquetSchema: Map[Column, ParquetColumnType], + timestampFormatters: Seq[TimestampFormatter] ) extends WriteSupport[PageReader] { - import org.embulk.output.s3_parquet.implicits._ + import implicits._ + + private val messageType: MessageType = + new MessageType("embulk", schema.getColumns.map { c => + parquetSchema(c).primitiveType(c) + }) - private var currentParquetFileWriter: ParquetFileWriter = _ + private var current: RecordConsumer = _ + + def showOutputSchema(logger: Logger): Unit = { + val sb = new JStringBuilder() + sb.append("=== Output Parquet Schema ===\n") + messageType.writeToStringBuilder(sb, null) // NOTE: indent is not used. 
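+ // NOTE: writeToStringBuilder renders the whole message type as multi-line text;
+ // it is split on newlines below so each schema line gets its own log entry.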
+ sb.append("=============================\n") + sb.toString.split("\n").foreach(logger.info) + } override def init(configuration: Configuration): WriteContext = { - val messageType: MessageType = EmbulkMessageType - .builder() - .withSchema(schema) - .withLogicalTypeHandlers(logicalTypeHandlers) - .build() val metadata: Map[String, String] = Map.empty // NOTE: When is this used? new WriteContext(messageType, metadata) } - override def prepareForWrite(recordConsumer: RecordConsumer): Unit = { - currentParquetFileWriter = ParquetFileWriter( - recordConsumer, - schema, - timestampFormatters, - logicalTypeHandlers - ) - } + override def prepareForWrite(recordConsumer: RecordConsumer): Unit = + current = recordConsumer override def write(record: PageReader): Unit = { - currentParquetFileWriter.write(record) + writingRecord { + schema.visitColumns(new ColumnVisitor { + override def booleanColumn(column: Column): Unit = nullOr(column) { + parquetSchema(column) + .consumeBoolean(current, record.getBoolean(column)) + } + override def longColumn(column: Column): Unit = nullOr(column) { + parquetSchema(column).consumeLong(current, record.getLong(column)) + } + override def doubleColumn(column: Column): Unit = nullOr(column) { + parquetSchema(column).consumeDouble(current, record.getDouble(column)) + } + override def stringColumn(column: Column): Unit = nullOr(column) { + parquetSchema(column).consumeString(current, record.getString(column)) + } + override def timestampColumn(column: Column): Unit = nullOr(column) { + parquetSchema(column).consumeTimestamp( + current, + record.getTimestamp(column), + timestampFormatters(column.getIndex) + ) + } + override def jsonColumn(column: Column): Unit = nullOr(column) { + parquetSchema(column).consumeJson(current, record.getJson(column)) + } + private def nullOr(column: Column)(f: => Unit): Unit = + if (!record.isNull(column)) writingColumn(column)(f) + }) + } + } + + private def writingRecord(f: => Unit): Unit = { + current.startMessage() + f + current.endMessage() + } + + private def writingColumn(column: Column)(f: => Unit): Unit = { + current.startField(column.getName, column.getIndex) + f + current.endField(column.getName, column.getIndex) } + + def newWriterBuilder(pathString: String): WriterBuilder = + WriterBuilder(new Path(pathString), this) } diff --git a/src/main/scala/org/embulk/output/s3_parquet/parquet/ParquetFileWriter.scala b/src/main/scala/org/embulk/output/s3_parquet/parquet/ParquetFileWriter.scala deleted file mode 100644 index 5eb5701..0000000 --- a/src/main/scala/org/embulk/output/s3_parquet/parquet/ParquetFileWriter.scala +++ /dev/null @@ -1,167 +0,0 @@ -package org.embulk.output.s3_parquet.parquet - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path -import org.apache.parquet.hadoop.ParquetWriter -import org.apache.parquet.hadoop.api.WriteSupport -import org.apache.parquet.io.api.{Binary, RecordConsumer} -import org.embulk.spi.{Column, ColumnVisitor, PageReader, Schema} -import org.embulk.spi.time.TimestampFormatter - -object ParquetFileWriter { - - case class Builder( - path: Path = null, - schema: Schema = null, - timestampFormatters: Seq[TimestampFormatter] = null, - logicalTypeHandlers: LogicalTypeHandlerStore = - LogicalTypeHandlerStore.empty - ) extends ParquetWriter.Builder[PageReader, Builder](path) { - - def withPath(path: Path): Builder = { - copy(path = path) - } - - def withPath(pathString: String): Builder = { - copy(path = new Path(pathString)) - } - - def withSchema(schema: Schema): Builder = { 
- copy(schema = schema) - } - - def withTimestampFormatters( - timestampFormatters: Seq[TimestampFormatter] - ): Builder = { - copy(timestampFormatters = timestampFormatters) - } - - def withLogicalTypeHandlers( - logicalTypeHandlers: LogicalTypeHandlerStore - ): Builder = { - copy(logicalTypeHandlers = logicalTypeHandlers) - } - - override def self(): Builder = { - this - } - - override def getWriteSupport( - conf: Configuration - ): WriteSupport[PageReader] = { - ParquetFileWriteSupport(schema, timestampFormatters, logicalTypeHandlers) - } - } - - def builder(): Builder = { - Builder() - } - -} - -private[parquet] case class ParquetFileWriter( - recordConsumer: RecordConsumer, - schema: Schema, - timestampFormatters: Seq[TimestampFormatter], - logicalTypeHandlers: LogicalTypeHandlerStore = LogicalTypeHandlerStore.empty -) { - - def write(record: PageReader): Unit = { - recordConsumer.startMessage() - writeRecord(record) - recordConsumer.endMessage() - } - - private def writeRecord(record: PageReader): Unit = { - - schema.visitColumns(new ColumnVisitor() { - - override def booleanColumn(column: Column): Unit = { - nullOr(column, { - withWriteFieldContext(column, { - recordConsumer.addBoolean(record.getBoolean(column)) - }) - }) - } - - override def longColumn(column: Column): Unit = { - nullOr(column, { - withWriteFieldContext(column, { - recordConsumer.addLong(record.getLong(column)) - }) - }) - } - - override def doubleColumn(column: Column): Unit = { - nullOr(column, { - withWriteFieldContext(column, { - recordConsumer.addDouble(record.getDouble(column)) - }) - }) - } - - override def stringColumn(column: Column): Unit = { - nullOr(column, { - withWriteFieldContext(column, { - val bin = Binary.fromString(record.getString(column)) - recordConsumer.addBinary(bin) - }) - }) - } - - override def timestampColumn(column: Column): Unit = { - nullOr( - column, { - withWriteFieldContext( - column, { - val t = record.getTimestamp(column) - - logicalTypeHandlers.get(column.getName, column.getType) match { - case Some(h) => - h.consume(t, recordConsumer) - case _ => - val ft = timestampFormatters(column.getIndex).format(t) - val bin = Binary.fromString(ft) - recordConsumer.addBinary(bin) - } - } - ) - } - ) - } - - override def jsonColumn(column: Column): Unit = { - nullOr( - column, { - withWriteFieldContext( - column, { - val msgPack = record.getJson(column) - - logicalTypeHandlers.get(column.getName, column.getType) match { - case Some(h) => - h.consume(msgPack, recordConsumer) - case _ => - val bin = Binary.fromString(msgPack.toJson) - recordConsumer.addBinary(bin) - } - } - ) - } - ) - } - - private def nullOr(column: Column, f: => Unit): Unit = { - if (!record.isNull(column)) f - } - - private def withWriteFieldContext(column: Column, f: => Unit): Unit = { - recordConsumer.startField(column.getName, column.getIndex) - f - recordConsumer.endField(column.getName, column.getIndex) - } - - }) - - } - -} diff --git a/src/main/scala/org/embulk/output/s3_parquet/parquet/TimeLogicalType.scala b/src/main/scala/org/embulk/output/s3_parquet/parquet/TimeLogicalType.scala new file mode 100644 index 0000000..55b8d0c --- /dev/null +++ b/src/main/scala/org/embulk/output/s3_parquet/parquet/TimeLogicalType.scala @@ -0,0 +1,118 @@ +package org.embulk.output.s3_parquet.parquet + +import java.time.{OffsetTime, ZoneId} +import java.time.temporal.ChronoField.{MICRO_OF_DAY, MILLI_OF_DAY, NANO_OF_DAY} + +import org.apache.parquet.io.api.RecordConsumer +import org.apache.parquet.schema.{LogicalTypeAnnotation, 
PrimitiveType, Types} +import org.apache.parquet.schema.LogicalTypeAnnotation.TimeUnit +import org.apache.parquet.schema.LogicalTypeAnnotation.TimeUnit.{ + MICROS, + MILLIS, + NANOS +} +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName +import org.embulk.config.ConfigException +import org.embulk.output.s3_parquet.catalog.GlueDataType +import org.embulk.spi.Column +import org.embulk.spi.`type`.{ + BooleanType, + DoubleType, + JsonType, + LongType, + StringType, + TimestampType +} +import org.embulk.spi.time.{Timestamp, TimestampFormatter} +import org.msgpack.value.Value +import org.slf4j.{Logger, LoggerFactory} + +case class TimeLogicalType( + isAdjustedToUtc: Boolean, + timeUnit: TimeUnit, + timeZone: ZoneId +) extends ParquetColumnType { + private val logger: Logger = LoggerFactory.getLogger(classOf[TimeLogicalType]) + private val UTC: ZoneId = ZoneId.of("UTC") + + override def primitiveType(column: Column): PrimitiveType = + column.getType match { + case _: LongType | _: TimestampType => + Types + .optional(timeUnit match { + case MILLIS => PrimitiveTypeName.INT32 + case MICROS | NANOS => PrimitiveTypeName.INT64 + }) + .as(LogicalTypeAnnotation.timeType(isAdjustedToUtc, timeUnit)) + .named(column.getName) + case _: BooleanType | _: DoubleType | _: StringType | _: JsonType | _ => + throw new ConfigException(s"Unsupported column type: ${column.getName}") + } + + override def glueDataType(column: Column): GlueDataType = + column.getType match { + case _: LongType | _: TimestampType => + timeUnit match { + case MILLIS => + warningWhenConvertingTimeToGlueType(GlueDataType.INT) + GlueDataType.INT + case MICROS | NANOS => + warningWhenConvertingTimeToGlueType(GlueDataType.BIGINT) + GlueDataType.BIGINT + } + case _: BooleanType | _: DoubleType | _: StringType | _: JsonType | _ => + throw new ConfigException(s"Unsupported column type: ${column.getName}") + } + + override def consumeBoolean(consumer: RecordConsumer, v: Boolean): Unit = + throw newUnsupportedMethodException("consumeBoolean") + + override def consumeString(consumer: RecordConsumer, v: String): Unit = + throw newUnsupportedMethodException("consumeString") + + override def consumeLong(consumer: RecordConsumer, v: Long): Unit = + timeUnit match { + case MILLIS => consumeLongAsInteger(consumer, v) + case MICROS | NANOS => consumer.addLong(v) + } + + override def consumeDouble(consumer: RecordConsumer, v: Double): Unit = + throw newUnsupportedMethodException("consumeDouble") + + override def consumeTimestamp( + consumer: RecordConsumer, + v: Timestamp, + formatter: TimestampFormatter + ): Unit = { + // * `TIME` with precision `MILLIS` is used for millisecond precision. + // It must annotate an `int32` that stores the number of milliseconds after midnight. + // * `TIME` with precision `MICROS` is used for microsecond precision. + // It must annotate an `int64` that stores the number of microseconds after midnight. + // * `TIME` with precision `NANOS` is used for nanosecond precision. + // It must annotate an `int64` that stores the number of nanoseconds after midnight. + // + // ref. 
https://github.com/apache/parquet-format/blob/apache-parquet-format-2.7.0/LogicalTypes.md#time + val zoneId = if (isAdjustedToUtc) UTC else timeZone + val offsetTime: OffsetTime = OffsetTime.ofInstant(v.getInstant, zoneId) + timeUnit match { + case MILLIS => + consumeLongAsInteger(consumer, offsetTime.get(MILLI_OF_DAY)) + case MICROS => + consumer.addLong(offsetTime.getLong(MICRO_OF_DAY)) + case NANOS => + consumer.addLong(offsetTime.getLong(NANO_OF_DAY)) + } + } + + override def consumeJson(consumer: RecordConsumer, v: Value): Unit = + throw newUnsupportedMethodException("consumeJson") + + private def warningWhenConvertingTimeToGlueType( + glueType: GlueDataType + ): Unit = + logger.warn( + s"time(isAdjustedToUtc = $isAdjustedToUtc, timeUnit = $timeUnit) is converted to Glue" + + s" ${glueType.name} but this is not represented correctly, because Glue does not" + + s" support time type. Please use `catalog.column_options` to define the type." + ) +} diff --git a/src/main/scala/org/embulk/output/s3_parquet/parquet/TimestampLogicalType.scala b/src/main/scala/org/embulk/output/s3_parquet/parquet/TimestampLogicalType.scala new file mode 100644 index 0000000..7f6809d --- /dev/null +++ b/src/main/scala/org/embulk/output/s3_parquet/parquet/TimestampLogicalType.scala @@ -0,0 +1,95 @@ +package org.embulk.output.s3_parquet.parquet + +import java.time.ZoneId + +import org.apache.parquet.io.api.RecordConsumer +import org.apache.parquet.schema.{LogicalTypeAnnotation, PrimitiveType, Types} +import org.apache.parquet.schema.LogicalTypeAnnotation.TimeUnit +import org.apache.parquet.schema.LogicalTypeAnnotation.TimeUnit.{ + MICROS, + MILLIS, + NANOS +} +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName +import org.embulk.config.ConfigException +import org.embulk.output.s3_parquet.catalog.GlueDataType +import org.embulk.spi.`type`.{ + BooleanType, + DoubleType, + JsonType, + LongType, + StringType, + TimestampType +} +import org.embulk.spi.time.{Timestamp, TimestampFormatter} +import org.embulk.spi.Column +import org.msgpack.value.Value +import org.slf4j.{Logger, LoggerFactory} + +case class TimestampLogicalType( + isAdjustedToUtc: Boolean, + timeUnit: TimeUnit, + timeZone: ZoneId +) extends ParquetColumnType { + private val logger: Logger = + LoggerFactory.getLogger(classOf[TimestampLogicalType]) + + override def primitiveType(column: Column): PrimitiveType = + column.getType match { + case _: LongType | _: TimestampType => + Types + .optional(PrimitiveTypeName.INT64) + .as(LogicalTypeAnnotation.timestampType(isAdjustedToUtc, timeUnit)) + .named(column.getName) + case _: BooleanType | _: DoubleType | _: StringType | _: JsonType | _ => + throw new ConfigException(s"Unsupported column type: ${column.getName}") + } + + override def glueDataType(column: Column): GlueDataType = + column.getType match { + case _: LongType | _: TimestampType => + timeUnit match { + case MILLIS => GlueDataType.TIMESTAMP + case MICROS | NANOS => + warningWhenConvertingTimestampToGlueType(GlueDataType.BIGINT) + GlueDataType.BIGINT + } + case _: BooleanType | _: DoubleType | _: StringType | _: JsonType | _ => + throw new ConfigException(s"Unsupported column type: ${column.getName}") + } + + override def consumeBoolean(consumer: RecordConsumer, v: Boolean): Unit = + throw newUnsupportedMethodException("consumeBoolean") + override def consumeString(consumer: RecordConsumer, v: String): Unit = + throw newUnsupportedMethodException("consumeString") + + override def consumeLong(consumer: RecordConsumer, v: Long): Unit = + 
consumer.addLong(v) + + override def consumeDouble(consumer: RecordConsumer, v: Double): Unit = + throw newUnsupportedMethodException("consumeDouble") + + override def consumeTimestamp( + consumer: RecordConsumer, + v: Timestamp, + formatter: TimestampFormatter + ): Unit = timeUnit match { + case MILLIS => consumer.addLong(v.toEpochMilli) + case MICROS => + consumer.addLong(v.getEpochSecond * 1_000_000L + (v.getNano / 1_000L)) + case NANOS => + consumer.addLong(v.getEpochSecond * 1_000_000_000L + v.getNano) + } + + override def consumeJson(consumer: RecordConsumer, v: Value): Unit = + throw newUnsupportedMethodException("consumeJson") + + private def warningWhenConvertingTimestampToGlueType( + glueType: GlueDataType + ): Unit = + logger.warn( + s"timestamp(isAdjustedToUtc = $isAdjustedToUtc, timeUnit = $timeUnit) is converted" + + s" to Glue ${glueType.name} but this is not represented correctly, because Glue" + + s" does not support time type. Please use `catalog.column_options` to define the type." + ) +} diff --git a/src/test/scala/org/embulk/output/s3_parquet/EmbulkPluginTestHelper.scala b/src/test/scala/org/embulk/output/s3_parquet/EmbulkPluginTestHelper.scala index 75423aa..a247016 100644 --- a/src/test/scala/org/embulk/output/s3_parquet/EmbulkPluginTestHelper.scala +++ b/src/test/scala/org/embulk/output/s3_parquet/EmbulkPluginTestHelper.scala @@ -39,12 +39,12 @@ import org.embulk.plugin.{ PluginClassLoaderModule } import org.embulk.spi.{Exec, ExecSession, OutputPlugin, PageTestUtils, Schema} -import org.msgpack.value.{Value, ValueFactory} +import org.embulk.spi.json.JsonParser +import org.msgpack.value.Value import org.scalatest.funsuite.AnyFunSuite import org.scalatest.BeforeAndAfter import org.scalatest.diagrams.Diagrams -import scala.jdk.CollectionConverters._ import scala.util.Using object EmbulkPluginTestHelper { @@ -84,6 +84,8 @@ abstract class EmbulkPluginTestHelper extends AnyFunSuite with BeforeAndAfter with Diagrams { + import implicits._ + private var exec: ExecSession = _ val TEST_S3_ENDPOINT: String = "http://localhost:4572" @@ -105,7 +107,7 @@ abstract class EmbulkPluginTestHelper withLocalStackS3Client { cli => @scala.annotation.tailrec def rmRecursive(listing: ObjectListing): Unit = { - listing.getObjectSummaries.asScala.foreach(o => + listing.getObjectSummaries.foreach(o => cli.deleteObject(TEST_BUCKET_NAME, o.getKey) ) if (listing.isTruncated) @@ -116,50 +118,47 @@ abstract class EmbulkPluginTestHelper withLocalStackS3Client(_.deleteBucket(TEST_BUCKET_NAME)) } + def execDoWith[A](f: => A): A = + try Exec.doWith(exec, () => f) + catch { + case ex: ExecutionException => throw ex.getCause + } + def runOutput( outConfig: ConfigSource, schema: Schema, data: Seq[Seq[Any]], messageTypeTest: MessageType => Unit = { _ => } ): Seq[Seq[AnyRef]] = { - try { - Exec.doWith( - exec, - () => { - val plugin = - exec.getInjector.getInstance(classOf[S3ParquetOutputPlugin]) - plugin.transaction( - outConfig, - schema, - 1, - (taskSource: TaskSource) => { - Using.resource(plugin.open(taskSource, schema, 0)) { output => - try { - PageTestUtils - .buildPage( - exec.getBufferAllocator, - schema, - data.flatten: _* - ) - .asScala - .foreach(output.add) - output.commit() - } - catch { - case ex: Throwable => - output.abort() - throw ex - } - } - Seq.empty.asJava + execDoWith { + val plugin = + exec.getInjector.getInstance(classOf[S3ParquetOutputPlugin]) + plugin.transaction( + outConfig, + schema, + 1, + (taskSource: TaskSource) => { + Using.resource(plugin.open(taskSource, schema, 0)) { 
output => + try { + PageTestUtils + .buildPage( + exec.getBufferAllocator, + schema, + data.flatten: _* + ) + .foreach(output.add) + output.commit() } - ) + catch { + case ex: Throwable => + output.abort() + throw ex + } + } + Seq.empty } ) } - catch { - case ex: ExecutionException => throw ex.getCause - } readS3Parquet(TEST_BUCKET_NAME, TEST_PATH_PREFIX, messageTypeTest) } @@ -243,9 +242,8 @@ abstract class EmbulkPluginTestHelper ): Seq[Seq[AnyRef]] = { val simpleRecord: SimpleRecord = reader.read() if (simpleRecord != null) { - val r: Seq[AnyRef] = simpleRecord.getValues.asScala + val r: Seq[AnyRef] = simpleRecord.getValues .map(_.getValue) - .toSeq return read(reader, records :+ r) } records @@ -254,29 +252,8 @@ abstract class EmbulkPluginTestHelper finally reader.close() } - def loadConfigSourceFromYamlString(yaml: String): ConfigSource = { + def loadConfigSourceFromYamlString(yaml: String): ConfigSource = new ConfigLoader(exec.getModelManager).fromYamlString(yaml) - } - - def newJson(map: Map[String, Any]): Value = { - ValueFactory - .newMapBuilder() - .putAll(map.map { - case (k: String, v: Any) => - val value: Value = - v match { - case str: String => ValueFactory.newString(str) - case bool: Boolean => ValueFactory.newBoolean(bool) - case long: Long => ValueFactory.newInteger(long) - case int: Int => ValueFactory.newInteger(int) - case double: Double => ValueFactory.newFloat(double) - case float: Float => ValueFactory.newFloat(float) - case _ => ValueFactory.newNil() - } - ValueFactory.newString(k) -> value - }.asJava) - .build() - } def newDefaultConfig: ConfigSource = loadConfigSourceFromYamlString( @@ -291,4 +268,6 @@ abstract class EmbulkPluginTestHelper |default_timezone: Asia/Tokyo |""".stripMargin ) + + def json(str: String): Value = new JsonParser().parse(str) } diff --git a/src/test/scala/org/embulk/output/s3_parquet/TestS3ParquetOutputPlugin.scala b/src/test/scala/org/embulk/output/s3_parquet/TestS3ParquetOutputPlugin.scala index 71720d6..7e353c3 100644 --- a/src/test/scala/org/embulk/output/s3_parquet/TestS3ParquetOutputPlugin.scala +++ b/src/test/scala/org/embulk/output/s3_parquet/TestS3ParquetOutputPlugin.scala @@ -24,11 +24,11 @@ class TestS3ParquetOutputPlugin extends EmbulkPluginTestHelper { // scalafmt: { maxColumn = 200 } val parser = TimestampParser.of("%Y-%m-%d %H:%M:%S.%N %z", "UTC") val data: Seq[Seq[Any]] = Seq( - Seq(true, 0L, 0.0d, "c212c89f91", parser.parse("2017-10-22 19:53:31.000000 +0900"), newJson(Map("a" -> 0, "b" -> "00"))), - Seq(false, 1L, -0.5d, "aaaaa", parser.parse("2017-10-22 19:53:31.000000 +0900"), newJson(Map("a" -> 1, "b" -> "11"))), - Seq(false, 2L, 1.5d, "90823c6a1f", parser.parse("2017-10-23 23:42:43.000000 +0900"), newJson(Map("a" -> 2, "b" -> "22"))), - Seq(true, 3L, 0.44d, "", parser.parse("2017-10-22 06:12:13.000000 +0900"), newJson(Map("a" -> 3, "b" -> "33", "c" -> 3.3))), - Seq(false, 9999L, 10000.33333d, "e56a40571c", parser.parse("2017-10-23 04:59:16.000000 +0900"), newJson(Map("a" -> 4, "b" -> "44", "c" -> 4.4, "d" -> true))) + Seq(true, 0L, 0.0d, "c212c89f91", parser.parse("2017-10-22 19:53:31.000000 +0900"), json("""{"a":0,"b":"00"}""")), + Seq(false, 1L, -0.5d, "aaaaa", parser.parse("2017-10-22 19:53:31.000000 +0900"), json("""{"a":1,"b":"11"}""")), + Seq(false, 2L, 1.5d, "90823c6a1f", parser.parse("2017-10-23 23:42:43.000000 +0900"), json("""{"a":2,"b":"22"}""")), + Seq(true, 3L, 0.44d, "", parser.parse("2017-10-22 06:12:13.000000 +0900"), json("""{"a":3,"b":"33","c":3.3}""")), + Seq(false, 9999L, 10000.33333d, "e56a40571c", 
parser.parse("2017-10-23 04:59:16.000000 +0900"), json("""{"a":4,"b":"44","c":4.4,"d":true}""")) ) // scalafmt: { maxColumn = 80 } @@ -38,79 +38,22 @@ class TestS3ParquetOutputPlugin extends EmbulkPluginTestHelper { schema, data, messageTypeTest = { messageType => - assert( - PrimitiveTypeName.BOOLEAN == messageType.getColumns - .get(0) - .getPrimitiveType - .getPrimitiveTypeName - ) - assert( - PrimitiveTypeName.INT64 == messageType.getColumns - .get(1) - .getPrimitiveType - .getPrimitiveTypeName - ) - assert( - PrimitiveTypeName.DOUBLE == messageType.getColumns - .get(2) - .getPrimitiveType - .getPrimitiveTypeName - ) - assert( - PrimitiveTypeName.BINARY == messageType.getColumns - .get(3) - .getPrimitiveType - .getPrimitiveTypeName - ) - assert( - PrimitiveTypeName.BINARY == messageType.getColumns - .get(4) - .getPrimitiveType - .getPrimitiveTypeName - ) - assert( - PrimitiveTypeName.BINARY == messageType.getColumns - .get(5) - .getPrimitiveType - .getPrimitiveTypeName - ) - - assert( - null == messageType.getColumns - .get(0) - .getPrimitiveType - .getLogicalTypeAnnotation - ) - assert( - null == messageType.getColumns - .get(1) - .getPrimitiveType - .getLogicalTypeAnnotation - ) - assert( - null == messageType.getColumns - .get(2) - .getPrimitiveType - .getLogicalTypeAnnotation - ) - assert( - LogicalTypeAnnotation.stringType() == messageType.getColumns - .get(3) - .getPrimitiveType - .getLogicalTypeAnnotation - ) - assert( - LogicalTypeAnnotation.stringType() == messageType.getColumns - .get(4) - .getPrimitiveType - .getLogicalTypeAnnotation - ) - assert( - LogicalTypeAnnotation.stringType() == messageType.getColumns - .get(5) - .getPrimitiveType - .getLogicalTypeAnnotation - ) + // format: off + assert(PrimitiveTypeName.BOOLEAN == messageType.getColumns.get(0).getPrimitiveType.getPrimitiveTypeName) + assert(PrimitiveTypeName.INT64 == messageType.getColumns.get(1).getPrimitiveType.getPrimitiveTypeName) + assert(PrimitiveTypeName.DOUBLE == messageType.getColumns.get(2).getPrimitiveType.getPrimitiveTypeName) + assert(PrimitiveTypeName.BINARY == messageType.getColumns.get(3).getPrimitiveType.getPrimitiveTypeName) + assert(PrimitiveTypeName.BINARY == messageType.getColumns.get(4).getPrimitiveType.getPrimitiveTypeName) + assert(PrimitiveTypeName.BINARY == messageType.getColumns.get(5).getPrimitiveType.getPrimitiveTypeName) + + assert(null == messageType.getColumns.get(0).getPrimitiveType.getLogicalTypeAnnotation) + assert(null == messageType.getColumns.get(1).getPrimitiveType.getLogicalTypeAnnotation) + assert(null == messageType.getColumns.get(2).getPrimitiveType.getLogicalTypeAnnotation) + + assert(LogicalTypeAnnotation.stringType() == messageType.getColumns.get(3).getPrimitiveType.getLogicalTypeAnnotation) + assert(LogicalTypeAnnotation.stringType() == messageType.getColumns.get(4).getPrimitiveType.getLogicalTypeAnnotation) + assert(LogicalTypeAnnotation.stringType() == messageType.getColumns.get(5).getPrimitiveType.getLogicalTypeAnnotation) + // format: on } ) @@ -160,21 +103,10 @@ class TestS3ParquetOutputPlugin extends EmbulkPluginTestHelper { schema, data, messageTypeTest = { messageType => - assert( - PrimitiveTypeName.INT64 == messageType.getColumns - .get(0) - .getPrimitiveType - .getPrimitiveTypeName - ) - assert( - LogicalTypeAnnotation.timestampType( - true, - LogicalTypeAnnotation.TimeUnit.MILLIS - ) == messageType.getColumns - .get(0) - .getPrimitiveType - .getLogicalTypeAnnotation - ) + // format: off + assert(PrimitiveTypeName.INT64 == 
messageType.getColumns.get(0).getPrimitiveType.getPrimitiveTypeName) + assert(LogicalTypeAnnotation.timestampType(true, LogicalTypeAnnotation.TimeUnit.MILLIS) == messageType.getColumns.get(0).getPrimitiveType.getLogicalTypeAnnotation) + // format: on } ) @@ -206,31 +138,20 @@ class TestS3ParquetOutputPlugin extends EmbulkPluginTestHelper { schema, data, messageTypeTest = { messageType => - assert( - PrimitiveTypeName.INT64 == messageType.getColumns - .get(0) - .getPrimitiveType - .getPrimitiveTypeName - ) - assert( - LogicalTypeAnnotation.timestampType( - true, - LogicalTypeAnnotation.TimeUnit.MICROS - ) == messageType.getColumns - .get(0) - .getPrimitiveType - .getLogicalTypeAnnotation - ) + // format: off + assert(PrimitiveTypeName.INT64 == messageType.getColumns.get(0).getPrimitiveType.getPrimitiveTypeName) + assert(LogicalTypeAnnotation.timestampType(true, LogicalTypeAnnotation.TimeUnit.MICROS) == messageType.getColumns.get(0).getPrimitiveType.getLogicalTypeAnnotation) + // format: on } ) assert(data.size == result.size) data.indices.foreach { i => - assert { - data(i).head.pipe(ts => - (ts.getEpochSecond * 1_000_000L) + (ts.getNano / 1_000L) - ) == result(i).head.asInstanceOf[Long] - } + // format: off + assert( + data(i).head.pipe(ts => (ts.getEpochSecond * 1_000_000L) + (ts.getNano / 1_000L)) == result(i).head.asInstanceOf[Long] + ) + // format: on } } @@ -254,31 +175,18 @@ class TestS3ParquetOutputPlugin extends EmbulkPluginTestHelper { schema, data, messageTypeTest = { messageType => - assert( - PrimitiveTypeName.INT64 == messageType.getColumns - .get(0) - .getPrimitiveType - .getPrimitiveTypeName - ) - assert( - LogicalTypeAnnotation.timestampType( - true, - LogicalTypeAnnotation.TimeUnit.NANOS - ) == messageType.getColumns - .get(0) - .getPrimitiveType - .getLogicalTypeAnnotation - ) + // format: off + assert(PrimitiveTypeName.INT64 == messageType.getColumns.get(0).getPrimitiveType.getPrimitiveTypeName) + assert(LogicalTypeAnnotation.timestampType(true, LogicalTypeAnnotation.TimeUnit.NANOS) == messageType.getColumns.get(0).getPrimitiveType.getLogicalTypeAnnotation) + // format: on } ) assert(data.size == result.size) data.indices.foreach { i => - assert { - data(i).head.pipe(ts => - (ts.getEpochSecond * 1_000_000_000L) + ts.getNano - ) == result(i).head.asInstanceOf[Long] - } + // format: off + assert(data(i).head.pipe(ts => (ts.getEpochSecond * 1_000_000_000L) + ts.getNano) == result(i).head.asInstanceOf[Long]) + // format: on } } } diff --git a/src/test/scala/org/embulk/output/s3_parquet/TestS3ParquetOutputPluginConfigException.scala b/src/test/scala/org/embulk/output/s3_parquet/TestS3ParquetOutputPluginConfigException.scala index cdf4fd7..93e752c 100644 --- a/src/test/scala/org/embulk/output/s3_parquet/TestS3ParquetOutputPluginConfigException.scala +++ b/src/test/scala/org/embulk/output/s3_parquet/TestS3ParquetOutputPluginConfigException.scala @@ -22,11 +22,7 @@ class TestS3ParquetOutputPluginConfigException extends EmbulkPluginTestHelper { ) val caught = intercept[ConfigException](runOutput(cfg, schema, data)) assert(caught.isInstanceOf[ConfigException]) - assert( - caught.getMessage.startsWith( - "string is not convertible by org.embulk.output.s3_parquet.parquet.TimestampMillisLogicalTypeHandler" - ) - ) + assert(caught.getMessage.startsWith("Unsupported column type: ")) } test( @@ -45,11 +41,7 @@ class TestS3ParquetOutputPluginConfigException extends EmbulkPluginTestHelper { ) val caught = intercept[ConfigException](runOutput(cfg, schema, data)) 
assert(caught.isInstanceOf[ConfigException]) - assert( - caught.getMessage.startsWith( - "string is not convertible by org.embulk.output.s3_parquet.parquet.TimestampMillisLogicalTypeHandler" - ) - ) + assert(caught.getMessage.startsWith("Unsupported column type: ")) } diff --git a/src/test/scala/org/embulk/output/s3_parquet/parquet/MockParquetRecordConsumer.scala b/src/test/scala/org/embulk/output/s3_parquet/parquet/MockParquetRecordConsumer.scala new file mode 100644 index 0000000..9ea44e0 --- /dev/null +++ b/src/test/scala/org/embulk/output/s3_parquet/parquet/MockParquetRecordConsumer.scala @@ -0,0 +1,59 @@ +package org.embulk.output.s3_parquet.parquet + +import org.apache.parquet.io.api.{Binary, RecordConsumer} + +case class MockParquetRecordConsumer() extends RecordConsumer { + case class Data private (messages: Seq[Message] = Seq()) { + def toData: Seq[Seq[Any]] = messages.map(_.toData) + } + case class Message private (fields: Seq[Field] = Seq()) { + def toData: Seq[Any] = { + val maxIndex: Int = fields.maxBy(_.index).index + val raw: Map[Int, Any] = fields.map(f => f.index -> f.value).toMap + 0.to(maxIndex).map(idx => raw.get(idx).orNull) + } + } + case class Field private (index: Int = 0, value: Any = null) + + private var _data: Data = Data() + private var _message: Message = Message() + private var _field: Field = Field() + + override def startMessage(): Unit = _message = Message() + override def endMessage(): Unit = + _data = _data.copy(messages = _data.messages :+ _message) + override def startField(field: String, index: Int): Unit = + _field = Field(index = index) + override def endField(field: String, index: Int): Unit = + _message = _message.copy(fields = _message.fields :+ _field) + override def startGroup(): Unit = throw new UnsupportedOperationException + override def endGroup(): Unit = throw new UnsupportedOperationException + override def addInteger(value: Int): Unit = + _field = _field.copy(value = value) + override def addLong(value: Long): Unit = _field = _field.copy(value = value) + override def addBoolean(value: Boolean): Unit = + _field = _field.copy(value = value) + override def addBinary(value: Binary): Unit = + _field = _field.copy(value = value) + override def addFloat(value: Float): Unit = + _field = _field.copy(value = value) + override def addDouble(value: Double): Unit = + _field = _field.copy(value = value) + + def writingMessage(f: => Unit): Unit = { + startMessage() + f + endMessage() + } + def writingField(field: String, index: Int)(f: => Unit): Unit = { + startField(field, index) + f + endField(field, index) + } + def writingSampleField(f: => Unit): Unit = { + writingMessage { + writingField("a", 0)(f) + } + } + def data: Seq[Seq[Any]] = _data.toData +} diff --git a/src/test/scala/org/embulk/output/s3_parquet/parquet/ParquetColumnTypeTestHelper.scala b/src/test/scala/org/embulk/output/s3_parquet/parquet/ParquetColumnTypeTestHelper.scala new file mode 100644 index 0000000..313b5ba --- /dev/null +++ b/src/test/scala/org/embulk/output/s3_parquet/parquet/ParquetColumnTypeTestHelper.scala @@ -0,0 +1,17 @@ +package org.embulk.output.s3_parquet.parquet + +import org.embulk.spi.Column +import org.embulk.spi.`type`.Types + +trait ParquetColumnTypeTestHelper { + + val SAMPLE_BOOLEAN_COLUMN: Column = new Column(0, "a", Types.BOOLEAN) + val SAMPLE_LONG_COLUMN: Column = new Column(0, "a", Types.LONG) + val SAMPLE_DOUBLE_COLUMN: Column = new Column(0, "a", Types.DOUBLE) + val SAMPLE_STRING_COLUMN: Column = new Column(0, "a", Types.STRING) + val 
SAMPLE_TIMESTAMP_COLUMN: Column = new Column(0, "a", Types.TIMESTAMP) + val SAMPLE_JSON_COLUMN: Column = new Column(0, "a", Types.JSON) + + def newMockRecordConsumer(): MockParquetRecordConsumer = + MockParquetRecordConsumer() +} diff --git a/src/test/scala/org/embulk/output/s3_parquet/parquet/TestDateLogicalType.scala b/src/test/scala/org/embulk/output/s3_parquet/parquet/TestDateLogicalType.scala new file mode 100644 index 0000000..8b3d645 --- /dev/null +++ b/src/test/scala/org/embulk/output/s3_parquet/parquet/TestDateLogicalType.scala @@ -0,0 +1,178 @@ +package org.embulk.output.s3_parquet.parquet + +import org.apache.parquet.schema.LogicalTypeAnnotation +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName +import org.embulk.config.ConfigException +import org.embulk.output.s3_parquet.catalog.GlueDataType +import org.embulk.spi.DataException +import org.embulk.spi.time.Timestamp +import org.scalatest.diagrams.Diagrams +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.prop.TableDrivenPropertyChecks + +import scala.util.chaining._ + +class TestDateLogicalType + extends AnyFunSuite + with ParquetColumnTypeTestHelper + with TableDrivenPropertyChecks + with Diagrams { + + private val conditions = Table( + "column", + Seq( + SAMPLE_BOOLEAN_COLUMN, + SAMPLE_LONG_COLUMN, + SAMPLE_DOUBLE_COLUMN, + SAMPLE_STRING_COLUMN, + SAMPLE_TIMESTAMP_COLUMN, + SAMPLE_JSON_COLUMN + ): _* + ) + + private val unsupportedEmbulkColumns = Seq( + SAMPLE_BOOLEAN_COLUMN, + SAMPLE_DOUBLE_COLUMN, + SAMPLE_STRING_COLUMN, + SAMPLE_JSON_COLUMN + ) + + test( + "#primitiveType(column) returns PrimitiveTypeName.INT32 with LogicalType" + ) { + forAll(conditions) { column => + whenever(!unsupportedEmbulkColumns.contains(column)) { + // format: off + assert(PrimitiveTypeName.INT32 == DateLogicalType.primitiveType(column).getPrimitiveTypeName) + assert(LogicalTypeAnnotation.dateType() == DateLogicalType.primitiveType(column).getLogicalTypeAnnotation) + // format: on + } + } + } + + test( + s"#primitiveType(column) cannot return any PrimitiveType when embulk column type is one of (${unsupportedEmbulkColumns + .map(_.getType.getName) + .mkString(",")})" + ) { + forAll(conditions) { column => + whenever(unsupportedEmbulkColumns.contains(column)) { + // format: off + assert(intercept[ConfigException](DateLogicalType.primitiveType(column)).getMessage.startsWith("Unsupported column type: ")) + // format: on + } + } + } + + test("#glueDataType(column) returns GlueDataType") { + forAll(conditions) { column => + whenever(!unsupportedEmbulkColumns.contains(column)) { + assert(GlueDataType.DATE == DateLogicalType.glueDataType(column)) + } + } + } + + test( + s"#glueDataType(column) cannot return any GlueDataType when embulk column type is one of (${unsupportedEmbulkColumns + .map(_.getType.getName) + .mkString(",")})" + ) { + forAll(conditions) { column => + whenever(unsupportedEmbulkColumns.contains(column)) { + // format: off + assert(intercept[ConfigException](DateLogicalType.glueDataType(column)).getMessage.startsWith("Unsupported column type: ")) + // format: on + } + } + } + + test("#consumeBoolean") { + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + // format: off + assert(intercept[ConfigException](DateLogicalType.consumeBoolean(consumer, true)).getMessage.endsWith("is unsupported.")) + // format: on + } + } + } + + test("#consumeString") { + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + // format: off + 
assert(intercept[ConfigException](DateLogicalType.consumeString(consumer, "")).getMessage.endsWith("is unsupported.")) + // format: on + } + } + } + + test("#consumeLong") { + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + DateLogicalType.consumeLong(consumer, 1L) + } + assert(consumer.data.head.head.isInstanceOf[Int]) + assert(consumer.data.head.head == 1) + } + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + // format: off + assert(intercept[DataException](DateLogicalType.consumeLong(consumer, Long.MaxValue)).getMessage.startsWith("Failed to cast Long: ")) + // format: on + } + } + } + + test("#consumeDouble") { + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + // format: off + assert(intercept[ConfigException](DateLogicalType.consumeDouble(consumer, 0.0d)).getMessage.endsWith("is unsupported.")) + // format: on + } + + } + } + + test("#consumeTimestamp") { + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + DateLogicalType.consumeTimestamp( + consumer, + Timestamp.ofEpochSecond(24 * 60 * 60), // 1day + null + ) + } + assert(consumer.data.head.head.isInstanceOf[Int]) + assert(consumer.data.head.head == 1) + } + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + // NOTE: See. java.time.Instant#MAX_SECOND + val instantMaxEpochSeconds = 31556889864403199L + // format: off + assert(intercept[DataException](DateLogicalType.consumeTimestamp(consumer, Timestamp.ofEpochSecond(instantMaxEpochSeconds), null)).getMessage.startsWith("Failed to cast Long: ")) + // format: on + } + } + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + // NOTE: See. java.time.Instant#MIN_SECOND + val instantMinEpochSeconds = -31557014167219200L + // format: off + assert(intercept[DataException](DateLogicalType.consumeTimestamp(consumer, Timestamp.ofEpochSecond(instantMinEpochSeconds), null)).getMessage.startsWith("Failed to cast Long: ")) + // format: on + } + } + } + + test("#consumeJson") { + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + // format: off + assert(intercept[ConfigException](DateLogicalType.consumeJson(consumer, null)).getMessage.endsWith("is unsupported.")) + // format: on + } + } + } +} diff --git a/src/test/scala/org/embulk/output/s3_parquet/parquet/TestDecimalLogicalType.scala b/src/test/scala/org/embulk/output/s3_parquet/parquet/TestDecimalLogicalType.scala new file mode 100644 index 0000000..84ff355 --- /dev/null +++ b/src/test/scala/org/embulk/output/s3_parquet/parquet/TestDecimalLogicalType.scala @@ -0,0 +1,179 @@ +package org.embulk.output.s3_parquet.parquet + +import org.apache.parquet.io.api.{Binary, RecordConsumer} +import org.apache.parquet.schema.LogicalTypeAnnotation +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName +import org.embulk.config.ConfigException +import org.embulk.output.s3_parquet.catalog.GlueDataType +import org.embulk.spi.`type`.{DoubleType, LongType, StringType} +import org.embulk.spi.DataException +import org.scalatest.diagrams.Diagrams +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.prop.TableDrivenPropertyChecks + +import scala.util.chaining._ + +class TestDecimalLogicalType + extends AnyFunSuite + with ParquetColumnTypeTestHelper + with TableDrivenPropertyChecks + with Diagrams { + + private val conditions = Table( + ("precision", "scale", "column"), { + for { + precision <- Seq(1, 9, 10, 18, 19) + scale <- Seq(0, 1, 20) + column <- Seq( + 
SAMPLE_BOOLEAN_COLUMN, + SAMPLE_LONG_COLUMN, + SAMPLE_DOUBLE_COLUMN, + SAMPLE_STRING_COLUMN, + SAMPLE_TIMESTAMP_COLUMN, + SAMPLE_JSON_COLUMN + ) + } yield (precision, scale, column) + }: _* + ) + + private val unsupportedEmbulkColumns = Seq( + SAMPLE_BOOLEAN_COLUMN, + SAMPLE_TIMESTAMP_COLUMN, + SAMPLE_JSON_COLUMN + ) + + def isValidScaleAndPrecision(scale: Int, precision: Int): Boolean = + scale >= 0 && scale < precision && precision > 0 + + test("throws IllegalArgumentException") { + // format: off + assert(intercept[IllegalArgumentException](DecimalLogicalType(-1, 5)).getMessage.startsWith("requirement failed: Scale must be zero or a positive integer.")) + assert(intercept[IllegalArgumentException](DecimalLogicalType(10, 5)).getMessage.startsWith("requirement failed: Scale must be a positive integer less than the precision.")) + // format: on + } + + test( + "#primitiveType(column) returns PrimitiveTypeName.{INT32, INT64, BINARY} with LogicalType" + ) { + forAll(conditions) { (precision, scale, column) => + whenever(isValidScaleAndPrecision(scale, precision)) { + // format: off + column.getType match { + case _: LongType if 1 <= precision && precision <= 9 => + assert(PrimitiveTypeName.INT32 == DecimalLogicalType(scale, precision).primitiveType(column).getPrimitiveTypeName) + assert(LogicalTypeAnnotation.decimalType(scale, precision) == DecimalLogicalType(scale, precision).primitiveType(column).getLogicalTypeAnnotation) + case _: LongType if 10 <= precision && precision <= 18 => + assert(PrimitiveTypeName.INT64 == DecimalLogicalType(scale, precision).primitiveType(column).getPrimitiveTypeName) + assert(LogicalTypeAnnotation.decimalType(scale, precision) == DecimalLogicalType(scale, precision).primitiveType(column).getLogicalTypeAnnotation) + case _: StringType | _: DoubleType => + assert(PrimitiveTypeName.BINARY == DecimalLogicalType(scale, precision).primitiveType(column).getPrimitiveTypeName) + assert(LogicalTypeAnnotation.decimalType(scale, precision) == DecimalLogicalType(scale, precision).primitiveType(column).getLogicalTypeAnnotation) + case _ => + assert(intercept[ConfigException](DecimalLogicalType(scale, precision).primitiveType(column)).getMessage.startsWith("Unsupported column type: ")) + } + // format: on + } + } + } + + test("#glueDataType(column) returns GlueDataType") { + forAll(conditions) { (precision, scale, column) => + whenever(isValidScaleAndPrecision(scale, precision)) { + // format: off + column.getType match { + case _: LongType | _: StringType | _: DoubleType => + assert(GlueDataType.DECIMAL(precision, scale) == DecimalLogicalType(scale, precision).glueDataType(column)) + case _ => + assert(intercept[ConfigException](DecimalLogicalType(scale, precision).glueDataType(column)).getMessage.startsWith("Unsupported column type: ")) + } + // format: on + } + } + } + + test("#consumeString") { + forAll(conditions) { (precision, scale, _) => + whenever(isValidScaleAndPrecision(scale, precision)) { + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + // format: off + assert(intercept[DataException](DecimalLogicalType(scale, precision).consumeString(consumer, "string")).getMessage.startsWith("Failed to cast String: ")) + // format: on + } + } + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + DecimalLogicalType(scale, precision).consumeString(consumer, "5.5") + } + assert(consumer.data.head.head.isInstanceOf[Binary]) + if (scale == 0) + assert(consumer.data.head.head == Binary.fromString("6")) + else 
assert(consumer.data.head.head == Binary.fromString("5.5")) + } + } + } + } + + test("#consumeLong") { + forAll(conditions) { (precision, scale, _) => + whenever(isValidScaleAndPrecision(scale, precision) && precision <= 18) { + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + DecimalLogicalType(scale, precision) + .consumeLong(consumer, 1L) + } + if (1 <= precision && precision <= 9) { + assert(consumer.data.head.head.isInstanceOf[Int]) + assert(consumer.data.head.head == 1) + } + else { + assert(consumer.data.head.head.isInstanceOf[Long]) + assert(consumer.data.head.head == 1) + } + } + } + whenever(isValidScaleAndPrecision(scale, precision) && precision > 18) { + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + // format: off + assert(intercept[ConfigException](DecimalLogicalType(scale, precision).consumeLong(consumer, 1L)).getMessage.startsWith("precision must be 1 <= precision <= 18 when consuming long values but precision is ")) + // format: on + } + } + } + } + } + + test("#consumeDouble") { + forAll(conditions) { (precision, scale, _) => + whenever(isValidScaleAndPrecision(scale, precision)) { + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + DecimalLogicalType(scale, precision) + .consumeDouble(consumer, 1.1d) + } + assert(consumer.data.head.head.isInstanceOf[Binary]) + if (scale == 0) + assert(consumer.data.head.head == Binary.fromString("1")) + else assert(consumer.data.head.head == Binary.fromString("1.1")) + } + } + } + } + + test("#consume{Boolean,Timestamp,Json} are unsupported.") { + def assertUnsupportedConsume(f: RecordConsumer => Unit) = + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + // format: off + assert(intercept[ConfigException](f(consumer)).getMessage.endsWith("is unsupported.")) + // format: on + } + } + assertUnsupportedConsume(DecimalLogicalType(5, 10).consumeBoolean(_, true)) + assertUnsupportedConsume( + DecimalLogicalType(5, 10).consumeTimestamp(_, null, null) + ) + assertUnsupportedConsume(DecimalLogicalType(5, 10).consumeJson(_, null)) + } +} diff --git a/src/test/scala/org/embulk/output/s3_parquet/parquet/TestDefaultColumnType.scala b/src/test/scala/org/embulk/output/s3_parquet/parquet/TestDefaultColumnType.scala new file mode 100644 index 0000000..3f07d67 --- /dev/null +++ b/src/test/scala/org/embulk/output/s3_parquet/parquet/TestDefaultColumnType.scala @@ -0,0 +1,157 @@ +package org.embulk.output.s3_parquet.parquet + +import org.apache.parquet.io.api.Binary +import org.apache.parquet.schema.LogicalTypeAnnotation +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName +import org.embulk.output.s3_parquet.catalog.GlueDataType +import org.embulk.spi.`type`.{ + BooleanType, + DoubleType, + JsonType, + LongType, + StringType, + TimestampType +} +import org.embulk.spi.json.JsonParser +import org.embulk.spi.time.{Timestamp, TimestampFormatter} +import org.scalatest.diagrams.Diagrams +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.prop.TableDrivenPropertyChecks + +import scala.util.chaining._ + +class TestDefaultColumnType + extends AnyFunSuite + with ParquetColumnTypeTestHelper + with TableDrivenPropertyChecks + with Diagrams { + + private val conditions = Table( + "column", + Seq( + SAMPLE_BOOLEAN_COLUMN, + SAMPLE_LONG_COLUMN, + SAMPLE_DOUBLE_COLUMN, + SAMPLE_STRING_COLUMN, + SAMPLE_TIMESTAMP_COLUMN, + SAMPLE_JSON_COLUMN + ): _* + ) + + test( + "#primitiveType(column) returns 
PrimitiveTypeName.{BOOLEAN,INT64,DOUBLE,BINARY}" + ) { + forAll(conditions) { column => + // format: off + column.getType match { + case _: BooleanType => + assert(PrimitiveTypeName.BOOLEAN == DefaultColumnType.primitiveType(column).getPrimitiveTypeName) + assert(null == DefaultColumnType.primitiveType(column).getLogicalTypeAnnotation) + case _: LongType => + assert(PrimitiveTypeName.INT64 == DefaultColumnType.primitiveType(column).getPrimitiveTypeName) + assert(null == DefaultColumnType.primitiveType(column).getLogicalTypeAnnotation) + case _: DoubleType => + assert(PrimitiveTypeName.DOUBLE == DefaultColumnType.primitiveType(column).getPrimitiveTypeName) + assert(null == DefaultColumnType.primitiveType(column).getLogicalTypeAnnotation) + case _: StringType | _: TimestampType | _: JsonType => + assert(PrimitiveTypeName.BINARY == DefaultColumnType.primitiveType(column).getPrimitiveTypeName) + assert(LogicalTypeAnnotation.stringType() == DefaultColumnType.primitiveType(column).getLogicalTypeAnnotation) + case _ => + fail() + } + // format: on + } + } + + test("#glueDataType(column) returns GlueDataType") { + forAll(conditions) { column => + // format: off + column.getType match { + case _: BooleanType => + assert(GlueDataType.BOOLEAN == DefaultColumnType.glueDataType(column)) + case _: LongType => + assert(GlueDataType.BIGINT == DefaultColumnType.glueDataType(column)) + case _: DoubleType => + assert(GlueDataType.DOUBLE == DefaultColumnType.glueDataType(column)) + case _: StringType | _: TimestampType | _: JsonType => + assert(GlueDataType.STRING == DefaultColumnType.glueDataType(column)) + case _ => + fail() + } + // format: on + } + } + + test("#consumeBoolean") { + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + DefaultColumnType.consumeBoolean(consumer, true) + } + assert(consumer.data.head.head.isInstanceOf[Boolean]) + assert(consumer.data.head.head == true) + } + } + + test("#consumeString") { + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + DefaultColumnType.consumeString(consumer, "string") + } + assert(consumer.data.head.head.isInstanceOf[Binary]) + assert(consumer.data.head.head == Binary.fromString("string")) + } + } + + test("#consumeLong") { + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + DefaultColumnType.consumeLong(consumer, Long.MaxValue) + } + assert(consumer.data.head.head.isInstanceOf[Long]) + assert(consumer.data.head.head == Long.MaxValue) + } + } + + test("#consumeDouble") { + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + DefaultColumnType.consumeDouble(consumer, Double.MaxValue) + } + assert(consumer.data.head.head.isInstanceOf[Double]) + assert(consumer.data.head.head == Double.MaxValue) + } + } + + test("#consumeTimestamp") { + val formatter = TimestampFormatter + .of("%Y-%m-%d %H:%M:%S.%6N %z", "UTC") + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + DefaultColumnType.consumeTimestamp( + consumer, + Timestamp.ofEpochMilli(Int.MaxValue), + formatter + ) + } + // format: off + assert(consumer.data.head.head.isInstanceOf[Binary]) + assert(consumer.data.head.head == Binary.fromString("1970-01-25 20:31:23.647000 +0000")) + // format: on + } + } + + test("#consumeJson") { + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + DefaultColumnType.consumeJson( + consumer, + new JsonParser().parse("""{"a":1,"b":"c","d":5.5,"e":true}""") + ) + } + // format: off + 
assert(consumer.data.head.head.isInstanceOf[Binary]) + assert(consumer.data.head.head == Binary.fromString("""{"a":1,"b":"c","d":5.5,"e":true}""")) + // format: on + } + } +} diff --git a/src/test/scala/org/embulk/output/s3_parquet/parquet/TestIntLogicalType.scala b/src/test/scala/org/embulk/output/s3_parquet/parquet/TestIntLogicalType.scala new file mode 100644 index 0000000..12973ac --- /dev/null +++ b/src/test/scala/org/embulk/output/s3_parquet/parquet/TestIntLogicalType.scala @@ -0,0 +1,349 @@ +package org.embulk.output.s3_parquet.parquet + +import org.apache.parquet.schema.LogicalTypeAnnotation +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName +import org.embulk.config.ConfigException +import org.embulk.output.s3_parquet.catalog.GlueDataType +import org.embulk.spi.DataException +import org.scalatest.diagrams.Diagrams +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.prop.TableDrivenPropertyChecks + +import scala.util.chaining._ +class TestIntLogicalType + extends AnyFunSuite + with ParquetColumnTypeTestHelper + with TableDrivenPropertyChecks + with Diagrams { + + private val conditions = Table( + ("bitWidth", "isSigned", "column"), { + for { + bitWidth <- Seq(8, 16, 32, 64) + isSigned <- Seq(true, false) + column <- Seq( + SAMPLE_BOOLEAN_COLUMN, + SAMPLE_LONG_COLUMN, + SAMPLE_DOUBLE_COLUMN, + SAMPLE_STRING_COLUMN, + SAMPLE_TIMESTAMP_COLUMN, + SAMPLE_JSON_COLUMN + ) + } yield (bitWidth, isSigned, column) + }: _* + ) + + private val unsupportedEmbulkColumns = Seq( + SAMPLE_TIMESTAMP_COLUMN, + SAMPLE_JSON_COLUMN + ) + + private def isINT32(bitWidth: Int): Boolean = bitWidth < 64 + + test( + "#primitiveType(column) returns PrimitiveTypeName.INT32 with LogicalType" + ) { + forAll(conditions) { (bitWidth, isSigned, column) => + whenever(isINT32(bitWidth) && !unsupportedEmbulkColumns.contains(column)) { + val logicalType = + IntLogicalType(bitWidth = bitWidth, isSigned = isSigned) + // format: off + assert(PrimitiveTypeName.INT32 == logicalType.primitiveType(column).getPrimitiveTypeName) + assert(LogicalTypeAnnotation.intType(bitWidth, isSigned) == logicalType.primitiveType(column).getLogicalTypeAnnotation) + // format: on + } + } + } + + test( + "#primitiveType(column) returns PrimitiveTypeName.INT64 with LogicalType" + ) { + forAll(conditions) { (bitWidth, isSigned, column) => + whenever(!isINT32(bitWidth) && !unsupportedEmbulkColumns.contains(column)) { + val logicalType = + IntLogicalType(bitWidth = bitWidth, isSigned = isSigned) + // format: off + assert(PrimitiveTypeName.INT64 == logicalType.primitiveType(column).getPrimitiveTypeName) + assert(LogicalTypeAnnotation.intType(bitWidth, isSigned) == logicalType.primitiveType(column).getLogicalTypeAnnotation) + // format: on + } + } + } + + test( + s"#primitiveType(column) cannot return any PrimitiveType when embulk column type is one of (${unsupportedEmbulkColumns + .map(_.getType.getName) + .mkString(",")})" + ) { + forAll(conditions) { (bitWidth, isSigned, column) => + whenever(unsupportedEmbulkColumns.contains(column)) { + // format: off + assert(intercept[ConfigException](IntLogicalType(bitWidth = bitWidth, isSigned = isSigned).primitiveType(column)).getMessage.startsWith("Unsupported column type: ")) + // format: on + } + } + } + + test("#glueDataType(column) returns GlueDataType") { + forAll(conditions) { (bitWidth, isSigned, column) => + whenever(!unsupportedEmbulkColumns.contains(column)) { + def assertGlueDataType(expected: GlueDataType) = { + // format: off + assert(expected == 
IntLogicalType(bitWidth = bitWidth, isSigned = isSigned).glueDataType(column)) + // format: on + } + if (isSigned) { + bitWidth match { + case 8 => assertGlueDataType(GlueDataType.TINYINT) + case 16 => assertGlueDataType(GlueDataType.SMALLINT) + case 32 => assertGlueDataType(GlueDataType.INT) + case 64 => assertGlueDataType(GlueDataType.BIGINT) + case _ => fail() + } + } + else { + bitWidth match { + case 8 => assertGlueDataType(GlueDataType.SMALLINT) + case 16 => assertGlueDataType(GlueDataType.INT) + case 32 => assertGlueDataType(GlueDataType.BIGINT) + case 64 => assertGlueDataType(GlueDataType.BIGINT) + case _ => fail() + } + } + } + } + } + + test( + s"#glueDataType(column) cannot return any GlueDataType when embulk column type is one of (${unsupportedEmbulkColumns + .map(_.getType.getName) + .mkString(",")})" + ) { + forAll(conditions) { (bitWidth, isSigned, column) => + whenever(unsupportedEmbulkColumns.contains(column)) { + // format: off + assert(intercept[ConfigException](IntLogicalType(bitWidth = bitWidth, isSigned = isSigned).glueDataType(column)).getMessage.startsWith("Unsupported column type: ")) + // format: on + } + } + } + + test("#consumeBoolean (INT32)") { + forAll(conditions) { (bitWidth, isSigned, _) => + whenever(isINT32(bitWidth)) { + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + IntLogicalType(bitWidth = bitWidth, isSigned = isSigned) + .consumeBoolean(consumer, true) + } + assert(consumer.data.head.head.isInstanceOf[Int]) + assert(consumer.data.head.head == 1) + } + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + IntLogicalType(bitWidth = bitWidth, isSigned = isSigned) + .consumeBoolean(consumer, false) + } + assert(consumer.data.head.head.isInstanceOf[Int]) + assert(consumer.data.head.head == 0) + } + } + } + } + + test("#consumeBoolean (INT64)") { + forAll(conditions) { (bitWidth, isSigned, _) => + whenever(!isINT32(bitWidth)) { + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + IntLogicalType(bitWidth = bitWidth, isSigned = isSigned) + .consumeBoolean(consumer, true) + } + assert(consumer.data.head.head.isInstanceOf[Long]) + assert(consumer.data.head.head == 1L) + } + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + IntLogicalType(bitWidth = bitWidth, isSigned = isSigned) + .consumeBoolean(consumer, false) + } + assert(consumer.data.head.head.isInstanceOf[Long]) + assert(consumer.data.head.head == 0L) + } + } + } + } + + test("#consumeString (INT32)") { + forAll(conditions) { (bitWidth, isSigned, _) => + whenever(isINT32(bitWidth)) { + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + IntLogicalType(bitWidth = bitWidth, isSigned = isSigned) + .consumeString(consumer, "1") + } + assert(consumer.data.head.head.isInstanceOf[Int]) + assert(consumer.data.head.head == 1) + } + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + // format: off + assert(intercept[DataException](IntLogicalType(bitWidth = bitWidth, isSigned = isSigned).consumeString(consumer, "string")).getMessage.startsWith("Failed to cast String: ")) + // format: on + } + } + } + } + } + + test("#consumeString (INT64)") { + forAll(conditions) { (bitWidth, isSigned, _) => + whenever(!isINT32(bitWidth)) { + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + IntLogicalType(bitWidth = bitWidth, isSigned = isSigned) + .consumeString(consumer, "1") + } + assert(consumer.data.head.head.isInstanceOf[Long]) + 
assert(consumer.data.head.head == 1L) + } + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + // format: off + assert(intercept[DataException](IntLogicalType(bitWidth = bitWidth, isSigned = isSigned).consumeString(consumer, "string")).getMessage.startsWith("Failed to cast String: ")) + // format: on + } + } + } + } + } + + test("#consumeLong (INT32)") { + forAll(conditions) { (bitWidth, isSigned, _) => + whenever(isINT32(bitWidth)) { + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + IntLogicalType(bitWidth = bitWidth, isSigned = isSigned) + .consumeLong(consumer, 1L) + } + assert(consumer.data.head.head.isInstanceOf[Int]) + assert(consumer.data.head.head == 1) + } + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + // format: off + assert(intercept[DataException](IntLogicalType(bitWidth = bitWidth, isSigned = isSigned).consumeLong(consumer, Long.MaxValue)).getMessage.startsWith("The value is out of the range: that is ")) + // format: on + } + } + } + } + } + + test("#consumeLong (INT64)") { + forAll(conditions) { (bitWidth, isSigned, _) => + whenever(!isINT32(bitWidth)) { + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + IntLogicalType(bitWidth = bitWidth, isSigned = isSigned) + .consumeLong(consumer, 1L) + } + assert(consumer.data.head.head.isInstanceOf[Long]) + assert(consumer.data.head.head == 1L) + } + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + IntLogicalType(bitWidth = bitWidth, isSigned = isSigned) + .consumeLong(consumer, Long.MaxValue) + } + assert(consumer.data.head.head.isInstanceOf[Long]) + assert(consumer.data.head.head == Long.MaxValue) + } + } + } + } + + test("#consumeDouble (INT32)") { + forAll(conditions) { (bitWidth, isSigned, _) => + whenever(isINT32(bitWidth)) { + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + IntLogicalType(bitWidth = bitWidth, isSigned = isSigned) + .consumeDouble(consumer, 1.4d) + } + assert(consumer.data.head.head.isInstanceOf[Int]) + assert(consumer.data.head.head == 1) + } + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + IntLogicalType(bitWidth = bitWidth, isSigned = isSigned) + .consumeDouble(consumer, 1.5d) + } + assert(consumer.data.head.head.isInstanceOf[Int]) + assert(consumer.data.head.head == 2) + } + + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + // format: off + assert(intercept[DataException](IntLogicalType(bitWidth = bitWidth, isSigned = isSigned).consumeDouble(consumer, Double.MaxValue)).getMessage.startsWith("The value is out of the range: that is ")) + // format: on + } + } + } + } + } + + test("#consumeDouble (INT64)") { + forAll(conditions) { (bitWidth, isSigned, _) => + whenever(!isINT32(bitWidth)) { + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + IntLogicalType(bitWidth = bitWidth, isSigned = isSigned) + .consumeDouble(consumer, 1.4d) + } + assert(consumer.data.head.head.isInstanceOf[Long]) + assert(consumer.data.head.head == 1L) + } + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + IntLogicalType(bitWidth = bitWidth, isSigned = isSigned) + .consumeDouble(consumer, 1.5d) + } + assert(consumer.data.head.head.isInstanceOf[Long]) + assert(consumer.data.head.head == 2L) + } + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + // format: off + assert(intercept[DataException](IntLogicalType(bitWidth = bitWidth, isSigned = 
isSigned).consumeDouble(consumer, Double.MaxValue)).getMessage.startsWith("The value is out of the range: ")) + // format: on + } + } + } + } + } + + test("#consumeTimestamp is unsupported") { + forAll(conditions) { (bitWidth, isSigned, _) => + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + // format: off + assert(intercept[ConfigException](IntLogicalType(bitWidth = bitWidth, isSigned = isSigned).consumeTimestamp(consumer, null, null)).getMessage.endsWith("is unsupported.")) + // format: on + } + } + } + } + test("#consumeJson is unsupported") { + forAll(conditions) { (bitWidth, isSigned, _) => + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + // format: off + assert(intercept[ConfigException](IntLogicalType(bitWidth = bitWidth, isSigned = isSigned).consumeJson(consumer, null)).getMessage.endsWith("is unsupported.")) + // format: on + } + } + } + } +} diff --git a/src/test/scala/org/embulk/output/s3_parquet/parquet/TestJsonLogicalType.scala b/src/test/scala/org/embulk/output/s3_parquet/parquet/TestJsonLogicalType.scala new file mode 100644 index 0000000..571146c --- /dev/null +++ b/src/test/scala/org/embulk/output/s3_parquet/parquet/TestJsonLogicalType.scala @@ -0,0 +1,148 @@ +package org.embulk.output.s3_parquet.parquet + +import org.apache.parquet.io.api.Binary +import org.apache.parquet.schema.LogicalTypeAnnotation +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName +import org.embulk.config.ConfigException +import org.embulk.output.s3_parquet.catalog.GlueDataType +import org.embulk.spi.`type`.{ + BooleanType, + DoubleType, + JsonType, + LongType, + StringType +} +import org.embulk.spi.json.JsonParser +import org.embulk.spi.time.TimestampFormatter +import org.msgpack.value.ValueFactory +import org.scalatest.diagrams.Diagrams +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.prop.TableDrivenPropertyChecks + +import scala.util.chaining._ + +class TestJsonLogicalType + extends AnyFunSuite + with ParquetColumnTypeTestHelper + with TableDrivenPropertyChecks + with Diagrams { + + private val conditions = Table( + "column", + Seq( + SAMPLE_BOOLEAN_COLUMN, + SAMPLE_LONG_COLUMN, + SAMPLE_DOUBLE_COLUMN, + SAMPLE_STRING_COLUMN, + SAMPLE_TIMESTAMP_COLUMN, + SAMPLE_JSON_COLUMN + ): _* + ) + + test( + "#primitiveType(column) returns PrimitiveTypeName.{BOOLEAN,INT64,DOUBLE,BINARY} with LogicalType" + ) { + forAll(conditions) { column => + // format: off + column.getType match { + case _: BooleanType | _: LongType | _: DoubleType | _: StringType | + _: JsonType => + assert(PrimitiveTypeName.BINARY == JsonLogicalType.primitiveType(column).getPrimitiveTypeName) + assert(LogicalTypeAnnotation.jsonType() == JsonLogicalType.primitiveType(column).getLogicalTypeAnnotation) + case _ => + assert(intercept[ConfigException](JsonLogicalType.primitiveType(column)).getMessage.startsWith("Unsupported column type: ")) + } + // format: on + } + } + + test("#glueDataType(column) returns GlueDataType") { + forAll(conditions) { column => + // format: off + column.getType match { + case _: BooleanType | _: LongType | _: DoubleType | _: StringType | + _: JsonType => + assert(GlueDataType.STRING == JsonLogicalType.glueDataType(column)) + case _ => + assert(intercept[ConfigException](JsonLogicalType.glueDataType(column)).getMessage.startsWith("Unsupported column type: ")) + } + // format: on + } + } + + test("#consumeBoolean") { + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + 
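+        // NOTE: scalar values are expected to be written as their JSON text, which is
+        // why the assertions below compare against Binary.fromString(...toJson).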
JsonLogicalType.consumeBoolean(consumer, true) + } + // format: off + assert(consumer.data.head.head.isInstanceOf[Binary]) + assert(consumer.data.head.head == Binary.fromString(ValueFactory.newBoolean(true).toJson)) + // format: on + } + } + + test("#consumeString") { + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + JsonLogicalType.consumeString(consumer, "string") + } + // format: off + assert(consumer.data.head.head.isInstanceOf[Binary]) + assert(consumer.data.head.head == Binary.fromString(ValueFactory.newString("string").toJson)) + // format: on + } + } + + test("#consumeLong") { + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + JsonLogicalType.consumeLong(consumer, Long.MaxValue) + } + // format: off + assert(consumer.data.head.head.isInstanceOf[Binary]) + assert(consumer.data.head.head == Binary.fromString(ValueFactory.newInteger(Long.MaxValue).toJson)) + // format: on + } + } + + test("#consumeDouble") { + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + JsonLogicalType.consumeDouble(consumer, Double.MaxValue) + } + // format: off + assert(consumer.data.head.head.isInstanceOf[Binary]) + assert(consumer.data.head.head == Binary.fromString(ValueFactory.newFloat(Double.MaxValue).toJson)) + // format: on + } + } + + test("#consumeTimestamp") { + val formatter = TimestampFormatter + .of("%Y-%m-%d %H:%M:%S.%6N %z", "UTC") + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + // format: off + assert(intercept[ConfigException](JsonLogicalType.consumeTimestamp(consumer, null, null)).getMessage.endsWith("is unsupported.")) + // format: on + } + } + } + + test("#consumeJson") { + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + JsonLogicalType.consumeJson( + consumer, + new JsonParser().parse("""{"a":1,"b":"c","d":5.5,"e":true}""") + ) + } + // format: off + assert(consumer.data.head.head.isInstanceOf[Binary]) + assert(consumer.data.head.head == Binary.fromString("""{"a":1,"b":"c","d":5.5,"e":true}""")) + // format: on + } + } + +} diff --git a/src/test/scala/org/embulk/output/s3_parquet/parquet/TestLogicalTypeHandler.scala b/src/test/scala/org/embulk/output/s3_parquet/parquet/TestLogicalTypeHandler.scala deleted file mode 100644 index af0668e..0000000 --- a/src/test/scala/org/embulk/output/s3_parquet/parquet/TestLogicalTypeHandler.scala +++ /dev/null @@ -1,101 +0,0 @@ -package org.embulk.output.s3_parquet.parquet - -import org.embulk.spi.DataException -import org.embulk.spi.`type`.Types -import org.scalatest.diagrams.Diagrams -import org.scalatest.funsuite.AnyFunSuite - -import scala.util.Try - -class TestLogicalTypeHandler extends AnyFunSuite with Diagrams { - - test("IntLogicalTypeHandler.isConvertible() returns true for long") { - val h = Int8LogicalTypeHandler - - assert(h.isConvertible(Types.LONG)) - assert(!h.isConvertible(Types.BOOLEAN)) - } - - test( - "IntLogicalTypeHandler.consume() raises DataException if given type is not long" - ) { - val h = Int8LogicalTypeHandler - val actual = Try(h.consume("invalid", null)) - - assert(actual.isFailure) - assert(actual.failed.get.isInstanceOf[DataException]) - } - - test( - "TimestampMillisLogicalTypeHandler.isConvertible() returns true for timestamp" - ) { - val h = TimestampMillisLogicalTypeHandler - - assert(h.isConvertible(Types.TIMESTAMP)) - assert(!h.isConvertible(Types.BOOLEAN)) - } - - test( - "TimestampMillisLogicalTypeHandler.consume() raises DataException if given type is not timestamp" - ) { - val h 
= TimestampMillisLogicalTypeHandler - val actual = Try(h.consume("invalid", null)) - - assert(actual.isFailure) - assert(actual.failed.get.isInstanceOf[DataException]) - } - - test( - "TimestampMicrosLogicalTypeHandler.isConvertible() returns true for timestamp" - ) { - val h = TimestampMicrosLogicalTypeHandler - - assert(h.isConvertible(Types.TIMESTAMP)) - assert(!h.isConvertible(Types.BOOLEAN)) - } - - test( - "TimestampMicrosLogicalTypeHandler.consume() raises DataException if given type is not timestamp" - ) { - val h = TimestampMicrosLogicalTypeHandler - val actual = Try(h.consume("invalid", null)) - - assert(actual.isFailure) - assert(actual.failed.get.isInstanceOf[DataException]) - } - - test( - "TimestampNanosLogicalTypeHandler.isConvertible() returns true for timestamp" - ) { - val h = TimestampNanosLogicalTypeHandler - - assert(h.isConvertible(Types.TIMESTAMP)) - assert(!h.isConvertible(Types.BOOLEAN)) - } - - test( - "TimestampNanosLogicalTypeHandler.consume() raises DataException if given type is not timestamp" - ) { - val h = TimestampNanosLogicalTypeHandler - val actual = Try(h.consume("invalid", null)) - - assert(actual.isFailure) - assert(actual.failed.get.isInstanceOf[DataException]) - } - - test("JsonLogicalTypeHandler.isConvertible() returns true for json") { - val h = JsonLogicalTypeHandler - - assert(h.isConvertible(Types.JSON)) - assert(!h.isConvertible(Types.BOOLEAN)) - } - - test( - "JsonLogicalTypeHandler.consume() raises DataException if given type is not json" - ) { - val h = JsonLogicalTypeHandler - val actual = Try(h.consume("invalid", null)) - assert(actual.isFailure) - assert(actual.failed.get.isInstanceOf[DataException]) - } -} diff --git a/src/test/scala/org/embulk/output/s3_parquet/parquet/TestLogicalTypeHandlerStore.scala b/src/test/scala/org/embulk/output/s3_parquet/parquet/TestLogicalTypeHandlerStore.scala deleted file mode 100644 index 7d426a5..0000000 --- a/src/test/scala/org/embulk/output/s3_parquet/parquet/TestLogicalTypeHandlerStore.scala +++ /dev/null @@ -1,177 +0,0 @@ -package org.embulk.output.s3_parquet.parquet - -import java.util.Optional - -import com.google.common.base.{Optional => GOptional} -import org.embulk.config.{ConfigException, TaskSource} -import org.embulk.output.s3_parquet.S3ParquetOutputPlugin.{ - ColumnOptionTask, - TypeOptionTask -} -import org.embulk.spi.`type`.{Types, Type => EType} -import org.scalatest.diagrams.Diagrams -import org.scalatest.funsuite.AnyFunSuite - -import scala.jdk.CollectionConverters._ -import scala.util.Try - -class TestLogicalTypeHandlerStore extends AnyFunSuite with Diagrams { - test("empty() returns empty maps") { - val rv = LogicalTypeHandlerStore.empty - - assert(rv.fromColumnName.isEmpty) - assert(rv.fromEmbulkType.isEmpty) - } - - test("fromEmbulkOptions() returns handlers for valid option tasks") { - val typeOpts = Map[String, TypeOptionTask]( - "timestamp" -> DummyTypeOptionTask( - Optional.of[String]("timestamp-millis") - ) - ).asJava - val columnOpts = Map[String, ColumnOptionTask]( - "col1" -> DummyColumnOptionTask(Optional.of[String]("timestamp-micros")) - ).asJava - - val expected1 = Map[EType, LogicalTypeHandler]( - Types.TIMESTAMP -> TimestampMillisLogicalTypeHandler - ) - val expected2 = Map[String, LogicalTypeHandler]( - "col1" -> TimestampMicrosLogicalTypeHandler - ) - - val rv = LogicalTypeHandlerStore.fromEmbulkOptions(typeOpts, columnOpts) - - assert(rv.fromEmbulkType == expected1) - assert(rv.fromColumnName == expected2) - } - - test( - "fromEmbulkOptions() raises ConfigException 
if invalid option tasks given" - ) { - val emptyTypeOpts = Map.empty[String, TypeOptionTask].asJava - val emptyColumnOpts = Map.empty[String, ColumnOptionTask].asJava - - val invalidTypeOpts = Map[String, TypeOptionTask]( - "unknown-embulk-type-name" -> DummyTypeOptionTask( - Optional.of[String]("timestamp-millis") - ), - "timestamp" -> DummyTypeOptionTask( - Optional.of[String]("unknown-parquet-logical-type-name") - ) - ).asJava - val invalidColumnOpts = Map[String, ColumnOptionTask]( - "col1" -> DummyColumnOptionTask( - Optional.of[String]("unknown-parquet-logical-type-name") - ) - ).asJava - - val try1 = Try( - LogicalTypeHandlerStore - .fromEmbulkOptions(invalidTypeOpts, emptyColumnOpts) - ) - assert(try1.isFailure) - assert(try1.failed.get.isInstanceOf[ConfigException]) - - val try2 = Try( - LogicalTypeHandlerStore - .fromEmbulkOptions(emptyTypeOpts, invalidColumnOpts) - ) - assert(try2.isFailure) - assert(try2.failed.get.isInstanceOf[ConfigException]) - - val try3 = Try( - LogicalTypeHandlerStore - .fromEmbulkOptions(invalidTypeOpts, invalidColumnOpts) - ) - assert(try3.isFailure) - assert(try3.failed.get.isInstanceOf[ConfigException]) - } - - test("get() returns a handler matched with primary column name condition") { - val typeOpts = Map[String, TypeOptionTask]( - "timestamp" -> DummyTypeOptionTask( - Optional.of[String]("timestamp-millis") - ) - ).asJava - val columnOpts = Map[String, ColumnOptionTask]( - "col1" -> DummyColumnOptionTask(Optional.of[String]("timestamp-micros")) - ).asJava - - val handlers = - LogicalTypeHandlerStore.fromEmbulkOptions(typeOpts, columnOpts) - - // It matches both of column name and embulk type, and column name should be primary - val expected = Some(TimestampMicrosLogicalTypeHandler) - val actual = handlers.get("col1", Types.TIMESTAMP) - - assert(actual == expected) - } - - test("get() returns a handler matched with type name condition") { - val typeOpts = Map[String, TypeOptionTask]( - "timestamp" -> DummyTypeOptionTask( - Optional.of[String]("timestamp-millis") - ) - ).asJava - val columnOpts = Map.empty[String, ColumnOptionTask].asJava - - val handlers = - LogicalTypeHandlerStore.fromEmbulkOptions(typeOpts, columnOpts) - - // It matches column name - val expected = Some(TimestampMillisLogicalTypeHandler) - val actual = handlers.get("col1", Types.TIMESTAMP) - - assert(actual == expected) - } - - test("get() returns None if not matched") { - val typeOpts = Map.empty[String, TypeOptionTask].asJava - val columnOpts = Map.empty[String, ColumnOptionTask].asJava - - val handlers = - LogicalTypeHandlerStore.fromEmbulkOptions(typeOpts, columnOpts) - - // It matches embulk type - val actual = handlers.get("col1", Types.TIMESTAMP) - - assert(actual.isEmpty) - } - - private case class DummyTypeOptionTask(lt: Optional[String]) - extends TypeOptionTask { - - override def getLogicalType: Optional[String] = { - lt - } - - override def validate(): Unit = {} - - override def dump(): TaskSource = { - null - } - } - - private case class DummyColumnOptionTask(lt: Optional[String]) - extends ColumnOptionTask { - - override def getTimeZoneId: GOptional[String] = { - GOptional.absent[String] - } - - override def getFormat: GOptional[String] = { - GOptional.absent[String] - } - - override def getLogicalType: Optional[String] = { - lt - } - - override def validate(): Unit = {} - - override def dump(): TaskSource = { - null - } - } -} diff --git a/src/test/scala/org/embulk/output/s3_parquet/parquet/TestTimeLogicalType.scala 
b/src/test/scala/org/embulk/output/s3_parquet/parquet/TestTimeLogicalType.scala new file mode 100644 index 0000000..14070ab --- /dev/null +++ b/src/test/scala/org/embulk/output/s3_parquet/parquet/TestTimeLogicalType.scala @@ -0,0 +1,223 @@ +package org.embulk.output.s3_parquet.parquet + +import java.time.ZoneId + +import org.apache.parquet.io.api.RecordConsumer +import org.apache.parquet.schema.LogicalTypeAnnotation +import org.apache.parquet.schema.LogicalTypeAnnotation.TimeUnit.{ + MICROS, + MILLIS, + NANOS +} +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName +import org.embulk.config.ConfigException +import org.embulk.output.s3_parquet.catalog.GlueDataType +import org.embulk.spi.DataException +import org.embulk.spi.time.Timestamp +import org.scalatest.diagrams.Diagrams +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.prop.TableDrivenPropertyChecks + +import scala.util.chaining._ + +class TestTimeLogicalType + extends AnyFunSuite + with ParquetColumnTypeTestHelper + with TableDrivenPropertyChecks + with Diagrams { + + private val conditions = Table( + ("isAdjustedToUtc", "timeUnit", "timeZone", "column"), { + for { + isAdjustedToUtc <- Seq(true, false) + timeUnit <- Seq(MILLIS, MICROS, NANOS) + timeZone <- Seq(ZoneId.of("UTC"), ZoneId.of("Asia/Tokyo")) + column <- Seq( + SAMPLE_BOOLEAN_COLUMN, + SAMPLE_LONG_COLUMN, + SAMPLE_DOUBLE_COLUMN, + SAMPLE_STRING_COLUMN, + SAMPLE_TIMESTAMP_COLUMN, + SAMPLE_JSON_COLUMN + ) + } yield (isAdjustedToUtc, timeUnit, timeZone, column) + }: _* + ) + + private val unsupportedEmbulkColumns = Seq( + SAMPLE_BOOLEAN_COLUMN, + SAMPLE_DOUBLE_COLUMN, + SAMPLE_STRING_COLUMN, + SAMPLE_JSON_COLUMN + ) + + test( + "#primitiveType(column) returns PrimitiveTypeName.{INT32,INT64} with LogicalType" + ) { + forAll(conditions) { (isAdjustedToUtc, timeUnit, timeZone, column) => + whenever(unsupportedEmbulkColumns.contains(column)) { + // format: off + assert(intercept[ConfigException](TimeLogicalType(isAdjustedToUtc = isAdjustedToUtc, timeUnit = timeUnit, timeZone = timeZone).primitiveType(column)).getMessage.startsWith("Unsupported column type: ")) + // format: on + } + + whenever(!unsupportedEmbulkColumns.contains(column)) { + val expectedPrimitiveTypeName = + if (timeUnit === MILLIS) PrimitiveTypeName.INT32 + else PrimitiveTypeName.INT64 + // format: off + assert(expectedPrimitiveTypeName == TimeLogicalType(isAdjustedToUtc = isAdjustedToUtc, timeUnit = timeUnit, timeZone = timeZone).primitiveType(column).getPrimitiveTypeName) + assert(LogicalTypeAnnotation.timeType(isAdjustedToUtc, timeUnit) == TimeLogicalType(isAdjustedToUtc = isAdjustedToUtc, timeUnit = timeUnit, timeZone = timeZone).primitiveType(column).getLogicalTypeAnnotation) + // format: on + } + } + } + + test("#glueDataType(column) returns GlueDataType") { + forAll(conditions) { (isAdjustedToUtc, timeUnit, timeZone, column) => + whenever(unsupportedEmbulkColumns.contains(column)) { + // format: off + assert(intercept[ConfigException](TimeLogicalType(isAdjustedToUtc = isAdjustedToUtc, timeUnit = timeUnit, timeZone = timeZone).glueDataType(column)).getMessage.startsWith("Unsupported column type: ")) + // format: on + } + whenever(!unsupportedEmbulkColumns.contains(column)) { + val expectedGlueDataType = + if (timeUnit === MILLIS) GlueDataType.INT + else GlueDataType.BIGINT + // format: off + assert(expectedGlueDataType == TimeLogicalType(isAdjustedToUtc = isAdjustedToUtc, timeUnit = timeUnit, timeZone = timeZone).glueDataType(column)) + // format: on + } + } + } + + 
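+  // NOTE: MILLIS times are written as INT32 millis-of-day and MICROS/NANOS as INT64,
+  // so #consumeLong keeps an Int for MILLIS (rejecting values that do not fit) and
+  // passes the Long through unchanged for the larger units. The #consumeTimestamp
+  // expectations derive from Timestamp.ofEpochMilli(Int.MaxValue), i.e.
+  // 1970-01-25T20:31:23.647Z: its time-of-day is 73,883,647 ms in UTC and
+  // 19,883,647 ms in Asia/Tokyo (05:31:23.647 local, UTC+9) when isAdjustedToUtc is
+  // false, scaled by 1,000 and 1,000,000 for MICROS and NANOS respectively.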
test("#consumeLong") { + forAll(conditions) { (isAdjustedToUtc, timeUnit, timeZone, _) => + timeUnit match { + case MILLIS => + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + TimeLogicalType( + isAdjustedToUtc = isAdjustedToUtc, + timeUnit = timeUnit, + timeZone = timeZone + ).consumeLong(consumer, 5) + } + assert(consumer.data.head.head.isInstanceOf[Int]) + assert(consumer.data.head.head == 5) + } + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + // format: off + assert(intercept[DataException](TimeLogicalType(isAdjustedToUtc = isAdjustedToUtc, timeUnit = timeUnit, timeZone = timeZone).consumeLong(consumer, Long.MaxValue)).getMessage.startsWith("Failed to cast Long: ")) + // format: on + } + } + case MICROS | NANOS => + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + TimeLogicalType( + isAdjustedToUtc = isAdjustedToUtc, + timeUnit = timeUnit, + timeZone = timeZone + ).consumeLong(consumer, 5) + } + assert(consumer.data.head.head.isInstanceOf[Long]) + assert(consumer.data.head.head == 5L) + } + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + TimeLogicalType( + isAdjustedToUtc = isAdjustedToUtc, + timeUnit = timeUnit, + timeZone = timeZone + ).consumeLong(consumer, Long.MaxValue) + } + assert(consumer.data.head.head.isInstanceOf[Long]) + assert(consumer.data.head.head == Long.MaxValue) + } + } + } + } + + test("#consumeTimestamp") { + forAll(conditions) { (isAdjustedToUtc, timeUnit, timeZone, _) => + timeUnit match { + case MILLIS => + val v = Timestamp.ofEpochMilli(Int.MaxValue) + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + TimeLogicalType( + isAdjustedToUtc = isAdjustedToUtc, + timeUnit = timeUnit, + timeZone = timeZone + ).consumeTimestamp(consumer, v, null) + } + assert(consumer.data.head.head.isInstanceOf[Int]) + if (timeZone.getId == "Asia/Tokyo" && !isAdjustedToUtc) + assert(consumer.data.head.head == 19883647) + else // UTC + assert(consumer.data.head.head == 73883647) + } + case MICROS => + val v = Timestamp.ofEpochMilli(Int.MaxValue) + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + TimeLogicalType( + isAdjustedToUtc = isAdjustedToUtc, + timeUnit = timeUnit, + timeZone = timeZone + ).consumeTimestamp(consumer, v, null) + } + assert(consumer.data.head.head.isInstanceOf[Long]) + if (timeZone.getId == "Asia/Tokyo" && !isAdjustedToUtc) + assert(consumer.data.head.head == 19883647000L) + else // UTC + assert(consumer.data.head.head == 73883647000L) + } + case NANOS => + val v = Timestamp.ofEpochMilli(Int.MaxValue) + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + TimeLogicalType( + isAdjustedToUtc = isAdjustedToUtc, + timeUnit = timeUnit, + timeZone = timeZone + ).consumeTimestamp(consumer, v, null) + } + assert(consumer.data.head.head.isInstanceOf[Long]) + if (timeZone.getId == "Asia/Tokyo" && !isAdjustedToUtc) + assert(consumer.data.head.head == 19883647000000L) + else // UTC + assert(consumer.data.head.head == 73883647000000L) + } + } + + } + } + + test("#consume{Boolean,Double,String,Json} are unsupported.") { + def assertUnsupportedConsume(f: RecordConsumer => Unit) = + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + // format: off + assert(intercept[ConfigException](f(consumer)).getMessage.endsWith("is unsupported.")) + // format: on + } + } + + forAll(conditions) { (isAdjustedToUtc, timeUnit, timeZone, _) => + val t = + TimeLogicalType( + isAdjustedToUtc 
= isAdjustedToUtc, + timeUnit = timeUnit, + timeZone = timeZone + ) + assertUnsupportedConsume(t.consumeBoolean(_, true)) + assertUnsupportedConsume(t.consumeDouble(_, 0.0d)) + assertUnsupportedConsume(t.consumeString(_, null)) + assertUnsupportedConsume(t.consumeJson(_, null)) + } + } + +} diff --git a/src/test/scala/org/embulk/output/s3_parquet/parquet/TestTimestampLogicalType.scala b/src/test/scala/org/embulk/output/s3_parquet/parquet/TestTimestampLogicalType.scala new file mode 100644 index 0000000..464b59f --- /dev/null +++ b/src/test/scala/org/embulk/output/s3_parquet/parquet/TestTimestampLogicalType.scala @@ -0,0 +1,189 @@ +package org.embulk.output.s3_parquet.parquet + +import java.time.ZoneId + +import org.apache.parquet.io.api.RecordConsumer +import org.apache.parquet.schema.LogicalTypeAnnotation +import org.apache.parquet.schema.LogicalTypeAnnotation.TimeUnit.{ + MICROS, + MILLIS, + NANOS +} +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName +import org.embulk.config.ConfigException +import org.embulk.output.s3_parquet.catalog.GlueDataType +import org.embulk.spi.time.Timestamp +import org.scalatest.diagrams.Diagrams +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.prop.TableDrivenPropertyChecks + +import scala.util.chaining._ + +class TestTimestampLogicalType + extends AnyFunSuite + with ParquetColumnTypeTestHelper + with TableDrivenPropertyChecks + with Diagrams { + + private val conditions = Table( + ("isAdjustedToUtc", "timeUnit", "timeZone", "column"), { + for { + isAdjustedToUtc <- Seq(true, false) + timeUnit <- Seq(MILLIS, MICROS, NANOS) + timeZone <- Seq(ZoneId.of("UTC"), ZoneId.of("Asia/Tokyo")) + column <- Seq( + SAMPLE_BOOLEAN_COLUMN, + SAMPLE_LONG_COLUMN, + SAMPLE_DOUBLE_COLUMN, + SAMPLE_STRING_COLUMN, + SAMPLE_TIMESTAMP_COLUMN, + SAMPLE_JSON_COLUMN + ) + } yield (isAdjustedToUtc, timeUnit, timeZone, column) + }: _* + ) + + private val unsupportedEmbulkColumns = Seq( + SAMPLE_BOOLEAN_COLUMN, + SAMPLE_DOUBLE_COLUMN, + SAMPLE_STRING_COLUMN, + SAMPLE_JSON_COLUMN + ) + + test( + "#primitiveType(column) returns PrimitiveTypeName.{INT32,INT64} with LogicalType" + ) { + forAll(conditions) { (isAdjustedToUtc, timeUnit, timeZone, column) => + whenever(unsupportedEmbulkColumns.contains(column)) { + // format: off + assert(intercept[ConfigException](TimestampLogicalType(isAdjustedToUtc = isAdjustedToUtc, timeUnit = timeUnit, timeZone = timeZone).primitiveType(column)).getMessage.startsWith("Unsupported column type: ")) + // format: on + } + + whenever(!unsupportedEmbulkColumns.contains(column)) { + // format: off + assert(PrimitiveTypeName.INT64 == TimestampLogicalType(isAdjustedToUtc = isAdjustedToUtc, timeUnit = timeUnit, timeZone = timeZone).primitiveType(column).getPrimitiveTypeName) + assert(LogicalTypeAnnotation.timeType(isAdjustedToUtc, timeUnit) == TimestampLogicalType(isAdjustedToUtc = isAdjustedToUtc, timeUnit = timeUnit, timeZone = timeZone).primitiveType(column).getLogicalTypeAnnotation) + // format: on + } + } + } + + test("#glueDataType(column) returns GlueDataType") { + forAll(conditions) { (isAdjustedToUtc, timeUnit, timeZone, column) => + whenever(unsupportedEmbulkColumns.contains(column)) { + // format: off + assert(intercept[ConfigException](TimestampLogicalType(isAdjustedToUtc = isAdjustedToUtc, timeUnit = timeUnit, timeZone = timeZone).glueDataType(column)).getMessage.startsWith("Unsupported column type: ")) + // format: on + } + whenever(!unsupportedEmbulkColumns.contains(column)) { + val expectedGlueDataType = + if (timeUnit 
=== MILLIS) GlueDataType.TIMESTAMP + else GlueDataType.BIGINT + // format: off + assert(expectedGlueDataType == TimestampLogicalType(isAdjustedToUtc = isAdjustedToUtc, timeUnit = timeUnit, timeZone = timeZone).glueDataType(column)) + // format: on + } + } + } + + test("#consumeLong") { + forAll(conditions) { (isAdjustedToUtc, timeUnit, timeZone, _) => + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + TimestampLogicalType( + isAdjustedToUtc = isAdjustedToUtc, + timeUnit = timeUnit, + timeZone = timeZone + ).consumeLong(consumer, 5) + } + assert(consumer.data.head.head.isInstanceOf[Long]) + assert(consumer.data.head.head == 5L) + } + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + TimestampLogicalType( + isAdjustedToUtc = isAdjustedToUtc, + timeUnit = timeUnit, + timeZone = timeZone + ).consumeLong(consumer, Long.MaxValue) + } + assert(consumer.data.head.head.isInstanceOf[Long]) + assert(consumer.data.head.head == Long.MaxValue) + } + } + } + + test("#consumeTimestamp") { + forAll(conditions) { (isAdjustedToUtc, timeUnit, timeZone, _) => + timeUnit match { + case MILLIS => + val v = Timestamp.ofEpochMilli(Int.MaxValue) + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + TimestampLogicalType( + isAdjustedToUtc = isAdjustedToUtc, + timeUnit = timeUnit, + timeZone = timeZone + ).consumeTimestamp(consumer, v, null) + } + assert(consumer.data.head.head.isInstanceOf[Long]) + assert(consumer.data.head.head == Int.MaxValue) + } + case MICROS => + val v = Timestamp.ofEpochMilli(Int.MaxValue) + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + TimestampLogicalType( + isAdjustedToUtc = isAdjustedToUtc, + timeUnit = timeUnit, + timeZone = timeZone + ).consumeTimestamp(consumer, v, null) + } + assert(consumer.data.head.head.isInstanceOf[Long]) + + assert(consumer.data.head.head == Int.MaxValue * 1_000L) + } + case NANOS => + val v = Timestamp.ofEpochMilli(Int.MaxValue) + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + TimestampLogicalType( + isAdjustedToUtc = isAdjustedToUtc, + timeUnit = timeUnit, + timeZone = timeZone + ).consumeTimestamp(consumer, v, null) + } + assert(consumer.data.head.head.isInstanceOf[Long]) + assert(consumer.data.head.head == Int.MaxValue * 1_000_000L) + } + } + + } + } + + test("#consume{Boolean,Double,String,Json} are unsupported.") { + def assertUnsupportedConsume(f: RecordConsumer => Unit) = + newMockRecordConsumer().tap { consumer => + consumer.writingSampleField { + // format: off + assert(intercept[ConfigException](f(consumer)).getMessage.endsWith("is unsupported.")) + // format: on + } + } + + forAll(conditions) { (isAdjustedToUtc, timeUnit, timeZone, _) => + val t = + TimestampLogicalType( + isAdjustedToUtc = isAdjustedToUtc, + timeUnit = timeUnit, + timeZone = timeZone + ) + assertUnsupportedConsume(t.consumeBoolean(_, true)) + assertUnsupportedConsume(t.consumeDouble(_, 0.0d)) + assertUnsupportedConsume(t.consumeString(_, null)) + assertUnsupportedConsume(t.consumeJson(_, null)) + } + } + +} From 442d4c5a29e0ecd6f0286eb07826e6b7dbb23ab7 Mon Sep 17 00:00:00 2001 From: Civitaspo Date: Mon, 25 May 2020 12:27:37 +0900 Subject: [PATCH 5/5] Update README: new logical type configurations --- README.md | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 47e8cef..fddb3da 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,16 @@ - **column_options**: a map whose keys are name of 
columns, and values are configuration with the following parameters (optional)
   - **timezone**: timezone if type of this column is timestamp. If not set, **default_timezone** is used. (string, optional)
   - **format**: timestamp format if type of this column is timestamp. If not set, **default_timestamp_format** is used. (string, optional)
-  - **logical_type**: a Parquet logical type name (`timestamp-millis`, `timestamp-micros`, `timestamp-nanos`, `json`, `int8`, `int16`, `int32`, `int64`, `uint8`, `uint16`, `uint32`, `uint64`) (string, optional)
+  - **converted_type**: a Parquet converted type name (`timestamp-millis`, `timestamp-micros`, `timestamp-nanos`, `json`, `int8`, `int16`, `int32`, `int64`, `uint8`, `uint16`, `uint32`, `uint64`) (string, optional)
+  - **logical_type**: **[DEPRECATED: Use `converted_type` instead]** a Parquet converted type name (`timestamp-millis`, `timestamp-micros`, `timestamp-nanos`, `json`, `int8`, `int16`, `int32`, `int64`, `uint8`, `uint16`, `uint32`, `uint64`) (string, optional)
+  - **logical_type**: configuration for the detailed logical type, as shown in the example after this list. See the [Logical Type Specification](https://github.com/apache/parquet-format/blob/apache-parquet-format-2.7.0/LogicalTypes.md) (optional)
+    - **name**: The name of the logical type (`"date"`, `"decimal"`, `"int"`, `"json"`, `"time"`, `"timestamp"`) (string, required)
+    - **bit_width**: The bit width for the `"int"` logical type (allowed bit widths are `8`, `16`, `32`, `64`). (int, default: `64`)
+    - **is_signed**: Whether the `"int"` logical type is signed (boolean, default: `true`)
+    - **scale**: The scale for the `"decimal"` logical type (int, default: `0`)
+    - **precision**: The precision for the `"decimal"` logical type (int, default: `0`)
+    - **is_adjusted_to_utc**: Whether `"time"` or `"timestamp"` values are adjusted to UTC (boolean, default: `true`)
+    - **time_unit**: The time unit for the `"time"` or `"timestamp"` logical type (allowed values are `"MILLIS"`, `"MICROS"`, `"NANOS"`)
 - **canned_acl**: grants one of the [canned ACLs](https://docs.aws.amazon.com/AmazonS3/latest/dev/acl-overview.html#CannedACL) for created objects (string, default: `private`)
 - **block_size**: The block size is the size of a row group being buffered in memory. This limits the memory usage when writing. Larger values will improve the I/O when reading but consume more memory when writing. (int, default: `134217728` (128MB))
 - **page_size**: The page size is for compression. When reading, each page can be decompressed independently. A block is composed of pages. The page is the smallest unit that must be read fully to access a single record. If this value is too small, the compression will deteriorate. (int, default: `1048576` (1MB))
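For instance, a column-level logical type could be configured roughly like this. This is an illustrative sketch only: the column name `registered_at` and the omitted bucket settings are placeholders, not defaults of this plugin.

```yaml
out:
  type: s3_parquet
  # bucket, path_prefix, and the other required settings are omitted here
  column_options:
    registered_at:              # a hypothetical timestamp column
      logical_type:
        name: "timestamp"
        is_adjusted_to_utc: true
        time_unit: "MICROS"
```

With such a configuration the column should be written as an INT64 Parquet column annotated with the `timestamp` logical type at microsecond precision.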
@@ -80,7 +89,7 @@
 |timestamp|string|
 |json|string|

- |parquet logical type|glue data type|note|
+ |parquet converted type|glue data type|note|
 |:---|:---|:---|
 |timestamp-millis|timestamp||
 |timestamp-micros|long|Glue cannot recognize timestamp-micros.|
@@ -106,7 +115,16 @@
 - **password**: proxy password (string, optional)
 - **buffer_dir**: buffer directory for parquet files to be uploaded on S3 (string, default: create a temporary directory)
 - **type_options**: a map whose keys are name of embulk type (`boolean`, `long`, `double`, `string`, `timestamp`, `json`), and values are configuration with the following parameters (optional)
-  - **logical_type**: a Parquet logical type name (`timestamp-millis`, `timestamp-micros`, `timestamp-nanos`, `json`, `int8`, `int16`, `int32`, `int64`, `uint8`, `uint16`, `uint32`, `uint64`) (string, optional)
+  - **converted_type**: a Parquet converted type name (`timestamp-millis`, `timestamp-micros`, `timestamp-nanos`, `json`, `int8`, `int16`, `int32`, `int64`, `uint8`, `uint16`, `uint32`, `uint64`) (string, optional)
+  - **logical_type**: **[DEPRECATED: Use `converted_type` instead]** a Parquet converted type name (`timestamp-millis`, `timestamp-micros`, `timestamp-nanos`, `json`, `int8`, `int16`, `int32`, `int64`, `uint8`, `uint16`, `uint32`, `uint64`) (string, optional)
+  - **logical_type**: configuration for the detailed logical type, as illustrated in the example below. See the [Logical Type Specification](https://github.com/apache/parquet-format/blob/apache-parquet-format-2.7.0/LogicalTypes.md) (optional)
+    - **name**: The name of the logical type (`"date"`, `"decimal"`, `"int"`, `"json"`, `"time"`, `"timestamp"`) (string, required)
+    - **bit_width**: The bit width for the `"int"` logical type (allowed bit widths are `8`, `16`, `32`, `64`). (int, default: `64`)
+    - **is_signed**: Whether the `"int"` logical type is signed (boolean, default: `true`)
+    - **scale**: The scale for the `"decimal"` logical type (int, default: `0`)
+    - **precision**: The precision for the `"decimal"` logical type (int, default: `0`)
+    - **is_adjusted_to_utc**: Whether `"time"` or `"timestamp"` values are adjusted to UTC (boolean, default: `true`)
+    - **time_unit**: The time unit for the `"time"` or `"timestamp"` logical type (allowed values are `"MILLIS"`, `"MICROS"`, `"NANOS"`)

## Example
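As a sketch of the type-level options described above (the bucket settings are omitted and the values shown are illustrative, not defaults), a `type_options` entry that writes every embulk `long` column as an unsigned 32-bit Parquet integer might look like:

```yaml
out:
  type: s3_parquet
  # bucket, path_prefix, and the other required settings are omitted here
  type_options:
    long:                       # applies to every embulk long column
      logical_type:
        name: "int"
        bit_width: 32
        is_signed: false
```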