Skip to content

Commit

Permalink
[spark] Integrate HuggingFace tokenizer
Browse files Browse the repository at this point in the history
  • Loading branch information
xyang16 committed Feb 14, 2023
1 parent 3f2a495 commit f400040
Show file tree
Hide file tree
Showing 6 changed files with 288 additions and 6 deletions.
25 changes: 20 additions & 5 deletions docker/spark/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,38 @@
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
# Builds a Spark image on top of an EMR 6.9.0 base image with the DJL jars and the djl_spark
# Python wheel installed.
# NOTE(review): two FROM lines with different registry hosts follow; this looks like the removed
# and the added line of a diff shown together - confirm which base image is intended.
FROM 895885662937.dkr.ecr.us-west-2.amazonaws.com/spark/emr-6.9.0:latest
FROM 711395599931.dkr.ecr.us-east-2.amazonaws.com/spark/emr-6.9.0:latest
LABEL maintainer="djl-dev@amazon.com"

# Add DJL jars
# Install DJL dependencies
# The steps below run as root; the user is switched to hadoop:hadoop at the end of the file.
USER root
# Versions used to build the Maven Central download URLs and jar file names below.
ARG DJL_VERSION=0.20.0
ARG JNA_VERSION=5.12.1
ARG JAVACPP_VERSION=1.5.8
ARG TENSORFLOW_CORE_VERSION=0.4.2
ARG PROTOBUF_VERSION=3.21.9

# Copy the locally built djl_spark wheel into the image, install it with pip3, then remove the
# copied dist directory.
COPY extensions/spark/setup/dist/ dist/
RUN pip3 install --no-cache-dir dist/djl_spark-${DJL_VERSION}-py3-none-any.whl && \
rm -rf dist

# DJL API, Spark extension and HuggingFace tokenizers jars, saved with a "djl-" file name prefix.
ADD https://repo1.maven.org/maven2/ai/djl/api/${DJL_VERSION}/api-${DJL_VERSION}.jar /usr/lib/spark/jars/djl-api-${DJL_VERSION}.jar
ADD https://repo1.maven.org/maven2/ai/djl/spark/spark/${DJL_VERSION}/spark-${DJL_VERSION}.jar /usr/lib/spark/jars/djl-spark-${DJL_VERSION}.jar
ADD https://repo1.maven.org/maven2/ai/djl/huggingface/tokenizers/${DJL_VERSION}/tokenizers-${DJL_VERSION}.jar /usr/lib/spark/jars/djl-tokenizers-${DJL_VERSION}.jar

# DJL engine jars (PyTorch, MXNet, ONNX Runtime, TensorFlow) and the JNA jar.
ADD https://repo1.maven.org/maven2/ai/djl/pytorch/pytorch-engine/${DJL_VERSION}/pytorch-engine-${DJL_VERSION}.jar /usr/lib/spark/jars/djl-pytorch-engine-${DJL_VERSION}.jar
ADD https://repo1.maven.org/maven2/ai/djl/mxnet/mxnet-engine/${DJL_VERSION}/mxnet-engine-${DJL_VERSION}.jar /usr/lib/spark/jars/djl-mxnet-engine-${DJL_VERSION}.jar
ADD https://repo1.maven.org/maven2/ai/djl/onnxruntime/onnxruntime-engine/${DJL_VERSION}/onnxruntime-engine-${DJL_VERSION}.jar /usr/lib/spark/jars/djl-onnxruntime-engine-${DJL_VERSION}.jar
ADD https://repo1.maven.org/maven2/ai/djl/tensorflow/tensorflow-engine/${DJL_VERSION}/tensorflow-engine-${DJL_VERSION}.jar /usr/lib/spark/jars/djl-tensorflow-engine-${DJL_VERSION}.jar
ADD https://repo1.maven.org/maven2/net/java/dev/jna/jna/${JNA_VERSION}/jna-${JNA_VERSION}.jar /usr/lib/spark/jars/

# NOTE(review): the tensorflow-engine ADD below is identical to the one above; it is probably the
# same line appearing as both the old and the new side of the diff - confirm and keep only one.
ADD https://repo1.maven.org/maven2/ai/djl/tensorflow/tensorflow-engine/${DJL_VERSION}/tensorflow-engine-${DJL_VERSION}.jar /usr/lib/spark/jars/djl-tensorflow-engine-${DJL_VERSION}.jar
ADD https://repo1.maven.org/maven2/org/bytedeco/javacpp/${JAVACPP_VERSION}/javacpp-${JAVACPP_VERSION}.jar /usr/lib/spark/jars/
ADD https://repo1.maven.org/maven2/org/tensorflow/tensorflow-core-api/${TENSORFLOW_CORE_VERSION}/tensorflow-core-api-${TENSORFLOW_CORE_VERSION}.jar /usr/lib/spark/jars/
# Remove the protobuf-java jars already present in the Spark jars directory before adding the
# version pinned by PROTOBUF_VERSION.
RUN rm /usr/lib/spark/jars/protobuf-java-*.jar
ADD https://repo1.maven.org/maven2/com/google/protobuf/protobuf-java/${PROTOBUF_VERSION}/protobuf-java-${PROTOBUF_VERSION}.jar /usr/lib/spark/jars/

# Give every user read permission on everything under the Spark jars directory.
RUN chmod -R +r /usr/lib/spark/jars/

# Set environment
ENV PYTORCH_PRECXX11 true
ENV OMP_NUM_THREADS 1

# Subsequent instructions and the container process run as the hadoop user and group.
USER hadoop:hadoop
1 change: 1 addition & 0 deletions extensions/spark/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ group "ai.djl.spark"

dependencies {
api project(":api")
api project(":extensions:tokenizers")
api "org.apache.spark:spark-core_2.12:${spark_version}"
api "org.apache.spark:spark-sql_2.12:${spark_version}"
api "org.apache.spark:spark-mllib_2.12:${spark_version}"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.spark.task.text

import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer
import org.apache.spark.ml.param.Param
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types.{ArrayType, LongType, StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Dataset, Row}

/**
 * HuggingFaceTextDecoder decodes token ids back into text using HuggingFace tokenizers in Spark.
 *
 * The decoded text is appended to every input row as a string column named by `outputCol`.
 *
 * @param uid An immutable unique ID for the object and its derivatives.
 */
class HuggingFaceTextDecoder(override val uid: String) extends TextPredictor[Array[Long], String]
  with HasInputCol with HasOutputCol {

  def this() = this(Identifiable.randomUID("HuggingFaceTextDecoder"))

  /** The name passed to `HuggingFaceTokenizer.newInstance` when the tokenizer is created. */
  final val name = new Param[String](this, "name", "The name of the tokenizer")

  /**
   * Sets the inputCol parameter.
   *
   * @param value the value of the parameter
   */
  def setInputCol(value: String): this.type = set(inputCol, value)

  /**
   * Sets the outputCol parameter.
   *
   * @param value the value of the parameter
   */
  def setOutputCol(value: String): this.type = set(outputCol, value)

  /**
   * Sets the name parameter.
   *
   * @param value the value of the parameter
   */
  def setName(value: String): this.type = set(name, value)

  setDefault(inputClass, classOf[Array[Long]])
  setDefault(outputClass, classOf[String])

  /**
   * Decodes String from the input ids on the provided dataset.
   *
   * @param dataset input dataset
   * @return output dataset
   */
  def decode(dataset: Dataset[_]): DataFrame = {
    transform(dataset)
  }

  /** @inheritdoc */
  override def transformRows(iter: Iterator[Row]): Iterator[Row] = {
    // One tokenizer is created per call and shared by every row of the iterator.
    val tokenizer = HuggingFaceTokenizer.newInstance($(name))
    iter.map { row =>
      val ids = row.getAs[Seq[Long]]($(inputCol)).toArray
      // `outputSchema` is not declared in this class; presumably it is inherited from
      // TextPredictor - TODO confirm.
      new GenericRowWithSchema(row.toSeq.toArray ++ Array[Any](tokenizer.decode(ids)),
        outputSchema)
    }
  }

  /** @inheritdoc */
  override def transformSchema(schema: StructType): StructType = {
    validateInputType(schema($(inputCol)))
    // The result is the input schema followed by one string column holding the decoded text.
    StructType(schema.fields ++ Array(StructField($(outputCol), StringType)))
  }

  /**
   * Checks that the input column is an array of longs.
   *
   * Both nullable-element and non-nullable-element arrays are accepted: `ArrayType(LongType)`
   * alone only equals the `containsNull = true` variant, which would reject columns whose
   * element nullability is `false`.
   *
   * @param input the input column field
   * @throws IllegalArgumentException if the column is not an array of longs
   */
  override def validateInputType(input: StructField): Unit = {
    val isLongArray = input.dataType match {
      case ArrayType(LongType, _) => true
      case _ => false
    }
    require(isLongArray,
      s"Input column ${input.name} type must be ArrayType(LongType) but got ${input.dataType}.")
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.spark.task.text

import ai.djl.huggingface.tokenizers.{Encoding, HuggingFaceTokenizer}
import org.apache.spark.ml.param.Param
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types.{ArrayType, LongType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Dataset, Row}

/**
 * HuggingFaceTextEncoder performs text encoding using HuggingFace tokenizers in Spark.
 *
 * Every input row is extended with one struct column, named by `outputCol`, that carries the
 * fields `ids`, `type_ids` and `attention_mask` of the encoding.
 *
 * @param uid An immutable unique ID for the object and its derivatives.
 */
class HuggingFaceTextEncoder(override val uid: String) extends TextPredictor[String, Encoding]
  with HasInputCol with HasOutputCol {

  def this() = this(Identifiable.randomUID("HuggingFaceTextEncoder"))

  /** The name handed to `HuggingFaceTokenizer.newInstance` when the tokenizer is created. */
  final val name = new Param[String](this, "name", "The name of the tokenizer")

  setDefault(inputClass, classOf[String])
  setDefault(outputClass, classOf[Encoding])

  /**
   * Sets the inputCol parameter.
   *
   * @param value the name of the column that holds the text to encode
   * @return this instance
   */
  def setInputCol(value: String): this.type = set(inputCol, value)

  /**
   * Sets the outputCol parameter.
   *
   * @param value the name of the struct column that receives the encoding
   * @return this instance
   */
  def setOutputCol(value: String): this.type = set(outputCol, value)

  /**
   * Sets the name parameter.
   *
   * @param value the name of the tokenizer
   * @return this instance
   */
  def setName(value: String): this.type = set(name, value)

  /**
   * Encodes the text of the provided dataset by delegating to `transform`.
   *
   * @param dataset input dataset
   * @return output dataset
   */
  def encode(dataset: Dataset[_]): DataFrame = transform(dataset)

  /** @inheritdoc */
  override def transformRows(iter: Iterator[Row]): Iterator[Row] = {
    // A single tokenizer instance serves all rows of this iterator.
    val hfTokenizer = HuggingFaceTokenizer.newInstance($(name))
    iter.map { row =>
      val text = row.getAs[String]($(inputCol))
      val encoded = hfTokenizer.encode(text)
      val encodingRow = Row(encoded.getIds, encoded.getTypeIds, encoded.getAttentionMask)
      // `outputSchema` is not declared in this class; presumably inherited - TODO confirm.
      new GenericRowWithSchema((row.toSeq :+ encodingRow).toArray, outputSchema)
    }
  }

  /** @inheritdoc */
  override def transformSchema(schema: StructType): StructType = {
    validateInputType(schema($(inputCol)))
    // Struct type of the appended column: three arrays of longs.
    val encodingType = StructType(Seq(
      StructField("ids", ArrayType(LongType)),
      StructField("type_ids", ArrayType(LongType)),
      StructField("attention_mask", ArrayType(LongType))))
    StructType(schema.fields :+ StructField($(outputCol), encodingType))
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.spark.task.text

import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer
import org.apache.spark.ml.param.Param
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Dataset, Row}

/**
 * HuggingFaceTextTokenizer performs text tokenization using HuggingFace tokenizers in Spark.
 *
 * Every input row is extended with one array-of-strings column, named by `outputCol`, that
 * holds the tokens of the text found in `inputCol`.
 *
 * @param uid An immutable unique ID for the object and its derivatives.
 */
class HuggingFaceTextTokenizer(override val uid: String) extends TextPredictor[String, Array[String]]
  with HasInputCol with HasOutputCol {

  def this() = this(Identifiable.randomUID("HuggingFaceTextTokenizer"))

  /** The name handed to `HuggingFaceTokenizer.newInstance` when the tokenizer is created. */
  final val name = new Param[String](this, "name", "The name of the tokenizer")

  setDefault(inputClass, classOf[String])
  setDefault(outputClass, classOf[Array[String]])

  /**
   * Sets the inputCol parameter.
   *
   * @param value the name of the column that holds the text to tokenize
   * @return this instance
   */
  def setInputCol(value: String): this.type = set(inputCol, value)

  /**
   * Sets the outputCol parameter.
   *
   * @param value the name of the column that receives the tokens
   * @return this instance
   */
  def setOutputCol(value: String): this.type = set(outputCol, value)

  /**
   * Sets the name parameter.
   *
   * @param value the name of the tokenizer
   * @return this instance
   */
  def setName(value: String): this.type = set(name, value)

  /**
   * Tokenizes the text of the provided dataset by delegating to `transform`.
   *
   * @param dataset input dataset
   * @return output dataset
   */
  def tokenize(dataset: Dataset[_]): DataFrame = transform(dataset)

  /** @inheritdoc */
  override def transformRows(iter: Iterator[Row]): Iterator[Row] = {
    // A single tokenizer instance serves all rows of this iterator.
    val hfTokenizer = HuggingFaceTokenizer.newInstance($(name))
    iter.map { row =>
      val text = row.getAs[String]($(inputCol))
      val tokens = hfTokenizer.tokenize(text).toArray
      // `outputSchema` is not declared in this class; presumably inherited - TODO confirm.
      new GenericRowWithSchema((row.toSeq :+ tokens).toArray, outputSchema)
    }
  }

  /** @inheritdoc */
  override def transformSchema(schema: StructType): StructType = {
    validateInputType(schema($(inputCol)))
    // The result is the input schema followed by one array-of-strings column.
    val tokensField = StructField($(outputCol), ArrayType(StringType))
    StructType(schema.fields :+ tokensField)
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.apache.spark.sql.{DataFrame, Dataset, Row}
*
* @param uid An immutable unique ID for the object and its derivatives.
*/
abstract class TextEmbedder(override val uid: String) extends TextPredictor[String, Array[Float]]
class TextEmbedder(override val uid: String) extends TextPredictor[String, Array[Float]]
with HasInputCol with HasOutputCol {

def this() = this(Identifiable.randomUID("TextEmbedder"))
Expand Down

0 comments on commit f400040

Please sign in to comment.