Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Factorize get subject session #1190

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ include = [
"python/lib/config_file.py",
"python/lib/env.py",
"python/lib/file_system.py",
"python/lib/get_subject_session.py",
"python/lib/logging.py",
"python/lib/make_env.py",
"python/lib/validate_subject_info.py",
Expand Down
6 changes: 6 additions & 0 deletions python/lib/database_lib/session_db.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""This class performs session table related database queries and common checks"""

from typing_extensions import deprecated

__license__ = "GPLv3"


@deprecated('Use `lib.db.models.session.DbSession` instead')
class SessionDB:
"""
This class performs database queries for session table.
Expand Down Expand Up @@ -35,6 +37,7 @@ def __init__(self, db, verbose):
self.db = db
self.verbose = verbose

@deprecated('Use `lib.db.queries.try_get_candidate_with_cand_id_visit_label` instead')
def create_session_dict(self, cand_id, visit_label):
"""
Queries the session table for a particular candidate ID and visit label and returns a dictionary
Expand All @@ -56,6 +59,7 @@ def create_session_dict(self, cand_id, visit_label):

return results[0] if results else None

@deprecated('Use `lib.db.queries.site.try_get_site_with_psc_id_visit_label` instead')
def get_session_center_info(self, pscid, visit_label):
"""
Get site information for a given visit.
Expand All @@ -77,6 +81,7 @@ def get_session_center_info(self, pscid, visit_label):

return results[0] if results else None

@deprecated('Use `lib.get_subject_session.get_candidate_next_visit_number` instead')
def determine_next_session_site_id_and_visit_number(self, cand_id):
"""
Determines the next session site and visit number based on the last session inserted for a given candidate.
Expand All @@ -99,6 +104,7 @@ def determine_next_session_site_id_and_visit_number(self, cand_id):

return results[0] if results else None

@deprecated('Use `lib.db.models.session.DbSession` instead')
def insert_into_session(self, fields, values):
"""
Insert a new row in the session table using fields list as column names and values as values.
Expand Down
3 changes: 1 addition & 2 deletions python/lib/db/models/notification_spool.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,4 @@ class DbNotificationSpool(Base):
origin : Mapped[Optional[str]] = mapped_column('Origin')
active : Mapped[bool] = mapped_column('Active', YNBool)

type : Mapped['db_notification_type.DbNotificationType'] \
= relationship('DbNotificationType')
type : Mapped['db_notification_type.DbNotificationType'] = relationship('DbNotificationType')
4 changes: 1 addition & 3 deletions python/lib/db/queries/site.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from sqlalchemy import select
from sqlalchemy.orm import Session as Database

from lib.db.models.candidate import DbCandidate
from lib.db.models.session import DbSession
from lib.db.models.site import DbSite

Expand All @@ -14,7 +13,6 @@ def try_get_site_with_cand_id_visit_label(db: Database, cand_id: int, visit_labe

return db.execute(select(DbSite)
.join(DbSession.site)
.join(DbSession.candidate)
.where(DbCandidate.cand_id == cand_id)
.where(DbSession.cand_id == cand_id)
.where(DbSession.visit_label == visit_label)
).scalar_one_or_none()
108 changes: 10 additions & 98 deletions python/lib/dcm2bids_imaging_pipeline_lib/base_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import lib.utilities
from lib.database import Database
from lib.database_lib.config import Config
from lib.db.queries.session import try_get_session_with_cand_id_visit_label
from lib.dicom_archive import DicomArchive
from lib.exception.determine_subject_info_error import DetermineSubjectInfoError
from lib.exception.validate_subject_info_error import ValidateSubjectInfoError
Expand Down Expand Up @@ -192,10 +193,14 @@ def determine_study_info(self):

# get the CenterID from the session table if the PSCID and visit label exists
# and could be extracted from the database
self.session_obj.create_session_dict(self.subject_info.cand_id, self.subject_info.visit_label)
session_dict = self.session_obj.session_info_dict
if session_dict:
return {"CenterName": session_dict["MRI_alias"], "CenterID": session_dict["CenterID"]}
self.session = try_get_session_with_cand_id_visit_label(
self.env.db,
self.subject_info.cand_id,
self.subject_info.visit_label,
)

if self.session is not None:
return {"CenterName": self.session.site.mri_alias, "CenterID": self.session.site_id}

# if could not find center information based on cand_id and visit_label, use the
# patient name to match it to the site alias or MRI alias
Expand Down Expand Up @@ -223,7 +228,7 @@ def determine_scanner_info(self):
self.dicom_archive_obj.tarchive_info_dict['ScannerSerialNumber'],
self.dicom_archive_obj.tarchive_info_dict['ScannerModel'],
self.site_dict['CenterID'],
self.session_obj.session_info_dict['ProjectID'] if self.session_obj.session_info_dict else None
self.session.project_id if self.session is not None else None,
)

log_verbose(self.env, f"Found Scanner ID: {scanner_id}")
Expand All @@ -248,99 +253,6 @@ def validate_subject_info(self):
upload_id=self.upload_id, fields=('IsCandidateInfoValidated',), values=('0',)
)

def get_session_info(self):
"""
Creates the session info dictionary based on entries found in the session table.
"""

self.session_obj.create_session_dict(self.subject_info.cand_id, self.subject_info.visit_label)

if self.session_obj.session_info_dict:
log_verbose(self.env, f"Session ID for the file to insert is {self.session_obj.session_info_dict['ID']}")

def create_session(self):
"""
Function that will create a new visit in the session table for the imaging scans after verification
that all the information necessary for the creation of the visit are present.
"""

create_visit = self.subject_info.create_visit

if create_visit is None:
log_error_exit(
self.env,
f"Visit {self.subject_info.visit_label} for candidate {self.subject_info.cand_id} does not exist.",
lib.exitcode.GET_SESSION_ID_FAILURE,
)

# check that the project ID and cohort ID refers to an existing row in project_cohort_rel table
self.session_obj.create_proj_cohort_rel_info_dict(create_visit.project_id, create_visit.cohort_id)
if not self.session_obj.proj_cohort_rel_info_dict.keys():
log_error_exit(
self.env,
(
f"Cannot create visit with project ID {create_visit.project_id}"
f" and cohort ID {create_visit.cohort_id}:"
f" no such association in table project_cohort_rel"
),
lib.exitcode.CREATE_SESSION_FAILURE,
)

# determine the visit number and center ID for the next session to be created
center_id, visit_nb = self.determine_new_session_site_and_visit_nb()
if not center_id:
log_error_exit(
self.env,
(
f"No center ID found for candidate {self.subject_info.cand_id}"
f", visit {self.subject_info.visit_label}"
)
)
else:
log_verbose(self.env, f"Set newVisitNo = {visit_nb} and center ID = {center_id}")

# create the new visit
session_id = self.session_obj.insert_into_session(
{
'CandID': self.subject_info.cand_id,
'Visit_label': self.subject_info.visit_label,
'CenterID': center_id,
'VisitNo': visit_nb,
'Current_stage': 'Not Started',
'Scan_done': 'Y',
'Submitted': 'N',
'CohortID': create_visit.cohort_id,
'ProjectID': create_visit.project_id
}
)
if session_id:
self.get_session_info()

def determine_new_session_site_and_visit_nb(self):
"""
Determines the site and visit number of the new session to be created.

:returns: The center ID and visit number of the future new session
"""
visit_nb = 0
center_id = 0

if self.subject_info.is_phantom:
center_info_dict = self.session_obj.get_session_center_info(
self.subject_info.psc_id, self.subject_info.visit_label,
)

if center_info_dict:
center_id = center_info_dict["CenterID"]
visit_nb = 1
else:
center_info_dict = self.session_obj.get_next_session_site_id_and_visit_number(self.subject_info.cand_id)
if center_info_dict:
center_id = center_info_dict["CenterID"]
visit_nb = center_info_dict["newVisitNo"]

return center_id, visit_nb

def check_if_tarchive_validated_in_db(self):
"""
Checks whether the DICOM archive was previously validated in the database (as per the value present
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def _move_and_update_dicom_archive(self):
archive_location = self.dicom_archive_obj.tarchive_info_dict["ArchiveLocation"]

fields_to_update = ("SessionID",)
values_for_update = (self.session_obj.session_id,)
values_for_update = (self.session.id,)
pattern = re.compile("^[0-9]{4}/")
if acq_date and not pattern.match(archive_location):
# move the DICOM archive into a year subfolder
Expand Down Expand Up @@ -412,7 +412,7 @@ def _update_mri_upload(self):
self.imaging_upload_obj.update_mri_upload(
upload_id=self.upload_id,
fields=("Inserting", "InsertionComplete", "number_of_mincInserted", "number_of_mincCreated", "SessionID"),
values=("0", "1", len(files_inserted_list), len(self.nifti_files_to_insert), self.session_obj.session_id)
values=("0", "1", len(files_inserted_list), len(self.nifti_files_to_insert), self.session.id)
)

def _get_summary_of_insertion(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from lib.dcm2bids_imaging_pipeline_lib.base_pipeline import BasePipeline
from lib.exception.determine_subject_info_error import DetermineSubjectInfoError
from lib.exception.validate_subject_info_error import ValidateSubjectInfoError
from lib.get_subject_session import get_subject_session
from lib.logging import log_error_exit, log_verbose
from lib.validate_subject_info import validate_subject_info

Expand Down Expand Up @@ -110,9 +111,7 @@ def __init__(self, loris_getopt_obj, script_name):
# ---------------------------------------------------------------------------------------------
# Determine/create the session the file should be linked to
# ---------------------------------------------------------------------------------------------
self.get_session_info()
if not self.session_obj.session_info_dict:
self.create_session()
self.session = get_subject_session(self.env, self.subject_info)

# ---------------------------------------------------------------------------------------------
# Determine acquisition protocol (or register into mri_protocol_violated_scans and exits)
Expand Down Expand Up @@ -169,9 +168,9 @@ def __init__(self, loris_getopt_obj, script_name):
self.exclude_violations_list = []
if not self.bypass_extra_checks:
self.violations_summary = self.imaging_obj.run_extra_file_checks(
self.session_obj.session_info_dict['ProjectID'],
self.session_obj.session_info_dict['CohortID'],
self.session_obj.session_info_dict['Visit_label'],
self.session.project_id,
self.session.cohort_id,
self.session.visit_label,
self.scan_type_id,
self.json_file_dict
)
Expand Down Expand Up @@ -357,15 +356,15 @@ def _determine_acquisition_protocol(self):
self.json_file_dict['DeviceSerialNumber'],
self.json_file_dict['ManufacturersModelName'],
self.site_dict['CenterID'],
self.session_obj.session_info_dict['ProjectID']
self.session.project_id,
)

# get the list of lines in the mri_protocol table that apply to the given scan based on the protocol group
protocols_list = self.imaging_obj.get_list_of_eligible_protocols_based_on_session_info(
self.session_obj.session_info_dict['ProjectID'],
self.session_obj.session_info_dict['CohortID'],
self.session_obj.session_info_dict['CenterID'],
self.session_obj.session_info_dict['Visit_label'],
self.session.project_id,
self.session.cohort_id,
self.session.site_id,
self.session.visit_label,
self.scanner_id
)

Expand Down Expand Up @@ -458,7 +457,7 @@ def _determine_new_nifti_assembly_rel_path(self):
# determine NIfTI file name
new_nifti_name = self._construct_nifti_filename(file_bids_entities_dict)
already_inserted_filenames = self.imaging_obj.get_list_of_files_already_inserted_for_session_id(
self.session_obj.session_info_dict['ID']
self.session.id,
)
while new_nifti_name in already_inserted_filenames:
file_bids_entities_dict['run'] += 1
Expand Down Expand Up @@ -680,7 +679,7 @@ def _register_into_files_and_parameter_file(self, nifti_rel_path):
)

files_insert_info_dict = {
'SessionID': self.session_obj.session_info_dict['ID'],
'SessionID': self.session.id,
'File': nifti_rel_path,
'SeriesUID': scan_param['SeriesInstanceUID'] if 'SeriesInstanceUID' in scan_param.keys() else None,
'EchoTime': scan_param['EchoTime'] if 'EchoTime' in scan_param.keys() else None,
Expand Down
85 changes: 85 additions & 0 deletions python/lib/get_subject_session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from typing import cast

import lib.exitcode
from lib.config_file import SubjectInfo
from lib.db.models.candidate import DbCandidate
from lib.db.models.session import DbSession
from lib.db.queries.candidate import try_get_candidate_with_cand_id
from lib.db.queries.session import try_get_session_with_cand_id_visit_label
from lib.db.queries.site import try_get_site_with_cand_id_visit_label
from lib.env import Env
from lib.logging import log_error_exit, log_verbose


def get_candidate_next_visit_number(candidate: DbCandidate) -> int:
"""
Get the next visit number for a new session for a given candidate.
"""

visit_numbers = [session.visit_number for session in candidate.sessions if session.visit_number is not None]
return max(*visit_numbers, 0) + 1


def get_subject_session(env: Env, subject_info: SubjectInfo) -> DbSession:
"""
Get the imaging session corresponding to a given subject configuration.

This function first looks for an adequate session in the database, and returns it if one is
found. If no session is found, this function creates a new session in the database if the
subject configuration allows it, or exits the program otherwise.
"""

session = _get_subject_session(env, subject_info)
log_verbose(env, f"Session ID for the file to insert is {session.id}")
return session


def _get_subject_session(env: Env, subject_info: SubjectInfo) -> DbSession:
"""
Implementation of `get_subject_session`.
"""

session = try_get_session_with_cand_id_visit_label(env.db, subject_info.cand_id, subject_info.visit_label)
if session is not None:
return session

if subject_info.create_visit is None:
log_error_exit(
env,
f"Visit {subject_info.visit_label} for candidate {subject_info.cand_id} does not exist.",
lib.exitcode.GET_SESSION_ID_FAILURE,
)

if subject_info.is_phantom:
site = try_get_site_with_cand_id_visit_label(env.db, subject_info.cand_id, subject_info.visit_label)
visit_number = 1
else:
candidate = try_get_candidate_with_cand_id(env.db, subject_info.cand_id)
# Safe because it has been checked that the candidate exists in `validate_subject_info`
candidate = cast(DbCandidate, candidate)
site = candidate.registration_site
visit_number = get_candidate_next_visit_number(candidate)

if site is None:
log_error_exit(
env,
f"No center ID found for candidate {subject_info.cand_id}, visit {subject_info.visit_label}"
)

log_verbose(env, f"Set newVisitNo = {visit_number} and center ID = {site.id}")

session = DbSession(
cand_id = subject_info.cand_id,
site_id = site.id,
visit_number = visit_number,
current_stage = 'Not Started',
scan_done = True,
submitted = False,
project_id = subject_info.create_visit.project_id,
cohort_id = subject_info.create_visit.cohort_id,
)

env.db.add(session)
env.db.commit()

return session
Loading
Loading