-
Notifications
You must be signed in to change notification settings - Fork 1
/
exact_LL_setup.m
228 lines (188 loc) · 7.72 KB
/
exact_LL_setup.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
function cone_map = exact_LL_setup( GC_stas , cone_params , cone_map )
%% cone_map = exact_LL_setup( GC_stas , cone_params , cone_map )
% Expand data and parameters into variables used to calculate likelihoods.
% Mainly, spatial supersampling by a factor of cone_params.supersample is
% applied to the cone positions, and the convolution of the STAs with the
% cone receptive fields is used to calculate the sparsity structure of
% connections between cones and GCs, as well as a map of LL increases.
%
% Inputs:
%   GC_stas     - struct array with one element per ganglion cell (GC).
%                 Fields used: .spatial (M0 x M1 x N_colors STA) and
%                 .spikes (only its length, the spike count, is used).
%   cone_params - struct; fields used: supersample, support_radius, sigma,
%                 stimulus_variance, colors, replusion_radii.
%   cone_map    - (optional) struct to extend. An existing cone_map.ROI is
%                 kept; otherwise the ROI is the full supersampled grid.
% Output:
%   cone_map    - the input struct extended with sizes, per-GC constants,
%                 lookup tables (coneConv, outofbounds), the sparsity
%                 structure and single-cone LL map (sparse_struct, LL), a
%                 visualization of LL (NICE), default sampler settings and
%                 an initial empty configuration (initX).
if nargin < 3 , cone_map = struct ; end

% size of data
[M0,M1,cone_map.N_colors] = size(GC_stas(1).spatial) ;

%% unpack data from GC_stas
% supersample factor
SS = cone_params.supersample ;
% radius (in supersampled units) used for convolutions with cone RFs;
% computed from the un-rounded support_radius, which is rounded up below
R = ceil( cone_params.support_radius * SS ) ;
cone_params.support_radius = ceil(cone_params.support_radius) ;
% copy cone_params (with the rounded support_radius)
cone_map.cone_params = cone_params ;

% set up Region of Interest if it doesn't already exist
% these are fractional (x,y) coordinates, due to supersampling;
% x varies fastest, i.e. column-major order over an (M0*SS) x (M1*SS) grid
if ~isfield(cone_map,'ROI')
    x = repmat( 1/(2*SS):1/SS:M0-1/(2*SS) , 1 , M1*SS ) ;
    y = repmat( 1/(2*SS):1/SS:M1-1/(2*SS) , M0*SS , 1 ) ;
    cone_map.ROI = [x' y(:)] ;
    clear x y
end
% total number of possible cone positions
cone_map.NROI = size(cone_map.ROI,1) ;

% unpack GC_stas into: colors-first STAs, norms of STAs and N_spikes
N_GC = length(GC_stas) ;
STA_norm = zeros(N_GC,1) ;
cone_map.N_spikes = zeros(N_GC,1) ;
STA = zeros(cone_map.N_colors,M0,M1,N_GC) ;
for i=1:N_GC
    cone_map.N_spikes(i) = length(GC_stas(i).spikes) ;
    STA(:,:,:,i) = reshape(permute(GC_stas(i).spatial,[3 1 2]), cone_map.N_colors, M0, M1 ) ;
    STA_norm(i) = norm(reshape(STA(:,:,:,i),1,[])) ;
end

%% calculate some constants etc in advance, to speed up actual calculations
% memoized function returning gaussian mass in a square pixel box
cone_map.gaus_boxed = gaus_in_a_box_memo( cone_params.sigma, SS, cone_params.support_radius ) ;
% per-GC constants used to calculate the log-likelihood
% NOTE(review): an all-zero STA gives STA_norm == 0, hence Inf prior_covs
% and NaN N_cones_terms -- confirm upstream data never contains one.
cone_map.cell_consts = cone_map.N_spikes * cone_params.stimulus_variance ;
cone_map.prior_covs = (cone_params.stimulus_variance ./ STA_norm ).^2 ;
cone_map.cov_factors = cone_map.cell_consts+cone_map.prior_covs ;
cone_map.N_cones_terms = log( cone_map.prior_covs ) - log( cone_map.cov_factors) ;
cone_map.quad_factors = cone_map.N_spikes.^2 ./ cone_map.cov_factors ;

% for fast lookup of legal shift moves in move.m: border entries of the
% supersampled grid are 1, the four corners are 2, the interior is 0
cone_map.outofbounds = sparse([],[],[],M0*SS,M1*SS,2*(M0+M1)*SS) ;
cone_map.outofbounds(:,[1 M1*SS]) = 1 ;
cone_map.outofbounds([1 M0*SS],:) = cone_map.outofbounds([1 M0*SS],:) + 1 ;

% copy params
cone_map.M0 = M0 ;
cone_map.M1 = M1 ;
cone_map.N_GC = N_GC ;
cone_map.SS = SS ;
cone_map.R = R ;
cone_map.STA = single( STA ) ;
cone_map.min_STA_W = -0.2 ; %min(STA_W(:)) ;
cone_map.colorDot = cone_params.colors * cone_params.colors' ;

%% make lookup table for dot products of cone RFs in all possible positions
% coneConv(xx,yy,ss,tt) = dot product of a cone RF placed at offset
% (f(xx),f(yy)) with one placed at sub-pixel offset ((ss-0.5)/SS,(tt-0.5)/SS).
% WW(xx,yy) = squared norm of the RF at each of the first SS x SS offsets.
cone_map.coneConv = zeros( 2*R+SS , 2*R+SS , SS , SS ) ;
WW = zeros(SS,SS) ;
f = 1/(2*SS):1/SS:2*R/SS+1 ;
for xx=1:2*R+SS
    x = f(xx) ;
    for yy=1:2*R+SS
        y = f(yy) ;
        a = make_filter_new(4*R/SS+1,4*R/SS+1,x+R/SS,y+R/SS,cone_map.gaus_boxed,...
            cone_map.cone_params.support_radius) ;
        for ss=1:SS
            s = (ss-0.5)/SS ;
            for tt=1:SS
                t = (tt-0.5)/SS ;
                b = make_filter_new(4*R/SS+1,4*R/SS+1,2*R/SS+s,2*R/SS+t,...
                    cone_map.gaus_boxed,cone_map.cone_params.support_radius) ;
                cone_map.coneConv(xx,yy,ss,tt) = dot(a(:),b(:)) ;
            end
        end
        if (xx<=SS) && (yy<=SS)
            WW(xx,yy) = dot(a(:),a(:)) ;
        end
    end
end

%% calculate sparsity structure and map of LL increases
[cone_map.sparse_struct,cone_map.LL] = ...
    make_sparse_struct(cone_map,STA,WW,cone_map.gaus_boxed) ;

%% cone_map.NICE is a pretty visualizable version of cone_map.LL
% Map LL into the basis given by cone_params.colors: LL * inv(colors)'.
% mrdivide (X / C' == X * inv(C)') avoids forming an explicit inverse, and
% the number of slices is taken from LL itself instead of hard-coding 3.
N_types = size(cone_map.LL,3) ;
QC = reshape( reshape(cone_map.LL,[],N_types) / cone_params.colors', size(cone_map.LL) ) ;
cone_map.NICE = plotable_evidence( QC ) ;
% imagesc( cone_map.NICE )

%% some default values
cone_map.N_iterations = 100000 ;
cone_map.max_time = 200000 ;
cone_map.N_fast = 1 ;
cone_map.q = 0.5 ;
cone_map.ID = 0 ;
cone_map.save_disk_space = false ;

%% initial empty X
cone_map.initX = initialize_X( cone_map.M0, cone_map.M1, ...
    cone_map.N_colors, cone_map.SS, ...
    cone_map.cone_params.replusion_radii, ...
    1, 1) ;
%% transfer all info from cone_map to cone_map.initX
cone_map.initX = transfer_info( cone_map, cone_map.initX ) ;

%% quick sanity check: compare flip_LL on the empty configuration with the
% precomputed single-cone map cone_map.LL in a 10x10 window around a fixed
% test location. Skipped when that window does not fit inside cone_map.LL
% (otherwise small stimuli would index out of bounds here).
mx = 63 ;
my = 32 ;
mc = 1 ;
range_x = mx+(-4:5) ;
range_y = my+(-4:5) ;
if max(range_x) <= size(cone_map.LL,1) && ...
        max(range_y) <= size(cone_map.LL,2) && ...
        mc <= size(cone_map.LL,3)
    mLL = max(cone_map.LL(:)) ;
    tX = flip_LL( cone_map.initX , [mx my mc] , cone_map , [1 1] ) ;
    fprintf('\nLL and flip_ll: %f,%f, at x%d,y%d,c%d\n',mLL,tX.ll,mx,my,mc) ;
    % index test relative to the window so it stays 10x10
    test = zeros(numel(range_x),numel(range_y)) ;
    for a = 1:numel(range_x)
        for b = 1:numel(range_y)
            tX = flip_LL( cone_map.initX , [range_x(a) range_y(b) mc] , cone_map , [1 1] ) ;
            test(a,b) = tX.ll ;
        end
    end
    fprintf('\n')
    disp(test)
    disp(cone_map.LL(range_x,range_y,mc))
    disp( test - cone_map.LL(range_x,range_y,mc) )
    fprintf('\n')
end
end
function [sparse_struct, LL] = make_sparse_struct(cone_map,STA,WW,gaus_boxed)
% [sparse_struct, LL] = make_sparse_struct(cone_map,STA,WW,gaus_boxed)
% Calculate the sparsity structure of connections between all possible cone
% locations and GCs, as well as the map of log-likelihood increases of all
% single-cone configurations.
%
% Inputs:
%   cone_map   - struct; fields read: M0, M1, SS, N_colors, N_GC,
%                quad_factors, N_cones_terms, cone_params.support_radius,
%                cone_params.colors (assumed to have N_colors rows and
%                N_colors columns -- TODO confirm for non-RGB setups)
%   STA        - N_colors x M0 x M1 x N_GC array of GC STAs
%   WW         - SS x SS normalizers, one per sub-pixel offset (the caller
%                fills it with squared cone-RF norms)
%   gaus_boxed - function handle; gaus_boxed(i,j) must return
%                (2*support+1)^2 values, reshaped here to a square kernel
% Outputs:
%   sparse_struct - (M0*SS) x (M1*SS) x N_colors cell array; each cell lists
%                   (as int16) the GCs whose LL contribution there is > 0.
%                   NOTE: int16 saturates at 32767 GCs.
%   LL            - (M0*SS) x (M1*SS) x N_colors array: sum over GCs of the
%                   per-GC LL contributions, each clipped at 0
M0 = cone_map.M0 ;
M1 = cone_map.M1 ;
SS = cone_map.SS ;
N_colors = cone_map.N_colors ;
support = cone_map.cone_params.support_radius ;
colors = cone_map.cone_params.colors ;
LL = zeros(M0*SS,M1*SS,N_colors) ;
sparse_struct = cell( M0*SS, M1*SS, N_colors ) ;

% for every supersampled location within one pixel, calculate the cone RF
% kernel, flipped in both dimensions
supersamples = 1/(2*SS):1/SS:1 ;
gs = cell(SS) ;
for ii=1:SS
    for jj=1:SS
        g = reshape( gaus_boxed(supersamples(ii),supersamples(jj)), ...
                     [2*support+1 2*support+1] ) ;
        gs{ii,jj} = g(end:-1:1,end:-1:1) ;
    end
end

fprintf('making sparse struct and LL for GC number')
for gc=1:cone_map.N_GC
    gcLL = zeros(M0*SS,M1*SS,N_colors) ;
    % convolve all cone RFs with all GC STAs
    for ii=1:SS
        for jj=1:SS
            CC = zeros(M0*M1,N_colors) ;
            for color=1:N_colors
                % reshape (not squeeze) keeps an M0 x M1 orientation even
                % when M0 == 1
                sta = reshape( STA(color,:,:,gc), M0, M1 ) ;
                % kernel is (2*support+1) square, so 'same' is exactly rows
                % support+1:M0+support, cols support+1:M1+support of the
                % full convolution
                CCC = conv2( sta, gs{ii,jj}, 'same' ) ;
                CC(:,color) = CCC(:) ;
            end
            C = 0.5 * cone_map.quad_factors(gc) * (CC * colors').^2 / WW(ii,jj) ;
            % the max here defines the sparsity
            C = max(0,C+0.5*cone_map.N_cones_terms(gc)) ;
            gcLL( ii:SS:M0*SS, jj:SS:M1*SS, :) = ...
                gcLL( ii:SS:M0*SS, jj:SS:M1*SS, :) + reshape(C,[M0 M1 N_colors]) ;
        end
    end
    % record sparsity in sparse_struct; find() on the 3-D array merges
    % the 2nd and 3rd dimensions, which are split back into (y, color)
    [x,yc] = find( gcLL>0 ) ;
    y = 1+mod(yc-1,M1*SS) ;
    c = ceil( yc/(M1*SS) ) ;
    for i=1:numel(x)
        sparse_struct{x(i),y(i),c(i)} = int16([sparse_struct{x(i),y(i),c(i)} gc]) ;
    end
    LL = LL + gcLL ;
    fprintf(' %d',gc)
end
end
function filter = make_filter_new(M0,M1,i,j,gaus_boxed, support)
% filter = make_filter_new(M0,M1,i,j,gaus_boxed,support)
% Return an M0 x M1 image that is zero except for the cone RF centered at
% (i,j); filter_bounds supplies the (possibly clipped) RF patch together
% with the top/right/bottom/left rows and columns where it is pasted.
% The result is inverted, which does not matter for the dot products it is
% used in (apply filter(end:-1:1,end:-1:1) to uninvert).
[patch, top, right, bottom, left] = filter_bounds( i, j, M0, M1, gaus_boxed, support) ;
filter = zeros(M0,M1) ;
filter(top:bottom, left:right) = patch ;
end