"""
Functions for computing coordinate transformations, derivatives and
sample means of parameters and data.
---
State-Space Analysis of Spike Correlations (Shimazaki et al. PLoS Comp Bio 2012)
Copyright (C) 2014 Thomas Sharp (thomas.sharp@riken.jp)
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
import itertools
import numpy
import scipy.misc
# Matrix to map from theta to probability
p_map = None
# Matrix to map from probability to eta
eta_map = None
def compute_D(N, O):
"""
Computes the number of natural parameters for the given number of cells and
order of interactions.
:param int N:
Total cells in the spike data.
    :param int O:
Order of spike-train interactions to estimate, for example, 2 =
pairwise, 3 = triplet-wise...
:returns:
Number of natural parameters for the spike-pattern interactions.
"""
D = numpy.sum([scipy.misc.comb(N, k, exact=1) for k in xrange(1, O + 1)])
return D
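
# Illustrative check (editorial, not from the original source): with N = 3
# cells and up to pairwise interactions (O = 2), the count is
# C(3,1) + C(3,2) = 3 + 3 = 6:
#
#     >>> compute_D(3, 2)
#     6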
def compute_eta(p):
"""
    Computes the expected values, eta, of spike patterns, for example
        eta_2 = p(0,1) + p(1,1)
    from the supplied probability mass.
:param numpy.ndarray p:
Probability mass of spike patterns.
:returns:
Expected values of spike patterns (eta) as a numpy.ndarray.
"""
global eta_map
eta = numpy.dot(eta_map, p)
return eta
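
# Illustrative usage (editorial sketch; assumes initialise(2, 2) has been
# called so that eta_map is set): for the uniform mass
# p = [0.25, 0.25, 0.25, 0.25] over the patterns (0,0), (1,0), (0,1), (1,1),
# each cell fires in half of the mass and both fire together in a quarter,
# so compute_eta(p) returns [0.5, 0.5, 0.25].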
def compute_fisher_info(p, eta):
"""
Computes the Fisher-information matrix of the expected values, eta, of spike
patterns for the purposes of Newton-Raphson gradient-ascent and
computation of the marginal probability distribution. For example, for two
neurons:
        H = [[ n1 - n1^2,       n12 - n1 * n2,   n12 - n1 * n12 ],
             [ n12 - n1 * n2,   n2 - n2^2,       n12 - n2 * n12 ],
             [ n12 - n1 * n12,  n12 - n2 * n12,  n12 - n12^2    ]]
:param numpy.ndarray p:
Probability mass of spike patterns.
:param numpy.ndarray eta:
Expected values of spike patterns.
:returns:
Fisher-information matrix as a numpy.ndarray.
"""
global p_map, eta_map
# Stack columns of p for next step
p_stack = numpy.repeat(p, eta.size).reshape(p.size, eta.size)
# Compute Fisher matrix
fisher = numpy.dot(eta_map, p_stack * p_map) - numpy.outer(eta, eta)
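    # Note (editorial): entry (i, j) of the first term is
    # sum_x p(x) * f_i(x) * f_j(x) = E[f_i * f_j], where f_i is the binary
    # coincidence feature for subset i; subtracting outer(eta, eta) therefore
    # yields the covariance of the feature vector, which is the Fisher
    # information of the log-linear model in theta coordinates.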
return fisher
def compute_p(theta):
"""
Computes the probability distribution of spike patterns, for example
                   e^(t1x1) e^(t2x2) e^(t12x1x2)
        p(x1,x2) = ------------------------------------
                   1 + e^(t1) + e^(t2) + e^(t1+t2+t12)
from the supplied natural parameters.
:param numpy.ndarray theta:
Natural `theta' parameters: t1, t2, ..., t12, ...
:returns:
Probability mass as a numpy.ndarray.
"""
global p_map
# Compute log probabilities
log_p = numpy.dot(p_map, theta)
# Take exponential and normalise
p = numpy.exp(log_p)
p_tmp = p / numpy.sum(p)
return p_tmp
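
# Illustrative usage (editorial sketch; assumes initialise(2, 2) has been
# called so that p_map is set): with theta = numpy.zeros(3), every pattern
# has log-probability 0 before normalisation, so compute_p returns the
# uniform mass [0.25, 0.25, 0.25, 0.25].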
def compute_psi(theta):
"""
Computes the normalisation parameter, psi, for the log-linear probability
mass function of spike patterns. For example, for two neurons
psi(theta) = 1 + e^(t1) + e^(t2) + e^(t1+t2+t12)
:param numpy.ndarray theta:
Natural `theta' parameters: t1, t2, ..., t12, ...
:returns:
Normalisation parameter, psi, of the log linear model as a float.
"""
global p_map
# Take coincident-pattern subsets of theta
tmp = numpy.dot(p_map, theta)
# Take the sum of the exponentials and take the log
tmp = numpy.sum(numpy.exp(tmp))
psi = numpy.log(tmp)
return psi
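
# Illustrative check (editorial, not from the original source): with
# theta = 0, each of the 2**N patterns contributes e^0 = 1 to the sum, so
# psi = log(2**N); for two cells, psi = log(4).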
def compute_y(spikes, order, window):
"""
Computes the empirical mean rate of each spike pattern across trials for
each timestep up to `order', for example
        y_12,t = 1 / R * \sum_{l=1}^{R} X1_l,t * X2_l,t
is a second-order pattern where t is the timestep and l is the trial.
:param numpy.ndarray spikes:
Binary matrix with dimensions (time, runs, cells), in which a `1' in
location (t, r, c) denotes a spike at time t in run r by cell c.
:param int order:
Order of spike-train interactions to estimate, for example, 2 =
pairwise, 3 = triplet-wise...
:param int window:
Bin-width for counting spikes, in milliseconds.
:returns:
Trial-mean rates of each pattern in each timestep, as a numpy.ndarray
with `time' rows and sum_{k=1}^{order} {n \choose k} columns, given
n cells.
"""
# Get spike-matrix metadata
T, R, N = spikes.shape
    # Bin spikes within non-overlapping windows, marking a binned spike
    # wherever at least one spike occurs (note: `window' must divide T)
    spikes = spikes.reshape((T / window, window, R, N))
    spikes = spikes.any(axis=1)
# Compute each n-choose-k subset of cell IDs up to `order'
subsets = enumerate_subsets(N, order)
# Set up the output array
y = numpy.zeros((T / window, len(subsets)))
# Iterate over each subset
for i in xrange(len(subsets)):
# Select the cells that are in the subset
sp = spikes[:,:,subsets[i]]
# Find the timesteps in which all subset-cells spike coincidentally
spc = sp.sum(axis=2) == len(subsets[i])
        # Average the occurrences of coincident patterns to get the mean rate
y[:,i] = spc.mean(axis=1)
return y
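
# Illustrative usage (editorial sketch with hypothetical shapes): for a
# binary array `spikes' of shape (1000, 20, 3) -- 1000 timesteps, 20 trials,
# 3 cells -- calling compute_y(spikes, 2, 10) yields y of shape (100, 6):
# columns 0-2 hold single-cell rates and columns 3-5 hold pairwise
# coincidence rates, each averaged over the 20 trials.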
def enumerate_subsets(N, O):
"""
Enumerates all N-choose-k subsets of cell IDs for k = 1, 2, ..., O. For
example,
        >>> enumerate_subsets(4, 2)
        [(0,), (1,), (2,), (3,), (0, 1),
         (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
:param int N:
Total cells from which to choose subsets.
:param int O:
Maximum size of subsets to enumerate.
:returns:
List of tuples, each tuple containing a subset of cell IDs.
"""
    # Compute each N-choose-k subset of cell IDs up to `O'
subsets = list()
ids = numpy.arange(N)
for k in xrange(1, O + 1):
subsets.extend(list(itertools.combinations(ids, k)))
# Assert that we've got the correct number of subsets
assert len(subsets) == compute_D(N, O)
return subsets
def enumerate_patterns(N):
"""
Enumerates all spike patterns in order, for example:
>>> enumerate_patterns(3)
array([[0, 0, 0],
[1, 0, 0],
[0, 1, 0],
[0, 0, 1],
[1, 1, 0],
[1, 0, 1],
[0, 1, 1],
[1, 1, 1]], dtype=uint8)
:param int N:
Number of cells for which to enumerate patterns.
:returns:
Binary matrix of spike patterns as a numpy.ndarray of dimensions
(2**N, N).
"""
# Get the spike patterns as ordered subsets
subsets = enumerate_subsets(N, N)
assert len(subsets) == 2**N - 1
# Generate output array and fill according to subsets
fx = numpy.zeros((2**N, N), dtype=numpy.uint8)
for i in xrange(len(subsets)):
fx[i+1,subsets[i]] = 1
return fx
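
# Note (editorial): row 0 of the result is the all-silent pattern, and row
# i + 1 marks the cells of enumerate_subsets(N, N)[i], so the pattern order
# lines up with the natural-parameter order that initialise() relies on.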
def initialise(N, O):
"""
Sets up matrices to transform between theta, probability and eta
coordinates. Computing probability requires finding subsets of theta for the
numerator; for example, with two cells, finding the numerator:
                   e^(t1x1) e^(t2x2) e^(t12x1x2)
        p(x1,x2) = ------------------------------------
                   1 + e^(t1) + e^(t2) + e^(t1+t2+t12)
    We calculate a `p_map' to do this for arbitrary numbers of neurons and
    orders of interactions. To map from probabilities to eta coordinates, we
    use an `eta_map', which is just the transpose of the `p_map'. The
    construction below is somewhat intricate, but it produces the correct
    maps.
This function has the side effect of setting global variables `p_map' and
`eta_map', which are used later by other functions in this module.
:param int N:
Total cells in the spike data.
    :param int O:
Order of spike-train interactions to estimate, for example, 2 =
pairwise, 3 = triplet-wise...
"""
global p_map, eta_map
# Create a matrix of binary spike patterns
fx = enumerate_patterns(N)
# Compute the number of natural parameters, given the order parameter
D = compute_D(N, O)
# Set up the output matrix
p_map = numpy.ones((2**N, D), dtype=numpy.uint8)
# Compute the map!
for i in xrange(1, D+1):
idx = numpy.nonzero(fx[i,:])[0]
for j in xrange(idx.size):
p_map[:,i-1] = p_map[:,i-1] & fx[:,idx[j]]
# Set up the eta map
eta_map = p_map.transpose()
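
# Minimal end-to-end sketch (editorial, not part of the original module):
# round-trips theta -> p -> eta for two cells; the theta values are
# arbitrary illustrative numbers.
if __name__ == '__main__':
    initialise(2, 2)
    theta = numpy.array([-1.0, -1.0, 0.5])  # t1, t2, t12 (arbitrary)
    p = compute_p(theta)       # mass over (0,0), (1,0), (0,1), (1,1)
    eta = compute_eta(p)       # [E[x1], E[x2], E[x1 * x2]]
    print p, numpy.sum(p)      # the mass sums to 1
    print eta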