Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

patient / person API endpoints for retrieving data #206

Merged
merged 5 commits into from
Feb 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/recordlinker/database/mpi_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,9 +254,15 @@ def get_patients_by_reference_ids(
) -> list[models.Patient | None]:
"""
Retrieve all the Patients by their reference ids. If a Patient is not found,
a None value will be returned in the list for that reference id.
a None value will be returned in the list for that reference id. Eagerly load
the Person associated with the Patient.
"""
query = select(models.Patient).where(models.Patient.reference_id.in_(reference_ids))
query = (
select(models.Patient)
.where(models.Patient.reference_id.in_(reference_ids))
.options(orm.joinedload(models.Patient.person))
)

patients_by_id: dict[uuid.UUID, models.Patient] = {
patient.reference_id: patient for patient in session.execute(query).scalars().all()
}
Expand Down
24 changes: 24 additions & 0 deletions src/recordlinker/routes/patient_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,30 @@ def create_patient(
)


@router.get(
"/{patient_reference_id}",
summary="Retrieve a patient record",
status_code=fastapi.status.HTTP_200_OK,
)
def get_patient(
patient_reference_id: uuid.UUID,
session: orm.Session = fastapi.Depends(get_session),
) -> schemas.PatientInfo:
"""
Retrieve an existing patient record in the MPI
"""
patient = service.get_patients_by_reference_ids(session, patient_reference_id)[0]
if patient is None:
raise fastapi.HTTPException(status_code=fastapi.status.HTTP_404_NOT_FOUND)

return schemas.PatientInfo(
patient_reference_id=patient.reference_id,
person_reference_id=patient.person.reference_id,
record=patient.record,
external_patient_id=patient.external_patient_id,
external_person_id=patient.external_person_id)


@router.patch(
"/{patient_reference_id}",
summary="Update a patient record",
Expand Down
22 changes: 22 additions & 0 deletions src/recordlinker/routes/person_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,25 @@ def update_person(

person = service.update_person_cluster(session, patients, person, commit=False)
return schemas.PersonRef(person_reference_id=person.reference_id)


@router.get(
"/{person_reference_id}",
summary="Retrieve a person cluster",
status_code=fastapi.status.HTTP_200_OK,
)
def get_person(
person_reference_id: uuid.UUID,
session: orm.Session = fastapi.Depends(get_session),
) -> schemas.PersonInfo:
"""
Retrieve an existing person cluster in the MPI
"""
person = service.get_person_by_reference_id(session, person_reference_id)
if person is None:
raise fastapi.HTTPException(status_code=fastapi.status.HTTP_404_NOT_FOUND)

return schemas.PersonInfo(
person_reference_id=person.reference_id,
patient_reference_ids=[patient.reference_id for patient in person.patients],
)
4 changes: 4 additions & 0 deletions src/recordlinker/schemas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@
from .link import MatchResponse
from .link import Prediction
from .mpi import PatientCreatePayload
from .mpi import PatientInfo
from .mpi import PatientPersonRef
from .mpi import PatientRef
from .mpi import PatientRefs
from .mpi import PatientUpdatePayload
from .mpi import PersonInfo
from .mpi import PersonRef
from .pii import Feature
from .pii import FeatureAttribute
Expand Down Expand Up @@ -44,6 +46,8 @@
"PatientRefs",
"PatientCreatePayload",
"PatientUpdatePayload",
"PatientInfo",
"PersonInfo",
"Cluster",
"ClusterGroup",
"PersonCluster",
Expand Down
13 changes: 13 additions & 0 deletions src/recordlinker/schemas/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,16 @@ def validate_both_not_empty(self) -> typing.Self:
if self.person_reference_id is None and self.record is None:
raise ValueError("at least one of person_reference_id or record must be provided")
return self


class PatientInfo(pydantic.BaseModel):
patient_reference_id: uuid.UUID
person_reference_id: uuid.UUID
record: PIIRecord
external_patient_id: str | None = None
external_person_id: str | None = None


class PersonInfo(pydantic.BaseModel):
person_reference_id: uuid.UUID
patient_reference_ids: list[uuid.UUID]
31 changes: 31 additions & 0 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import contextlib
import functools
import gzip
import json
import os
import pathlib

import pytest
import sqlalchemy
import sqlalchemy.event
from fastapi.testclient import TestClient

from recordlinker import database
Expand Down Expand Up @@ -82,3 +85,31 @@ def enhanced_algorithm():
for algo in utils.read_json("assets/initial_algorithms.json"):
if algo["label"] == "dibbs-enhanced":
return models.Algorithm.from_dict(**algo)


@contextlib.contextmanager
def count_queries(session):
"""
Context manager that counts the number of queries executed within the scope.

Usage:
```
with count_queries(session) as count:
session.query(...).all()
assert count() == 1
```
"""
query_count = 0

def _count(conn, cursor, statement, parameters, context, executemany):
nonlocal query_count
query_count += 1

# Attach the event listener
sqlalchemy.event.listen(sqlalchemy.Engine, "before_cursor_execute", _count)

try:
yield lambda: query_count
finally:
# Remove the event listener
sqlalchemy.event.remove(sqlalchemy.Engine, "before_cursor_execute", _count)
52 changes: 46 additions & 6 deletions tests/unit/database/test_mpi_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import pytest
import sqlalchemy.exc
from conftest import count_queries
from conftest import db_dialect

from recordlinker import models
Expand Down Expand Up @@ -317,10 +318,19 @@ def test_update_record(self, session):
patient = models.Patient(person=models.Person(), data={"sex": "M"})
session.add(patient)
session.flush()
session.add(models.BlockingValue(patient_id=patient.id, blockingkey=models.BlockingKey.SEX.id, value="M"))
record = schemas.PIIRecord(**{"name": [{"given": ["John"], "family": "Doe"}], "birthdate": "1980-01-01"})
session.add(
models.BlockingValue(
patient_id=patient.id, blockingkey=models.BlockingKey.SEX.id, value="M"
)
)
record = schemas.PIIRecord(
**{"name": [{"given": ["John"], "family": "Doe"}], "birthdate": "1980-01-01"}
)
patient = mpi_service.update_patient(session, patient, record=record)
assert patient.data == {"name": [{"given": ["John"], "family": "Doe"}], "birth_date": "1980-01-01"}
assert patient.data == {
"name": [{"given": ["John"], "family": "Doe"}],
"birth_date": "1980-01-01",
}
assert len(patient.blocking_values) == 3

def test_update_person(self, session):
Expand All @@ -346,7 +356,13 @@ def test_no_values(self, session):
other_patient = models.Patient()
session.add(other_patient)
session.flush()
session.add(models.BlockingValue(patient_id=other_patient.id, blockingkey=models.BlockingKey.FIRST_NAME.id, value="John"))
session.add(
models.BlockingValue(
patient_id=other_patient.id,
blockingkey=models.BlockingKey.FIRST_NAME.id,
value="John",
)
)
session.flush()
patient = models.Patient()
session.add(patient)
Expand All @@ -359,8 +375,16 @@ def test_with_values(self, session):
patient = models.Patient()
session.add(patient)
session.flush()
session.add(models.BlockingValue(patient_id=patient.id, blockingkey=models.BlockingKey.FIRST_NAME.id, value="John"))
session.add(models.BlockingValue(patient_id=patient.id, blockingkey=models.BlockingKey.LAST_NAME.id, value="Smith"))
session.add(
models.BlockingValue(
patient_id=patient.id, blockingkey=models.BlockingKey.FIRST_NAME.id, value="John"
)
)
session.add(
models.BlockingValue(
patient_id=patient.id, blockingkey=models.BlockingKey.LAST_NAME.id, value="Smith"
)
)
session.flush()
assert len(patient.blocking_values) == 2
mpi_service.delete_blocking_values_for_patient(session, patient)
Expand Down Expand Up @@ -724,6 +748,22 @@ def test_reference_ids(self, session):
session, uuid.uuid4(), patient.reference_id
) == [None, patient]

def test_eager_load_of_person(self, session):
pat_ref = uuid.uuid4()
per_ref = uuid.uuid4()
person = models.Person(reference_id=per_ref)
patient = models.Patient(person=person, reference_id=pat_ref, data={})
session.add(patient)
session.flush()
session.expire(person) # expiring the cache to fully test the query
session.expire(patient) # expiring the cache to fully test the query
with count_queries(session) as qcount:
pats = mpi_service.get_patients_by_reference_ids(session, pat_ref)
assert patient == pats[0]
assert per_ref == pats[0].person.reference_id
# assert only one query was made
assert qcount() == 1


class TestGetPersonByReferenceId:
def test_invalid_reference_id(self, session):
Expand Down
36 changes: 36 additions & 0 deletions tests/unit/routes/test_patient_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,39 @@ def test_delete_patient(self, client):
.first()
)
assert patient is None


class TestGetPatient:
def test_invalid_reference_id(self, client):
response = client.get("/patient/123")
assert response.status_code == 422

def test_invalid_patient(self, client):
response = client.get(f"/patient/{uuid.uuid4()}")
assert response.status_code == 404

def test_get_patient(self, client):
patient = models.Patient(person=models.Person(), data={
"name": [{"given": ["John"], "family": "Doe"}],
}, external_patient_id="123", external_person_id="456")
client.session.add(patient)
client.session.flush()
response = client.get(f"/patient/{patient.reference_id}")
assert response.status_code == 200
assert response.json() == {
"patient_reference_id": str(patient.reference_id),
"person_reference_id": str(patient.person.reference_id),
"record": {
"external_id": None,
"birth_date": None,
"sex": None,
"address": [],
"name": [{"family": "Doe", "given": ["John"], "use": None, "prefix": [], "suffix": []}],
"telecom": [],
"race": None,
"identifiers": [],
},
"external_patient_id": "123",
"external_person_id": "456",
}

37 changes: 37 additions & 0 deletions tests/unit/routes/test_person_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,40 @@ def test_update_person(self, client):
assert resp.json()["person_reference_id"] == str(new_person.reference_id)
assert resp.json()["person_reference_id"] == str(pat1.person.reference_id)
assert resp.json()["person_reference_id"] == str(pat2.person.reference_id)


class TestGetPerson:
def test_invalid_person_id(self, client):
response = client.get("/person/123")
assert response.status_code == 422

def test_invalid_person(self, client):
response = client.get(f"/person/{uuid.uuid4()}")
assert response.status_code == 404

def test_empty_patients(self, client):
person = models.Person()
client.session.add(person)
client.session.flush()

response = client.get(f"/person/{person.reference_id}")
assert response.status_code == 200
assert response.json() == {
"person_reference_id": str(person.reference_id),
"patient_reference_ids": [],
}

def test_with_patients(self, client):
person = models.Person()
pat1 = models.Patient(person=person, data={})
client.session.add(pat1)
pat2 = models.Patient(person=person, data={})
client.session.add(pat2)
client.session.flush()

response = client.get(f"/person/{person.reference_id}")
assert response.status_code == 200
assert response.json() == {
"person_reference_id": str(person.reference_id),
"patient_reference_ids": [str(pat1.reference_id), str(pat2.reference_id)],
}
Loading