Skip to content

Commit

Permalink
fix: fix issues with form recognizer parsing and form ontology learner
Browse files Browse the repository at this point in the history
  • Loading branch information
mhamilton723 committed May 7, 2022
1 parent f232058 commit 5205114
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,10 @@ class FormOntologyLearner(override val uid: String) extends Estimator[FormOntolo
def this() = this(Identifiable.randomUID("FormOntologyLearner"))

private[ml] def extractOntology(fromRow: Row => AnalyzeResponse)(r: Row): StructType = {
val fieldResults = fromRow(r.getStruct(0)).analyzeResult.documentResults.get.head.fields
val fieldResults = fromRow(r.getStruct(0)).analyzeResult.documentResults
.getOrElse(throw new IllegalArgumentException("A row does not have a `analyzeResult.documentResults` field," +
" please filter these out before using the FormOntologyLearner"))
.head.fields
new StructType(fieldResults
.mapValues(_.toFieldResultRecursive.toSimplifiedDataType)
.map({ case (name, dt) => StructField(name, dt) }).toArray)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import spray.json._

abstract class FormRecognizerBase(override val uid: String) extends CognitiveServicesBaseNoHandler(uid)
with HasCognitiveServiceInput with HasInternalJsonOutputParser with BasicAsyncReply
with HasImageInput with HasSetLocation with HasSetLinkedService {
with HasImageInput with HasSetLocation with HasSetLinkedService with HasModelVersion {

override protected def prepareEntity: Row => Option[AbstractHttpEntity] = {
r =>
Expand Down Expand Up @@ -306,21 +306,8 @@ class GetCustomModel(override val uid: String) extends CognitiveServicesBase(uid

def urlPath: String = "formrecognizer/v2.1/custom/models"

override protected def prepareUrl: Row => String = {
val urlParams: Array[ServiceParam[Any]] =
getUrlParams.asInstanceOf[Array[ServiceParam[Any]]];
// This semicolon is needed to avoid argument confusion
{ row: Row =>
val base = getUrl + s"/${getValue(row, modelId)}"
val appended = if (!urlParams.isEmpty) {
"?" + URLEncodingUtils.format(urlParams.flatMap(p =>
getValueOpt(row, p).map(v => p.name -> p.toValueString(v))
).toMap)
} else {
""
}
base + appended
}
override protected def prepareUrlRoot: Row => String = { row =>
getUrl + s"/${getValue(row, modelId)}"
}

override protected def prepareMethod(): HttpRequestBase = new HttpGet()
Expand All @@ -347,21 +334,8 @@ class AnalyzeCustomModel(override val uid: String) extends FormRecognizerBase(ui

def urlPath: String = "formrecognizer/v2.1/custom/models"

override protected def prepareUrl: Row => String = {
val urlParams: Array[ServiceParam[Any]] =
getUrlParams.asInstanceOf[Array[ServiceParam[Any]]];
// This semicolon is needed to avoid argument confusion
{ row: Row =>
val base = getUrl + s"/${getValue(row, modelId)}/analyze"
val appended = if (!urlParams.isEmpty) {
"?" + URLEncodingUtils.format(urlParams.flatMap(p =>
getValueOpt(row, p).map(v => p.name -> p.toValueString(v))
).toMap)
} else {
""
}
base + appended
}
override protected def prepareUrlRoot: Row => String = {row =>
getUrl + s"/${getValue(row, modelId)}/analyze"
}

override protected def responseDataType: DataType = AnalyzeResponse.schema
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,19 @@ trait HasPrebuiltModelID extends HasServiceParams {
def getPrebuiltModelIdCol: String = getVectorParam(prebuiltModelId)
}

object AnalyzeDocument extends ComplexParamsReadable[AnalyzeDocument] {
// Different versions might have different results so make sure tests pass before updating
val DefaultAPIVersion = "2022-01-30-preview"
}
object AnalyzeDocument extends ComplexParamsReadable[AnalyzeDocument]

class AnalyzeDocument(override val uid: String) extends CognitiveServicesBaseNoHandler(uid)
with HasCognitiveServiceInput with HasInternalJsonOutputParser with BasicAsyncReply
with HasPrebuiltModelID with HasPages with HasLocale
with HasPrebuiltModelID with HasPages with HasLocale with HasAPIVersion
with HasImageInput with HasSetLocation with BasicLogging {
logClass()

setDefault(apiVersion -> Left("2022-01-30-preview"))

def this() = this(Identifiable.randomUID("AnalyzeDocument"))

def urlPath: String = "formrecognizer/documentModels/"
def urlPath: String = "formrecognizer/documentModels"

val stringIndexType = new ServiceParam[String](this, "stringIndexType", "Method used to " +
"compute string offset and length.", {
Expand Down Expand Up @@ -75,21 +74,8 @@ class AnalyzeDocument(override val uid: String) extends CognitiveServicesBaseNoH
"Payload needs to contain image bytes or url. This code should not run"))
}

override protected def prepareUrl: Row => String = {
val urlParams: Array[ServiceParam[Any]] =
getUrlParams.asInstanceOf[Array[ServiceParam[Any]]];
// This semicolon is needed to avoid argument confusion
{ row: Row =>
val base = getUrl + s"${getValue(row, prebuiltModelId)}:analyze?api-version=${AnalyzeDocument.DefaultAPIVersion}"
val appended = if (!urlParams.isEmpty) {
"&" + URLEncodingUtils.format(urlParams.flatMap(p =>
getValueOpt(row, p).map(v => p.name -> p.toValueString(v))
).toMap)
} else {
""
}
base + appended
}
override protected def prepareUrlRoot: Row => String = { row =>
getUrl + s"/${getValue(row, prebuiltModelId)}:analyze"
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ case class AnalyzeResultV3(apiVersion: String,

case class PageResultV3(pageNumber: Int,
angle: Double,
width: Int,
height: Int,
width: Double,
height: Double,
unit: String,
spans: Option[Seq[FormSpan]],
words: Option[Seq[FormWord]],
Expand All @@ -37,19 +37,19 @@ case class PageResultV3(pageNumber: Int,

case class FormSpan(offset: Int, length: Int)

case class FormWord(content: String, boundingBox: Option[Seq[Int]], confidence: Double, span: FormSpan)
case class FormWord(content: String, boundingBox: Option[Seq[Double]], confidence: Double, span: FormSpan)

case class FormSelectionMark(state: String, boundingBox: Option[Seq[Int]], confidence: Double, span: FormSpan)
case class FormSelectionMark(state: String, boundingBox: Option[Seq[Double]], confidence: Double, span: FormSpan)

case class FormLine(content: String, boundingBox: Option[Seq[Int]], spans: Option[Seq[FormSpan]])
case class FormLine(content: String, boundingBox: Option[Seq[Double]], spans: Option[Seq[FormSpan]])

case class TableResultV3(rowCount: Int,
columnCount: Int,
boundingRegions: Option[Seq[BoundingRegion]],
spans: Option[Seq[FormSpan]],
cells: Option[Seq[FormCell]])

case class BoundingRegion(pageNumber: Int, boundingBox: Option[Seq[Int]])
case class BoundingRegion(pageNumber: Int, boundingBox: Option[Seq[Double]])

case class FormCell(kind: String,
rowIndex: Int,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package com.microsoft.azure.synapse.ml.cognitive.split1

import com.microsoft.azure.synapse.ml.cognitive._
import com.microsoft.azure.synapse.ml.core.test.fuzzing.{EstimatorFuzzing, TestObject}
import org.apache.spark.SparkException
import org.apache.spark.ml.util.MLReadable
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.{ArrayType, DoubleType, StringType, StructType}
Expand All @@ -31,8 +32,30 @@ class FormOntologyLearnerSuite extends EstimatorFuzzing[FormOntologyLearner] wit
"https://mmlsparkdemo.blob.core.windows.net/ignite2021/forms/2009/Invoice12241.pdf"
).toDF("url")

lazy val tableUrlDF: DataFrame = Seq(
"https://mmlspark.blob.core.windows.net/datasets/FormRecognizer/tables1.pdf"
).toDF("url")

lazy val df: DataFrame = analyzeInvoices.transform(urlDF).cache()

test("Yields a reasonable error message when input rows dont contain documentResults") {
val analyzedDf = new AnalyzeLayout()
.setSubscriptionKey(cognitiveKey)
.setLocation("eastus")
.setImageUrlCol("url")
.setOutputCol("layout")
.setConcurrency(5)
.transform(tableUrlDF)

assertThrows[SparkException] {
new FormOntologyLearner()
.setInputCol("layout")
.setOutputCol("unified_ontology")
.fit(analyzedDf)
.transform(analyzedDf)
}
}

test("Basic Usage") {

val targetSchema = new StructType()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ trait FormRecognizerUtils extends TestBase with CognitiveKey with Flaky {

lazy val bytesDF5: DataFrame = createTestDataframe(baseUrl, Seq("id1.jpg"), returnBytes = true)

lazy val imageDf6: DataFrame = createTestDataframe(baseUrl, Seq("tables1.pdf"), returnBytes = false)

lazy val pdfDf1: DataFrame = createTestDataframe(baseUrl, Seq("layout2.pdf"), returnBytes = false)

lazy val pdfDf2: DataFrame = createTestDataframe(baseUrl, Seq("invoice1.pdf", "invoice3.pdf"), returnBytes = false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,17 @@ class AnalyzeDocumentSuite extends TransformerFuzzing[AnalyzeDocument] with Form
super.assertDFEq(prep(df1), prep(df2))(eq)
}

test("basic usage with tables") {
val fromRow = AnalyzeDocumentResponse.makeFromRowConverter
analyzeDocument
.setPrebuiltModelId("prebuilt-layout")
.setImageUrlCol("source")
.transform(imageDf6)
.collect()
.map(r => fromRow(r.getAs[Row]("result")))
.foreach(r => assert(r.analyzeResult.pages.get.head.pageNumber >= 0))
}

def analyzeDocument: AnalyzeDocument = new AnalyzeDocument()
.setSubscriptionKey(cognitiveKey)
.setLocation("eastus")
Expand Down Expand Up @@ -128,7 +139,7 @@ class AnalyzeDocumentSuite extends TransformerFuzzing[AnalyzeDocument] with Form
resultAssert(result, "WA WASHINGTON\n20 1234567XX1101\nDRIVER LICENSE\nFEDERAL LIMITS APPLY\n" +
"4d LIC#WDLABCD456DG 9CLASS\nDONORS\n1 TALBOT\n2 LIAM R.\n3 DOB 01/06/1958\n",
"Address,CountryRegion,DateOfBirth,DateOfExpiration,DocumentNumber," +
"Endorsements,FirstName,LastName,Region,Restrictions,Sex")
"Endorsements,FirstName,LastName,Region,Restrictions,Sex")
}
}

Expand Down

0 comments on commit 5205114

Please sign in to comment.