@@ -74,7 +74,11 @@ def __init__(self, vars=None, covariance=None, scaling=1., n_chains=100,
7474 self .check_bnd = check_bound
7575 self .tune_interval = tune_interval
7676 self .steps_until_tune = tune_interval
77+
7778 self .proposal_dist = proposal_dist (self .covariance )
79+ self .proposal_samples_array = self .proposal_dist (n_chains )
80+ self .stage_sample = 0
81+
7882 self .accepted = 0
7983 self .beta = 0
8084 self .stage = 0
@@ -85,6 +89,7 @@ def __init__(self, vars=None, covariance=None, scaling=1., n_chains=100,
8589 [[v .dtype in pm .discrete_types ] * (v .dsize or 1 ) for v in vars ])
8690 self .any_discrete = self .discrete .any ()
8791 self .all_discrete = self .discrete .all ()
92+
8893 # create initial population
8994 self .population = []
9095 self .array_population = np .zeros (n_chains )
@@ -113,7 +118,8 @@ def astep(self, q0):
113118 self .steps_until_tune = self .tune_interval
114119 self .accepted = 0
115120
116- delta = self .proposal_dist () * self .scaling
121+ delta = self .proposal_samples_array [self .stage_sample , :] * \
122+ self .scaling
117123
118124 if self .any_discrete :
119125 if self .all_discrete :
@@ -143,6 +149,7 @@ def astep(self, q0):
143149 if q_new is q :
144150 self .accepted += 1
145151
152+ self .stage_sample += 1
146153 self .steps_until_tune -= 1
147154 return q_new
148155
@@ -179,7 +186,7 @@ def calc_beta(self):
179186 def calc_covariance (self ):
180187 """
181188 Calculate trace covariance matrix based on importance weights.
182-
189+
183190 Returns
184191 -------
185192 Ndarray of weighted covariances (NumPy > 1.10. required)
@@ -228,7 +235,7 @@ def select_end_points(self, mtrace):
228235
229236 # map end array_endpoints to dict points
230237 for i in range (self .n_chains ):
231- population .append (bij .rmap (array_population [i ,:]))
238+ population .append (bij .rmap (array_population [i , :]))
232239
233240 else :
234241 # for initial stage only one trace that contains all points for
@@ -243,7 +250,7 @@ def select_end_points(self, mtrace):
243250
244251 population = []
245252 for i in range (self .n_chains ):
246- population .append (bij .rmap (array_population [i ,:]))
253+ population .append (bij .rmap (array_population [i , :]))
247254
248255 return population , array_population , likelihoods
249256
@@ -403,6 +410,8 @@ def ATMIP_sample(n_steps, step=None, start=None, trace=None, chain=0,
403410 # Metropolis sampling intermediate stages
404411 stage_path = homepath + '/stage_' + str (step .stage )
405412 step .proposal_dist = MvNPd (step .covariance )
413+ step .proposal_samples_array = step .proposal_dist (
414+ step .n_chains * n_steps )
406415 sample_args = {
407416 'draws' : n_steps ,
408417 'step' : step ,
@@ -414,14 +423,17 @@ def ATMIP_sample(n_steps, step=None, start=None, trace=None, chain=0,
414423 step .population , step .array_population , step .likelihoods = \
415424 step .select_end_points (mtrace )
416425 step .beta , step .old_beta , step .weights = step .calc_beta ()
426+ step .stage += 1
427+ step .stage_sample = 0
428+
417429 if step .beta > 1. :
418430 print 'Beta > 1.:' , str (step .beta )
419431 step .beta = 1.
420432 break
421433
422434 step .covariance = step .calc_covariance ()
423435 step .res_indx = step .resample ()
424- step . stage += 1
436+
425437
426438 # Metropolis sampling final stage
427439 print 'Sample final stage'
@@ -431,6 +443,8 @@ def ATMIP_sample(n_steps, step=None, start=None, trace=None, chain=0,
431443 step .weights = temp / np .sum (temp )
432444 step .covariance = step .calc_covariance ()
433445 step .proposal_dist = MvNPd (step .covariance )
446+ step .proposal_samples_array = step .proposal_dist (
447+ step .n_chains * n_steps )
434448 step .res_indx = step .resample ()
435449
436450 sample_args ['step' ] = step
@@ -524,12 +538,12 @@ def tune(acc_rate):
524538 ----------
525539 acc_rate: scalar float
526540 Acceptance rate of the Metropolis sampling
527-
541+
528542 Returns
529543 -------
530544 scaling: scalar float
531545 """
532-
546+
533547 # a and b after Muto & Beck 2008 .
534548 a = 1. / 9
535549 b = 8. / 9
0 commit comments