Skip to content

Commit

Permalink
patient / person API endpoints for retrieving data (#206)
Browse files Browse the repository at this point in the history
## Description
Creating `GET /person/<ref-id>` and `GET /patient/<ref-id>` to retrieve
info on the existing MPI data elements.

## Related Issues
closes #162 

## Additional Notes
Adding a testing helper context manager to count database queries,
hoping this can be useful when we want to make assertions on how many
SQL queries are issued.
  • Loading branch information
ericbuckley authored Feb 11, 2025
1 parent 1af8368 commit 3f18299
Show file tree
Hide file tree
Showing 9 changed files with 221 additions and 8 deletions.
10 changes: 8 additions & 2 deletions src/recordlinker/database/mpi_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,9 +254,15 @@ def get_patients_by_reference_ids(
) -> list[models.Patient | None]:
"""
Retrieve all the Patients by their reference ids. If a Patient is not found,
a None value will be returned in the list for that reference id.
a None value will be returned in the list for that reference id. Eagerly load
the Person associated with the Patient.
"""
query = select(models.Patient).where(models.Patient.reference_id.in_(reference_ids))
query = (
select(models.Patient)
.where(models.Patient.reference_id.in_(reference_ids))
.options(orm.joinedload(models.Patient.person))
)

patients_by_id: dict[uuid.UUID, models.Patient] = {
patient.reference_id: patient for patient in session.execute(query).scalars().all()
}
Expand Down
24 changes: 24 additions & 0 deletions src/recordlinker/routes/patient_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,30 @@ def create_patient(
)


@router.get(
"/{patient_reference_id}",
summary="Retrieve a patient record",
status_code=fastapi.status.HTTP_200_OK,
)
def get_patient(
patient_reference_id: uuid.UUID,
session: orm.Session = fastapi.Depends(get_session),
) -> schemas.PatientInfo:
"""
Retrieve an existing patient record in the MPI
"""
patient = service.get_patients_by_reference_ids(session, patient_reference_id)[0]
if patient is None:
raise fastapi.HTTPException(status_code=fastapi.status.HTTP_404_NOT_FOUND)

return schemas.PatientInfo(
patient_reference_id=patient.reference_id,
person_reference_id=patient.person.reference_id,
record=patient.record,
external_patient_id=patient.external_patient_id,
external_person_id=patient.external_person_id)


@router.patch(
"/{patient_reference_id}",
summary="Update a patient record",
Expand Down
22 changes: 22 additions & 0 deletions src/recordlinker/routes/person_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,25 @@ def update_person(

person = service.update_person_cluster(session, patients, person, commit=False)
return schemas.PersonRef(person_reference_id=person.reference_id)


@router.get(
"/{person_reference_id}",
summary="Retrieve a person cluster",
status_code=fastapi.status.HTTP_200_OK,
)
def get_person(
person_reference_id: uuid.UUID,
session: orm.Session = fastapi.Depends(get_session),
) -> schemas.PersonInfo:
"""
Retrieve an existing person cluster in the MPI
"""
person = service.get_person_by_reference_id(session, person_reference_id)
if person is None:
raise fastapi.HTTPException(status_code=fastapi.status.HTTP_404_NOT_FOUND)

return schemas.PersonInfo(
person_reference_id=person.reference_id,
patient_reference_ids=[patient.reference_id for patient in person.patients],
)
4 changes: 4 additions & 0 deletions src/recordlinker/schemas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@
from .link import MatchResponse
from .link import Prediction
from .mpi import PatientCreatePayload
from .mpi import PatientInfo
from .mpi import PatientPersonRef
from .mpi import PatientRef
from .mpi import PatientRefs
from .mpi import PatientUpdatePayload
from .mpi import PersonInfo
from .mpi import PersonRef
from .pii import Feature
from .pii import FeatureAttribute
Expand Down Expand Up @@ -44,6 +46,8 @@
"PatientRefs",
"PatientCreatePayload",
"PatientUpdatePayload",
"PatientInfo",
"PersonInfo",
"Cluster",
"ClusterGroup",
"PersonCluster",
Expand Down
13 changes: 13 additions & 0 deletions src/recordlinker/schemas/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,16 @@ def validate_both_not_empty(self) -> typing.Self:
if self.person_reference_id is None and self.record is None:
raise ValueError("at least one of person_reference_id or record must be provided")
return self


class PatientInfo(pydantic.BaseModel):
patient_reference_id: uuid.UUID
person_reference_id: uuid.UUID
record: PIIRecord
external_patient_id: str | None = None
external_person_id: str | None = None


class PersonInfo(pydantic.BaseModel):
person_reference_id: uuid.UUID
patient_reference_ids: list[uuid.UUID]
31 changes: 31 additions & 0 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import contextlib
import functools
import gzip
import json
import os
import pathlib

import pytest
import sqlalchemy
import sqlalchemy.event
from fastapi.testclient import TestClient

from recordlinker import database
Expand Down Expand Up @@ -82,3 +85,31 @@ def enhanced_algorithm():
for algo in utils.read_json("assets/initial_algorithms.json"):
if algo["label"] == "dibbs-enhanced":
return models.Algorithm.from_dict(**algo)


@contextlib.contextmanager
def count_queries(session):
"""
Context manager that counts the number of queries executed within the scope.
Usage:
```
with count_queries(session) as count:
session.query(...).all()
assert count() == 1
```
"""
query_count = 0

def _count(conn, cursor, statement, parameters, context, executemany):
nonlocal query_count
query_count += 1

# Attach the event listener
sqlalchemy.event.listen(sqlalchemy.Engine, "before_cursor_execute", _count)

try:
yield lambda: query_count
finally:
# Remove the event listener
sqlalchemy.event.remove(sqlalchemy.Engine, "before_cursor_execute", _count)
52 changes: 46 additions & 6 deletions tests/unit/database/test_mpi_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import pytest
import sqlalchemy.exc
from conftest import count_queries
from conftest import db_dialect

from recordlinker import models
Expand Down Expand Up @@ -317,10 +318,19 @@ def test_update_record(self, session):
patient = models.Patient(person=models.Person(), data={"sex": "M"})
session.add(patient)
session.flush()
session.add(models.BlockingValue(patient_id=patient.id, blockingkey=models.BlockingKey.SEX.id, value="M"))
record = schemas.PIIRecord(**{"name": [{"given": ["John"], "family": "Doe"}], "birthdate": "1980-01-01"})
session.add(
models.BlockingValue(
patient_id=patient.id, blockingkey=models.BlockingKey.SEX.id, value="M"
)
)
record = schemas.PIIRecord(
**{"name": [{"given": ["John"], "family": "Doe"}], "birthdate": "1980-01-01"}
)
patient = mpi_service.update_patient(session, patient, record=record)
assert patient.data == {"name": [{"given": ["John"], "family": "Doe"}], "birth_date": "1980-01-01"}
assert patient.data == {
"name": [{"given": ["John"], "family": "Doe"}],
"birth_date": "1980-01-01",
}
assert len(patient.blocking_values) == 3

def test_update_person(self, session):
Expand All @@ -346,7 +356,13 @@ def test_no_values(self, session):
other_patient = models.Patient()
session.add(other_patient)
session.flush()
session.add(models.BlockingValue(patient_id=other_patient.id, blockingkey=models.BlockingKey.FIRST_NAME.id, value="John"))
session.add(
models.BlockingValue(
patient_id=other_patient.id,
blockingkey=models.BlockingKey.FIRST_NAME.id,
value="John",
)
)
session.flush()
patient = models.Patient()
session.add(patient)
Expand All @@ -359,8 +375,16 @@ def test_with_values(self, session):
patient = models.Patient()
session.add(patient)
session.flush()
session.add(models.BlockingValue(patient_id=patient.id, blockingkey=models.BlockingKey.FIRST_NAME.id, value="John"))
session.add(models.BlockingValue(patient_id=patient.id, blockingkey=models.BlockingKey.LAST_NAME.id, value="Smith"))
session.add(
models.BlockingValue(
patient_id=patient.id, blockingkey=models.BlockingKey.FIRST_NAME.id, value="John"
)
)
session.add(
models.BlockingValue(
patient_id=patient.id, blockingkey=models.BlockingKey.LAST_NAME.id, value="Smith"
)
)
session.flush()
assert len(patient.blocking_values) == 2
mpi_service.delete_blocking_values_for_patient(session, patient)
Expand Down Expand Up @@ -724,6 +748,22 @@ def test_reference_ids(self, session):
session, uuid.uuid4(), patient.reference_id
) == [None, patient]

def test_eager_load_of_person(self, session):
pat_ref = uuid.uuid4()
per_ref = uuid.uuid4()
person = models.Person(reference_id=per_ref)
patient = models.Patient(person=person, reference_id=pat_ref, data={})
session.add(patient)
session.flush()
session.expire(person) # expiring the cache to fully test the query
session.expire(patient) # expiring the cache to fully test the query
with count_queries(session) as qcount:
pats = mpi_service.get_patients_by_reference_ids(session, pat_ref)
assert patient == pats[0]
assert per_ref == pats[0].person.reference_id
# assert only one query was made
assert qcount() == 1


class TestGetPersonByReferenceId:
def test_invalid_reference_id(self, session):
Expand Down
36 changes: 36 additions & 0 deletions tests/unit/routes/test_patient_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,39 @@ def test_delete_patient(self, client):
.first()
)
assert patient is None


class TestGetPatient:
def test_invalid_reference_id(self, client):
response = client.get("/patient/123")
assert response.status_code == 422

def test_invalid_patient(self, client):
response = client.get(f"/patient/{uuid.uuid4()}")
assert response.status_code == 404

def test_get_patient(self, client):
patient = models.Patient(person=models.Person(), data={
"name": [{"given": ["John"], "family": "Doe"}],
}, external_patient_id="123", external_person_id="456")
client.session.add(patient)
client.session.flush()
response = client.get(f"/patient/{patient.reference_id}")
assert response.status_code == 200
assert response.json() == {
"patient_reference_id": str(patient.reference_id),
"person_reference_id": str(patient.person.reference_id),
"record": {
"external_id": None,
"birth_date": None,
"sex": None,
"address": [],
"name": [{"family": "Doe", "given": ["John"], "use": None, "prefix": [], "suffix": []}],
"telecom": [],
"race": None,
"identifiers": [],
},
"external_patient_id": "123",
"external_person_id": "456",
}

37 changes: 37 additions & 0 deletions tests/unit/routes/test_person_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,40 @@ def test_update_person(self, client):
assert resp.json()["person_reference_id"] == str(new_person.reference_id)
assert resp.json()["person_reference_id"] == str(pat1.person.reference_id)
assert resp.json()["person_reference_id"] == str(pat2.person.reference_id)


class TestGetPerson:
def test_invalid_person_id(self, client):
response = client.get("/person/123")
assert response.status_code == 422

def test_invalid_person(self, client):
response = client.get(f"/person/{uuid.uuid4()}")
assert response.status_code == 404

def test_empty_patients(self, client):
person = models.Person()
client.session.add(person)
client.session.flush()

response = client.get(f"/person/{person.reference_id}")
assert response.status_code == 200
assert response.json() == {
"person_reference_id": str(person.reference_id),
"patient_reference_ids": [],
}

def test_with_patients(self, client):
person = models.Person()
pat1 = models.Patient(person=person, data={})
client.session.add(pat1)
pat2 = models.Patient(person=person, data={})
client.session.add(pat2)
client.session.flush()

response = client.get(f"/person/{person.reference_id}")
assert response.status_code == 200
assert response.json() == {
"person_reference_id": str(person.reference_id),
"patient_reference_ids": [str(pat1.reference_id), str(pat2.reference_id)],
}

0 comments on commit 3f18299

Please sign in to comment.