diff --git a/encoders/src/main/scala/au/csiro/pathling/encoders/SchemaProcessor.scala b/encoders/src/main/scala/au/csiro/pathling/encoders/SchemaProcessor.scala index 731b05f58e..d5a1b30dc0 100644 --- a/encoders/src/main/scala/au/csiro/pathling/encoders/SchemaProcessor.scala +++ b/encoders/src/main/scala/au/csiro/pathling/encoders/SchemaProcessor.scala @@ -24,6 +24,7 @@ package au.csiro.pathling.encoders import au.csiro.pathling.schema._ import ca.uhn.fhir.context._ +import org.hl7.fhir.instance.model.api.IBaseReference /** @@ -95,8 +96,20 @@ trait SchemaProcessor[DT, SF] extends SchemaVisitor[DT, SF] with EncoderSettings } } + private def includeElement(elementDefinition: BaseRuntimeElementDefinition[_]): Boolean = { + val nestingLevel = EncodingContext.currentNestingLevel(elementDefinition) + if (classOf[IBaseReference].isAssignableFrom(elementDefinition.getImplementingClass)) { + // This is a special provision for References which disallows any nesting. + // It removes the `assigner` field from the Identifier type instances + // nested inside a Reference (in its `identifier` element). + nestingLevel <= 0 + } else { + nestingLevel <= maxNestingLevel + } + } + override def visitElement(elementCtx: ElementCtx[DT, SF]): Seq[SF] = { - if (EncodingContext.currentNestingLevel(elementCtx.elementDefinition) <= maxNestingLevel) { + if (includeElement(elementCtx.elementDefinition)) { buildValue(elementCtx.childDefinition, elementCtx.elementDefinition, elementCtx.elementName) } else { Nil diff --git a/encoders/src/main/scala/au/csiro/pathling/encoders/SerializerBuilder.scala b/encoders/src/main/scala/au/csiro/pathling/encoders/SerializerBuilder.scala index 25e4489ab2..87ff0a1d9c 100644 --- a/encoders/src/main/scala/au/csiro/pathling/encoders/SerializerBuilder.scala +++ b/encoders/src/main/scala/au/csiro/pathling/encoders/SerializerBuilder.scala @@ -24,10 +24,8 @@ package au.csiro.pathling.encoders import au.csiro.pathling.encoders.ExtensionSupport.{EXTENSIONS_FIELD_NAME, FID_FIELD_NAME} -import au.csiro.pathling.encoders.QuantitySupport.{CODE_CANONICALIZED_FIELD_NAME, VALUE_CANONICALIZED_FIELD_NAME} import au.csiro.pathling.encoders.SerializerBuilderProcessor.{dataTypeToUtf8Expr, getChildExpression, objectTypeFor} -import au.csiro.pathling.encoders.datatypes.{DataTypeMappings, DecimalCustomCoder} -import au.csiro.pathling.encoders.terminology.ucum.Ucum +import au.csiro.pathling.encoders.datatypes.DataTypeMappings import au.csiro.pathling.schema.SchemaVisitor.isCollection import au.csiro.pathling.schema._ import ca.uhn.fhir.context.BaseRuntimeElementDefinition.ChildTypeEnum @@ -36,7 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.objects.{ExternalMapToCatalyst, import org.apache.spark.sql.catalyst.expressions.{BoundReference, CreateNamedStruct, Expression, If, IsNull, Literal} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -import org.hl7.fhir.instance.model.api.{IBaseDatatype, IBaseHasExtensions, IBaseResource} +import org.hl7.fhir.instance.model.api.{IBaseDatatype, IBaseHasExtensions, IBaseReference, IBaseResource} import org.hl7.fhir.r4.model.{Base, Extension, Quantity} import org.hl7.fhir.utilities.xhtml.XhtmlNode @@ -226,16 +224,24 @@ private[encoders] object SerializerBuilderProcessor { // Primitive single-value types typically use the Element suffix in their // accessors, with the exception of the "div" field for reasons that are not clear. //noinspection DuplicatedCode - if (field.isInstanceOf[RuntimeChildPrimitiveDatatypeDefinition] && - field.getMax == 1 && - field.getElementName != "div") - "get" + field.getElementName.capitalize + "Element" - else { - if (field.getElementName.equals("class")) { - "get" + field.getElementName.capitalize + "_" - } else { + field match { + case p: RuntimeChildPrimitiveDatatypeDefinition if p.getMax == 1 && p + .getElementName != "div" => + if ("reference" == p.getElementName && classOf[IBaseReference] + .isAssignableFrom(p.getField.getDeclaringClass)) { + // Special case for subclasses of IBaseReference. + // The accessor getReferenceElement returns IdType rather than + // StringType and getReferenceElement_ needs to be used instead. + // All subclasses of IBaseReference have a getReferenceElement_ + // method. + "getReferenceElement_" + } else { + "get" + p.getElementName.capitalize + "Element" + } + case f if f.getElementName.equals("class") => + "get" + f.getElementName.capitalize + "_" + case _ => "get" + field.getElementName.capitalize - } } } diff --git a/encoders/src/main/scala/au/csiro/pathling/encoders/datatypes/R4DataTypeMappings.scala b/encoders/src/main/scala/au/csiro/pathling/encoders/datatypes/R4DataTypeMappings.scala index 920e1afad2..43cbe00e22 100644 --- a/encoders/src/main/scala/au/csiro/pathling/encoders/datatypes/R4DataTypeMappings.scala +++ b/encoders/src/main/scala/au/csiro/pathling/encoders/datatypes/R4DataTypeMappings.scala @@ -30,7 +30,6 @@ import ca.uhn.fhir.model.api.TemporalPrecisionEnum import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal import org.apache.spark.sql.catalyst.expressions.objects.{InitializeJavaBean, Invoke, NewInstance, StaticInvoke} import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Literal} -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, DataTypes, ObjectType} import org.hl7.fhir.instance.model.api.{IBase, IBaseDatatype, IPrimitiveType} import org.hl7.fhir.r4.model._ @@ -55,36 +54,13 @@ class R4DataTypeMappings extends DataTypeMappings { override def baseType(): Class[_ <: IBaseDatatype] = classOf[org.hl7.fhir.r4.model.Type] - override def overrideCompositeExpression(inputObject: Expression, + override def overrideCompositeExpression(inputObject: Expression, definition: BaseRuntimeElementCompositeDefinition[_]): Option[Seq[ExpressionWithName]] = { - - if (definition.getImplementingClass == classOf[Reference]) { - // Reference type, so return only supported fields. We also explicitly use the IIDType for the - // reference element, since that differs from the conventions used to infer other types. - val reference = dataTypeToUtf8Expr( - Invoke(inputObject, - "getReferenceElement", - ObjectType(classOf[IdType]))) - - val display = dataTypeToUtf8Expr( - Invoke(inputObject, - "getDisplayElement", - ObjectType(classOf[org.hl7.fhir.r4.model.StringType]))) - - Some(List(("reference", reference), ("display", display))) - } else { - None - } + None } override def skipField(definition: BaseRuntimeElementCompositeDefinition[_], child: BaseRuntimeChildDefinition): Boolean = { - - // References may be recursive, so include only the reference adn display name. - val skipRecursiveReference = definition.getImplementingClass == classOf[Reference] && - !(child.getElementName == "reference" || - child.getElementName == "display") - // Contains elements are currently not encoded in our Spark dataset. val skipContains = definition .getImplementingClass == classOf[ValueSet.ValueSetExpansionContainsComponent] && @@ -95,8 +71,7 @@ class R4DataTypeMappings extends DataTypeMappings { // "modifierExtensionExtension", not "extensionExtension". // See: https://github.com/hapifhir/hapi-fhir/issues/3414 val skipModifierExtension = child.getElementName.equals("modifierExtension") - - skipRecursiveReference || skipContains || skipModifierExtension + skipContains || skipModifierExtension } override def primitiveEncoderExpression(inputObject: Expression, diff --git a/encoders/src/test/java/au/csiro/pathling/encoders/FhirEncodersTest.java b/encoders/src/test/java/au/csiro/pathling/encoders/FhirEncodersTest.java index 426f55801f..281ba09ad3 100644 --- a/encoders/src/test/java/au/csiro/pathling/encoders/FhirEncodersTest.java +++ b/encoders/src/test/java/au/csiro/pathling/encoders/FhirEncodersTest.java @@ -23,7 +23,9 @@ package au.csiro.pathling.encoders; +import static org.apache.spark.sql.functions.col; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertSame; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -273,13 +275,80 @@ public void coding() { @Test public void reference() { + final Condition conditionWithReferences = TestData.conditionWithReferencesWithIdentifiers(); + + final Dataset conditionL3Dataset = spark + .createDataset(ImmutableList.of(conditionWithReferences), ENCODERS_L3.of(Condition.class)); + + final Condition decodedL3Condition = conditionL3Dataset.head(); + + assertEquals( + RowFactory.create( + "withReferencesWithIdentifiers", + "Patient/example", + "http://terminology.hl7.org/CodeSystem/v2-0203", + "MR", + "https://fhir.example.com/identifiers/mrn", + "urn:id" + ), + conditionL3Dataset.select( + col("id"), + col("subject.reference"), + col("subject.identifier.type.coding.system").getItem(0), + col("subject.identifier.type.coding.code").getItem(0), + col("subject.identifier.system"), + col("subject.identifier.value") + ).head()); + + assertEquals("Patient/example", + decodedL3Condition.getSubject().getReference()); + + assertEquals("urn:id", + decodedL3Condition.getSubject().getIdentifier().getValue()); + + // the assigner should be pruned from the reference identifier. + assertTrue(conditionWithReferences.getSubject().getIdentifier().hasAssigner()); + assertFalse(decodedL3Condition.getSubject().getIdentifier().hasAssigner()); + } + - assertEquals(condition.getSubject().getReference(), - conditionsDataset.select("subject.reference").head().get(0)); - assertEquals(condition.getSubject().getReference(), - decodedCondition.getSubject().getReference()); + @Test + public void identifier() { + final Condition conditionWithIdentifiers = TestData.conditionWithIdentifiersWithReferences(); + + final Dataset conditionL3Dataset = spark + .createDataset(ImmutableList.of(conditionWithIdentifiers), ENCODERS_L3.of(Condition.class)); + + final Condition decodedL3Condition = conditionL3Dataset.head(); + + assertEquals( + RowFactory.create( + "withIdentifiersWithReferences", + "http://terminology.hl7.org/CodeSystem/v2-0203", + "MR", + "https://fhir.example.com/identifiers/mrn", + "urn:id01", + "Organization/001", + "urn:id02" + ), + conditionL3Dataset.select( + col("id"), + col("identifier.type.coding").getItem(0).getField("system").getItem(0), + col("identifier.type.coding").getItem(0).getField("code").getItem(0), + col("identifier.system").getItem(0), + col("identifier.value").getItem(0), + col("identifier.assigner.reference").getItem(0), + col("identifier.assigner.identifier.value").getItem(0) + ).head()); + + // the assigner should be pruned from the reference identifier. + assertTrue(conditionWithIdentifiers.getIdentifier().get(0).getAssigner().getIdentifier() + .hasAssigner()); + assertFalse( + decodedL3Condition.getIdentifier().get(0).getAssigner().getIdentifier().hasAssigner()); } + @Test public void integer() { @@ -325,13 +394,13 @@ public void choiceBigDecimalInQuestionnaire() { .getAnswerDecimalType().getValue(); final BigDecimal queriedDecimal = (BigDecimal) questionnaireDataset - .select(functions.col("item").getItem(0).getField("enableWhen").getItem(0) + .select(col("item").getItem(0).getField("enableWhen").getItem(0) .getField("answerDecimal")) .head() .get(0); final int queriedDecimal_scale = questionnaireDataset - .select(functions.col("item").getItem(0).getField("enableWhen").getItem(0) + .select(col("item").getItem(0).getField("enableWhen").getItem(0) .getField("answerDecimal_scale")) .head() .getInt(0); @@ -360,13 +429,13 @@ public void choiceBigDecimalInQuestionnaireResponse() { .getValueDecimalType().getValue(); final BigDecimal queriedDecimal = (BigDecimal) questionnaireResponseDataset - .select(functions.col("item").getItem(0).getField("answer").getItem(0) + .select(col("item").getItem(0).getField("answer").getItem(0) .getField("valueDecimal")) .head() .get(0); final int queriedDecimal_scale = questionnaireResponseDataset - .select(functions.col("item").getItem(0).getField("answer").getItem(0) + .select(col("item").getItem(0).getField("answer").getItem(0) .getField("valueDecimal_scale")) .head() .getInt(0); @@ -516,13 +585,13 @@ public void testNestedQuestionnaire() { assertEquals(Stream.of("Item/0", "Item/0", "Item/0", "Item/0").map(RowFactory::create) .collect(Collectors.toUnmodifiableList()), - questionnaireDataset_L3.select(functions.col("item").getItem(0).getField("linkId")) + questionnaireDataset_L3.select(col("item").getItem(0).getField("linkId")) .collectAsList()); assertEquals(Stream.of(null, "Item/1.0", "Item/1.0", "Item/1.0").map(RowFactory::create) .collect(Collectors.toUnmodifiableList()), questionnaireDataset_L3 - .select(functions.col("item") + .select(col("item") .getItem(1).getField("item") .getItem(0).getField("linkId")) .collectAsList()); @@ -530,7 +599,7 @@ public void testNestedQuestionnaire() { assertEquals(Stream.of(null, null, "Item/2.1.0", "Item/2.1.0").map(RowFactory::create) .collect(Collectors.toUnmodifiableList()), questionnaireDataset_L3 - .select(functions.col("item") + .select(col("item") .getItem(2).getField("item") .getItem(1).getField("item") .getItem(0).getField("linkId")) @@ -539,7 +608,7 @@ public void testNestedQuestionnaire() { assertEquals(Stream.of(null, null, null, "Item/3.2.1.0").map(RowFactory::create) .collect(Collectors.toUnmodifiableList()), questionnaireDataset_L3 - .select(functions.col("item") + .select(col("item") .getItem(3).getField("item") .getItem(2).getField("item") .getItem(1).getField("item") diff --git a/encoders/src/test/java/au/csiro/pathling/encoders/LightweightFhirEncodersTest.java b/encoders/src/test/java/au/csiro/pathling/encoders/LightweightFhirEncodersTest.java index 849ded3fcd..9cbcdedf10 100644 --- a/encoders/src/test/java/au/csiro/pathling/encoders/LightweightFhirEncodersTest.java +++ b/encoders/src/test/java/au/csiro/pathling/encoders/LightweightFhirEncodersTest.java @@ -43,14 +43,20 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder; import org.apache.spark.sql.catalyst.encoders.RowEncoder; import org.hl7.fhir.r4.model.BaseResource; +import org.hl7.fhir.r4.model.CodeableConcept; +import org.hl7.fhir.r4.model.Coding; import org.hl7.fhir.r4.model.Condition; import org.hl7.fhir.r4.model.Device; +import org.hl7.fhir.r4.model.Expression; import org.hl7.fhir.r4.model.IdType; +import org.hl7.fhir.r4.model.Identifier; +import org.hl7.fhir.r4.model.Identifier.IdentifierUse; import org.hl7.fhir.r4.model.MolecularSequence; import org.hl7.fhir.r4.model.MolecularSequence.MolecularSequenceQualityRocComponent; import org.hl7.fhir.r4.model.Observation; import org.hl7.fhir.r4.model.PlanDefinition; import org.hl7.fhir.r4.model.PlanDefinition.PlanDefinitionActionComponent; +import org.hl7.fhir.r4.model.Reference; import org.json4s.jackson.JsonMethods; import org.junit.jupiter.api.Test; import scala.collection.mutable.WrappedArray; @@ -162,6 +168,65 @@ public void testHtmlNarrative() { assertSerDeIsIdentity(encoder, conditionWithNarrative); } + @Test + public void testReference() { + final ExpressionEncoder encoder = fhirEncoders + .of(Condition.class); + final Condition conditionWithFullReference = new Condition(); + final Identifier identifier = new Identifier() + .setSystem("urn:id-system") + .setValue("id-valule") + .setUse(IdentifierUse.OFFICIAL) + .setType(new CodeableConcept().addCoding(new Coding().setCode("code").setSystem("system")) + .setText("text")); + final Reference referenceWithAllFields = new Reference("Patient/1234") + .setDisplay("Some Display Name") + .setType("Patient") + .setIdentifier(identifier); + // Set also the Element inherited fields + referenceWithAllFields.setId("some-id"); + conditionWithFullReference.setSubject(referenceWithAllFields); + assertSerDeIsIdentity(encoder, conditionWithFullReference); + } + + @Test + public void testIdentifier() { + final ExpressionEncoder encoder = fhirEncoders + .of(Condition.class); + final Condition conditionWithIdentifierWithAssigner = new Condition(); + + final Reference assignerReference = new Reference("Organization/1234") + .setDisplay("Some Display Name") + .setType("Organization"); + + final Identifier identifier = new Identifier() + .setSystem("urn:id-system") + .setValue("id-valule") + .setUse(IdentifierUse.OFFICIAL) + .setAssigner(assignerReference) + .setType(new CodeableConcept().addCoding(new Coding().setCode("code").setSystem("system")) + .setText("text")); + conditionWithIdentifierWithAssigner.addIdentifier(identifier); + assertSerDeIsIdentity(encoder, conditionWithIdentifierWithAssigner); + } + + @Test + public void testExpression() { + + // Expression contains 'reference' field + // We are checking that it is encoded in generic way not and not the subject to special case for Reference 'reference' field. + final ExpressionEncoder encoder = fhirEncoders + .of(PlanDefinition.class); + + final PlanDefinition planDefinition = new PlanDefinition(); + + final PlanDefinitionActionComponent actionComponent = planDefinition + .getActionFirstRep(); + actionComponent.getConditionFirstRep().setExpression(new Expression().setLanguage("language") + .setExpression("expression").setDescription("description")); + assertSerDeIsIdentity(encoder, planDefinition); + } + @Test public void testThrowsExceptionWhenUnsupportedResource() { for (final String resourceName : EXCLUDED_RESOURCES) { @@ -264,7 +329,7 @@ public void testQuantityArrayCanonicalization() { final List properties = deviceRow.getList(deviceRow.fieldIndex("property")); final Row propertyRow = properties.get(0); final List quantityArray = propertyRow.getList(propertyRow.fieldIndex("valueQuantity")); - + final Row quantity1 = quantityArray.get(0); assertQuantity(quantity1, "0.0010", "m"); diff --git a/encoders/src/test/java/au/csiro/pathling/encoders/SchemaConverterTest.java b/encoders/src/test/java/au/csiro/pathling/encoders/SchemaConverterTest.java index 02fc596650..eb36ee8467 100644 --- a/encoders/src/test/java/au/csiro/pathling/encoders/SchemaConverterTest.java +++ b/encoders/src/test/java/au/csiro/pathling/encoders/SchemaConverterTest.java @@ -25,6 +25,7 @@ import static au.csiro.pathling.test.SchemaAsserts.assertFieldNotPresent; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -93,9 +94,8 @@ public class SchemaConverterTest { private StructType medRequestSchema; private StructType questionnaireSchema; private StructType questionnaireResponseSchema; - private StructType deviceSchema; - + private StructType observationSchema_L2; /** * Traverses a DataType recursively passing all encountered StructTypes to the provided consumer. @@ -173,16 +173,17 @@ public void setUp() { questionnaireSchema = converter_L0.resourceSchema(Questionnaire.class); questionnaireResponseSchema = converter_L0.resourceSchema(QuestionnaireResponse.class); deviceSchema = converter_L0.resourceSchema(Device.class); + observationSchema_L2 = converter_L2.resourceSchema(Observation.class); } @Test public void resourceHasId() { - assertTrue(getField(conditionSchema, true, "id") instanceof StringType); + assertInstanceOf(StringType.class, getField(conditionSchema, true, "id")); } @Test public void boundCodeToStruct() { - assertTrue(getField(conditionSchema, true, "verificationStatus") instanceof StructType); + assertInstanceOf(StructType.class, getField(conditionSchema, true, "verificationStatus")); } @Test @@ -190,11 +191,11 @@ public void codingToStruct() { final DataType codingType = getField(conditionSchema, true, "severity", "coding"); - assertTrue(getField(codingType, true, "system") instanceof StringType); - assertTrue(getField(codingType, true, "version") instanceof StringType); - assertTrue(getField(codingType, true, "code") instanceof StringType); - assertTrue(getField(codingType, true, "display") instanceof StringType); - assertTrue(getField(codingType, true, "userSelected") instanceof BooleanType); + assertInstanceOf(StringType.class, getField(codingType, true, "system")); + assertInstanceOf(StringType.class, getField(codingType, true, "version")); + assertInstanceOf(StringType.class, getField(codingType, true, "code")); + assertInstanceOf(StringType.class, getField(codingType, true, "display")); + assertInstanceOf(BooleanType.class, getField(codingType, true, "userSelected")); } @Test @@ -202,30 +203,30 @@ public void codeableConceptToStruct() { final DataType codeableType = getField(conditionSchema, true, "severity"); - assertTrue(codeableType instanceof StructType); - assertTrue(getField(codeableType, true, "coding") instanceof ArrayType); - assertTrue(getField(codeableType, true, "text") instanceof StringType); + assertInstanceOf(StructType.class, codeableType); + assertInstanceOf(ArrayType.class, getField(codeableType, true, "coding")); + assertInstanceOf(StringType.class, getField(codeableType, true, "text")); } @Test public void idToString() { - assertTrue(getField(conditionSchema, true, "id") instanceof StringType); + assertInstanceOf(StringType.class, getField(conditionSchema, true, "id")); } @Test public void narrativeToStruct() { - assertTrue(getField(conditionSchema, true, "text", "status") instanceof StringType); - assertTrue(getField(conditionSchema, true, "text", "div") instanceof StringType); + assertInstanceOf(StringType.class, getField(conditionSchema, true, "text", "status")); + assertInstanceOf(StringType.class, getField(conditionSchema, true, "text", "div")); } @Test public void expandChoiceFields() { - assertTrue(getField(conditionSchema, true, "onsetPeriod") instanceof StructType); - assertTrue(getField(conditionSchema, true, "onsetRange") instanceof StructType); - assertTrue(getField(conditionSchema, true, "onsetDateTime") instanceof StringType); - assertTrue(getField(conditionSchema, true, "onsetString") instanceof StringType); - assertTrue(getField(conditionSchema, true, "onsetAge") instanceof StructType); + assertInstanceOf(StructType.class, getField(conditionSchema, true, "onsetPeriod")); + assertInstanceOf(StructType.class, getField(conditionSchema, true, "onsetRange")); + assertInstanceOf(StringType.class, getField(conditionSchema, true, "onsetDateTime")); + assertInstanceOf(StringType.class, getField(conditionSchema, true, "onsetString")); + assertInstanceOf(StructType.class, getField(conditionSchema, true, "onsetAge")); } @Test @@ -244,19 +245,21 @@ public void orderChoiceFields() { @Test public void decimalWithinChoiceField() { - assertTrue(getField(questionnaireSchema, true, "item", "enableWhen", - "answerDecimal") instanceof DecimalType); - assertTrue(getField(questionnaireSchema, true, "item", "enableWhen", - "answerDecimal_scale") instanceof IntegerType); - assertTrue(getField(questionnaireResponseSchema, true, "item", "answer", - "valueDecimal") instanceof DecimalType); - assertTrue(getField(questionnaireResponseSchema, true, "item", "answer", - "valueDecimal_scale") instanceof IntegerType); + assertInstanceOf(DecimalType.class, getField(questionnaireSchema, true, "item", "enableWhen", + "answerDecimal")); + assertInstanceOf(IntegerType.class, getField(questionnaireSchema, true, "item", "enableWhen", + "answerDecimal_scale")); + assertInstanceOf(DecimalType.class, + getField(questionnaireResponseSchema, true, "item", "answer", + "valueDecimal")); + assertInstanceOf(IntegerType.class, + getField(questionnaireResponseSchema, true, "item", "answer", + "valueDecimal_scale")); } @Test public void instantToTimestamp() { - assertTrue(getField(observationSchema, true, "issued") instanceof TimestampType); + assertInstanceOf(TimestampType.class, getField(observationSchema, true, "issued")); } @Test @@ -266,17 +269,59 @@ public void timeToString() { @Test public void bigDecimalToDecimal() { - assertTrue( - getField(observationSchema, true, "valueQuantity", "value") instanceof DecimalType); + assertInstanceOf(DecimalType.class, + getField(observationSchema, true, "valueQuantity", "value")); } @Test public void reference() { - assertTrue( - getField(observationSchema, true, "subject", "reference") instanceof StringType); - assertTrue(getField(observationSchema, true, "subject", "display") instanceof StringType); + assertInstanceOf(StringType.class, getField(observationSchema, true, "subject", "id")); + assertInstanceOf(StringType.class, getField(observationSchema, true, "subject", "reference")); + assertInstanceOf(StringType.class, getField(observationSchema, true, "subject", "display")); + assertInstanceOf(StringType.class, getField(observationSchema, true, "subject", "type")); + assertInstanceOf(StructType.class, getField(observationSchema, true, "subject", "identifier")); + assertInstanceOf(StringType.class, + getField(observationSchema, true, "subject", "identifier", "value")); + } + @Test + public void identifier() { + assertInstanceOf(StringType.class, + unArray(getField(observationSchema, true, "identifier", "value"))); + // `assigner` field should be present in the root level `Identifier` schema. + assertInstanceOf(StructType.class, + unArray(getField(observationSchema, true, "identifier", "assigner"))); + assertInstanceOf(StringType.class, + unArray(getField(observationSchema, true, "identifier", "assigner", "reference"))); + + } + + @Test + public void identifierInReference() { + // + // Identifier (assigner) in root Reference + // + assertFieldNotPresent("assigner", getField(observationSchema, true, "subject", "identifier")); + // The `assigner` field should not be present in Identifier schema of the Reference `identifier` field. + assertFieldNotPresent("assigner", + getField(observationSchema_L2, true, "subject", "identifier")); + + // + // Identifier (assigner) in a Reference nested in an Identifier + // + // the `identifier` field should not be present because for normal nesting rules for 0-level nesting + assertFieldNotPresent("identifier", + unArray(getField(observationSchema, true, "identifier", "assigner"))); + // the `identifier` field should be present because for normal nesting rules for 2-level nesting + assertInstanceOf(StructType.class, + unArray(getField(observationSchema_L2, true, "identifier", "assigner", "identifier"))); + // but it should not have the assigner field + assertFieldNotPresent("assigner", + unArray(getField(observationSchema_L2, true, "identifier", "assigner", "identifier"))); + } + + @Test public void preferredNameOnly() { @@ -374,7 +419,7 @@ public void testExtensions() { final MapType extensionsContainerType = (MapType) getField(extensionSchema, true, "_extension"); assertEquals(DataTypes.IntegerType, extensionsContainerType.keyType()); - assertTrue(extensionsContainerType.valueType() instanceof ArrayType); + assertInstanceOf(ArrayType.class, extensionsContainerType.valueType()); traverseSchema(extensionSchema, t -> { assertEquals(DataTypes.IntegerType, t.fields()[t.fieldIndex("_fid")].dataType()); @@ -424,13 +469,13 @@ public void testQuantityArray() { } private void assertQuantityType(final DataType quantityType) { - assertTrue(getField(quantityType, true, "value") instanceof DecimalType); - assertTrue(getField(quantityType, true, "value_scale") instanceof IntegerType); - assertTrue(getField(quantityType, true, "comparator") instanceof StringType); - assertTrue(getField(quantityType, true, "unit") instanceof StringType); - assertTrue(getField(quantityType, true, "system") instanceof StringType); - assertTrue(getField(quantityType, true, "code") instanceof StringType); + assertInstanceOf(DecimalType.class, getField(quantityType, true, "value")); + assertInstanceOf(IntegerType.class, getField(quantityType, true, "value_scale")); + assertInstanceOf(StringType.class, getField(quantityType, true, "comparator")); + assertInstanceOf(StringType.class, getField(quantityType, true, "unit")); + assertInstanceOf(StringType.class, getField(quantityType, true, "system")); + assertInstanceOf(StringType.class, getField(quantityType, true, "code")); assertEquals(FlexiDecimal.DATA_TYPE, getField(quantityType, true, "_value_canonicalized")); - assertTrue(getField(quantityType, true, "_code_canonicalized") instanceof StringType); + assertInstanceOf(StringType.class, getField(quantityType, true, "_code_canonicalized")); } } diff --git a/encoders/src/test/java/au/csiro/pathling/encoders/TestData.java b/encoders/src/test/java/au/csiro/pathling/encoders/TestData.java index 010d1586a3..1faa47c94b 100644 --- a/encoders/src/test/java/au/csiro/pathling/encoders/TestData.java +++ b/encoders/src/test/java/au/csiro/pathling/encoders/TestData.java @@ -138,6 +138,44 @@ public static Condition newCondition() { return condition; } + public static Condition conditionWithReferencesWithIdentifiers() { + final Condition condition = new Condition(); + condition.setId("withReferencesWithIdentifiers"); + final Coding typeCoding = new Coding("http://terminology.hl7.org/CodeSystem/v2-0203", "MR", + null); + final CodeableConcept typeConcept = new CodeableConcept(typeCoding); + condition.setSubject( + new Reference("Patient/example") + .setDisplay("Display name") + .setIdentifier( + new Identifier() + .setType(typeConcept) + .setSystem("https://fhir.example.com/identifiers/mrn") + .setValue("urn:id") + .setAssigner(new Reference("Organization/001")) + ) + ); + return condition; + } + + public static Condition conditionWithIdentifiersWithReferences() { + final Condition condition = new Condition(); + condition.setId("withIdentifiersWithReferences"); + final Coding typeCoding = new Coding("http://terminology.hl7.org/CodeSystem/v2-0203", "MR", + null); + final CodeableConcept typeConcept = new CodeableConcept(typeCoding); + condition + .addIdentifier() + .setType(typeConcept) + .setSystem("https://fhir.example.com/identifiers/mrn") + .setValue("urn:id01") + .setAssigner(new Reference("Organization/001") + .setIdentifier(new Identifier().setValue("urn:id02") + .setAssigner(new Reference("Organization/002")))); + return condition; + } + + public static Condition conditionWithVersion() { final Condition condition = new Condition(); final IdType id = new IdType("Condition", "with-version", "1"); @@ -145,6 +183,7 @@ public static Condition conditionWithVersion() { return condition; } + /** * Returns a FHIR Observation for testing purposes. */ diff --git a/fhir-server/src/test/java/au/csiro/pathling/fhirpath/parser/ParserTest.java b/fhir-server/src/test/java/au/csiro/pathling/fhirpath/parser/ParserTest.java index ab17db6576..d27ec73230 100644 --- a/fhir-server/src/test/java/au/csiro/pathling/fhirpath/parser/ParserTest.java +++ b/fhir-server/src/test/java/au/csiro/pathling/fhirpath/parser/ParserTest.java @@ -836,15 +836,7 @@ void testQuantityAdditionWithOverflow() { .selectResult() .hasRows(spark, "responses/ParserTest/testQuantityAdditionWithOverflow_code.csv"); } - - @Test - void testTraversalToUnsupportedReferenceChild() { - final String expression = "reverseResolve(MedicationRequest.subject).requester.identifier"; - final InvalidUserInputError error = assertThrows(InvalidUserInputError.class, - expression); - assertEquals("No such child: " + expression, error.getMessage()); - } - + @Test void testResolutionOfExtensionReference() { mockResource(ResourceType.PATIENT, ResourceType.ENCOUNTER, ResourceType.GOAL); diff --git a/fhirpath/src/main/java/au/csiro/pathling/fhirpath/element/ReferencePath.java b/fhirpath/src/main/java/au/csiro/pathling/fhirpath/element/ReferencePath.java index 61fa51bd55..b5219bdd7d 100644 --- a/fhirpath/src/main/java/au/csiro/pathling/fhirpath/element/ReferencePath.java +++ b/fhirpath/src/main/java/au/csiro/pathling/fhirpath/element/ReferencePath.java @@ -69,16 +69,4 @@ public Column getResourceEquality(@Nonnull final Column targetId, @Nonnull final Column targetCode) { return Referrer.resourceEqualityFor(this, targetCode, targetId); } - - @Nonnull - @Override - public Optional getChildElement(@Nonnull final String name) { - // We only encode the reference and display elements of the Reference type. - if (name.equals("reference") || name.equals("display")) { - return super.getChildElement(name); - } else { - return Optional.empty(); - } - } - } diff --git a/fhirpath/src/test/java/au/csiro/pathling/fhirpath/parser/FhirPathTest.java b/fhirpath/src/test/java/au/csiro/pathling/fhirpath/parser/FhirPathTest.java new file mode 100644 index 0000000000..d69bc21170 --- /dev/null +++ b/fhirpath/src/test/java/au/csiro/pathling/fhirpath/parser/FhirPathTest.java @@ -0,0 +1,115 @@ +/* + * Copyright 2023 Commonwealth Scientific and Industrial Research + * Organisation (CSIRO) ABN 41 687 119 230. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package au.csiro.pathling.fhirpath.parser; + +import static au.csiro.pathling.test.assertions.Assertions.assertThat; +import static org.mockito.Mockito.when; + +import au.csiro.pathling.encoders.FhirEncoders; +import au.csiro.pathling.fhirpath.ResourcePath; +import au.csiro.pathling.io.source.DataSource; +import au.csiro.pathling.terminology.TerminologyService; +import au.csiro.pathling.terminology.TerminologyServiceFactory; +import au.csiro.pathling.test.SpringBootUnitTest; +import au.csiro.pathling.test.assertions.FhirPathAssertion; +import au.csiro.pathling.test.builders.ParserContextBuilder; +import ca.uhn.fhir.context.FhirContext; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import javax.annotation.Nonnull; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SparkSession; +import org.hl7.fhir.r4.model.Condition; +import org.hl7.fhir.r4.model.Enumerations.ResourceType; +import org.hl7.fhir.r4.model.Identifier; +import org.hl7.fhir.r4.model.Reference; +import org.hl7.fhir.r4.model.Resource; +import org.junit.jupiter.api.Test; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.mock.mockito.MockBean; + +@SpringBootUnitTest +public class FhirPathTest { + + @Autowired + protected SparkSession spark; + + @Autowired + FhirContext fhirContext; + + @Autowired + TerminologyService terminologyService; + + @Autowired + FhirEncoders fhirEncoders; + + @Autowired + TerminologyServiceFactory terminologyServiceFactory; + + @MockBean + DataSource dataSource; + + + @SuppressWarnings("SameParameterValue") + @Nonnull + protected FhirPathAssertion assertThatResultOf(@Nonnull final ResourceType resourceType, + @Nonnull final String expression) { + final ResourcePath subjectResource = ResourcePath + .build(fhirContext, dataSource, resourceType, resourceType.toCode(), true); + + final ParserContext parserContext = new ParserContextBuilder(spark, fhirContext) + .terminologyClientFactory(terminologyServiceFactory) + .database(dataSource) + .inputContext(subjectResource) + .build(); + final Parser resourceParser = new Parser(parserContext); + return assertThat(resourceParser.parse(expression)); + } + + + void withResources(@Nonnull final Resource resources) { + + // group resources by type + // and then encode them into a dataset and setup the mock datasorce + + Stream.of(resources).collect(Collectors.groupingBy(Resource::getResourceType)) + .forEach((resourceType, resourcesOfType) -> { + final ResourceType resourceTypeEnum = ResourceType.fromCode(resourceType.name()); + final Dataset dataset = spark.createDataset(resourcesOfType, + fhirEncoders.of(resourceTypeEnum.toCode())) + .toDF(); + when(dataSource.read(resourceTypeEnum)).thenReturn(dataset.cache()); + }); + } + + + @Test + void testTraversalIntoReferenceIdentifier() { + withResources( + new Condition() + .setSubject(new Reference().setIdentifier(new Identifier().setValue("value"))) + .setId("001") + ); + assertThatResultOf(ResourceType.CONDITION, "subject.identifier.value") + .selectResult() + .hasRows(RowFactory.create("001", "value")); + } + +}