Add fit, fit_transform, transform, predict, and score methods #119

Merged Jun 23, 2022 (31 commits)
3342ef4  Add fit, fit_transform, transform, predict, and score methods to AFQD… (richford, Jun 18, 2022)
da8202a  DEP: Update groupyr dependency (richford, Jun 19, 2022)
41b1537  MAINT: Update .zenodo.json to add Jason and John (richford, Jun 19, 2022)
7c39696  BF: Add model_ prefix to the fit, transform, predict, etc. methods (richford, Jun 19, 2022)
a0a1d31  Use allclose in doctest instead of exact values (richford, Jun 19, 2022)
7dd8099  Update cross_validate doctest expected result (richford, Jun 19, 2022)
d567fc6  BF: Okay really correct the doctest values (richford, Jun 19, 2022)
0a96670  Use np.allclose in cross_validate doctest (richford, Jun 19, 2022)
2c6f1b7  Add tests of model fit, transform, predict, etc. (richford, Jun 19, 2022)
068226f  Add AFQDataset to doc/api.rst (richford, Jun 19, 2022)
ebf7c6d  Remove y param from model.transform (richford, Jun 19, 2022)
a6cacd1  Add a doc example to show manipulation of AFQDataset (richford, Jun 20, 2022)
110fa5d  STY: Fix flake error in demo_afq_dataset.py (richford, Jun 20, 2022)
3da02c9  Add plot_hbn_site_profiles example (richford, Jun 21, 2022)
2378560  Add s3fs to dev dependencies (richford, Jun 21, 2022)
f54c4c8  BF: Fix input checking for plot_bundle_profiles function (richford, Jun 21, 2022)
5de5c45  Undo the redundant commits to plot_bundle_profiles (richford, Jun 21, 2022)
d87fc35  Add test for AFQDataset.copy() (richford, Jun 21, 2022)
ef4e9a2  BF: Use equal_nan=True in unit tests for AFQDataset.copy() (richford, Jun 21, 2022)
2df83e0  Update afqinsight/datasets.py (richford, Jun 21, 2022)
644e6d0  Update afqinsight/datasets.py (richford, Jun 21, 2022)
877a5a0  Update afqinsight/datasets.py (richford, Jun 21, 2022)
31d1bcc  Update afqinsight/datasets.py (richford, Jun 21, 2022)
8b0511b  Update afqinsight/datasets.py (richford, Jun 21, 2022)
e0a3520  Use np.allclose to verify deep copy in AFQDataset.copy unit test (richford, Jun 21, 2022)
3354d87  Use ellipses and normalize whitespace in doctest for cross_validate (richford, Jun 21, 2022)
edce6a5  Merge branch 'enh/fit-on-dataset' of github.com:richford/AFQ-Insight … (richford, Jun 21, 2022)
7bcb14c  DOC: Incorporate @arokem's suggestions into autodoc examples (richford, Jun 22, 2022)
9e22072  STY: Fix flake8 trailing whitespace error (richford, Jun 22, 2022)
4356673  Update examples/plot_hbn_site_profiles.py (richford, Jun 22, 2022)
8c6c626  STY: change 'for' to 'to' (richford, Jun 22, 2022)
10 changes: 10 additions & 0 deletions .zenodo.json
@@ -15,6 +15,16 @@
"affiliation": "The University of Washington",
"name": "Rokem, Ariel",
"orcid": "0000-0003-0679-1985"
},
{
"affiliation": "University of Washington",
"name": "Kruper, John",
"orcid": "0000-0003-0081-391X"
},
{
"affiliation": "Stanford University",
"name": "Yeatman, Jason",
"orcid": "0000-0002-2686-1293"
}
],
"description": "<p>AFQ-Insight is a Python library for statistical learning with tractometry data.</p>",
5 changes: 3 additions & 2 deletions afqinsight/cross_validate.py
@@ -348,6 +348,7 @@ def cross_validate_checkpoint(

Examples
--------
>>> import numpy as np
>>> import shutil
>>> import tempfile
>>> from sklearn import datasets, linear_model
@@ -364,8 +365,8 @@
>>> cv_results = cross_validate_checkpoint(lasso, X, y, cv=3, checkpoint=False)
>>> sorted(cv_results.keys())
['fit_time', 'score_time', 'test_score']
>>> cv_results['test_score']
array([0.33150734, 0.08022311, 0.03531764])
>>> cv_results['test_score'] # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
array([0.33150..., 0.08022..., 0.03531...])

Multiple metric evaluation using ``cross_validate``, an estimator
pipeline, and checkpointing (please refer the ``scoring`` parameter doc
130 changes: 130 additions & 0 deletions afqinsight/datasets.py
@@ -635,6 +635,26 @@ def shape(self):
else:
return self.X.shape

def copy(self):
"""Return a deep copy of this dataset.

Returns
-------
AFQDataset
A deep copy of this dataset
"""
return AFQDataset(
X=self.X,
y=self.y,
groups=self.groups,
feature_names=self.feature_names,
target_cols=self.target_cols,
group_names=self.group_names,
subjects=self.subjects,
sessions=self.sessions,
classes=self.classes,
)

def bundle_means(self):
"""Return diffusion metrics averaged along the length of each bundle.

@@ -749,6 +769,116 @@ def as_tensorflow_dataset(self, bundles_as_channels=True, channels_last=True):
else:
return tf.data.Dataset.from_tensor_slices((X, self.y.astype(float)))

def model_fit(self, model, **fit_params):
"""Fit the dataset with a provided model object.

Parameters
----------
model : sklearn model
The estimator or transformer to fit

**fit_params : dict
Additional parameters to pass to the fit method

Returns
-------
model : object
The fitted model
"""
return model.fit(X=self.X, y=self.y, **fit_params)

def model_fit_transform(self, model, **fit_params):
"""Fit and transform the dataset with a provided model object.

Parameters
----------
model : sklearn model
The estimator or transformer to fit

**fit_params : dict
Additional parameters to pass to the fit_transform method

Returns
-------
dataset_new : AFQDataset
New AFQDataset with transformed features
"""
return AFQDataset(
X=model.fit_transform(X=self.X, y=self.y, **fit_params),
y=self.y,
groups=self.groups,
feature_names=self.feature_names,
target_cols=self.target_cols,
group_names=self.group_names,
subjects=self.subjects,
sessions=self.sessions,
classes=self.classes,
)

def model_transform(self, model, **transform_params):
"""Transform the dataset with a provided model object.

Parameters
----------
model : sklearn model
The estimator or transformer to use to transform the features

**transform_params : dict
Additional parameters to pass to the transform method

Returns
-------
dataset_new : AFQDataset
New AFQDataset with transformed features
"""
return AFQDataset(
X=model.transform(X=self.X, **transform_params),
y=self.y,
groups=self.groups,
feature_names=self.feature_names,
target_cols=self.target_cols,
group_names=self.group_names,
subjects=self.subjects,
sessions=self.sessions,
classes=self.classes,
)

def model_predict(self, model, **predict_params):
"""Predict the targets with a provided model object.

Parameters
----------
model : sklearn model
The estimator or transformer to use to predict the targets

**predict_params : dict
Additional parameters to pass to the predict method

Returns
-------
y_pred : ndarray
Predicted targets
"""
return model.predict(X=self.X, **predict_params)

def model_score(self, model, **score_params):
"""Score a model on this dataset.

Parameters
----------
model : sklearn model
The estimator or transformer to use to score the model

**score_params : dict
Additional parameters to pass to the `score` method, e.g., `sample_weight`

Returns
-------
score : float
The score of the model (e.g. R2, accuracy, etc.)
"""
return model.score(X=self.X, y=self.y, **score_params)


def _download_url_to_file(url, output_fn, encoding="utf-8", verbose=True):
fn_abs = op.abspath(output_fn)
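All five `model_*` methods added above follow one delegation pattern: forward `self.X` (and, where relevant, `self.y`) to the wrapped estimator, and wrap any transformed features in a new dataset. A minimal sketch of that pattern, using a stand-in `MiniDataset` class and a tiny hand-rolled `MeanImputer` instead of the real `AFQDataset` and scikit-learn (all names here are illustrative, not part of the library):

```python
import numpy as np


class MiniDataset:
    """Stand-in illustrating AFQDataset's delegation pattern."""

    def __init__(self, X, y):
        self.X = X
        self.y = y

    def model_fit(self, model, **fit_params):
        # Fit the model on this dataset's features and targets.
        return model.fit(X=self.X, y=self.y, **fit_params)

    def model_transform(self, model, **transform_params):
        # Return a NEW dataset with transformed features; y is untouched.
        return MiniDataset(
            X=model.transform(X=self.X, **transform_params), y=self.y
        )


class MeanImputer:
    """Tiny column-mean imputer mimicking a fit/transform API."""

    def fit(self, X, y=None):
        self.statistics_ = np.nanmean(X, axis=0)  # per-column means, NaNs ignored
        return self

    def transform(self, X):
        X = np.asarray(X, dtype=float).copy()
        rows, cols = np.where(np.isnan(X))
        X[rows, cols] = self.statistics_[cols]  # fill NaNs with column means
        return X


X = np.array([[1.0, np.nan], [3.0, 4.0]])
dataset = MiniDataset(X, y=np.array([0.0, 1.0]))
imputer = dataset.model_fit(MeanImputer())  # returns the fitted imputer
imputed = dataset.model_transform(imputer)  # returns a new MiniDataset
```

The design choice worth noting is that `model_transform` does not mutate the dataset in place; like the PR's `model_fit_transform`, it returns a fresh dataset carrying the original metadata alongside the transformed feature matrix.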
65 changes: 65 additions & 0 deletions afqinsight/tests/test_datasets.py
@@ -14,6 +14,8 @@
AFQDataset,
standardize_subject_id,
)
from sklearn.impute import SimpleImputer
from sklearn.linear_model import Lasso

data_path = op.join(afqi.__path__[0], "data")
test_data_path = op.join(data_path, "test_data")
@@ -163,6 +165,69 @@ def test_AFQDataset_shape_len_index():
assert repr(dataset) == "AFQDataset(n_samples=10, n_features=4)" # nosec


def test_AFQDataset_fit_transform():
sarica_dir = download_sarica()
dataset = AFQDataset.from_files(
fn_nodes=op.join(sarica_dir, "nodes.csv"),
fn_subjects=op.join(sarica_dir, "subjects.csv"),
dwi_metrics=["md", "fa"],
target_cols=["class"],
label_encode_cols=["class"],
)

# Test that model_fit fits the imputer
imputer = dataset.model_fit(SimpleImputer())
assert np.allclose(imputer.statistics_, np.nanmean(dataset.X, axis=0))

# Test that model_transform imputes the data
dataset_imputed = dataset.model_transform(imputer)
assert np.allclose(dataset_imputed.X, imputer.transform(dataset.X))

# Test that fit_transform does the same as fit and then transform
dataset_transformed = dataset.model_fit_transform(SimpleImputer())
assert np.allclose(dataset_transformed.X, dataset_imputed.X)


def test_AFQDataset_copy():
wh_dir = download_weston_havens()
dataset_1 = AFQDataset.from_files(
fn_nodes=op.join(wh_dir, "nodes.csv"),
fn_subjects=op.join(wh_dir, "subjects.csv"),
dwi_metrics=["md", "fa"],
target_cols=["Age"],
)
dataset_2 = dataset_1.copy()

# Test that it copied
assert np.allclose(dataset_1.X, dataset_2.X, equal_nan=True)
assert dataset_1.groups == dataset_2.groups
assert dataset_1.group_names == dataset_2.group_names
assert dataset_1.subjects == dataset_2.subjects

# Test that it's a deep copy
dataset_1.X = np.zeros_like(dataset_2.X)
dataset_1.y = np.zeros_like(dataset_2.y)
assert not np.allclose(dataset_2.X, dataset_1.X, equal_nan=True)
assert not np.allclose(dataset_1.y, dataset_2.y, equal_nan=True)


def test_AFQDataset_predict_score():
wh_dir = download_weston_havens()
dataset = AFQDataset.from_files(
fn_nodes=op.join(wh_dir, "nodes.csv"),
fn_subjects=op.join(wh_dir, "subjects.csv"),
dwi_metrics=["md", "fa"],
target_cols=["Age"],
)
dataset = dataset.model_fit_transform(SimpleImputer(strategy="median"))
estimator = dataset.model_fit(Lasso())
y_pred = dataset.model_predict(estimator)
assert np.allclose(estimator.predict(dataset.X), y_pred)
assert np.allclose(
estimator.score(dataset.X, dataset.y), dataset.model_score(estimator)
)


def test_drop_target_na():
dataset = AFQDataset(X=np.random.rand(10, 4), y=np.random.rand(10))
dataset.y[:5] = np.nan
32 changes: 16 additions & 16 deletions doc/api.rst
@@ -4,6 +4,14 @@ API Reference

.. currentmodule:: afqinsight

Datasets
========

This class encapsulates an AFQ dataset and provides static methods to read data
from CSV files conforming to the AFQ data standard.

.. autoclass:: AFQDataset

Pipelines
=========

@@ -13,22 +21,6 @@

.. autofunction:: make_afq_classifier_pipeline

Cross Validation
================

This function validates model performance using cross-validation, while
checkpointing the estimators and scores.

.. autofunction:: cross_validate_checkpoint

Dataset Loader
==============

This function reads data from csv files conforming to the AFQ data standard
and return feature and target matrices, grouping arrays, and subject IDs.

.. autofunction:: load_afq_data

Transformers
============

@@ -37,3 +29,11 @@
sklearn-compatible pipelines.

.. autoclass:: AFQDataFrameMapper

Cross Validation
================

This function validates model performance using cross-validation, while
checkpointing the estimators and scores.

.. autofunction:: cross_validate_checkpoint
5 changes: 4 additions & 1 deletion doc/conf.py
@@ -70,7 +70,10 @@
# source_encoding = 'utf-8-sig'

# Generate the plots for the gallery
plot_gallery = "True"
plot_gallery = True
gallery_conf = {
"filename_pattern": ["/plot", "/demo"],
}

# The master toctree document.
master_doc = "index"