Merge pull request #23 from civitaspo/limit-the-usage-of-ContextClassLoader

Limit the usage of swapping ContextClassLoader
civitaspo authored Mar 10, 2020
2 parents 88ead75 + 7223227 commit 9aefba9
Showing 6 changed files with 67 additions and 44 deletions.
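
For context: Hadoop and Parquet resolve a number of classes (compression codecs, FileSystem implementations) through the thread context classloader (TCCL) via java.util.ServiceLoader, and on an Embulk worker thread the TCCL may not see the plugin's bundled jars. Before this commit the plugin swapped the TCCL around the whole transaction; this commit limits the swap to the call sites that actually need it. A minimal sketch of the scoped-swap pattern, with a hypothetical `Codec` SPI standing in for what Hadoop loads:

import java.util.ServiceLoader
import scala.jdk.CollectionConverters._

// Hypothetical SPI, loaded the way Hadoop loads FileSystem implementations:
// ServiceLoader consults the current thread's context classloader.
trait Codec

object ScopedSwapSketch {
  def loadCodecs(pluginClassLoader: ClassLoader): List[Codec] = {
    val thread = Thread.currentThread()
    val original = thread.getContextClassLoader
    // Swap only around the lookup that needs the plugin's jars ...
    thread.setContextClassLoader(pluginClassLoader)
    try ServiceLoader.load(classOf[Codec]).iterator().asScala.toList
    // ... and always restore the previous loader, even on failure.
    finally thread.setContextClassLoader(original)
  }
}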
1 change: 1 addition & 0 deletions example/config.yml
@@ -19,6 +19,7 @@ out:
  type: s3_parquet
  bucket: example
  region: us-east-1
  endpoint: http://127.0.0.1:4572
  path_prefix: path/to/my-obj.
  file_ext: snappy.parquet
  compression_codec: snappy
1 change: 1 addition & 0 deletions example/with_catalog.yml
@@ -19,6 +19,7 @@ out:
  type: s3_parquet
  bucket: example
  region: us-east-1
  endpoint: http://127.0.0.1:4572
  path_prefix: path/to/my-obj-2.
  file_ext: snappy.parquet
  compression_codec: snappy
1 change: 1 addition & 0 deletions example/with_logicaltypes.yml
@@ -19,6 +19,7 @@ out:
  type: s3_parquet
  bucket: example
  region: us-east-1
  endpoint: http://127.0.0.1:4572
  path_prefix: path/to/my-obj-2.
  file_ext: snappy.parquet
  compression_codec: snappy
18 changes: 18 additions & 0 deletions src/main/scala/org/embulk/output/s3_parquet/ContextClassLoaderSwapper.scala
@@ -0,0 +1,18 @@
package org.embulk.output.s3_parquet

// WARNING: This object should be used for limited purposes only.
object ContextClassLoaderSwapper {

  // Runs `f` with the classloader of `klass` installed as the current
  // thread's context classloader, restoring the previous loader afterwards.
  def using[A](klass: Class[_])(f: => A): A = {
    val currentThread = Thread.currentThread()
    val original = currentThread.getContextClassLoader
    val target = klass.getClassLoader
    currentThread.setContextClassLoader(target)
    try f
    finally currentThread.setContextClassLoader(original)
  }

  def usingPluginClass[A](f: => A): A = {
    using(classOf[S3ParquetOutputPlugin])(f)
  }
}
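
A brief usage sketch of the helper above (the body is illustrative): the block runs with the plugin's classloader installed as the TCCL, and the previous loader is restored on exit, even if the body throws.

val loaderName: String = ContextClassLoaderSwapper.usingPluginClass {
  // Inside the block, TCCL-based lookups (ServiceLoader, Hadoop
  // reflection utilities) can see classes from the plugin's jars.
  Thread.currentThread().getContextClassLoader.toString
}
// Here the original context classloader is back in place.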
src/main/scala/org/embulk/output/s3_parquet/S3ParquetOutputPlugin.scala
@@ -133,15 +133,6 @@ class S3ParquetOutputPlugin extends OutputPlugin {

  val logger: Logger = LoggerFactory.getLogger(classOf[S3ParquetOutputPlugin])

  private def withPluginContextClassLoader[A](f: => A): A = {
    val original: ClassLoader = Thread.currentThread.getContextClassLoader
    Thread.currentThread.setContextClassLoader(
      classOf[S3ParquetOutputPlugin].getClassLoader
    )
    try f
    finally Thread.currentThread.setContextClassLoader(original)
  }

  override def transaction(
      config: ConfigSource,
      schema: Schema,
@@ -150,10 +141,9 @@
  ): ConfigDiff = {
    val task: PluginTask = config.loadConfig(classOf[PluginTask])

    withPluginContextClassLoader {
      configure(task, schema)
      control.run(task.dump)
    }
    configure(task, schema)
    control.run(task.dump)

    task.getCatalog.ifPresent { catalog =>
      val location =
        s"s3://${task.getBucket}/${task.getPathPrefix.replaceFirst("(.*/)[^/]+$", "$1")}"
@@ -303,34 +293,43 @@
      task.getTypeOptions,
      task.getColumnOptions
    )
    val parquetWriter: ParquetWriter[PageReader] = ParquetFileWriter
      .builder()
      .withPath(bufferFile)
      .withSchema(schema)
      .withLogicalTypeHandlers(logicalTypeHandlers)
      .withTimestampFormatters(timestampFormatters)
      .withCompressionCodec(task.getCompressionCodec)
      .withDictionaryEncoding(
        task.getEnableDictionaryEncoding.orElse(
          ParquetProperties.DEFAULT_IS_DICTIONARY_ENABLED
        )
      )
      .withDictionaryPageSize(
        task.getPageSize.orElse(ParquetProperties.DEFAULT_DICTIONARY_PAGE_SIZE)
      )
      .withMaxPaddingSize(
        task.getMaxPaddingSize.orElse(ParquetWriter.MAX_PADDING_SIZE_DEFAULT)
      )
      .withPageSize(
        task.getPageSize.orElse(ParquetProperties.DEFAULT_PAGE_SIZE)
      )
      .withRowGroupSize(
        task.getBlockSize.orElse(ParquetWriter.DEFAULT_BLOCK_SIZE)
      )
      .withValidation(ParquetWriter.DEFAULT_IS_VALIDATING_ENABLED)
      .withWriteMode(org.apache.parquet.hadoop.ParquetFileWriter.Mode.CREATE)
      .withWriterVersion(ParquetProperties.DEFAULT_WRITER_VERSION)
      .build()
    val parquetWriter: ParquetWriter[PageReader] =
      ContextClassLoaderSwapper.usingPluginClass {
        ParquetFileWriter
          .builder()
          .withPath(bufferFile)
          .withSchema(schema)
          .withLogicalTypeHandlers(logicalTypeHandlers)
          .withTimestampFormatters(timestampFormatters)
          .withCompressionCodec(task.getCompressionCodec)
          .withDictionaryEncoding(
            task.getEnableDictionaryEncoding.orElse(
              ParquetProperties.DEFAULT_IS_DICTIONARY_ENABLED
            )
          )
          .withDictionaryPageSize(
            task.getPageSize.orElse(
              ParquetProperties.DEFAULT_DICTIONARY_PAGE_SIZE
            )
          )
          .withMaxPaddingSize(
            task.getMaxPaddingSize.orElse(
              ParquetWriter.MAX_PADDING_SIZE_DEFAULT
            )
          )
          .withPageSize(
            task.getPageSize.orElse(ParquetProperties.DEFAULT_PAGE_SIZE)
          )
          .withRowGroupSize(
            task.getBlockSize.orElse(ParquetWriter.DEFAULT_BLOCK_SIZE)
          )
          .withValidation(ParquetWriter.DEFAULT_IS_VALIDATING_ENABLED)
          .withWriteMode(
            org.apache.parquet.hadoop.ParquetFileWriter.Mode.CREATE
          )
          .withWriterVersion(ParquetProperties.DEFAULT_WRITER_VERSION)
          .build()
      }

    logger.info(
      s"Local Buffer File: $bufferFile, Destination: s3://$destS3bucket/$destS3Key"
src/main/scala/org/embulk/output/s3_parquet/S3ParquetPageOutput.scala
@@ -33,7 +33,9 @@ case class S3ParquetPageOutput(
  override def close(): Unit = {
    synchronized {
      if (!isClosed) {
        writer.close()
        ContextClassLoaderSwapper.usingPluginClass {
          writer.close()
        }
        isClosed = true
      }
    }
@@ -46,11 +48,12 @@

  override def commit(): TaskReport = {
    close()
    val result: UploadResult = aws.withTransferManager {
      xfer: TransferManager =>
    val result: UploadResult = ContextClassLoaderSwapper.usingPluginClass {
      aws.withTransferManager { xfer: TransferManager =>
        val upload: Upload =
          xfer.upload(destBucket, destKey, new File(outputLocalFile))
        upload.waitForUploadResult()
      }
    }
    cleanup()
    Exec
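
One property of this design worth noting: because `using` restores the previous loader in a `finally`, a failed upload or writer close cannot leak the plugin's classloader onto Embulk's worker thread. A small illustrative check, hypothetical rather than from this repository:

val before = Thread.currentThread().getContextClassLoader
try {
  ContextClassLoaderSwapper.usingPluginClass {
    throw new RuntimeException("simulated upload failure")
  }
} catch { case _: RuntimeException => () }
// The swap is undone even though the block threw.
assert(Thread.currentThread().getContextClassLoader eq before)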
