From 38a1c50737be074638fe9ec0846f16a1de644da5 Mon Sep 17 00:00:00 2001 From: Justin Kiggins Date: Tue, 21 Nov 2017 22:49:33 -0800 Subject: [PATCH 1/6] refactor tests --- tests/test_calcium.py | 44 +++++++++++++++++++++---------------------- 1 file changed, 21 insertions(+), 23 deletions(-) diff --git a/tests/test_calcium.py b/tests/test_calcium.py index de4a29c..c31efd9 100644 --- a/tests/test_calcium.py +++ b/tests/test_calcium.py @@ -1,5 +1,5 @@ -from neuroglia.calcium import MedianFilterDetrend, SavGolFilterDetrend -from neuroglia.calcium import OASISInferer +from neuroglia.calcium import MedianFilterDetrender, SavGolFilterDetrender +from neuroglia.calcium import CalciumDeconvolver from oasis.functions import gen_data from sklearn.base import clone @@ -10,13 +10,6 @@ import numpy.testing as npt import xarray.testing as xrt -# Test for proper parameter structure -def test_params(): - fn_list = [MedianFilterDetrend(), SavGolFilterDetrend(), OASISInferer()] - for fn in fn_list: - new_object_params = fn.get_params(deep=False) - for name, param in new_object_params.items(): - new_object_params[name] = clone(param, safe=False) # Test functions perform as expected true_b = 2 @@ -26,28 +19,33 @@ def test_params(): LBL = ['a', 'b', 'c'] sin_scale = 5 -data = y + sin_scale*np.sin(.05*TIME)[:,None] -DFF = pd.DataFrame(data, TIME, LBL) +# data = y +DFF = pd.DataFrame(y, TIME, LBL) +DFF_WITH_DRIFT = DFF.apply(lambda y: y + sin_scale*np.sin(.05*TIME),axis=0) assert np.all(np.mean(DFF) > 2) -def test_MedianFilterDetrend(): - tmp = MedianFilterDetrend().fit_transform(DFF) +def test_MedianFilterDetrender(): + detrender = MedianFilterDetrender() + tmp = detrender.fit_transform(DFF_WITH_DRIFT) assert np.all(np.isclose(np.mean(tmp), 0, atol=.1)) + clone(detrender) -def test_SavGolFilterDetrend(): - tmp = SavGolFilterDetrend().fit_transform(DFF) +def test_SavGolFilterDetrender(): + detrender = SavGolFilterDetrender() + tmp = detrender.fit_transform(DFF_WITH_DRIFT) assert np.all(np.isclose(np.mean(tmp), 0, atol=.1)) + clone(detrender) -def test_OASISInferer(): - tmp = OASISInferer().fit_transform(SavGolFilterDetrend().fit_transform(DFF)) - assert np.all(np.array([np.corrcoef(true_s[n], np.array(tmp[a]))[0][1] for n,a in zip(range(3), LBL)]) > 0.6) - tmp = OASISInferer().fit_transform(MedianFilterDetrend().fit_transform(DFF)) +def test_CalciumDeconvolver(): + deconvolver = CalciumDeconvolver() + tmp = deconvolver.fit_transform(DFF) assert np.all(np.array([np.corrcoef(true_s[n], np.array(tmp[a]))[0][1] for n,a in zip(range(3), LBL)]) > 0.6) + clone(deconvolver) if __name__ == '__main__': - test_MedianFilterDetrend() - test_SavGolFilterDetrend() - test_OASISInferer() - test_params() + test_MedianFilterDetrender() + test_SavGolFilterDetrender() + test_CalciumDeconvolver() + # test_params() From 852ab8f1f1105a0142869ffe58cdb64ce9ed9443 Mon Sep 17 00:00:00 2001 From: Justin Kiggins Date: Tue, 21 Nov 2017 22:50:12 -0800 Subject: [PATCH 2/6] refactor transformers --- neuroglia/calcium.py | 90 ++++++++++++++++++++------------------------ 1 file changed, 41 insertions(+), 49 deletions(-) diff --git a/neuroglia/calcium.py b/neuroglia/calcium.py index 2162cc3..a9d6524 100644 --- a/neuroglia/calcium.py +++ b/neuroglia/calcium.py @@ -4,7 +4,7 @@ from scipy.signal import medfilt, savgol_filter -class MedianFilterDetrend(BaseEstimator, TransformerMixin): +class MedianFilterDetrender(BaseEstimator, TransformerMixin): """ Median filter detrending """ @@ -15,7 +15,7 @@ def __init__(self, self.window = window self.peak_std_threshold = peak_std_threshold - def robust_std(self, x): + def _robust_std(self, x): ''' Robust estimate of std ''' @@ -23,22 +23,22 @@ def robust_std(self, x): return 1.4826*MAD def fit(self, X, y=None): - self.fit_params = {} return self def transform(self,X): + self.fit_params = {} X_new = X.copy() for col in X.columns: tmp_data = X[col].values.astype(np.double) mf = medfilt(tmp_data, self.window) - mf = np.minimum(mf, self.peak_std_threshold * self.robust_std(mf)) + mf = np.minimum(mf, self.peak_std_threshold * self._robust_std(mf)) self.fit_params[col] = dict(mf=mf) X_new[col] = tmp_data - mf return X_new -class SavGolFilterDetrend(BaseEstimator, TransformerMixin): +class SavGolFilterDetrender(BaseEstimator, TransformerMixin): """ Savitzky-Golay filter detrending """ @@ -50,10 +50,10 @@ def __init__(self, self.order = order def fit(self, X, y=None): - self.fit_params = {} return self def transform(self,X): + self.fit_params = {} X_new = X.copy() for col in X.columns: tmp_data = X[col].values.astype(np.double) @@ -64,19 +64,16 @@ def transform(self,X): return X_new -class EventRescale(BaseEstimator, TransformerMixin): +class EventRescaler(BaseEstimator, TransformerMixin): """ - Savitzky-Golay filter detrending + rescale events """ - def __init__(self, - log_transform=True, - scale=5): + def __init__(self,log_transform=True,scale=5): self.log_transform = log_transform self.scale = scale def fit(self, X, y=None): - self.fit_params = {} return self def transform(self,X): @@ -91,56 +88,51 @@ def transform(self,X): return X_new -class OASISInferer(BaseEstimator, TransformerMixin): + +def oasis_kwargs(penalty,indicator): + + kwargs = {} + + if penalty=='l0': + kwargs.update(penalty=0) + elif penalty=='l1': + kwargs.update(penalty=1) + # elif penalty=='l2': + # kwargs.update(penalty=2) + + if indicator.lower()=='gcamp6f': + kwargs.update(g=(None,)) + elif indicator.lower()=='gcamp6s': + kwargs.update(g=(None,None)) + + return kwargs + + +class CalciumDeconvolver(BaseEstimator, TransformerMixin): """docstring for OASISInferer.""" - def __init__(self, - output='spikes', - g=(None,), - sn=None, - b=None, - b_nonneg=True, - optimize_g=0, - penalty=0, - **kwargs - ): - super(OASISInferer, self).__init__() - - self.output = output - self.g = g - self.sn = sn - self.b = b - self.b_nonneg = b_nonneg - self.optimize_g = optimize_g + def __init__(self,penalty='l0',indicator='GCaMP6f'): self.penalty = penalty - self.kwargs = kwargs + self.indicator = indicator def fit(self, X, y=None): - self.fit_params = {} return self def transform(self,X): + kwargs = oasis_kwargs( + self.penalty, + self.indicator, + ) + X_new = X.copy() + self.fit_params = {} for col in X.columns: - c, s, b, g, lam = deconvolve( + denoised, spikes, b, g, lam = deconvolve( X[col].values.astype(np.double), - g = self.g, - sn = self.sn, - b = self.b, - b_nonneg = self.b_nonneg, - optimize_g = self.optimize_g, - penalty = self.penalty, - **self.kwargs - ) + **kwargs) self.fit_params[col] = dict(b=b,g=g,lam=lam,) - - if self.output=='denoised': - X_new[col] = c - elif self.output=='spikes': - X_new[col] = np.maximum(0, s) - else: - raise NotImplementedError + X_new[col] = spikes return X_new From 47c3423d2f909ba6b284c41aa7e7338d0e0bdd90 Mon Sep 17 00:00:00 2001 From: Justin Kiggins Date: Wed, 22 Nov 2017 21:48:53 -0800 Subject: [PATCH 3/6] docs & docs & more --- neuroglia/calcium.py | 147 +++++++++++++++++++++++++++++++++++++++---- 1 file changed, 136 insertions(+), 11 deletions(-) diff --git a/neuroglia/calcium.py b/neuroglia/calcium.py index a9d6524..b79ba59 100644 --- a/neuroglia/calcium.py +++ b/neuroglia/calcium.py @@ -23,6 +23,20 @@ def _robust_std(self, x): return 1.4826*MAD def fit(self, X, y=None): + """Do nothing and return the estimator unchanged + + This method is here to implement the scikit-learn API and work in + scikit-learn pipelines. + + Parameters + ---------- + X : array-like + + Returns + ------- + self + + """ return self def transform(self,X): @@ -50,6 +64,20 @@ def __init__(self, self.order = order def fit(self, X, y=None): + """Do nothing and return the estimator unchanged + + This method is here to implement the scikit-learn API and work in + scikit-learn pipelines. + + Parameters + ---------- + X : array-like + + Returns + ------- + self + + """ return self def transform(self,X): @@ -74,6 +102,20 @@ def __init__(self,log_transform=True,scale=5): self.scale = scale def fit(self, X, y=None): + """Do nothing and return the estimator unchanged + + This method is here to implement the scikit-learn API and work in + scikit-learn pipelines. + + Parameters + ---------- + X : array-like + + Returns + ------- + self + + """ return self def transform(self,X): @@ -89,7 +131,7 @@ def transform(self,X): -def oasis_kwargs(penalty,indicator): +def oasis_kwargs(penalty=None,model=None): kwargs = {} @@ -97,31 +139,84 @@ def oasis_kwargs(penalty,indicator): kwargs.update(penalty=0) elif penalty=='l1': kwargs.update(penalty=1) - # elif penalty=='l2': - # kwargs.update(penalty=2) - if indicator.lower()=='gcamp6f': + if model.lower()=='exponential': kwargs.update(g=(None,)) - elif indicator.lower()=='gcamp6s': + elif model.lower()=='double_exponential': kwargs.update(g=(None,None)) return kwargs class CalciumDeconvolver(BaseEstimator, TransformerMixin): - """docstring for OASISInferer.""" - def __init__(self,penalty='l0',indicator='GCaMP6f'): + """Deconvolve calcium traces to detect putative spiking events + + This transformer deconvolves each trace to yield a sparse trace where each + bin is weighted according to the likelihood of spiking events. + + We use the OASIS algorithm from https://github.com/j-friedrich/OASIS/ + + Note: you must install OASIS for `CalciumDeconvolver` to work. + + :: + pip install cython + pip install git+https://github.com/j-friedrich/OASIS.git + + Parameters + ---------- + penalty : {'l0', 'l1'} + Specify the norm used in the penalization when fitting. + model : {'exponential','double_exponential'} + What type of model to fit for the calcium dynamics. Typically, a fast + calcium indicator can be fit with the single 'exponential' model, + whereas an indicator with a slow rise will benefit from using the + 'double_exponential' model, which fits an exponential to the rise time + of the calcium response as well. + + Notes + ----- + + This estimator is stateless (besides constructor parameters), the + fit method does nothing but is useful when used in a pipeline. + """ + def __init__(self,penalty='l0',model='exponential',threshold=0.001): self.penalty = penalty - self.indicator = indicator + self.model = model + self.threshold = threshold def fit(self, X, y=None): + """Do nothing and return the estimator unchanged + + This method is here to implement the scikit-learn API and work in + scikit-learn pipelines. + + Parameters + ---------- + X : array-like + + Returns + ------- + self + + """ return self def transform(self,X): + """Deconvolve each column of X + + Parameters + ---------- + X : DataFrame in `traces` structure [n_samples, n_traces] + + Returns + ------- + Xt : DataFrame in `traces` structure [n_samples, n_traces] + The deconvolved data events. + """ kwargs = oasis_kwargs( self.penalty, - self.indicator, + self.model, ) X_new = X.copy() @@ -136,6 +231,22 @@ def transform(self,X): return X_new + def predict(self,X): + """Find spikes + + Parameters + ---------- + X : DataFrame in `traces` structure [n_samples, n_traces] + + Returns + ------- + y : DataFrame in `traces` structure [n_samples, n_traces] + Predicted spike events. + """ + y = self.transform(X) > self.threshold + return y + + def normalize_trace(trace, window=3, percentile=8): """ normalized the trace by substracting off a rolling baseline @@ -164,13 +275,27 @@ def normalize_trace(trace, window=3, percentile=8): class Normalize(BaseEstimator,TransformerMixin): - """docstring for Normalize.""" + """ Calculate rolling dF/F + """ def __init__(self, window=3.0, percentile=8): - super(Normalize, self).__init__() self.window = window self.percentile = percentile def fit(self, X, y=None): + """Do nothing and return the estimator unchanged + + This method is here to implement the scikit-learn API and work in + scikit-learn pipelines. + + Parameters + ---------- + X : array-like + + Returns + ------- + self + + """ return self def transform(self,X): From 804884c43d1e4677ddb4858db2a59d9dce59d661 Mon Sep 17 00:00:00 2001 From: Justin Kiggins Date: Wed, 22 Nov 2017 22:06:42 -0800 Subject: [PATCH 4/6] calcium deconvolver works as classifier --- neuroglia/calcium.py | 4 ++-- tests/test_calcium.py | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/neuroglia/calcium.py b/neuroglia/calcium.py index b79ba59..9460caf 100644 --- a/neuroglia/calcium.py +++ b/neuroglia/calcium.py @@ -1,5 +1,5 @@ import numpy as np -from sklearn.base import TransformerMixin, BaseEstimator +from sklearn.base import TransformerMixin, BaseEstimator, ClassifierMixin from oasis.functions import deconvolve from scipy.signal import medfilt, savgol_filter @@ -148,7 +148,7 @@ def oasis_kwargs(penalty=None,model=None): return kwargs -class CalciumDeconvolver(BaseEstimator, TransformerMixin): +class CalciumDeconvolver(BaseEstimator, TransformerMixin, ClassifierMixin): """Deconvolve calcium traces to detect putative spiking events This transformer deconvolves each trace to yield a sparse trace where each diff --git a/tests/test_calcium.py b/tests/test_calcium.py index c31efd9..189599f 100644 --- a/tests/test_calcium.py +++ b/tests/test_calcium.py @@ -41,6 +41,10 @@ def test_CalciumDeconvolver(): deconvolver = CalciumDeconvolver() tmp = deconvolver.fit_transform(DFF) assert np.all(np.array([np.corrcoef(true_s[n], np.array(tmp[a]))[0][1] for n,a in zip(range(3), LBL)]) > 0.6) + + acc = deconvolver.score(DFF,true_s.T>deconvolver.threshold) + print(acc) + clone(deconvolver) From 91c721070af9084a9346391c04078bd5dc15e0ef Mon Sep 17 00:00:00 2001 From: Justin Kiggins Date: Wed, 22 Nov 2017 22:33:41 -0800 Subject: [PATCH 5/6] more docs more docs --- neuroglia/calcium.py | 99 +++++++++++++++++++++++++++++++++++++++----- 1 file changed, 88 insertions(+), 11 deletions(-) diff --git a/neuroglia/calcium.py b/neuroglia/calcium.py index 9460caf..c151827 100644 --- a/neuroglia/calcium.py +++ b/neuroglia/calcium.py @@ -5,19 +5,25 @@ class MedianFilterDetrender(BaseEstimator, TransformerMixin): - """ - Median filter detrending + """Detrend the calcium signal using the local median + + Parameters + ---------- + window : int, optional (default: 101) + Number of samples to use to compute local median + peak_std_threshold : float, optional (default: 4.0) + If the median exceeds this threshold, it will be capped at this level. + """ def __init__(self, window=101, - peak_std_threshold=4): + peak_std_threshold=4.0): self.window = window self.peak_std_threshold = peak_std_threshold def _robust_std(self, x): - ''' - Robust estimate of std + '''Robust estimate of std ''' MAD = np.median(np.abs(x - np.median(x))) return 1.4826*MAD @@ -40,6 +46,17 @@ def fit(self, X, y=None): return self def transform(self,X): + """Detrend each column of X + + Parameters + ---------- + X : DataFrame in `traces` structure [n_samples, n_traces] + + Returns + ------- + Xt : DataFrame in `traces` structure [n_samples, n_traces] + The detrended data. + """ self.fit_params = {} X_new = X.copy() for col in X.columns: @@ -53,8 +70,15 @@ def transform(self,X): class SavGolFilterDetrender(BaseEstimator, TransformerMixin): - """ - Savitzky-Golay filter detrending + """Detrend the calcium signal using a Savitzky-Golay filter + + Parameters + ---------- + window : int, optional (default: 201) + Number of samples to use to build the Savitzky-Golay filter + order : int, optional (default: 3) + Order of the Savitzky-Golay filter + """ def __init__(self, window=201, @@ -81,6 +105,17 @@ def fit(self, X, y=None): return self def transform(self,X): + """Detrend each column of X + + Parameters + ---------- + X : DataFrame in `traces` structure [n_samples, n_traces] + + Returns + ------- + Xt : DataFrame in `traces` structure [n_samples, n_traces] + The detrended data. + """ self.fit_params = {} X_new = X.copy() for col in X.columns: @@ -93,8 +128,23 @@ def transform(self,X): class EventRescaler(BaseEstimator, TransformerMixin): - """ - rescale events + """Rescale detected calcium events + + Rescaling and log-transforming the output of the CalciumDeconvolver may + yield values closer to the number of spikes elicited in a sample bin + + This transformer multiplies the input values by `scale` then, if + `log_transform` is `True`, adds 1 and log-transforms the data. + + That is, if log_transform is True, it returns `np.log(1.0 + scale * X)`, + else it returns `scale * X` + + Parameters + ---------- + log_transform : boolean, optional (default: True) + Perform the log transform + scale : float, optional (default: 5.0) + Value to rescale the data before the log_transform """ def __init__(self,log_transform=True,scale=5): @@ -119,6 +169,17 @@ def fit(self, X, y=None): return self def transform(self,X): + """Rescale events in X + + Parameters + ---------- + X : DataFrame in `traces` structure [n_samples, n_traces] + + Returns + ------- + Xt : DataFrame in `traces` structure [n_samples, n_traces] + The rescaled data. + """ X_new = X.copy() for col in X.columns: tmp_data = X[col].values.astype(np.double) @@ -275,7 +336,14 @@ def normalize_trace(trace, window=3, percentile=8): class Normalize(BaseEstimator,TransformerMixin): - """ Calculate rolling dF/F + """ Normalize the trace by a rolling baseline (that is, calculate dF/F) + + Parameters + --------- + window: float, optional (default: 3.0) + time in minutes + percentile: int, optional (default: 8) + percentile to subtract off """ def __init__(self, window=3.0, percentile=8): self.window = window @@ -299,8 +367,17 @@ def fit(self, X, y=None): return self def transform(self,X): - # this is where the magic happens + """Normalize each column of X + + Parameters + ---------- + X : DataFrame in `traces` structure [n_samples, n_traces] + Returns + ------- + Xt : DataFrame in `traces` structure [n_samples, n_traces] + The normalized calcium traces. + """ df_norm = pd.DataFrame() for col in X.columns: df_norm[col] = normalize_trace( From 75dc977b0349475c6de2e22144b8cf204395681a Mon Sep 17 00:00:00 2001 From: Justin Kiggins Date: Wed, 22 Nov 2017 23:50:30 -0800 Subject: [PATCH 6/6] example runs --- examples/plot_calcium_deconvolve.py | 61 +++++++++++++++++++++++++++++ neuroglia/calcium.py | 2 +- 2 files changed, 62 insertions(+), 1 deletion(-) create mode 100644 examples/plot_calcium_deconvolve.py diff --git a/examples/plot_calcium_deconvolve.py b/examples/plot_calcium_deconvolve.py new file mode 100644 index 0000000..bdc453b --- /dev/null +++ b/examples/plot_calcium_deconvolve.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python +""" +Deconvolve synthetic calcium traces +============================== + +This is an example of how to infer spike events + +""" + +####################################################### +# First, we'll generate some fake data + +import numpy as np +import pandas as pd +from oasis.functions import gen_data + +neuron_ids = ['a', 'b', 'c'] +sampling_rate = 30.0 + +traces, _, spikes = map(np.squeeze, gen_data(N=3, b=2, seed=0)) + +time = np.arange(0, traces.shape[1]/sampling_rate, 1/sampling_rate) + +traces = pd.DataFrame(traces.T, index=time, columns=neuron_ids) +spikes = pd.DataFrame(spikes.T, index=time, columns=neuron_ids) + +######################################################## +# let's plot the data + +import matplotlib.pyplot as plt +traces.plot() +plt.show() + +########################################################## +# Now, we'll deconvolve the data + +from neuroglia.calcium import CalciumDeconvolver + +deconvolver = CalciumDeconvolver() +detected_events = deconvolver.transform(traces) + +for neuron in neuron_ids: + y_true = spikes[neuron] + y_pred = detected_events[neuron] + corr = np.corrcoef(y_pred,y_true)[0,1] + print("{}: {:0.2f}".format(neuron,corr)) + +detected_events.plot() +plt.show() + +########################################################## +# Now, we'll predict spikes + +spikes_pred = deconvolver.predict(traces) +spikes_true = (spikes>0).astype(int) + +for neuron in neuron_ids: + y_true = spikes_true[neuron] + y_pred = spikes_pred[neuron] + corr = np.corrcoef(y_pred,y_true)[0,1] + print("{}: {:0.2f}".format(neuron,corr)) diff --git a/neuroglia/calcium.py b/neuroglia/calcium.py index c151827..72b2e35 100644 --- a/neuroglia/calcium.py +++ b/neuroglia/calcium.py @@ -304,7 +304,7 @@ def predict(self,X): y : DataFrame in `traces` structure [n_samples, n_traces] Predicted spike events. """ - y = self.transform(X) > self.threshold + y = (self.transform(X) > self.threshold).astype(int) return y