Merge pull request #80 from richford/enh/bundle-agg
ENH: Allow bundle aggregation and put load_afq_data output into a named_tuple
arokem authored Jul 9, 2021
2 parents 5f6c86b + 2f630d0 commit 4d923be
Showing 6 changed files with 235 additions and 88 deletions.
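
A minimal usage sketch of the API after this merge (field names come from the diff below; the workdir path is a placeholder, not part of this commit):

from afqinsight.datasets import load_afq_data

# load_afq_data now returns an AFQData namedtuple rather than a
# variable-length tuple, so results can be accessed by field name.
# "my_afq_study" is a hypothetical directory of AFQ CSV files.
data = load_afq_data(
    workdir="my_afq_study",
    dwi_metrics=["md", "fa"],
    target_cols=["class"],
    label_encode_cols=["class"],
)
print(data.X.shape)   # (n_samples, n_features)
print(data.sessions)  # None unless return_sessions=True

# With return_bundle_means=True, metrics are averaged along each bundle,
# so there is one feature per (metric, bundle) pair and group_names
# coincides with feature_names.
means = load_afq_data(
    workdir="my_afq_study",
    dwi_metrics=["md", "fa"],
    target_cols=["class"],
    label_encode_cols=["class"],
    return_bundle_means=True,
)
assert means.feature_names == means.group_names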
116 changes: 71 additions & 45 deletions afqinsight/datasets.py
@@ -13,7 +13,22 @@
from .transform import AFQDataFrameMapper

__all__ = ["load_afq_data", "output_beta_to_afq"]
-DATA_DIR = op.join(op.expanduser("~"), ".cache", "afq-insight")
+_DATA_DIR = op.join(op.expanduser("~"), ".cache", "afq-insight")
+_FIELDS = [
+    "X",
+    "y",
+    "groups",
+    "feature_names",
+    "group_names",
+    "subjects",
+    "sessions",
+    "classes",
+]
+try:
+    AFQData = namedtuple("AFQData", _FIELDS, defaults=(None,) * len(_FIELDS))
+except TypeError:
+    AFQData = namedtuple("AFQData", _FIELDS)
+    AFQData.__new__.__defaults__ = (None,) * len(AFQData._fields)


def load_afq_data(
@@ -27,6 +42,7 @@ def load_afq_data(
    unsupervised=False,
    concat_subject_session=False,
    return_sessions=False,
+    return_bundle_means=False,
):
    """Load AFQ data from CSV, transform it, return feature matrix and target.
@@ -81,32 +97,39 @@ def load_afq_data(
    return_sessions : bool, default=False
        If True, return sessionID
+    return_bundle_means : bool, default=False
+        If True, return diffusion metrics averaged along the length of each
+        bundle.

    Returns
    -------
-    X : array-like of shape (n_samples, n_features)
-        The feature samples.
-    y : array-like of shape (n_samples,) or (n_samples, n_targets), optional
-        Target values. Returned only if ``unsupervised`` is False
-    groups : list of numpy.ndarray
-        feature indices for each feature group
-    feature_names : list of tuples
-        The multi-indexed columns of X
-    group_names : list of tuples
-        The multi-indexed groups of X
-    subjects : list
-        Subject IDs
-    sessions : list
-        Session IDs. Returned only if ``return_sessions`` is True.
-    classes : dict
-        Class labels for each column specified in ``label_encode_cols``.
-        Returned only if ``unsupervised`` is False
+    AFQData : namedtuple
+        A namedtuple with the fields:
+    X : array-like of shape (n_samples, n_features)
+        The feature samples.
+    y : array-like of shape (n_samples,) or (n_samples, n_targets), optional
+        Target values. This will be None if ``unsupervised`` is True
+    groups : list of numpy.ndarray
+        feature indices for each feature group
+    feature_names : list of tuples
+        The multi-indexed columns of X
+    group_names : list of tuples
+        The multi-indexed groups of X
+    subjects : list
+        Subject IDs
+    sessions : list
+        Session IDs. This will be None if ``return_sessions`` is False.
+    classes : dict
+        Class labels for each column specified in ``label_encode_cols``.
+        This will be None if ``unsupervised`` is True
    See Also
    --------
@@ -129,18 +152,24 @@ def load_afq_data(
    if dwi_metrics is not None:
        nodes = nodes[["tractID", "nodeID", "subjectID"] + dwi_metrics]

-    mapper = AFQDataFrameMapper()
+    if return_bundle_means:
+        mapper = AFQDataFrameMapper(bundle_agg_func="mean")
+    else:
+        mapper = AFQDataFrameMapper()

    X = mapper.fit_transform(nodes)
-    subjects = mapper.subjects_
    groups = mapper.groups_
    feature_names = mapper.feature_names_
-    group_names = [tup[0:2] for tup in feature_names if tup[2] == 0]
+    subjects = mapper.subjects_
+
+    if return_bundle_means:
+        group_names = feature_names
+    else:
+        group_names = [tup[0:2] for tup in feature_names if tup[2] == 0]

    if unsupervised:
-        if return_sessions:
-            output = X, groups, feature_names, group_names, subjects, sessions
-        else:
-            output = X, groups, feature_names, group_names, subjects
+        y = None
+        classes = None
    else:
        if target_cols is None:
            raise ValueError(
@@ -150,6 +179,7 @@ def load_afq_data(
"load data for an unsupervised learning "
"problem, please set `unsupervised=True`."
)

# Read using sep=None, engine="python" to allow for both csv and tsv
targets = pd.read_csv(
fn_subjects, sep=None, engine="python", index_col=index_col
@@ -165,12 +195,11 @@ def load_afq_data(
        )

        # Select user defined target columns
-        if target_cols is not None:
-            y = targets.loc[:, target_cols]
+        y = targets.loc[:, target_cols]

        # Label encode the user-supplied categorical columns
-        classes = {}
        if label_encode_cols is not None:
+            classes = {}
            if not set(label_encode_cols) <= set(target_cols):
                raise ValueError(
                    "label_encode_cols must be a subset of target_cols; "
@@ -181,24 +210,21 @@ def load_afq_data(
            for col in label_encode_cols:
                y.loc[:, col] = le.fit_transform(y[col])
                classes[col] = le.classes_
+        else:
+            classes = None

        y = np.squeeze(y.to_numpy())

-        if return_sessions:
-            output = (
-                X,
-                y,
-                groups,
-                feature_names,
-                group_names,
-                subjects,
-                sessions,
-                classes,
-            )
-        else:
-            output = X, y, groups, feature_names, group_names, subjects, classes
-
-    return output
+    return AFQData(
+        X=X,
+        y=y,
+        groups=groups,
+        feature_names=feature_names,
+        group_names=group_names,
+        subjects=subjects,
+        sessions=sessions,
+        classes=classes,
+    )


def output_beta_to_afq(
@@ -427,7 +453,7 @@ def download_sarica(data_home=None):
    Human Brain Mapping, vol. 38, pp. 727-739, 2017
    DOI: 10.1002/hbm.23412
    """
-    data_home = data_home if data_home is not None else DATA_DIR
+    data_home = data_home if data_home is not None else _DATA_DIR
    _download_afq_dataset("sarica", data_home=data_home)
    return op.join(data_home, "sarica_data")

@@ -468,6 +494,6 @@ def download_weston_havens(data_home=None):
    Nature Communications, vol. 5:1, pp. 4932, 2014
    DOI: 10.1038/ncomms5932
    """
-    data_home = data_home if data_home is not None else DATA_DIR
+    data_home = data_home if data_home is not None else _DATA_DIR
    _download_afq_dataset("weston_havens", data_home=data_home)
    return op.join(data_home, "weston_havens_data")
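
The bundle-mean path relies on AFQDataFrameMapper(bundle_agg_func="mean"); the reference computation in the tests below shows what that aggregation amounts to. As a standalone pandas sketch (the toy values are made up; only the subjectID/tractID/nodeID column layout is taken from this commit):

import pandas as pd

# Average each metric over all nodes within a bundle, then pivot bundles
# into columns: one (metric, tractID) column per subject row.
nodes = pd.DataFrame(
    {
        "subjectID": ["s0", "s0", "s1", "s1"],
        "tractID": ["CST_L"] * 4,
        "nodeID": [0, 1, 0, 1],
        "fa": [0.50, 0.54, 0.48, 0.46],
    }
)
bundle_means = (
    nodes.groupby(["subjectID", "tractID"])
    .agg("mean")
    .drop("nodeID", axis="columns")
    .unstack("tractID")
)
print(bundle_means)  # s0 -> 0.52, s1 -> 0.47 in the ("fa", "CST_L") column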
92 changes: 60 additions & 32 deletions afqinsight/tests/test_datasets.py
@@ -13,47 +13,51 @@

def test_fetch():
    sarica_dir = download_sarica()
-    X, y, groups, feature_names, group_names, subjects, _ = load_afq_data(
+    X, y, groups, feature_names, group_names, subjects, _, _ = load_afq_data(
        workdir=sarica_dir,
        dwi_metrics=["md", "fa"],
        target_cols=["class"],
        label_encode_cols=["class"],
    )

-    assert X.shape == (48, 4000)
-    assert y.shape == (48,)
-    assert len(groups) == 40
-    assert len(feature_names) == 4000
-    assert len(group_names) == 40
-    assert len(subjects) == 48
-    assert op.isfile(op.join(afqi.datasets.DATA_DIR, "sarica_data", "nodes.csv"))
-    assert op.isfile(op.join(afqi.datasets.DATA_DIR, "sarica_data", "subjects.csv"))
+    assert X.shape == (48, 4000)  # nosec
+    assert y.shape == (48,)  # nosec
+    assert len(groups) == 40  # nosec
+    assert len(feature_names) == 4000  # nosec
+    assert len(group_names) == 40  # nosec
+    assert len(subjects) == 48  # nosec
+    assert op.isfile(
+        op.join(afqi.datasets._DATA_DIR, "sarica_data", "nodes.csv")
+    )  # nosec
+    assert op.isfile(
+        op.join(afqi.datasets._DATA_DIR, "sarica_data", "subjects.csv")
+    )  # nosec

    wh_dir = download_weston_havens()
-    X, y, groups, feature_names, group_names, subjects, classes = load_afq_data(
-        workdir=wh_dir,
-        dwi_metrics=["md", "fa"],
-        target_cols=["Age"],
+    X, y, groups, feature_names, group_names, subjects, _, _ = load_afq_data(
+        workdir=wh_dir, dwi_metrics=["md", "fa"], target_cols=["Age"]
    )

-    assert X.shape == (77, 4000)
-    assert y.shape == (77,)
-    assert len(groups) == 40
-    assert len(feature_names) == 4000
-    assert len(group_names) == 40
-    assert len(subjects) == 77
-    assert op.isfile(op.join(afqi.datasets.DATA_DIR, "weston_havens_data", "nodes.csv"))
-    assert op.isfile(
-        op.join(afqi.datasets.DATA_DIR, "weston_havens_data", "subjects.csv")
-    )
+    assert X.shape == (77, 4000)  # nosec
+    assert y.shape == (77,)  # nosec
+    assert len(groups) == 40  # nosec
+    assert len(feature_names) == 4000  # nosec
+    assert len(group_names) == 40  # nosec
+    assert len(subjects) == 77  # nosec
+    assert op.isfile(
+        op.join(afqi.datasets._DATA_DIR, "weston_havens_data", "nodes.csv")
+    )  # nosec
+    assert op.isfile(
+        op.join(afqi.datasets._DATA_DIR, "weston_havens_data", "subjects.csv")
+    )  # nosec

    with tempfile.TemporaryDirectory() as td:
        _ = download_sarica(data_home=td)
        _ = download_weston_havens(data_home=td)
-        assert op.isfile(op.join(td, "sarica_data", "nodes.csv"))
-        assert op.isfile(op.join(td, "sarica_data", "subjects.csv"))
-        assert op.isfile(op.join(td, "weston_havens_data", "nodes.csv"))
-        assert op.isfile(op.join(td, "weston_havens_data", "subjects.csv"))
+        assert op.isfile(op.join(td, "sarica_data", "nodes.csv"))  # nosec
+        assert op.isfile(op.join(td, "sarica_data", "subjects.csv"))  # nosec
+        assert op.isfile(op.join(td, "weston_havens_data", "nodes.csv"))  # nosec
+        assert op.isfile(op.join(td, "weston_havens_data", "subjects.csv"))  # nosec


def test_load_afq_data_smoke():
@@ -62,7 +66,8 @@ def test_load_afq_data_smoke():
        target_cols=["test_class"],
        label_encode_cols=["test_class"],
    )
-    assert len(output) == 7  # nosec
+    assert len(output) == 8  # nosec
+    assert output.sessions is None  # nosec

    output = load_afq_data(
        workdir=test_data_path,
@@ -78,7 +83,10 @@ def test_load_afq_data_smoke():
        label_encode_cols=["test_class"],
        unsupervised=True,
    )
-    assert len(output) == 5  # nosec
+    assert len(output) == 8  # nosec
+    assert output.y is None  # nosec
+    assert output.classes is None  # nosec
+    assert output.sessions is None  # nosec

    output = load_afq_data(
        workdir=test_data_path,
@@ -87,14 +95,17 @@ def test_load_afq_data_smoke():
        unsupervised=True,
        return_sessions=True,
    )
-    assert len(output) == 6  # nosec
+    assert len(output) == 8  # nosec
+    assert output.y is None  # nosec
+    assert output.classes is None  # nosec


def test_load_afq_data():
-    X, y, groups, feature_names, group_names, subjects, classes = load_afq_data(
+    (X, y, groups, feature_names, group_names, subjects, _, classes) = load_afq_data(
        workdir=test_data_path,
        target_cols=["test_class"],
        label_encode_cols=["test_class"],
+        return_bundle_means=False,
    )

    nodes = pd.read_csv(op.join(test_data_path, "nodes.csv"))
@@ -112,7 +123,24 @@ def test_load_afq_data():
    assert feature_names == cols_ref  # nosec
    assert group_names == [tup[0:2] for tup in cols_ref if tup[2] == 0]  # nosec
    assert set(subjects) == set(nodes.subjectID.unique())  # nosec
-    assert all(classes["test_class"] == np.array(["c0", "c1"]))
+    assert all(classes["test_class"] == np.array(["c0", "c1"]))  # nosec
+
+    (X, y, groups, feature_names, group_names, subjects, _, classes) = load_afq_data(
+        workdir=test_data_path,
+        target_cols=["test_class"],
+        label_encode_cols=["test_class"],
+        return_bundle_means=True,
+    )
+
+    means_ref = (
+        nodes.groupby(["subjectID", "tractID"])
+        .agg("mean")
+        .drop("nodeID", axis="columns")
+        .unstack("tractID")
+    )
+    assert np.allclose(X, means_ref.to_numpy(), equal_nan=True)  # nosec
+    assert group_names == means_ref.columns.to_list()  # nosec
+    assert feature_names == means_ref.columns.to_list()  # nosec
    with pytest.raises(ValueError):
        load_afq_data(
@@ -123,4 +151,4 @@ def test_load_afq_data():
    with pytest.raises(ValueError) as ee:
        load_afq_data(test_data_path)

-    assert "please set `unsupervised=True`" in str(ee.value)
+    assert "please set `unsupervised=True`" in str(ee.value)  # nosec
30 changes: 30 additions & 0 deletions afqinsight/tests/test_transform.py
@@ -34,6 +34,36 @@ def test_AFQDataFrameMapper():
    assert set(subjects) == set(nodes.subjectID.unique())  # nosec


def test_AFQDataFrameMapper_mean():
nodes_path = op.join(test_data_path, "nodes.csv")
nodes = pd.read_csv(nodes_path)
transformer = AFQDataFrameMapper(bundle_agg_func="mean")
X = transformer.fit_transform(nodes)
groups = transformer.groups_
cols = transformer.feature_names_
subjects = transformer.subjects_

X_ref = (
nodes.groupby(["subjectID", "tractID"])
.agg("mean")
.drop("nodeID", axis="columns")
.unstack("tractID")
.to_numpy()
)
groups_ref = [np.array([idx]) for idx in range(X.shape[1])]
cols_ref = set(
[
tuple([item[0], item[1]])
for item in np.load(op.join(test_data_path, "test_transform_cols.npy"))
]
)

assert np.allclose(groups, groups_ref) # nosec
assert set(cols) == cols_ref # nosec
assert set(subjects) == set(nodes.subjectID.unique()) # nosec
assert np.allclose(X, X_ref, equal_nan=True) # nosec


def test_AFQDataFrameMapper_fit_transform():
    nodes_path = op.join(test_data_path, "nodes.csv")
    nodes = pd.read_csv(nodes_path)