From 5205114b81ff102af7dfd0e19d83f6aaacc0958a Mon Sep 17 00:00:00 2001 From: Mark Date: Fri, 6 May 2022 17:43:29 -0400 Subject: [PATCH] fix: fix issues with form recognizer parsing and form ontology learner --- .../ml/cognitive/FormOntologyLearner.scala | 5 ++- .../synapse/ml/cognitive/FormRecognizer.scala | 36 +++---------------- .../ml/cognitive/FormRecognizerV3.scala | 28 ++++----------- .../cognitive/FormRecognizerV3Schemas.scala | 12 +++---- .../split1/FormOntologyLearnerSuite.scala | 23 ++++++++++++ .../split1/FormRecognizerSuite.scala | 2 ++ .../split1/FormRecognizerV3Suite.scala | 13 ++++++- 7 files changed, 59 insertions(+), 60 deletions(-) diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/FormOntologyLearner.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/FormOntologyLearner.scala index a53d787b64..e6ed790bae 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/FormOntologyLearner.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/FormOntologyLearner.scala @@ -46,7 +46,10 @@ class FormOntologyLearner(override val uid: String) extends Estimator[FormOntolo def this() = this(Identifiable.randomUID("FormOntologyLearner")) private[ml] def extractOntology(fromRow: Row => AnalyzeResponse)(r: Row): StructType = { - val fieldResults = fromRow(r.getStruct(0)).analyzeResult.documentResults.get.head.fields + val fieldResults = fromRow(r.getStruct(0)).analyzeResult.documentResults + .getOrElse(throw new IllegalArgumentException("A row does not have a `analyzeResult.documentResults` field," + + " please filter these out before using the FormOntologyLearner")) + .head.fields new StructType(fieldResults .mapValues(_.toFieldResultRecursive.toSimplifiedDataType) .map({ case (name, dt) => StructField(name, dt) }).toArray) diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/FormRecognizer.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/FormRecognizer.scala index 571f255f2f..9f7fbd4ea4 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/FormRecognizer.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/FormRecognizer.scala @@ -18,7 +18,7 @@ import spray.json._ abstract class FormRecognizerBase(override val uid: String) extends CognitiveServicesBaseNoHandler(uid) with HasCognitiveServiceInput with HasInternalJsonOutputParser with BasicAsyncReply - with HasImageInput with HasSetLocation with HasSetLinkedService { + with HasImageInput with HasSetLocation with HasSetLinkedService with HasModelVersion { override protected def prepareEntity: Row => Option[AbstractHttpEntity] = { r => @@ -306,21 +306,8 @@ class GetCustomModel(override val uid: String) extends CognitiveServicesBase(uid def urlPath: String = "formrecognizer/v2.1/custom/models" - override protected def prepareUrl: Row => String = { - val urlParams: Array[ServiceParam[Any]] = - getUrlParams.asInstanceOf[Array[ServiceParam[Any]]]; - // This semicolon is needed to avoid argument confusion - { row: Row => - val base = getUrl + s"/${getValue(row, modelId)}" - val appended = if (!urlParams.isEmpty) { - "?" + URLEncodingUtils.format(urlParams.flatMap(p => - getValueOpt(row, p).map(v => p.name -> p.toValueString(v)) - ).toMap) - } else { - "" - } - base + appended - } + override protected def prepareUrlRoot: Row => String = { row => + getUrl + s"/${getValue(row, modelId)}" } override protected def prepareMethod(): HttpRequestBase = new HttpGet() @@ -347,21 +334,8 @@ class AnalyzeCustomModel(override val uid: String) extends FormRecognizerBase(ui def urlPath: String = "formrecognizer/v2.1/custom/models" - override protected def prepareUrl: Row => String = { - val urlParams: Array[ServiceParam[Any]] = - getUrlParams.asInstanceOf[Array[ServiceParam[Any]]]; - // This semicolon is needed to avoid argument confusion - { row: Row => - val base = getUrl + s"/${getValue(row, modelId)}/analyze" - val appended = if (!urlParams.isEmpty) { - "?" + URLEncodingUtils.format(urlParams.flatMap(p => - getValueOpt(row, p).map(v => p.name -> p.toValueString(v)) - ).toMap) - } else { - "" - } - base + appended - } + override protected def prepareUrlRoot: Row => String = {row => + getUrl + s"/${getValue(row, modelId)}/analyze" } override protected def responseDataType: DataType = AnalyzeResponse.schema diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/FormRecognizerV3.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/FormRecognizerV3.scala index c42c50e9be..6737aba995 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/FormRecognizerV3.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/FormRecognizerV3.scala @@ -34,20 +34,19 @@ trait HasPrebuiltModelID extends HasServiceParams { def getPrebuiltModelIdCol: String = getVectorParam(prebuiltModelId) } -object AnalyzeDocument extends ComplexParamsReadable[AnalyzeDocument] { - // Different versions might have different results so make sure tests pass before updating - val DefaultAPIVersion = "2022-01-30-preview" -} +object AnalyzeDocument extends ComplexParamsReadable[AnalyzeDocument] class AnalyzeDocument(override val uid: String) extends CognitiveServicesBaseNoHandler(uid) with HasCognitiveServiceInput with HasInternalJsonOutputParser with BasicAsyncReply - with HasPrebuiltModelID with HasPages with HasLocale + with HasPrebuiltModelID with HasPages with HasLocale with HasAPIVersion with HasImageInput with HasSetLocation with BasicLogging { logClass() + setDefault(apiVersion -> Left("2022-01-30-preview")) + def this() = this(Identifiable.randomUID("AnalyzeDocument")) - def urlPath: String = "formrecognizer/documentModels/" + def urlPath: String = "formrecognizer/documentModels" val stringIndexType = new ServiceParam[String](this, "stringIndexType", "Method used to " + "compute string offset and length.", { @@ -75,21 +74,8 @@ class AnalyzeDocument(override val uid: String) extends CognitiveServicesBaseNoH "Payload needs to contain image bytes or url. This code should not run")) } - override protected def prepareUrl: Row => String = { - val urlParams: Array[ServiceParam[Any]] = - getUrlParams.asInstanceOf[Array[ServiceParam[Any]]]; - // This semicolon is needed to avoid argument confusion - { row: Row => - val base = getUrl + s"${getValue(row, prebuiltModelId)}:analyze?api-version=${AnalyzeDocument.DefaultAPIVersion}" - val appended = if (!urlParams.isEmpty) { - "&" + URLEncodingUtils.format(urlParams.flatMap(p => - getValueOpt(row, p).map(v => p.name -> p.toValueString(v)) - ).toMap) - } else { - "" - } - base + appended - } + override protected def prepareUrlRoot: Row => String = { row => + getUrl + s"/${getValue(row, prebuiltModelId)}:analyze" } } diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/FormRecognizerV3Schemas.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/FormRecognizerV3Schemas.scala index 89c87765b7..3e12e144ec 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/FormRecognizerV3Schemas.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/FormRecognizerV3Schemas.scala @@ -27,8 +27,8 @@ case class AnalyzeResultV3(apiVersion: String, case class PageResultV3(pageNumber: Int, angle: Double, - width: Int, - height: Int, + width: Double, + height: Double, unit: String, spans: Option[Seq[FormSpan]], words: Option[Seq[FormWord]], @@ -37,11 +37,11 @@ case class PageResultV3(pageNumber: Int, case class FormSpan(offset: Int, length: Int) -case class FormWord(content: String, boundingBox: Option[Seq[Int]], confidence: Double, span: FormSpan) +case class FormWord(content: String, boundingBox: Option[Seq[Double]], confidence: Double, span: FormSpan) -case class FormSelectionMark(state: String, boundingBox: Option[Seq[Int]], confidence: Double, span: FormSpan) +case class FormSelectionMark(state: String, boundingBox: Option[Seq[Double]], confidence: Double, span: FormSpan) -case class FormLine(content: String, boundingBox: Option[Seq[Int]], spans: Option[Seq[FormSpan]]) +case class FormLine(content: String, boundingBox: Option[Seq[Double]], spans: Option[Seq[FormSpan]]) case class TableResultV3(rowCount: Int, columnCount: Int, @@ -49,7 +49,7 @@ case class TableResultV3(rowCount: Int, spans: Option[Seq[FormSpan]], cells: Option[Seq[FormCell]]) -case class BoundingRegion(pageNumber: Int, boundingBox: Option[Seq[Int]]) +case class BoundingRegion(pageNumber: Int, boundingBox: Option[Seq[Double]]) case class FormCell(kind: String, rowIndex: Int, diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/cognitive/split1/FormOntologyLearnerSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/cognitive/split1/FormOntologyLearnerSuite.scala index 7d307b3066..2e396aff75 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/cognitive/split1/FormOntologyLearnerSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/cognitive/split1/FormOntologyLearnerSuite.scala @@ -5,6 +5,7 @@ package com.microsoft.azure.synapse.ml.cognitive.split1 import com.microsoft.azure.synapse.ml.cognitive._ import com.microsoft.azure.synapse.ml.core.test.fuzzing.{EstimatorFuzzing, TestObject} +import org.apache.spark.SparkException import org.apache.spark.ml.util.MLReadable import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.{ArrayType, DoubleType, StringType, StructType} @@ -31,8 +32,30 @@ class FormOntologyLearnerSuite extends EstimatorFuzzing[FormOntologyLearner] wit "https://mmlsparkdemo.blob.core.windows.net/ignite2021/forms/2009/Invoice12241.pdf" ).toDF("url") + lazy val tableUrlDF: DataFrame = Seq( + "https://mmlspark.blob.core.windows.net/datasets/FormRecognizer/tables1.pdf" + ).toDF("url") + lazy val df: DataFrame = analyzeInvoices.transform(urlDF).cache() + test("Yields a reasonable error message when input rows dont contain documentResults") { + val analyzedDf = new AnalyzeLayout() + .setSubscriptionKey(cognitiveKey) + .setLocation("eastus") + .setImageUrlCol("url") + .setOutputCol("layout") + .setConcurrency(5) + .transform(tableUrlDF) + + assertThrows[SparkException] { + new FormOntologyLearner() + .setInputCol("layout") + .setOutputCol("unified_ontology") + .fit(analyzedDf) + .transform(analyzedDf) + } + } + test("Basic Usage") { val targetSchema = new StructType() diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/cognitive/split1/FormRecognizerSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/cognitive/split1/FormRecognizerSuite.scala index 397a4b0f39..3b6f17fb7d 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/cognitive/split1/FormRecognizerSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/cognitive/split1/FormRecognizerSuite.scala @@ -130,6 +130,8 @@ trait FormRecognizerUtils extends TestBase with CognitiveKey with Flaky { lazy val bytesDF5: DataFrame = createTestDataframe(baseUrl, Seq("id1.jpg"), returnBytes = true) + lazy val imageDf6: DataFrame = createTestDataframe(baseUrl, Seq("tables1.pdf"), returnBytes = false) + lazy val pdfDf1: DataFrame = createTestDataframe(baseUrl, Seq("layout2.pdf"), returnBytes = false) lazy val pdfDf2: DataFrame = createTestDataframe(baseUrl, Seq("invoice1.pdf", "invoice3.pdf"), returnBytes = false) diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/cognitive/split1/FormRecognizerV3Suite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/cognitive/split1/FormRecognizerV3Suite.scala index 436c558f28..029f7f39f7 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/cognitive/split1/FormRecognizerV3Suite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/cognitive/split1/FormRecognizerV3Suite.scala @@ -60,6 +60,17 @@ class AnalyzeDocumentSuite extends TransformerFuzzing[AnalyzeDocument] with Form super.assertDFEq(prep(df1), prep(df2))(eq) } + test("basic usage with tables") { + val fromRow = AnalyzeDocumentResponse.makeFromRowConverter + analyzeDocument + .setPrebuiltModelId("prebuilt-layout") + .setImageUrlCol("source") + .transform(imageDf6) + .collect() + .map(r => fromRow(r.getAs[Row]("result"))) + .foreach(r => assert(r.analyzeResult.pages.get.head.pageNumber >= 0)) + } + def analyzeDocument: AnalyzeDocument = new AnalyzeDocument() .setSubscriptionKey(cognitiveKey) .setLocation("eastus") @@ -128,7 +139,7 @@ class AnalyzeDocumentSuite extends TransformerFuzzing[AnalyzeDocument] with Form resultAssert(result, "WA WASHINGTON\n20 1234567XX1101\nDRIVER LICENSE\nFEDERAL LIMITS APPLY\n" + "4d LIC#WDLABCD456DG 9CLASS\nDONORS\n1 TALBOT\n2 LIAM R.\n3 DOB 01/06/1958\n", "Address,CountryRegion,DateOfBirth,DateOfExpiration,DocumentNumber," + - "Endorsements,FirstName,LastName,Region,Restrictions,Sex") + "Endorsements,FirstName,LastName,Region,Restrictions,Sex") } }