Skip to content

Commit

Permalink
Merge pull request #71 from richford/enh/unsupervised-load-afq-data
Browse files Browse the repository at this point in the history
ENH: Add unsupervised boolean parameter to load_afq_data
  • Loading branch information
richford authored Apr 20, 2021
2 parents 55bf0f6 + 65c0f56 commit 89e0761
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 23 deletions.
128 changes: 105 additions & 23 deletions afqinsight/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,42 @@

def load_afq_data(
workdir,
target_cols,
dwi_metrics=None,
target_cols=None,
label_encode_cols=None,
index_col="subjectID",
fn_nodes="nodes.csv",
fn_subjects="subjects.csv",
unsupervised=False,
concat_subject_session=False,
return_sessions=False,
):
"""Load AFQ data from CSV, transform it, return feature matrix and target.
This function expects a directory with a diffusion metric csv file
(specified by ``fn_nodes``) and, optionally, a phenotypic data file
(specified by ``fn_subjects``). The nodes csv file must be a long format
dataframe with the following columns: "subjectID," "nodeID," "tractID,"
and an optional "sessionID". All other columns are assumed to be diffusion
metric columns, which can be optionally subset using the ``dwi_metrics``
parameter.
For supervised learning problems (with parameter ``unsupervised=False``)
this function will also load phenotypic targets from a subjects csv/tsv
file. This function will load the subject data, drop subjects that are
not found in the dwi feature matrix, and optionally label encode
categorical values.
Parameters
----------
workdir : str
Directory in which to find the AFQ csv files
target_cols : list of strings
dwi_metrics : list of strings, optional
List of diffusion metrics to extract from nodes csv.
e.g. ["dki_md", "dki_fa"]
target_cols : list of strings, optional
List of column names in subjects csv file to use as target variables
label_encode_cols : list of strings, subset of target_cols
Expand All @@ -44,13 +66,24 @@ def load_afq_data(
fn_subjects : str, default='subjects.csv'
Filename for the subjects csv file.
unsupervised : bool, default=False
If True, do not load target data from the ``fn_subjects`` file.
concat_subject_session : bool, default=False
If True, create new subject IDs by concatenating the existing subject
IDs with the session IDs. This is useful when subjects have multiple
sessions and you wish to disambiguate between them.
return_sessions : bool, default=False
If True, return sessionID
Returns
-------
X : array-like of shape (n_samples, n_features)
The feature samples.
y : array-like of shape (n_samples,) or (n_samples, n_targets), optional
Target values.
Target values. Returned only if ``unsupervised`` is False
groups : list of numpy.ndarray
feature indices for each feature group
Expand All @@ -64,8 +97,12 @@ def load_afq_data(
subjects : list
Subject IDs
sessions : list
Session IDs. Returned only if ``return_sessions`` is True.
classes : dict
Class labels for each column specified in ``label_encode_cols``
Class labels for each column specified in ``label_encode_cols``.
Returned only if ``unsupervised`` is False
See Also
--------
Expand All @@ -76,26 +113,17 @@ def load_afq_data(
fn_subjects = op.join(workdir, fn_subjects)

nodes = pd.read_csv(fn_nodes)
targets = pd.read_csv(fn_subjects, index_col=index_col).drop(
["Unnamed: 0"], axis="columns"
)

y = targets.loc[:, target_cols]

classes = {}
if label_encode_cols is not None:
if not set(label_encode_cols) <= set(target_cols):
raise ValueError(
"label_encode_cols must be a subset of target_cols; "
"got {0} instead.".format(label_encode_cols)
)
unnamed_cols = [col for col in nodes.columns if "Unnamed:" in col]
nodes.drop(unnamed_cols, axis="columns", inplace=True)

le = LabelEncoder()
for col in label_encode_cols:
y.loc[:, col] = le.fit_transform(y[col])
classes[col] = le.classes_
sessions = nodes["sessionID"] if "sessionID" in nodes.columns else None
if concat_subject_session:
nodes["subjectID"] = nodes["subjectID"] + nodes["sessionID"].astype(str)

y = np.squeeze(y.values)
nodes.drop("sessionID", axis="columns", inplace=True, errors="ignore")

if dwi_metrics is not None:
nodes = nodes[["tractID", "nodeID", "subjectID"] + dwi_metrics]

mapper = AFQDataFrameMapper()
X = mapper.fit_transform(nodes)
Expand All @@ -104,7 +132,61 @@ def load_afq_data(
group_names = [tup[0:2] for tup in feature_names if tup[2] == 0]
subjects = mapper.subjects_

return X, y, groups, feature_names, group_names, subjects, classes
if unsupervised:
if return_sessions:
output = X, groups, feature_names, group_names, subjects, sessions
else:
output = X, groups, feature_names, group_names, subjects
else:
# Read using sep=None, engine="python" to allow for both csv and tsv
targets = pd.read_csv(
fn_subjects, sep=None, engine="python", index_col=index_col
)

# Drop unnamed columns
unnamed_cols = [col for col in targets.columns if "Unnamed:" in col]
targets.drop(unnamed_cols, axis="columns", inplace=True)

# Drop subjects that are not in the dwi feature matrix
targets = pd.DataFrame(index=subjects).merge(
targets, how="left", left_index=True, right_index=True
)

# Select user defined target columns
if target_cols is not None:
y = targets.loc[:, target_cols]

# Label encode the user-supplied categorical columns
classes = {}
if label_encode_cols is not None:
if not set(label_encode_cols) <= set(target_cols):
raise ValueError(
"label_encode_cols must be a subset of target_cols; "
"got {0} instead.".format(label_encode_cols)
)

le = LabelEncoder()
for col in label_encode_cols:
y.loc[:, col] = le.fit_transform(y[col])
classes[col] = le.classes_

y = np.squeeze(y.to_numpy())

if return_sessions:
output = (
X,
y,
groups,
feature_names,
group_names,
subjects,
sessions,
classes,
)
else:
output = X, y, groups, feature_names, group_names, subjects, classes

return output


def output_beta_to_afq(
Expand Down
34 changes: 34 additions & 0 deletions afqinsight/tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,40 @@
test_data_path = op.join(data_path, "test_data")


def test_load_afq_data_smoke():
    """Smoke-test the number of items returned by ``load_afq_data``.

    Each case pairs the extra keyword arguments for one call with the
    length expected of the returned tuple.
    """
    cases = (
        ({}, 7),
        ({"return_sessions": True}, 8),
        ({"unsupervised": True}, 5),
        ({"unsupervised": True, "return_sessions": True}, 6),
    )

    for extra_kwargs, n_expected in cases:
        # Fresh list arguments are built for every call, as in the
        # original one-call-per-case layout.
        output = load_afq_data(
            workdir=test_data_path,
            target_cols=["test_class"],
            label_encode_cols=["test_class"],
            **extra_kwargs,
        )
        assert len(output) == n_expected  # nosec


def test_load_afq_data():
X, y, groups, feature_names, group_names, subjects, classes = load_afq_data(
workdir=test_data_path,
Expand Down

0 comments on commit 89e0761

Please sign in to comment.