DOC: Mention take_along_axis in choose #14117
Comments
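@lucianopaz I think your need here should be served very well by:

```python
import numpy as np

# 8 index arrays into the first axis of `b`, one index per (3, 4) position
a = np.random.randint(0, 1000, size=(8, 3, 4))
# 1000 "choices", each of shape (3, 4)
b = np.random.rand(1000, 3, 4)
# res[i, j, k] == b[a[i, j, k], j, k]; res.shape == (8, 3, 4)
res = np.take_along_axis(b, a, axis=0)
```

I think it would be good to put

EDIT: If you need this in a version which is compatible with much older numpy versions, you can write it as advanced indexing.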
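The EDIT above notes that the same gather can be written with advanced indexing for numpy versions older than 1.15, where `take_along_axis` was added. A minimal sketch of that idea for `axis=0`, assuming `arr.ndim >= 2` and an `indices` array with the same number of dimensions as `arr`; the helper name `take_along_axis0` is hypothetical and not part of numpy:

```python
import numpy as np


def take_along_axis0(arr, indices):
    """Hypothetical helper: out[i, j, ...] = arr[indices[i, j, ...], j, ...].

    Assumes arr.ndim >= 2 and indices.ndim == arr.ndim, with the trailing
    dimensions of `indices` broadcastable against arr.shape[1:].
    """
    # One open-mesh index array per trailing axis, e.g. shapes (3, 1) and (1, 4)
    trailing = tuple(np.ogrid[tuple(slice(n) for n in arr.shape[1:])])
    # Broadcasting `indices` against the open mesh pairs every entry of
    # `indices` with its own trailing position (j, k, ...)
    return arr[(indices,) + trailing]


# Same shapes as in the take_along_axis example above
a = np.random.randint(0, 1000, size=(8, 3, 4))
b = np.random.rand(1000, 3, 4)
res = take_along_axis0(b, a)  # res.shape == (8, 3, 4)
```

The open mesh from `np.ogrid` keeps the trailing index arrays small (shapes `(3, 1)` and `(1, 4)` here) and lets broadcasting expand them to the shape of `indices`, which is the same pairing `take_along_axis` performs internally.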
Thanks a lot @seberg!
Yes, if you do not mind diving a bit in the code, in
Thanks again @seberg!
Preamble
In the pymc3 package we implemented a series of distributions. One of these distributions is the `Categorical` distribution, where we have an ndarray, `p`, that represents the probabilities of getting one of `K` categories (indexed from `0` to `p.shape[-1] - 1`).

Under some circumstances, it's useful to stack many independent categorical distributions in the same distribution instance (i.e. to write down a multidimensional categorical distribution). In these cases, `p.ndim` is larger than 1.

PyMC3 focuses on doing MCMC, and to accomplish that, we write down a distribution's log probability. To fix our implementation of multidimensional categorical distributions, we recently switched to using `choose` instead of advanced indexing (we actually use theano, but `theano.tensor.choose` later dispatches to `numpy.choose`). However, we encountered the following issue, which can be reproduced with the following simple code:

Reproducing code example:
Feature request
It would be really great if this 32-array limit were lifted in the specific case where the supplied `choices` parameter is an `ndarray` instance with a non-object dtype (e.g. `float64`, `int64` or similar).

Numpy/Python version information:
python version: 3.6.7 | packaged by conda-forge | (default, Feb 26 2019, 03:50:56) [GCC 7.3.0]
numpy version: 1.16.1
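The log-probability computation described in the preamble can be sketched in plain numpy as follows. This is a minimal illustration under our own assumptions, not PyMC3's or theano's code: `p` holds the `K` category probabilities on its last axis, `value` holds one integer category per distribution, and the function names `categorical_logp_choose` and `categorical_logp_take` are hypothetical. It contrasts the `np.choose` formulation with the `np.take_along_axis` alternative suggested in the thread.

```python
import numpy as np


def categorical_logp_choose(p, value):
    """Elementwise log p[..., value] via np.choose (hypothetical helper).

    The category axis is moved to the front so that each of the K categories
    becomes one entry of `choices`; this is where a K larger than the
    32-array limit reported in this issue (numpy 1.16.1) becomes a problem.
    """
    choices = np.moveaxis(p, -1, 0)  # shape (K, ...)
    return np.log(np.choose(value, choices))


def categorical_logp_take(p, value):
    """Same quantity via np.take_along_axis (hypothetical helper)."""
    # Add a length-1 category axis to `value`, gather along it, then drop it
    picked = np.take_along_axis(p, value[..., np.newaxis], axis=-1)
    return np.log(picked[..., 0])


# Illustrative data: 5 x 2 independent categorical distributions, K = 4
rng = np.random.RandomState(0)
p = rng.rand(5, 2, 4)
p /= p.sum(axis=-1, keepdims=True)  # normalise each distribution
value = rng.randint(0, 4, size=(5, 2))

print(np.allclose(categorical_logp_choose(p, value),
                  categorical_logp_take(p, value)))  # True
```

Because `take_along_axis` gathers along one axis of a single array instead of broadcasting `K` separate arrays, it is not bound by the number of `choices` arrays, so the second function also works for, say, `K = 1000`.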