Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementing full encoding and decoding support for References #1786

Merged
merged 7 commits into from
May 21, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ package au.csiro.pathling.encoders

import au.csiro.pathling.schema._
import ca.uhn.fhir.context._
import org.hl7.fhir.instance.model.api.IBaseReference


/**
Expand Down Expand Up @@ -95,8 +96,20 @@ trait SchemaProcessor[DT, SF] extends SchemaVisitor[DT, SF] with EncoderSettings
}
}

/**
 * Decides whether an element should be included in the schema at the current
 * nesting level.
 *
 * References are a special case that disallows any nesting: this prunes the
 * `assigner` field from the Identifier type instances nested inside a
 * Reference (in its `identifier` element). All other elements honour the
 * configured `maxNestingLevel`.
 */
private def includeElement(elementDefinition: BaseRuntimeElementDefinition[_]): Boolean = {
  val isReference =
    classOf[IBaseReference].isAssignableFrom(elementDefinition.getImplementingClass)
  val allowedNesting = if (isReference) 0 else maxNestingLevel
  EncodingContext.currentNestingLevel(elementDefinition) <= allowedNesting
}

override def visitElement(elementCtx: ElementCtx[DT, SF]): Seq[SF] = {
if (EncodingContext.currentNestingLevel(elementCtx.elementDefinition) <= maxNestingLevel) {
if (includeElement(elementCtx.elementDefinition)) {
buildValue(elementCtx.childDefinition, elementCtx.elementDefinition, elementCtx.elementName)
} else {
Nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,8 @@
package au.csiro.pathling.encoders

import au.csiro.pathling.encoders.ExtensionSupport.{EXTENSIONS_FIELD_NAME, FID_FIELD_NAME}
import au.csiro.pathling.encoders.QuantitySupport.{CODE_CANONICALIZED_FIELD_NAME, VALUE_CANONICALIZED_FIELD_NAME}
import au.csiro.pathling.encoders.SerializerBuilderProcessor.{dataTypeToUtf8Expr, getChildExpression, objectTypeFor}
import au.csiro.pathling.encoders.datatypes.{DataTypeMappings, DecimalCustomCoder}
import au.csiro.pathling.encoders.terminology.ucum.Ucum
import au.csiro.pathling.encoders.datatypes.DataTypeMappings
import au.csiro.pathling.schema.SchemaVisitor.isCollection
import au.csiro.pathling.schema._
import ca.uhn.fhir.context.BaseRuntimeElementDefinition.ChildTypeEnum
Expand All @@ -36,7 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.objects.{ExternalMapToCatalyst,
import org.apache.spark.sql.catalyst.expressions.{BoundReference, CreateNamedStruct, Expression, If, IsNull, Literal}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.hl7.fhir.instance.model.api.{IBaseDatatype, IBaseHasExtensions, IBaseResource}
import org.hl7.fhir.instance.model.api.{IBaseDatatype, IBaseHasExtensions, IBaseReference, IBaseResource}
import org.hl7.fhir.r4.model.{Base, Extension, Quantity}
import org.hl7.fhir.utilities.xhtml.XhtmlNode

Expand Down Expand Up @@ -226,16 +224,24 @@ private[encoders] object SerializerBuilderProcessor {
// Primitive single-value types typically use the Element suffix in their
// accessors, with the exception of the "div" field for reasons that are not clear.
//noinspection DuplicatedCode
if (field.isInstanceOf[RuntimeChildPrimitiveDatatypeDefinition] &&
field.getMax == 1 &&
field.getElementName != "div")
"get" + field.getElementName.capitalize + "Element"
else {
if (field.getElementName.equals("class")) {
"get" + field.getElementName.capitalize + "_"
} else {
field match {
case p: RuntimeChildPrimitiveDatatypeDefinition if p.getMax == 1 && p
.getElementName != "div" =>
if ("reference" == p.getElementName && classOf[IBaseReference]
.isAssignableFrom(p.getField.getDeclaringClass)) {
// Special case for subclasses of IBaseReference.
// The accessor getReferenceElement returns IdType rather than
// StringType and getReferenceElement_ needs to be used instead.
// All subclasses of IBaseReference have a getReferenceElement_
// method.
"getReferenceElement_"
johngrimes marked this conversation as resolved.
Show resolved Hide resolved
} else {
"get" + p.getElementName.capitalize + "Element"
}
case f if f.getElementName.equals("class") =>
"get" + f.getElementName.capitalize + "_"
case _ =>
"get" + field.getElementName.capitalize
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ import ca.uhn.fhir.model.api.TemporalPrecisionEnum
import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
import org.apache.spark.sql.catalyst.expressions.objects.{InitializeJavaBean, Invoke, NewInstance, StaticInvoke}
import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Literal}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, DataTypes, ObjectType}
import org.hl7.fhir.instance.model.api.{IBase, IBaseDatatype, IPrimitiveType}
import org.hl7.fhir.r4.model._
Expand All @@ -55,36 +54,13 @@ class R4DataTypeMappings extends DataTypeMappings {

override def baseType(): Class[_ <: IBaseDatatype] = classOf[org.hl7.fhir.r4.model.Type]

override def overrideCompositeExpression(inputObject: Expression,
override def overrideCompositeExpression(inputObject: Expression,
definition: BaseRuntimeElementCompositeDefinition[_]): Option[Seq[ExpressionWithName]] = {

if (definition.getImplementingClass == classOf[Reference]) {
// Reference type, so return only supported fields. We also explicitly use the IdType for the
// reference element, since that differs from the conventions used to infer other types.
val reference = dataTypeToUtf8Expr(
Invoke(inputObject,
"getReferenceElement",
ObjectType(classOf[IdType])))

val display = dataTypeToUtf8Expr(
Invoke(inputObject,
"getDisplayElement",
ObjectType(classOf[org.hl7.fhir.r4.model.StringType])))

Some(List(("reference", reference), ("display", display)))
} else {
None
}
None
}

override def skipField(definition: BaseRuntimeElementCompositeDefinition[_],
child: BaseRuntimeChildDefinition): Boolean = {

// References may be recursive, so include only the reference and display name.
val skipRecursiveReference = definition.getImplementingClass == classOf[Reference] &&
!(child.getElementName == "reference" ||
child.getElementName == "display")

// Contains elements are currently not encoded in our Spark dataset.
val skipContains = definition
.getImplementingClass == classOf[ValueSet.ValueSetExpansionContainsComponent] &&
Expand All @@ -95,8 +71,7 @@ class R4DataTypeMappings extends DataTypeMappings {
// "modifierExtensionExtension", not "extensionExtension".
// See: https://github.com/hapifhir/hapi-fhir/issues/3414
val skipModifierExtension = child.getElementName.equals("modifierExtension")

skipRecursiveReference || skipContains || skipModifierExtension
skipContains || skipModifierExtension
}

override def primitiveEncoderExpression(inputObject: Expression,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@

package au.csiro.pathling.encoders;

import static org.apache.spark.sql.functions.col;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertTrue;
Expand Down Expand Up @@ -273,13 +275,80 @@ public void coding() {

// Verifies encoding and decoding of a Reference element (Condition.subject),
// including its nested Identifier, using encoders configured with nesting
// level 3 (ENCODERS_L3).
@Test
public void reference() {
final Condition conditionWithReferences = TestData.conditionWithReferencesWithIdentifiers();

// Encode the condition into a Spark dataset.
final Dataset<Condition> conditionL3Dataset = spark
.createDataset(ImmutableList.of(conditionWithReferences), ENCODERS_L3.of(Condition.class));

// Decode the single row back into a HAPI Condition object.
final Condition decodedL3Condition = conditionL3Dataset.head();

// The encoded columns should expose the reference string and the nested
// identifier (including its type coding) as queryable fields.
assertEquals(
RowFactory.create(
"withReferencesWithIdentifiers",
"Patient/example",
"http://terminology.hl7.org/CodeSystem/v2-0203",
"MR",
"https://fhir.example.com/identifiers/mrn",
"urn:id"
),
conditionL3Dataset.select(
col("id"),
col("subject.reference"),
col("subject.identifier.type.coding.system").getItem(0),
col("subject.identifier.type.coding.code").getItem(0),
col("subject.identifier.system"),
col("subject.identifier.value")
).head());

// The decoded object should round-trip the reference and identifier values.
assertEquals("Patient/example",
decodedL3Condition.getSubject().getReference());

assertEquals("urn:id",
decodedL3Condition.getSubject().getIdentifier().getValue());

// the assigner should be pruned from the reference identifier.
assertTrue(conditionWithReferences.getSubject().getIdentifier().hasAssigner());
assertFalse(decodedL3Condition.getSubject().getIdentifier().hasAssigner());
}


assertEquals(condition.getSubject().getReference(),
conditionsDataset.select("subject.reference").head().get(0));
assertEquals(condition.getSubject().getReference(),
decodedCondition.getSubject().getReference());
// Verifies encoding and decoding of an Identifier element that carries an
// `assigner` Reference, which in turn carries its own nested Identifier,
// using encoders configured with nesting level 3 (ENCODERS_L3).
@Test
public void identifier() {
final Condition conditionWithIdentifiers = TestData.conditionWithIdentifiersWithReferences();

// Encode the condition into a Spark dataset.
final Dataset<Condition> conditionL3Dataset = spark
.createDataset(ImmutableList.of(conditionWithIdentifiers), ENCODERS_L3.of(Condition.class));

// Decode the single row back into a HAPI Condition object.
final Condition decodedL3Condition = conditionL3Dataset.head();

// The encoded columns should expose the identifier, its type coding, and the
// assigner reference (with the assigner's own identifier value).
assertEquals(
RowFactory.create(
"withIdentifiersWithReferences",
"http://terminology.hl7.org/CodeSystem/v2-0203",
"MR",
"https://fhir.example.com/identifiers/mrn",
"urn:id01",
"Organization/001",
"urn:id02"
),
conditionL3Dataset.select(
col("id"),
col("identifier.type.coding").getItem(0).getField("system").getItem(0),
col("identifier.type.coding").getItem(0).getField("code").getItem(0),
col("identifier.system").getItem(0),
col("identifier.value").getItem(0),
col("identifier.assigner.reference").getItem(0),
col("identifier.assigner.identifier.value").getItem(0)
).head());

// the assigner should be pruned from the reference identifier.
assertTrue(conditionWithIdentifiers.getIdentifier().get(0).getAssigner().getIdentifier()
.hasAssigner());
assertFalse(
decodedL3Condition.getIdentifier().get(0).getAssigner().getIdentifier().hasAssigner());
}


@Test
public void integer() {

Expand Down Expand Up @@ -325,13 +394,13 @@ public void choiceBigDecimalInQuestionnaire() {
.getAnswerDecimalType().getValue();

final BigDecimal queriedDecimal = (BigDecimal) questionnaireDataset
.select(functions.col("item").getItem(0).getField("enableWhen").getItem(0)
.select(col("item").getItem(0).getField("enableWhen").getItem(0)
.getField("answerDecimal"))
.head()
.get(0);

final int queriedDecimal_scale = questionnaireDataset
.select(functions.col("item").getItem(0).getField("enableWhen").getItem(0)
.select(col("item").getItem(0).getField("enableWhen").getItem(0)
.getField("answerDecimal_scale"))
.head()
.getInt(0);
Expand Down Expand Up @@ -360,13 +429,13 @@ public void choiceBigDecimalInQuestionnaireResponse() {
.getValueDecimalType().getValue();

final BigDecimal queriedDecimal = (BigDecimal) questionnaireResponseDataset
.select(functions.col("item").getItem(0).getField("answer").getItem(0)
.select(col("item").getItem(0).getField("answer").getItem(0)
.getField("valueDecimal"))
.head()
.get(0);

final int queriedDecimal_scale = questionnaireResponseDataset
.select(functions.col("item").getItem(0).getField("answer").getItem(0)
.select(col("item").getItem(0).getField("answer").getItem(0)
.getField("valueDecimal_scale"))
.head()
.getInt(0);
Expand Down Expand Up @@ -516,21 +585,21 @@ public void testNestedQuestionnaire() {

assertEquals(Stream.of("Item/0", "Item/0", "Item/0", "Item/0").map(RowFactory::create)
.collect(Collectors.toUnmodifiableList()),
questionnaireDataset_L3.select(functions.col("item").getItem(0).getField("linkId"))
questionnaireDataset_L3.select(col("item").getItem(0).getField("linkId"))
.collectAsList());

assertEquals(Stream.of(null, "Item/1.0", "Item/1.0", "Item/1.0").map(RowFactory::create)
.collect(Collectors.toUnmodifiableList()),
questionnaireDataset_L3
.select(functions.col("item")
.select(col("item")
.getItem(1).getField("item")
.getItem(0).getField("linkId"))
.collectAsList());

assertEquals(Stream.of(null, null, "Item/2.1.0", "Item/2.1.0").map(RowFactory::create)
.collect(Collectors.toUnmodifiableList()),
questionnaireDataset_L3
.select(functions.col("item")
.select(col("item")
.getItem(2).getField("item")
.getItem(1).getField("item")
.getItem(0).getField("linkId"))
Expand All @@ -539,7 +608,7 @@ public void testNestedQuestionnaire() {
assertEquals(Stream.of(null, null, null, "Item/3.2.1.0").map(RowFactory::create)
.collect(Collectors.toUnmodifiableList()),
questionnaireDataset_L3
.select(functions.col("item")
.select(col("item")
.getItem(3).getField("item")
.getItem(2).getField("item")
.getItem(1).getField("item")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,20 @@
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
import org.apache.spark.sql.catalyst.encoders.RowEncoder;
import org.hl7.fhir.r4.model.BaseResource;
import org.hl7.fhir.r4.model.CodeableConcept;
import org.hl7.fhir.r4.model.Coding;
import org.hl7.fhir.r4.model.Condition;
import org.hl7.fhir.r4.model.Device;
import org.hl7.fhir.r4.model.Expression;
import org.hl7.fhir.r4.model.IdType;
import org.hl7.fhir.r4.model.Identifier;
import org.hl7.fhir.r4.model.Identifier.IdentifierUse;
import org.hl7.fhir.r4.model.MolecularSequence;
import org.hl7.fhir.r4.model.MolecularSequence.MolecularSequenceQualityRocComponent;
import org.hl7.fhir.r4.model.Observation;
import org.hl7.fhir.r4.model.PlanDefinition;
import org.hl7.fhir.r4.model.PlanDefinition.PlanDefinitionActionComponent;
import org.hl7.fhir.r4.model.Reference;
import org.json4s.jackson.JsonMethods;
import org.junit.jupiter.api.Test;
import scala.collection.mutable.WrappedArray;
Expand Down Expand Up @@ -162,6 +168,65 @@ public void testHtmlNarrative() {
assertSerDeIsIdentity(encoder, conditionWithNarrative);
}

// Round-trip test: a Reference populated on every field (reference, display,
// type, identifier, and the inherited Element id) must survive serialization
// and deserialization unchanged.
@Test
public void testReference() {
final CodeableConcept identifierType = new CodeableConcept()
.addCoding(new Coding().setCode("code").setSystem("system"))
.setText("text");
final Identifier subjectIdentifier = new Identifier()
.setSystem("urn:id-system")
.setValue("id-valule")
.setUse(IdentifierUse.OFFICIAL)
.setType(identifierType);
final Reference fullyPopulatedReference = new Reference("Patient/1234")
.setDisplay("Some Display Name")
.setType("Patient")
.setIdentifier(subjectIdentifier);
// Also populate the Element-inherited id field.
fullyPopulatedReference.setId("some-id");
final Condition condition = new Condition();
condition.setSubject(fullyPopulatedReference);
final ExpressionEncoder<Condition> encoder = fhirEncoders.of(Condition.class);
assertSerDeIsIdentity(encoder, condition);
}

// Round-trip test: an Identifier carrying an `assigner` Reference must
// survive serialization and deserialization unchanged.
@Test
public void testIdentifier() {
final Reference assigner = new Reference("Organization/1234")
.setDisplay("Some Display Name")
.setType("Organization");
final CodeableConcept identifierType = new CodeableConcept()
.addCoding(new Coding().setCode("code").setSystem("system"))
.setText("text");
final Identifier identifierWithAssigner = new Identifier()
.setSystem("urn:id-system")
.setValue("id-valule")
.setUse(IdentifierUse.OFFICIAL)
.setAssigner(assigner)
.setType(identifierType);
final Condition condition = new Condition();
condition.addIdentifier(identifierWithAssigner);
final ExpressionEncoder<Condition> encoder = fhirEncoders.of(Condition.class);
assertSerDeIsIdentity(encoder, condition);
}

// The Expression type also declares a field named 'reference'. This test
// checks that it is encoded generically and is not caught by the special
// casing applied to the Reference type's 'reference' field.
@Test
public void testExpression() {
final Expression expression = new Expression()
.setLanguage("language")
.setExpression("expression")
.setDescription("description");
final PlanDefinition planDefinition = new PlanDefinition();
final PlanDefinitionActionComponent action = planDefinition.getActionFirstRep();
action.getConditionFirstRep().setExpression(expression);
final ExpressionEncoder<PlanDefinition> encoder = fhirEncoders.of(PlanDefinition.class);
assertSerDeIsIdentity(encoder, planDefinition);
}

@Test
public void testThrowsExceptionWhenUnsupportedResource() {
for (final String resourceName : EXCLUDED_RESOURCES) {
Expand Down Expand Up @@ -264,7 +329,7 @@ public void testQuantityArrayCanonicalization() {
final List<Row> properties = deviceRow.getList(deviceRow.fieldIndex("property"));
final Row propertyRow = properties.get(0);
final List<Row> quantityArray = propertyRow.getList(propertyRow.fieldIndex("valueQuantity"));

final Row quantity1 = quantityArray.get(0);
assertQuantity(quantity1, "0.0010", "m");

Expand Down
Loading
Loading