diff --git a/src/recordlinker/database/mpi_service.py b/src/recordlinker/database/mpi_service.py index 4a942e57..707a1df3 100644 --- a/src/recordlinker/database/mpi_service.py +++ b/src/recordlinker/database/mpi_service.py @@ -151,6 +151,46 @@ def bulk_insert_patients( return patients +def update_patient( + session: orm.Session, + patient: models.Patient, + record: typing.Optional[schemas.PIIRecord] = None, + person: typing.Optional[models.Person] = None, + external_patient_id: typing.Optional[str] = None, + commit: bool = True, +) -> models.Patient: + """ + Updates an existing patient record in the database. + + :param session: The database session + :param patient: The Patient to update + :param record: Optional PIIRecord to update + :param person: Optional Person to associate with the Patient + :param external_patient_id: Optional external patient ID + :param commit: Whether to commit the transaction + + :returns: The updated Patient record + """ + if patient.id is None: + raise ValueError("Patient has not yet been inserted into the database") + + if record: + patient.record = record + delete_blocking_values_for_patient(session, patient, commit=False) + insert_blocking_values(session, [patient], commit=False) + + if person: + patient.person = person + + if external_patient_id is not None: + patient.external_patient_id = external_patient_id + + session.flush() + if commit: + session.commit() + return patient + + def insert_blocking_values( session: orm.Session, patients: typing.Sequence[models.Patient], @@ -190,6 +230,23 @@ def insert_blocking_values( session.commit() +def delete_blocking_values_for_patient( + session: orm.Session, patient: models.Patient, commit: bool = True +) -> None: + """ + Delete all BlockingValues for a given Patient. + + :param session: The database session + :param patient: The Patient to delete BlockingValues for + :param commit: Whether to commit the transaction + + :returns: None + """ + session.query(models.BlockingValue).filter(models.BlockingValue.patient_id == patient.id).delete() + if commit: + session.commit() + + def get_patient_by_reference_id( session: orm.Session, reference_id: uuid.UUID ) -> models.Patient | None: diff --git a/src/recordlinker/routes/patient_router.py b/src/recordlinker/routes/patient_router.py index 73478526..13a45fbe 100644 --- a/src/recordlinker/routes/patient_router.py +++ b/src/recordlinker/routes/patient_router.py @@ -6,6 +6,7 @@ the patient API endpoints. """ +import typing import uuid import fastapi @@ -65,6 +66,91 @@ def update_person( patient_reference_id=patient.reference_id, person_reference_id=person.reference_id ) + +@router.post( + "/", + summary="Create a patient record and link to an existing person", + status_code=fastapi.status.HTTP_201_CREATED, +) +def create_patient( + payload: typing.Annotated[schemas.PatientCreatePayload, fastapi.Body], + session: orm.Session = fastapi.Depends(get_session), +) -> schemas.PatientRef: + """ + Create a new patient record in the MPI and link to an existing person. + """ + person = service.get_person_by_reference_id(session, payload.person_reference_id) + + if person is None: + raise fastapi.HTTPException( + status_code=fastapi.status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=[ + { + "loc": ["body", "person_reference_id"], + "msg": "Person not found", + "type": "value_error", + } + ], + ) + + patient = service.insert_patient( + session, + payload.record, + person=person, + external_patient_id=payload.record.external_id, + commit=False, + ) + return schemas.PatientRef( + patient_reference_id=patient.reference_id, external_patient_id=patient.external_patient_id + ) + + +@router.patch( + "/{patient_reference_id}", + summary="Update a patient record", + status_code=fastapi.status.HTTP_200_OK, +) +def update_patient( + patient_reference_id: uuid.UUID, + payload: typing.Annotated[schemas.PatientUpdatePayload, fastapi.Body], + session: orm.Session = fastapi.Depends(get_session), +) -> schemas.PatientRef: + """ + Update an existing patient record in the MPI + """ + patient = service.get_patient_by_reference_id(session, patient_reference_id) + if patient is None: + raise fastapi.HTTPException(status_code=fastapi.status.HTTP_404_NOT_FOUND) + + person = None + if payload.person_reference_id: + person = service.get_person_by_reference_id(session, payload.person_reference_id) + if person is None: + raise fastapi.HTTPException( + status_code=fastapi.status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=[ + { + "loc": ["body", "person_reference_id"], + "msg": "Person not found", + "type": "value_error", + } + ], + ) + + external_patient_id = getattr(payload.record, "external_id", None) + patient = service.update_patient( + session, + patient, + person=person, + record=payload.record, + external_patient_id=external_patient_id, + commit=False, + ) + return schemas.PatientRef( + patient_reference_id=patient.reference_id, external_patient_id=patient.external_patient_id + ) + + @router.delete( "/{patient_reference_id}", summary="Delete a Patient", @@ -80,5 +166,5 @@ def delete_patient( if patient is None: raise fastapi.HTTPException(status_code=fastapi.status.HTTP_404_NOT_FOUND) - - return service.delete_patient(session, patient) \ No newline at end of file + + return service.delete_patient(session, patient) diff --git a/src/recordlinker/schemas/__init__.py b/src/recordlinker/schemas/__init__.py index 9eede76f..c121f5ea 100644 --- a/src/recordlinker/schemas/__init__.py +++ b/src/recordlinker/schemas/__init__.py @@ -9,8 +9,10 @@ from .link import MatchFhirResponse from .link import MatchResponse from .link import Prediction +from .mpi import PatientCreatePayload from .mpi import PatientPersonRef from .mpi import PatientRef +from .mpi import PatientUpdatePayload from .mpi import PersonRef from .pii import Feature from .pii import FeatureAttribute @@ -38,6 +40,8 @@ "PersonRef", "PatientRef", "PatientPersonRef", + "PatientCreatePayload", + "PatientUpdatePayload", "Cluster", "ClusterGroup", "PersonCluster", diff --git a/src/recordlinker/schemas/mpi.py b/src/recordlinker/schemas/mpi.py index 22972cbf..8cc3a7db 100644 --- a/src/recordlinker/schemas/mpi.py +++ b/src/recordlinker/schemas/mpi.py @@ -1,7 +1,10 @@ +import typing import uuid import pydantic +from .pii import PIIRecord + class PersonRef(pydantic.BaseModel): person_reference_id: uuid.UUID @@ -16,3 +19,22 @@ class PatientRef(pydantic.BaseModel): class PatientPersonRef(pydantic.BaseModel): patient_reference_id: uuid.UUID person_reference_id: uuid.UUID + + +class PatientCreatePayload(pydantic.BaseModel): + person_reference_id: uuid.UUID + record: PIIRecord + + +class PatientUpdatePayload(pydantic.BaseModel): + person_reference_id: uuid.UUID | None = None + record: PIIRecord | None = None + + @pydantic.model_validator(mode="after") + def validate_both_not_empty(self) -> typing.Self: + """ + Ensure that either person_reference_id or record is not None. + """ + if self.person_reference_id is None and self.record is None: + raise ValueError("at least one of person_reference_id or record must be provided") + return self diff --git a/tests/unit/database/test_mpi_service.py b/tests/unit/database/test_mpi_service.py index cf91f6fe..6b7a38d5 100644 --- a/tests/unit/database/test_mpi_service.py +++ b/tests/unit/database/test_mpi_service.py @@ -310,6 +310,65 @@ def test_error(self, session): assert mpi_service.bulk_insert_patients(session, []) +class TestUpdatePatient: + def test_no_patient(self, session): + with pytest.raises(ValueError): + mpi_service.update_patient(session, models.Patient(), schemas.PIIRecord()) + + def test_update_record(self, session): + patient = models.Patient(person=models.Person(), data={"sex": "M"}) + session.add(patient) + session.flush() + session.add(models.BlockingValue(patient_id=patient.id, blockingkey=models.BlockingKey.SEX.id, value="M")) + record = schemas.PIIRecord(**{"name": [{"given": ["John"], "family": "Doe"}], "birthdate": "1980-01-01"}) + patient = mpi_service.update_patient(session, patient, record=record) + assert patient.data == {"name": [{"given": ["John"], "family": "Doe"}], "birth_date": "1980-01-01"} + assert len(patient.blocking_values) == 3 + + def test_update_person(self, session): + person = models.Person() + session.add(person) + patient = models.Patient() + session.add(patient) + session.flush() + patient = mpi_service.update_patient(session, patient, person=person) + assert patient.person_id == person.id + + def test_update_external_patient_id(self, session): + patient = models.Patient() + session.add(patient) + session.flush() + + patient = mpi_service.update_patient(session, patient, external_patient_id="123") + assert patient.external_patient_id == "123" + + +class TestDeleteBlockingValuesForPatient: + def test_no_values(self, session): + other_patient = models.Patient() + session.add(other_patient) + session.flush() + session.add(models.BlockingValue(patient_id=other_patient.id, blockingkey=models.BlockingKey.FIRST_NAME.id, value="John")) + session.flush() + patient = models.Patient() + session.add(patient) + session.flush() + assert len(patient.blocking_values) == 0 + mpi_service.delete_blocking_values_for_patient(session, patient) + assert len(patient.blocking_values) == 0 + + def test_with_values(self, session): + patient = models.Patient() + session.add(patient) + session.flush() + session.add(models.BlockingValue(patient_id=patient.id, blockingkey=models.BlockingKey.FIRST_NAME.id, value="John")) + session.add(models.BlockingValue(patient_id=patient.id, blockingkey=models.BlockingKey.LAST_NAME.id, value="Smith")) + session.flush() + assert len(patient.blocking_values) == 2 + mpi_service.delete_blocking_values_for_patient(session, patient) + assert len(patient.blocking_values) == 0 + + class TestGetBlockData: @pytest.fixture def prime_index(self, session): diff --git a/tests/unit/routes/test_patient_router.py b/tests/unit/routes/test_patient_router.py index f869e510..bf6ac8ce 100644 --- a/tests/unit/routes/test_patient_router.py +++ b/tests/unit/routes/test_patient_router.py @@ -66,6 +66,106 @@ def test_update_person(self, client): assert resp.json()["patient_reference_id"] == str(patient.reference_id) assert resp.json()["person_reference_id"] == str(new_person.reference_id) + +class TestCreatePatient: + def test_missing_data(self, client): + response = client.post("/patient") + assert response.status_code == 422 + + def test_invalid_person(self, client): + data = {"person_reference_id": str(uuid.uuid4()), "record": {}} + response = client.post("/patient", json=data) + assert response.status_code == 422 + assert response.json() == { + "detail": [ + { + "loc": ["body", "person_reference_id"], + "msg": "Person not found", + "type": "value_error", + } + ] + } + + def test_create_patient(self, client): + person = models.Person() + client.session.add(person) + client.session.flush() + + data = { + "person_reference_id": str(person.reference_id), + "record": {"name": [{"given": ["John"], "family": "Doe"}], "external_id": "123"}, + } + response = client.post("/patient", json=data) + assert response.status_code == 201 + patient = client.session.query(models.Patient).first() + assert response.json() == { + "patient_reference_id": str(patient.reference_id), + "external_patient_id": "123", + } + assert len(patient.blocking_values) == 2 + assert patient.person == person + assert patient.data == data["record"] + + +class TestUpdatePatient: + def test_missing_data(self, client): + response = client.patch(f"/patient/{uuid.uuid4()}") + assert response.status_code == 422 + + def test_invalid_reference_id(self, client): + data = {"person_reference_id": str(uuid.uuid4())} + response = client.patch(f"/patient/{uuid.uuid4()}", json=data) + assert response.status_code == 404 + + def test_invalid_person(self, client): + patient = models.Patient() + client.session.add(patient) + client.session.flush() + + data = {"person_reference_id": str(uuid.uuid4())} + response = client.patch(f"/patient/{patient.reference_id}", json=data) + assert response.status_code == 422 + assert response.json() == { + "detail": [ + { + "loc": ["body", "person_reference_id"], + "msg": "Person not found", + "type": "value_error", + } + ] + } + + def test_no_data_to_update(self, client): + patient = models.Patient() + client.session.add(patient) + client.session.flush() + + response = client.patch(f"/patient/{patient.reference_id}", json={}) + assert response.status_code == 422 + + def test_update_patient(self, client): + person = models.Person() + client.session.add(person) + patient = models.Patient() + client.session.add(patient) + client.session.flush() + + data = { + "person_reference_id": str(person.reference_id), + "record": {"name": [{"given": ["John"], "family": "Doe"}], "external_id": "123"}, + } + response = client.patch(f"/patient/{patient.reference_id}", json=data) + assert response.status_code == 200 + assert response.json() == { + "patient_reference_id": str(patient.reference_id), + "external_patient_id": "123", + } + patient = client.session.get(models.Patient, patient.id) + assert len(patient.blocking_values) == 2 + assert patient.person == person + assert patient.data == data["record"] + + class TestDeletePatient: def test_invalid_reference_id(self, client): response = client.delete(f"/patient/{uuid.uuid4()}") @@ -80,6 +180,8 @@ def test_delete_patient(self, client): assert resp.status_code == 204 patient = ( - client.session.query(models.Patient).filter(models.Patient.reference_id == patient.reference_id).first() + client.session.query(models.Patient) + .filter(models.Patient.reference_id == patient.reference_id) + .first() ) assert patient is None