@@ -184,9 +184,12 @@ def fastd2logp(self, vars=None):
     def logpt(self):
         """Theano scalar of log-probability of the model"""
         if getattr(self, 'total_size', None) is not None:
-            return tt.sum(self.logp_elemwiset) * self.scaling
+            logp = tt.sum(self.logp_elemwiset) * self.scaling
         else:
-            return tt.sum(self.logp_elemwiset)
+            logp = tt.sum(self.logp_elemwiset)
+        if self.name is not None:
+            logp.name = '__logp_%s' % self.name
+        return logp


 class InitContextMeta(type):
@@ -277,6 +280,173 @@ def tree_contains(self, item):
         return dict.__contains__(self, item)


+class ValueGradFunction(object):
+    """Create a theano function that computes a value and its gradient.
+
+    Parameters
+    ----------
+    cost : theano variable
+        The value that we compute with its gradient.
+    grad_vars : list of named theano variables or None
+        The arguments with respect to which the gradient is computed.
+    extra_vars : list of named theano variables or None
+        Other arguments of the function that are assumed constant. They
+        are stored in shared variables and can be set using
+        `set_extra_values`.
+    dtype : str, default=theano.config.floatX
+        The dtype of the arrays.
+    casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, default='no'
+        Casting rule for casting `grad_vars` to the array dtype.
+        See `numpy.can_cast` for a description of the options.
+        Keep in mind that we cast the variables to the array *and*
+        back from the array dtype to the variable dtype.
+    kwargs
+        Extra arguments are passed on to `theano.function`.
+
+    Attributes
+    ----------
+    size : int
+        The number of elements in the parameter array.
+    profile : theano profiling object or None
+        The profiling object of the theano function that computes value and
+        gradient. This is None unless `profile=True` was set in the
+        kwargs.
+    """
+    def __init__(self, cost, grad_vars, extra_vars=None, dtype=None,
+                 casting='no', **kwargs):
+        if extra_vars is None:
+            extra_vars = []
+
+        names = [arg.name for arg in grad_vars + extra_vars]
+        if any(name is None for name in names):
+            raise ValueError('Arguments must be named.')
+        if len(set(names)) != len(names):
+            raise ValueError('Names of the arguments are not unique.')
+
+        if cost.ndim > 0:
+            raise ValueError('Cost must be a scalar.')
+
+        self._grad_vars = grad_vars
+        self._extra_vars = extra_vars
+        self._extra_var_names = set(var.name for var in extra_vars)
+        self._cost = cost
+        self._ordering = ArrayOrdering(grad_vars)
+        self.size = self._ordering.size
+        self._extra_are_set = False
+        if dtype is None:
+            dtype = theano.config.floatX
+        self.dtype = dtype
+        for var in self._grad_vars:
+            if not np.can_cast(var.dtype, self.dtype, casting):
+                raise TypeError('Invalid dtype for variable %s. Can not '
+                                'cast to %s with casting rule %s.'
+                                % (var.name, self.dtype, casting))
+            if not np.issubdtype(var.dtype, float):
+                raise TypeError('Invalid dtype for variable %s. Must be '
+                                'floating point but is %s.'
+                                % (var.name, var.dtype))
+
+        givens = []
+        self._extra_vars_shared = {}
+        for var in extra_vars:
+            shared = theano.shared(var.tag.test_value, var.name + '_shared__')
+            self._extra_vars_shared[var.name] = shared
+            givens.append((var, shared))
+
+        self._vars_joined, self._cost_joined = self._build_joined(
+            self._cost, grad_vars, self._ordering.vmap)
+
+        grad = tt.grad(self._cost_joined, self._vars_joined)
+        grad.name = '__grad'
+
+        inputs = [self._vars_joined]
+
+        self._theano_function = theano.function(
+            inputs, [self._cost_joined, grad], givens=givens, **kwargs)
+
+    def set_extra_values(self, extra_vars):
+        self._extra_are_set = True
+        for var in self._extra_vars:
+            self._extra_vars_shared[var.name].set_value(extra_vars[var.name])
+
+    def get_extra_values(self):
+        if not self._extra_are_set:
+            raise ValueError('Extra values are not set.')
+
+        return {var.name: self._extra_vars_shared[var.name].get_value()
+                for var in self._extra_vars}
+
+    def __call__(self, array, grad_out=None, extra_vars=None):
+        if extra_vars is not None:
+            self.set_extra_values(extra_vars)
+
+        if not self._extra_are_set:
+            raise ValueError('Extra values are not set.')
+
+        if array.shape != (self.size,):
+            raise ValueError('Invalid shape for array. Must be %s but is %s.'
+                             % ((self.size,), array.shape))
+
+        if grad_out is None:
+            out = np.empty_like(array)
+        else:
+            out = grad_out
+
+        logp, dlogp = self._theano_function(array)
+        if grad_out is None:
+            return logp, dlogp
+        else:
+            out[...] = dlogp
+            return logp
+
+    @property
+    def profile(self):
+        """Profiling information of the underlying theano function."""
+        return self._theano_function.profile
+
+    def dict_to_array(self, point):
+        """Convert a dictionary with values for grad_vars to an array."""
+        array = np.empty(self.size, dtype=self.dtype)
+        for varmap in self._ordering.vmap:
+            array[varmap.slc] = point[varmap.var].ravel().astype(self.dtype)
+        return array
+
+    def array_to_dict(self, array):
+        """Convert an array to a dictionary containing the grad_vars."""
+        if array.shape != (self.size,):
+            raise ValueError('Array should have shape (%s,) but has %s'
+                             % (self.size, array.shape))
+        if array.dtype != self.dtype:
+            raise ValueError('Array has invalid dtype. Should be %s but is %s'
+                             % (self.dtype, array.dtype))
+        point = {}
+        for varmap in self._ordering.vmap:
+            data = array[varmap.slc].reshape(varmap.shp)
+            point[varmap.var] = data.astype(varmap.dtyp)
+
+        return point
+
+    def array_to_full_dict(self, array):
+        """Convert an array to a dictionary with grad_vars and extra_vars."""
+        point = self.array_to_dict(array)
+        for name, var in self._extra_vars_shared.items():
+            point[name] = var.get_value()
+        return point
+
+    def _build_joined(self, cost, args, vmap):
+        args_joined = tt.vector('__args_joined')
+        args_joined.tag.test_value = np.zeros(self.size, dtype=self.dtype)
+
+        joined_slices = {}
+        for vmap in vmap:
+            sliced = args_joined[vmap.slc].reshape(vmap.shp)
+            sliced.name = vmap.var
+            joined_slices[vmap.var] = sliced
+
+        replace = {var: joined_slices[var.name] for var in args}
+        return args_joined, theano.clone(cost, replace=replace)
+
+
 class Model(six.with_metaclass(InitContextMeta, Context, Factor)):
     """Encapsulates the variables and likelihood factors of a model.

@@ -419,7 +589,6 @@ def bijection(self):
         return bij

     @property
-    @memoize
     def dict_to_array(self):
         return self.bijection.map

@@ -428,23 +597,34 @@ def ndim(self):
         return sum(var.dsize for var in self.free_RVs)

     @property
-    @memoize
     def logp_array(self):
         return self.bijection.mapf(self.fastlogp)

     @property
-    @memoize
     def dlogp_array(self):
         vars = inputvars(self.cont_vars)
         return self.bijection.mapf(self.fastdlogp(vars))

+    def logp_dlogp_function(self, grad_vars=None, **kwargs):
+        if grad_vars is None:
+            grad_vars = list(typefilter(self.free_RVs, continuous_types))
+        else:
+            for var in grad_vars:
+                if var.dtype not in continuous_types:
+                    raise ValueError("Can only compute the gradient of "
+                                     "continuous types: %s" % var)
+        varnames = [var.name for var in grad_vars]
+        extra_vars = [var for var in self.free_RVs if var.name not in varnames]
+        return ValueGradFunction(self.logpt, grad_vars, extra_vars, **kwargs)
+
     @property
-    @memoize
     def logpt(self):
         """Theano scalar of log-probability of the model"""
         with self:
             factors = [var.logpt for var in self.basic_RVs] + self.potentials
-            return tt.add(*map(tt.sum, factors))
+            logp = tt.add(*map(tt.sum, factors))
+            logp.name = '__logp'
+            return logp

     @property
     def varlogpt(self):
@@ -595,7 +775,6 @@ def __getitem__(self, key):
         except KeyError:
             raise e

-    @memoize
     def makefn(self, outs, mode=None, *args, **kwargs):
         """Compiles a Theano function which returns `outs` and takes the variable
         ancestors of `outs` as inputs.
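
Usage sketch (editor's note, not part of the commit): `Model.logp_dlogp_function` wraps the new `ValueGradFunction` so that the model log-probability and its gradient come from a single compiled theano function operating on one flat parameter array. The model and variable names below are illustrative.

```python
# Minimal sketch of the new API; the model here is illustrative only.
import numpy as np
import pymc3 as pm

with pm.Model() as model:
    mu = pm.Normal('mu', mu=0., sd=1.)
    pm.Normal('obs', mu=mu, sd=1., observed=np.random.randn(100))

    # Compile one theano function returning logp and d(logp)/d(mu).
    func = model.logp_dlogp_function()

# There are no non-gradient free RVs here, but the extra values still
# have to be marked as set before the function can be called.
func.set_extra_values({})

x = func.dict_to_array({'mu': np.array(0.5)})  # pack a point into a flat array
logp, dlogp = func(x)                          # value and gradient at mu = 0.5
```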
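A second sketch of the `grad_vars` / `extra_vars` split: when the gradient is taken with respect to only some free RVs, the remaining free RVs are held in shared variables and supplied through `set_extra_values`. The model is again illustrative; `model.test_point` is used only to avoid hard-coding the name of the automatically transformed variable.

```python
# Sketch: differentiate logp with respect to mu only; the transformed free
# RV backing `sd` becomes an extra (constant) input held in a shared variable.
import numpy as np
import pymc3 as pm

with pm.Model() as model:
    mu = pm.Normal('mu', mu=0., sd=1.)
    sd = pm.HalfNormal('sd', sd=1.)
    pm.Normal('obs', mu=mu, sd=sd, observed=np.random.randn(100))

    func = model.logp_dlogp_function(grad_vars=[mu])

# set_extra_values only looks up the names of the extra variables, so the
# full test point can be passed directly.
func.set_extra_values(model.test_point)

x = func.dict_to_array({'mu': np.array(1.0)})
logp, dlogp = func(x)                 # gradient has length 1 (just mu)

# array_to_full_dict recovers the grad_vars from the array and the
# extra_vars from their shared storage.
point = func.array_to_full_dict(x)
```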