Open
Labels
API design, enhancement, module:probability&simulation
Description
Design issue related to adding outer product evaluation mode to distribution methods.
Related to comments of @FelixWick in #327 (comment).
Currently, `BaseDistribution` descendants support up-broadcasting of arguments, but not up-broadcasting of results when arguments with additional dimensions are given. Mathematically, this corresponds to taking outer products instead of element-wise products, in scalar or vector cases.
This is a natural feature to offer: for instance, `tensorflow.probability` functions in this way, and `cyclic-boosting` also has the feature.
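To make the element-wise vs outer distinction concrete, here is a minimal sketch in plain `numpy`. It is not taken from any of the packages mentioned above, and the arrays `a` and `b` are illustrative only.

```python
import numpy as np

a = np.array([1.0, 2.0, 3.0])
b = np.array([10.0, 20.0, 30.0])

# element-wise product: a[i] is paired with b[i], result has shape (3,)
elementwise = a * b

# outer product: every a[i] is paired with every b[j], result has shape (3, 3)
outer = np.multiply.outer(a, b)

print(elementwise.shape)
print(outer.shape)
```

The element-wise result is the diagonal of the outer result, which is why the outer mode can be seen as a strict generalization.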
There are a few obstacles to adding this:
- expectations on public inputs and outputs being `pd.DataFrame`-s, where `numpy.ndarray`-s are not a supported input type
- unclear target interface, e.g., `cyclic-boosting` has an extra argument that controls elementwise vs outer product, which I find cumbersome and somewhat unintuitive.
Imo the most elegant way to allow this is as follows:
- allow `numpy.ndarray` and coercibles as input to all distribution methods. These must have the shape of the distribution, or 1, for dimensions present in the distribution. Inputs that are currently valid - up to 2D coercibles for array distributions, and floats for scalar distributions - are not coerced and are treated as at present.
- If a `numpy.ndarray` is passed, dimensions that are not present in the distribution will lead to output upcasting; a special case is outer product mode for scalar distributions.
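The shape rule in the two bullets above could be sketched roughly as below. `result_shape` is a hypothetical helper for illustration, not an existing function in the package. Padding arguments with fewer dimensions than the distribution is an extra assumption on my side that the proposal does not spell out.

```python
def result_shape(dist_shape, arg_shape):
    """Hypothetical helper: output shape under the proposed upcasting rule."""
    dist_shape = tuple(dist_shape)
    arg_shape = tuple(arg_shape)
    n = len(dist_shape)

    # assumption: an argument with fewer dimensions than the distribution
    # is padded with trailing 1-s, so that scalars broadcast everywhere
    padded = arg_shape + (1,) * max(0, n - len(arg_shape))
    shared, extra = padded[:n], padded[n:]

    # dimensions present in the distribution must match its shape, or be 1
    for axis, (d, a) in enumerate(zip(dist_shape, shared)):
        if a != d and a != 1:
            raise ValueError(
                f"argument has length {a} on axis {axis}, expected {d} or 1"
            )

    # dimensions not present in the distribution are appended to the output
    return dist_shape + extra


print(result_shape((), (5,)))            # scalar distribution, outer product mode
print(result_shape((2, 2), (1, 1, 5)))   # array distribution, one extra dimension
```

Note that this keeps the distribution dimensions leading, which differs from default `numpy` broadcasting, where shapes are aligned from the trailing end.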
Examples:
- For a scalar distribution `n = Normal(0, 1)`, passing `n.pdf([1, 2, 3, 4, 5])` will lead to evaluation at 5 points, and return a `np.ndarray` of shape `(5,)`.
- For an array distribution `n = Normal([[0, 1], [2, 3]], 1)`, calling `n.pdf([[[1, 2, 3, 4, 5]]])` leads to a return `np.ndarray` of shape `(2, 2, 5)` - upcasting happens along the 3rd dimension of length 5 that the argument introduces.
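As a rough check of the shapes in these two examples, here is a sketch that mimics the proposed behaviour with plain `numpy`. `outer_pdf` is a hypothetical stand-in and not the `Normal.pdf` method; it writes out the normal density by hand and assumes the distribution dimensions stay leading.

```python
import numpy as np


def outer_pdf(mu, sigma, x):
    """Hypothetical stand-in: normal pdf with output upcasting."""
    mu = np.asarray(mu, dtype=float)
    sigma = np.broadcast_to(np.asarray(sigma, dtype=float), mu.shape)
    x = np.asarray(x, dtype=float)

    # number of dimensions that x has beyond those of the distribution
    n_extra = max(0, x.ndim - mu.ndim)

    # append singleton axes to the parameters, so that the extra
    # dimensions of x end up as trailing dimensions of the result
    expanded = mu.shape + (1,) * n_extra
    mu_e = mu.reshape(expanded)
    sigma_e = sigma.reshape(expanded)

    z = (x - mu_e) / sigma_e
    return np.exp(-0.5 * z**2) / (sigma_e * np.sqrt(2 * np.pi))


scalar_out = outer_pdf(0, 1, [1, 2, 3, 4, 5])
array_out = outer_pdf([[0, 1], [2, 3]], 1, [[[1, 2, 3, 4, 5]]])

print(scalar_out.shape)
print(array_out.shape)
```

In the array case, entry `[i, j, k]` of the result is the density of the `(i, j)`-th distribution at the `k`-th evaluation point.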
Thoughts, @setoguchi-naoki, @FelixWick, @VascoSch92, @ShreeshaM07 ?