diff --git a/examples/plot_calcium_deconvolve.py b/examples/plot_calcium_deconvolve.py
new file mode 100644
index 0000000..bdc453b
--- /dev/null
+++ b/examples/plot_calcium_deconvolve.py
@@ -0,0 +1,61 @@
+#!/usr/bin/env python
+"""
+Deconvolve synthetic calcium traces
+===================================
+
+This example shows how to infer spiking events from calcium traces.
+
+"""
+
+#######################################################
+# First, we'll generate some synthetic data
+
+import numpy as np
+import pandas as pd
+from oasis.functions import gen_data
+
+neuron_ids = ['a', 'b', 'c']
+sampling_rate = 30.0
+
+traces, _, spikes = map(np.squeeze, gen_data(N=3, b=2, seed=0))
+
+time = np.arange(0, traces.shape[1] / sampling_rate, 1 / sampling_rate)
+
+traces = pd.DataFrame(traces.T, index=time, columns=neuron_ids)
+spikes = pd.DataFrame(spikes.T, index=time, columns=neuron_ids)
+
+########################################################
+# Let's plot the data
+
+import matplotlib.pyplot as plt
+traces.plot()
+plt.show()
+
+##########################################################
+# Now, we'll deconvolve the data
+
+from neuroglia.calcium import CalciumDeconvolver
+
+deconvolver = CalciumDeconvolver()
+detected_events = deconvolver.transform(traces)
+
+for neuron in neuron_ids:
+    y_true = spikes[neuron]
+    y_pred = detected_events[neuron]
+    corr = np.corrcoef(y_pred, y_true)[0, 1]
+    print("{}: {:0.2f}".format(neuron, corr))
+
+detected_events.plot()
+plt.show()
+
+##########################################################
+# Now, we'll predict spikes
+
+spikes_pred = deconvolver.predict(traces)
+spikes_true = (spikes > 0).astype(int)
+
+for neuron in neuron_ids:
+    y_true = spikes_true[neuron]
+    y_pred = spikes_pred[neuron]
+    corr = np.corrcoef(y_pred, y_true)[0, 1]
+    print("{}: {:0.2f}".format(neuron, corr))
diff --git a/neuroglia/calcium.py b/neuroglia/calcium.py
index 2162cc3..72b2e35 100644
--- a/neuroglia/calcium.py
+++ b/neuroglia/calcium.py
@@ -1,46 +1,84 @@
 import numpy as np
-from sklearn.base import TransformerMixin, BaseEstimator
+from sklearn.base import TransformerMixin, BaseEstimator, ClassifierMixin
 from oasis.functions import deconvolve
 from scipy.signal import medfilt, savgol_filter
 
-class MedianFilterDetrend(BaseEstimator, TransformerMixin):
-    """
-    Median filter detrending
+class MedianFilterDetrender(BaseEstimator, TransformerMixin):
+    """Detrend the calcium signal using the local median
+
+    Parameters
+    ----------
+    window : int, optional (default: 101)
+        Number of samples used to compute the local median (must be odd)
+    peak_std_threshold : float, optional (default: 4.0)
+        The local median is capped at this many robust standard deviations
+        of the filtered trace.
+
     """
     def __init__(self,
                  window=101,
-                 peak_std_threshold=4):
+                 peak_std_threshold=4.0):
         self.window = window
         self.peak_std_threshold = peak_std_threshold
 
-    def robust_std(self, x):
-        '''
-        Robust estimate of std
+    def _robust_std(self, x):
+        '''Robust estimate of the standard deviation (1.4826 * MAD)
         '''
         MAD = np.median(np.abs(x - np.median(x)))
         return 1.4826*MAD
 
     def fit(self, X, y=None):
-        self.fit_params = {}
+        """Do nothing and return the estimator unchanged
+
+        This method is here to implement the scikit-learn API and work in
+        scikit-learn pipelines.
+
+        Parameters
+        ----------
+        X : array-like
+
+        Returns
+        -------
+        self
+
+        """
         return self
 
     def transform(self,X):
+        """Detrend each column of X
+
+        Parameters
+        ----------
+        X : DataFrame in `traces` structure [n_samples, n_traces]
+
+        Returns
+        -------
+        Xt : DataFrame in `traces` structure [n_samples, n_traces]
+            The detrended data.
+ """ + self.fit_params = {} X_new = X.copy() for col in X.columns: tmp_data = X[col].values.astype(np.double) mf = medfilt(tmp_data, self.window) - mf = np.minimum(mf, self.peak_std_threshold * self.robust_std(mf)) + mf = np.minimum(mf, self.peak_std_threshold * self._robust_std(mf)) self.fit_params[col] = dict(mf=mf) X_new[col] = tmp_data - mf return X_new -class SavGolFilterDetrend(BaseEstimator, TransformerMixin): - """ - Savitzky-Golay filter detrending +class SavGolFilterDetrender(BaseEstimator, TransformerMixin): + """Detrend the calcium signal using a Savitzky-Golay filter + + Parameters + ---------- + window : int, optional (default: 201) + Number of samples to use to build the Savitzky-Golay filter + order : int, optional (default: 3) + Order of the Savitzky-Golay filter + """ def __init__(self, window=201, @@ -50,10 +88,35 @@ def __init__(self, self.order = order def fit(self, X, y=None): - self.fit_params = {} + """Do nothing and return the estimator unchanged + + This method is here to implement the scikit-learn API and work in + scikit-learn pipelines. + + Parameters + ---------- + X : array-like + + Returns + ------- + self + + """ return self def transform(self,X): + """Detrend each column of X + + Parameters + ---------- + X : DataFrame in `traces` structure [n_samples, n_traces] + + Returns + ------- + Xt : DataFrame in `traces` structure [n_samples, n_traces] + The detrended data. + """ + self.fit_params = {} X_new = X.copy() for col in X.columns: tmp_data = X[col].values.astype(np.double) @@ -64,22 +127,59 @@ def transform(self,X): return X_new -class EventRescale(BaseEstimator, TransformerMixin): - """ - Savitzky-Golay filter detrending +class EventRescaler(BaseEstimator, TransformerMixin): + """Rescale detected calcium events + + Rescaling and log-transforming the output of the CalciumDeconvolver may + yield values closer to the number of spikes elicited in a sample bin + + This transformer multiplies the input values by `scale` then, if + `log_transform` is `True`, adds 1 and log-transforms the data. + + That is, if log_transform is True, it returns `np.log(1.0 + scale * X)`, + else it returns `scale * X` + + Parameters + ---------- + log_transform : boolean, optional (default: True) + Perform the log transform + scale : float, optional (default: 5.0) + Value to rescale the data before the log_transform """ - def __init__(self, - log_transform=True, - scale=5): + def __init__(self,log_transform=True,scale=5): self.log_transform = log_transform self.scale = scale def fit(self, X, y=None): - self.fit_params = {} + """Do nothing and return the estimator unchanged + + This method is here to implement the scikit-learn API and work in + scikit-learn pipelines. + + Parameters + ---------- + X : array-like + + Returns + ------- + self + + """ return self def transform(self,X): + """Rescale events in X + + Parameters + ---------- + X : DataFrame in `traces` structure [n_samples, n_traces] + + Returns + ------- + Xt : DataFrame in `traces` structure [n_samples, n_traces] + The rescaled data. 
+ """ X_new = X.copy() for col in X.columns: tmp_data = X[col].values.astype(np.double) @@ -91,59 +191,123 @@ def transform(self,X): return X_new -class OASISInferer(BaseEstimator, TransformerMixin): - """docstring for OASISInferer.""" - def __init__(self, - output='spikes', - g=(None,), - sn=None, - b=None, - b_nonneg=True, - optimize_g=0, - penalty=0, - **kwargs - ): - super(OASISInferer, self).__init__() - - self.output = output - self.g = g - self.sn = sn - self.b = b - self.b_nonneg = b_nonneg - self.optimize_g = optimize_g + +def oasis_kwargs(penalty=None,model=None): + + kwargs = {} + + if penalty=='l0': + kwargs.update(penalty=0) + elif penalty=='l1': + kwargs.update(penalty=1) + + if model.lower()=='exponential': + kwargs.update(g=(None,)) + elif model.lower()=='double_exponential': + kwargs.update(g=(None,None)) + + return kwargs + + +class CalciumDeconvolver(BaseEstimator, TransformerMixin, ClassifierMixin): + """Deconvolve calcium traces to detect putative spiking events + + This transformer deconvolves each trace to yield a sparse trace where each + bin is weighted according to the likelihood of spiking events. + + We use the OASIS algorithm from https://github.com/j-friedrich/OASIS/ + + Note: you must install OASIS for `CalciumDeconvolver` to work. + + :: + pip install cython + pip install git+https://github.com/j-friedrich/OASIS.git + + Parameters + ---------- + penalty : {'l0', 'l1'} + Specify the norm used in the penalization when fitting. + model : {'exponential','double_exponential'} + What type of model to fit for the calcium dynamics. Typically, a fast + calcium indicator can be fit with the single 'exponential' model, + whereas an indicator with a slow rise will benefit from using the + 'double_exponential' model, which fits an exponential to the rise time + of the calcium response as well. + + Notes + ----- + + This estimator is stateless (besides constructor parameters), the + fit method does nothing but is useful when used in a pipeline. + """ + def __init__(self,penalty='l0',model='exponential',threshold=0.001): self.penalty = penalty - self.kwargs = kwargs + self.model = model + self.threshold = threshold def fit(self, X, y=None): - self.fit_params = {} + """Do nothing and return the estimator unchanged + + This method is here to implement the scikit-learn API and work in + scikit-learn pipelines. + + Parameters + ---------- + X : array-like + + Returns + ------- + self + + """ return self def transform(self,X): + """Deconvolve each column of X + + Parameters + ---------- + X : DataFrame in `traces` structure [n_samples, n_traces] + + Returns + ------- + Xt : DataFrame in `traces` structure [n_samples, n_traces] + The deconvolved data events. 
+ """ + + kwargs = oasis_kwargs( + self.penalty, + self.model, + ) X_new = X.copy() + self.fit_params = {} for col in X.columns: - c, s, b, g, lam = deconvolve( + denoised, spikes, b, g, lam = deconvolve( X[col].values.astype(np.double), - g = self.g, - sn = self.sn, - b = self.b, - b_nonneg = self.b_nonneg, - optimize_g = self.optimize_g, - penalty = self.penalty, - **self.kwargs - ) + **kwargs) self.fit_params[col] = dict(b=b,g=g,lam=lam,) - - if self.output=='denoised': - X_new[col] = c - elif self.output=='spikes': - X_new[col] = np.maximum(0, s) - else: - raise NotImplementedError + X_new[col] = spikes return X_new + def predict(self,X): + """Find spikes + + Parameters + ---------- + X : DataFrame in `traces` structure [n_samples, n_traces] + + Returns + ------- + y : DataFrame in `traces` structure [n_samples, n_traces] + Predicted spike events. + """ + y = (self.transform(X) > self.threshold).astype(int) + return y + + def normalize_trace(trace, window=3, percentile=8): """ normalized the trace by substracting off a rolling baseline @@ -172,18 +336,48 @@ def normalize_trace(trace, window=3, percentile=8): class Normalize(BaseEstimator,TransformerMixin): - """docstring for Normalize.""" + """ Normalize the trace by a rolling baseline (that is, calculate dF/F) + + Parameters + --------- + window: float, optional (default: 3.0) + time in minutes + percentile: int, optional (default: 8) + percentile to subtract off + """ def __init__(self, window=3.0, percentile=8): - super(Normalize, self).__init__() self.window = window self.percentile = percentile def fit(self, X, y=None): + """Do nothing and return the estimator unchanged + + This method is here to implement the scikit-learn API and work in + scikit-learn pipelines. + + Parameters + ---------- + X : array-like + + Returns + ------- + self + + """ return self def transform(self,X): - # this is where the magic happens + """Normalize each column of X + + Parameters + ---------- + X : DataFrame in `traces` structure [n_samples, n_traces] + Returns + ------- + Xt : DataFrame in `traces` structure [n_samples, n_traces] + The normalized calcium traces. 
+ """ df_norm = pd.DataFrame() for col in X.columns: df_norm[col] = normalize_trace( diff --git a/tests/test_calcium.py b/tests/test_calcium.py index de4a29c..189599f 100644 --- a/tests/test_calcium.py +++ b/tests/test_calcium.py @@ -1,5 +1,5 @@ -from neuroglia.calcium import MedianFilterDetrend, SavGolFilterDetrend -from neuroglia.calcium import OASISInferer +from neuroglia.calcium import MedianFilterDetrender, SavGolFilterDetrender +from neuroglia.calcium import CalciumDeconvolver from oasis.functions import gen_data from sklearn.base import clone @@ -10,13 +10,6 @@ import numpy.testing as npt import xarray.testing as xrt -# Test for proper parameter structure -def test_params(): - fn_list = [MedianFilterDetrend(), SavGolFilterDetrend(), OASISInferer()] - for fn in fn_list: - new_object_params = fn.get_params(deep=False) - for name, param in new_object_params.items(): - new_object_params[name] = clone(param, safe=False) # Test functions perform as expected true_b = 2 @@ -26,28 +19,37 @@ def test_params(): LBL = ['a', 'b', 'c'] sin_scale = 5 -data = y + sin_scale*np.sin(.05*TIME)[:,None] -DFF = pd.DataFrame(data, TIME, LBL) +# data = y +DFF = pd.DataFrame(y, TIME, LBL) +DFF_WITH_DRIFT = DFF.apply(lambda y: y + sin_scale*np.sin(.05*TIME),axis=0) assert np.all(np.mean(DFF) > 2) -def test_MedianFilterDetrend(): - tmp = MedianFilterDetrend().fit_transform(DFF) +def test_MedianFilterDetrender(): + detrender = MedianFilterDetrender() + tmp = detrender.fit_transform(DFF_WITH_DRIFT) assert np.all(np.isclose(np.mean(tmp), 0, atol=.1)) + clone(detrender) -def test_SavGolFilterDetrend(): - tmp = SavGolFilterDetrend().fit_transform(DFF) +def test_SavGolFilterDetrender(): + detrender = SavGolFilterDetrender() + tmp = detrender.fit_transform(DFF_WITH_DRIFT) assert np.all(np.isclose(np.mean(tmp), 0, atol=.1)) + clone(detrender) -def test_OASISInferer(): - tmp = OASISInferer().fit_transform(SavGolFilterDetrend().fit_transform(DFF)) - assert np.all(np.array([np.corrcoef(true_s[n], np.array(tmp[a]))[0][1] for n,a in zip(range(3), LBL)]) > 0.6) - tmp = OASISInferer().fit_transform(MedianFilterDetrend().fit_transform(DFF)) +def test_CalciumDeconvolver(): + deconvolver = CalciumDeconvolver() + tmp = deconvolver.fit_transform(DFF) assert np.all(np.array([np.corrcoef(true_s[n], np.array(tmp[a]))[0][1] for n,a in zip(range(3), LBL)]) > 0.6) + acc = deconvolver.score(DFF,true_s.T>deconvolver.threshold) + print(acc) + + clone(deconvolver) + if __name__ == '__main__': - test_MedianFilterDetrend() - test_SavGolFilterDetrend() - test_OASISInferer() - test_params() + test_MedianFilterDetrender() + test_SavGolFilterDetrender() + test_CalciumDeconvolver() + # test_params()