Skip to content

Commit

Permalink
feat: add translator (#1108)
Browse files Browse the repository at this point in the history
* feat: add text translation

* refactor textTranslator to be cleaner

* add document translation

* fix textAndTranslation param name

* fix ServiceParam name consistency issue & update document translation tests

* format

* fix: `from` is keyword in python so rename the ServiceParam and set exemption in fuzzingTest

* update fuzzingTest

* fix `from` as python keyword issue

* address comments

* refactor translator tests
  • Loading branch information
serena-ruan authored Jul 13, 2021
1 parent d287be6 commit 84d8d24
Show file tree
Hide file tree
Showing 7 changed files with 1,138 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -208,42 +208,11 @@ object RecognizeText extends ComplexParamsReadable[RecognizeText] {
}
}

trait BasicAsyncReply extends HasAsyncReply {

trait HasAsyncReply extends Params {
val backoffs: IntArrayParam = new IntArrayParam(
this, "backoffs", "array of backoffs to use in the handler")

/** @group getParam */
def getBackoffs: Array[Int] = $(backoffs)

/** @group setParam */
def setBackoffs(value: Array[Int]): this.type = set(backoffs, value)

val maxPollingRetries: IntParam = new IntParam(
this, "maxPollingRetries", "number of times to poll")

/** @group getParam */
def getMaxPollingRetries: Int = $(maxPollingRetries)

/** @group setParam */
def setMaxPollingRetries(value: Int): this.type = set(maxPollingRetries, value)

val pollingDelay: IntParam = new IntParam(
this, "pollingDelay", "number of milliseconds to wait between polling")

/** @group getParam */
def getPollingDelay: Int = $(pollingDelay)

/** @group setParam */
def setPollingDelay(value: Int): this.type = set(pollingDelay, value)

//scalastyle:off magic.number
setDefault(backoffs -> Array(100, 500, 1000), maxPollingRetries -> 1000, pollingDelay -> 300)
//scalastyle:on magic.number

private def queryForResult(key: Option[String],
client: CloseableHttpClient,
location: URI): Option[HTTPResponseData] = {
protected def queryForResult(key: Option[String],
client: CloseableHttpClient,
location: URI): Option[HTTPResponseData] = {
val get = new HttpGet()
get.setURI(location)
key.foreach(get.setHeader("Ocp-Apim-Subscription-Key", _))
Expand Down Expand Up @@ -284,13 +253,54 @@ trait HasAsyncReply extends Params {
response
}
}
}


trait HasAsyncReply extends Params {
val backoffs: IntArrayParam = new IntArrayParam(
this, "backoffs", "array of backoffs to use in the handler")

/** @group getParam */
def getBackoffs: Array[Int] = $(backoffs)

/** @group setParam */
def setBackoffs(value: Array[Int]): this.type = set(backoffs, value)

val maxPollingRetries: IntParam = new IntParam(
this, "maxPollingRetries", "number of times to poll")

/** @group getParam */
def getMaxPollingRetries: Int = $(maxPollingRetries)

/** @group setParam */
def setMaxPollingRetries(value: Int): this.type = set(maxPollingRetries, value)

val pollingDelay: IntParam = new IntParam(
this, "pollingDelay", "number of milliseconds to wait between polling")

/** @group getParam */
def getPollingDelay: Int = $(pollingDelay)

/** @group setParam */
def setPollingDelay(value: Int): this.type = set(pollingDelay, value)

//scalastyle:off magic.number
setDefault(backoffs -> Array(100, 500, 1000), maxPollingRetries -> 1000, pollingDelay -> 300)
//scalastyle:on magic.number

protected def queryForResult(key: Option[String],
client: CloseableHttpClient,
location: URI): Option[HTTPResponseData]

protected def handlingFunc(client: CloseableHttpClient,
request: HTTPRequestData): HTTPResponseData

}


class RecognizeText(override val uid: String)
extends CognitiveServicesBaseNoHandler(uid)
with HasAsyncReply
with BasicAsyncReply
with HasImageInput with HasCognitiveServiceInput
with HasInternalJsonOutputParser with HasSetLocation with BasicLogging with HasSetLinkedService {
logClass()
Expand Down Expand Up @@ -336,7 +346,7 @@ object Read extends ComplexParamsReadable[Read] {

class Read(override val uid: String)
extends CognitiveServicesBaseNoHandler(uid)
with HasAsyncReply
with BasicAsyncReply
with HasImageInput with HasCognitiveServiceInput
with HasInternalJsonOutputParser with HasSetLocation with BasicLogging with HasSetLinkedService {
logClass()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.cognitive

import com.microsoft.ml.spark.codegen.Wrappable
import com.microsoft.ml.spark.logging.BasicLogging
import org.apache.http.entity.{AbstractHttpEntity, ContentType, StringEntity}
import org.apache.spark.ml.ComplexParamsReadable
import org.apache.spark.ml.param.ServiceParam
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.types.DataType
import com.microsoft.ml.spark.build.BuildInfo
import com.microsoft.ml.spark.io.http.{HTTPRequestData, HTTPResponseData, HandlingUtils, HeaderValues}
import com.microsoft.ml.spark.io.http.HandlingUtils.{convertAndClose, sendWithRetries}
import org.apache.commons.io.IOUtils
import org.apache.http.client.methods.HttpGet
import org.apache.http.impl.client.CloseableHttpClient
import spray.json._

import java.net.URI
import java.util.concurrent.TimeoutException
import scala.concurrent.blocking

trait DocumentTranslatorAsyncReply extends BasicAsyncReply {

import TranslatorJsonProtocol._

override protected def queryForResult(key: Option[String],
client: CloseableHttpClient,
location: URI): Option[HTTPResponseData] = {
val get = new HttpGet()
get.setURI(location)
key.foreach(get.setHeader("Ocp-Apim-Subscription-Key", _))
get.setHeader("User-Agent", s"mmlspark/${BuildInfo.version}${HeaderValues.PlatformInfo}")
val resp = convertAndClose(sendWithRetries(client, get, getBackoffs))
get.releaseConnection()
val status = IOUtils.toString(resp.entity.get.content, "UTF-8")
.parseJson.asJsObject.fields.get("status").map(_.convertTo[String])
status.map(_.toLowerCase()).flatMap {
case "succeeded" | "failed" | "canceled" | "ValidationFailed" => Some(resp)
case "notstarted" | "running" | "cancelling" => None
case s => throw new RuntimeException(s"Received unknown status code: $s")
}
}
}

object DocumentTranslator extends ComplexParamsReadable[DocumentTranslator]

class DocumentTranslator(override val uid: String) extends CognitiveServicesBaseNoHandler(uid)
with HasInternalJsonOutputParser with HasCognitiveServiceInput with HasServiceName
with Wrappable with DocumentTranslatorAsyncReply with BasicLogging {

import TranslatorJsonProtocol._

logClass()

def this() = this(Identifiable.randomUID("DocumentTranslator"))

val filterPrefix = new ServiceParam[String](
this, "filterPrefix", "A case-sensitive prefix string to filter documents in the source" +
" path for translation. For example, when using an Azure storage blob Uri, use the prefix to" +
" restrict sub folders for translation.")

def setFilterPrefix(v: String): this.type = setScalarParam(filterPrefix, v)

def setFilterPrefixCol(v: String): this.type = setVectorParam(filterPrefix, v)

val filterSuffix = new ServiceParam[String](
this, "filterSuffix", "A case-sensitive suffix string to filter documents in the source" +
" path for translation. This is most often use for file extensions.")

def setFilterSuffix(v: String): this.type = setScalarParam(filterSuffix, v)

def setFilterSuffixCol(v: String): this.type = setVectorParam(filterSuffix, v)

val sourceLanguage = new ServiceParam[String](this, "sourceLanguage", "Language code." +
" If none is specified, we will perform auto detect on the document.")

def setSourceLanguage(v: String): this.type = setScalarParam(sourceLanguage, v)

def setSourceLanguageCol(v: String): this.type = setVectorParam(sourceLanguage, v)

val sourceUrl = new ServiceParam[String](this, "sourceUrl", "Location of the folder /" +
" container or single file with your documents.", isRequired = true)

def setSourceUrl(v: String): this.type = setScalarParam(sourceUrl, v)

def setSourceUrlCol(v: String): this.type = setVectorParam(sourceUrl, v)

val sourceStorageSource = new ServiceParam[String](this, "sourceStorageSource",
"Storage source of source input.")

def setSourceStorageSource(v: String): this.type = setScalarParam(sourceStorageSource, v)

def setSourceStorageSourceCol(v: String): this.type = setVectorParam(sourceStorageSource, v)

val storageType = new ServiceParam[String](this, "storageType", "Storage type of the input" +
" documents source string. Required for single document translation only.")

def setStorageType(v: String): this.type = setScalarParam(storageType, v)

def setStorageTypeCol(v: String): this.type = setVectorParam(storageType, v)

val targets = new ServiceParam[Seq[TargetInput]](this, "targets", "Destination for the" +
" finished translated documents.")

def setTargets(v: Seq[TargetInput]): this.type = setScalarParam(targets, v)

def setTargetsCol(v: String): this.type = setVectorParam(targets, v)

override protected def prepareEntity: Row => Option[AbstractHttpEntity] = {
def fetchGlossaries(row: Row): Option[Seq[Glossary]] = {
try {
Option(row.getSeq(1).asInstanceOf[Seq[Row]].map(
x => Glossary(x.getString(0), x.getString(1), Option(x.getString(2)), Option(x.getString(3)))
))
} catch {
case _: NullPointerException => Option(row.getAs[Seq[Glossary]]("glossaries"))
}
}

r =>
Some(new StringEntity(
Map("inputs" -> Seq(
BatchRequest(source = SourceInput(
filter = Option(DocumentFilter(
prefix = getValueOpt(r, filterPrefix),
suffix = getValueOpt(r, filterSuffix))),
language = getValueOpt(r, sourceLanguage),
storageSource = getValueOpt(r, sourceStorageSource),
sourceUrl = getValue(r, sourceUrl)),
storageType = getValueOpt(r, storageType),
targets = getValue(r, targets).asInstanceOf[Seq[Row]].map(
row => TargetInput(Option(row.getString(0)),
fetchGlossaries(row),
row.getString(2), row.getString(3), Option(row.getString(4))))
))).toJson.compactPrint, ContentType.APPLICATION_JSON))
}

override def transform(dataset: Dataset[_]): DataFrame = {
logTransform[DataFrame]({
setUrl(s"https://$getServiceName.cognitiveservices.azure.com/translator/text/batch/v1.0/batches")
getInternalTransformer(dataset.schema).transform(dataset)
})
}

override def responseDataType: DataType = TranslationStatusResponse.schema
}

Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import spray.json.DefaultJsonProtocol._
import spray.json._

abstract class FormRecognizerBase(override val uid: String) extends CognitiveServicesBaseNoHandler(uid)
with HasCognitiveServiceInput with HasInternalJsonOutputParser with HasAsyncReply
with HasCognitiveServiceInput with HasInternalJsonOutputParser with BasicAsyncReply
with HasImageInput with HasSetLocation {

override protected def prepareEntity: Row => Option[AbstractHttpEntity] = {
Expand Down Expand Up @@ -82,6 +82,7 @@ trait HasLocale extends HasServiceParams {
}

object FormsFlatteners {

import FormsJsonProtocol._

def flattenReadResults(inputCol: String, outputCol: String): UDFTransformer = {
Expand Down
Loading

0 comments on commit 84d8d24

Please sign in to comment.