Skip to content

Commit

Permalink
WIP: Establish code to handle vector outputs
Browse files Browse the repository at this point in the history
Co-authored-by: Adrian Seyboldt <aseyboldt@users.noreply.github.com>
  • Loading branch information
ricardoV94 and aseyboldt committed Apr 23, 2024
1 parent 08db524 commit 769809e
Show file tree
Hide file tree
Showing 4 changed files with 196 additions and 158 deletions.
6 changes: 5 additions & 1 deletion pytensor/link/numba/dispatch/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,9 +492,12 @@ def numba_funcify_Elemwise(op, node, **kwargs):
scalar_op_fn = numba_funcify(
op.scalar_op, node=scalar_node, parent_node=node, fastmath=flags, **kwargs
)
# TODO: Implement me
# scalar_op_fn = save_outputs_of_scalar_fn(op, scalar_op_fn)

ndim = node.outputs[0].ndim
output_bc_patterns = tuple([(False,) * ndim for _ in node.outputs])
nout = len(node.outputs)
output_bc_patterns = tuple([(False,) * ndim for _ in range(nout)])
input_bc_patterns = tuple([input_var.broadcastable for input_var in node.inputs])
output_dtypes = tuple(variable.dtype for variable in node.outputs)
inplace_pattern = tuple(op.inplace_pattern.items())
Expand All @@ -516,6 +519,7 @@ def elemwise_wrapper(*inputs):
inplace_pattern_enc,
(), # constant_inputs
inputs,
[() for _ in range(nout)], # core_shapes
None, # size
)

Expand Down
119 changes: 65 additions & 54 deletions pytensor/link/numba/dispatch/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pickle
from collections.abc import Callable
from copy import copy
from functools import singledispatch
from textwrap import dedent, indent
from typing import Any

Expand Down Expand Up @@ -168,25 +169,10 @@ def impl(rng):
return impl


@numba_funcify.register(ptr.RandomVariable)
def numba_funcify_RandomVariable(op: RandomVariable, node, **kwargs):
_, size, _, *args = node.inputs
# None sizes are represented as empty tuple for the time being
# https://github.com/pymc-devs/pytensor/issues/568
[size_len] = size.type.shape
size_is_None = size_len == 0

inplace = op.inplace

if op.ndim_supp > 0:
raise NotImplementedError("Multivariate random variables not supported yet")

# if any(ndim_param > 0 for ndim_param in op.ndims_params):
# raise NotImplementedError(
# "Random variables with non scalar core inputs not supported yet"
# )
@singledispatch
def core_rv_fn(op: Op):
"""Return the core function for a random variable operation."""

# TODO: Use dispatch, so users can define the core case
# Use string repr for default like below
# inner_code = dedent(f"""
# @numba_basic.numba_njit
Expand All @@ -197,15 +183,67 @@ def numba_funcify_RandomVariable(op: RandomVariable, node, **kwargs):
# exec(inner_code)
# scalar_op_fn = locals()['scalar_op_fn']

# @numba_basic.numba_njit
# def core_op_fn(rng, mu, scale):
# return rng.normal(mu, scale)
raise NotImplementedError()


@core_rv_fn.register(ptr.NormalRV)
def core_NormalRV(op):
    """Return the core sampling function for a normal random variable.

    The returned njit-ed function draws one core sample and writes it
    into the pre-allocated ``out`` buffer instead of returning it.
    """

    @numba_basic.numba_njit
    def random_fn(rng, mu, scale, out):
        out[...] = rng.normal(mu, scale)

    # Signal to the caller that this function writes into `out` itself.
    random_fn.handles_out = True
    return random_fn


@core_rv_fn.register(ptr.CategoricalRV)
def core_CategoricalRV(op):
    """Return the core sampling function for a categorical random variable.

    The returned njit-ed function draws one core sample by inverse-CDF
    sampling and writes it into the pre-allocated ``out`` buffer.
    """

    @numba_basic.numba_njit
    def random_fn(rng, p, out):
        # Inverse-CDF: draw a uniform and locate it in the cumulative
        # probability vector of `p`.
        unif_sample = rng.uniform(0, 1)
        # TODO: Check if LLVM can lift constant cumsum(p) out of the loop
        out[...] = np.searchsorted(np.cumsum(p), unif_sample)

    # Signal to the caller that this function writes into `out` itself.
    random_fn.handles_out = True
    return random_fn


@core_rv_fn.register(ptr.MvNormalRV)
def core_MvNormalRV(op):
    """Return the core sampling function for a multivariate normal RV.

    The returned njit-ed function draws one core sample and writes it
    into the pre-allocated ``out`` buffer instead of returning it.
    """

    # NOTE: use the project's njit wrapper for consistency with the other
    # registered core functions (was a bare `numba.njit`).
    @numba_basic.numba_njit
    def random_fn(rng, mean, cov, out):
        # Sample via the Cholesky factor: x = mean + L @ z with z ~ N(0, I).
        chol = np.linalg.cholesky(cov)
        stdnorm = rng.normal(size=cov.shape[-1])
        # TODO: Avoid the intermediate allocation, e.g.:
        # np.dot(chol, stdnorm, out=out); out[...] += mean
        out[...] = mean + np.dot(chol, stdnorm)

    # Signal to the caller that this function writes into `out` itself.
    random_fn.handles_out = True
    return random_fn


@numba_funcify.register(ptr.RandomVariable)
def numba_funcify_RandomVariable(op: RandomVariable, node, **kwargs):
_, size, _, *args = node.inputs
# None sizes are represented as empty tuple for the time being
# https://github.com/pymc-devs/pytensor/issues/568
[size_len] = size.type.shape
size_is_None = size_len == 0

inplace = op.inplace

# TODO: Add core_shape to node.inputs
if op.ndim_supp > 0:
raise NotImplementedError("Multivariate RandomVariable not implemented yet")

# TODO: Create a wrapper (string processing?) that takes a core function without outputs
# and saves those outputs in the variables passed by `_vectorized`
core_op_fn = core_rv_fn(op)
if not getattr(core_op_fn, "handles_out", False):
# core_op_fn = store_core_outputs(op, core_op_fn)
raise NotImplementedError()

# TODO: Refactor this code, it's the same with Elemwise
batch_ndim = node.default_output().ndim - op.ndim_supp
output_bc_patterns = ((False,) * batch_ndim,)
input_bc_patterns = tuple(
Expand Down Expand Up @@ -234,12 +272,14 @@ def random_wrapper(rng, size, dtype, *inputs):
inplace_pattern_enc,
(rng,),
inputs,
None if size_is_None else numba_ndarray.to_fixed_tuple(size, size_len),
((),), # TODO: correct core_shapes
None
if size_is_None
else numba_ndarray.to_fixed_tuple(size, size_len), # size
)
return rng, draws

def random(rng, size, dtype, *inputs):
# TODO: Add code that will be tested for coverage
pass

@overload(random)
Expand Down Expand Up @@ -330,35 +370,6 @@ def body_fn(a):
)


# Disabled old implementation, kept for reference while the singledispatch
# `core_rv_fn` path is being established (registration commented out).
# @numba_funcify.register(ptr.CategoricalRV)
def numba_funcify_CategoricalRV(op, node, **kwargs):
    """Build a numba implementation of a categorical random variable.

    Returns an njit-ed ``(rng, size, dtype, p) -> (rng, res)`` function
    that samples each entry by inverse-CDF search in ``cumsum(p)``.
    """
    # Output dtype and static size/ndim info captured as closure constants
    # so numba can specialize the compiled function.
    out_dtype = node.outputs[1].type.numpy_dtype
    size_len = int(get_vector_length(node.inputs[1]))
    p_ndim = node.inputs[-1].ndim

    @numba_basic.numba_njit
    def categorical_rv(rng, size, dtype, p):
        if not size_len:
            # No explicit size: batch shape comes from `p`'s leading dims.
            size_tpl = p.shape[:-1]
        else:
            size_tpl = numba_ndarray.to_fixed_tuple(size, size_len)
            p = np.broadcast_to(p, size_tpl + p.shape[-1:])

        # Workaround https://github.com/numba/numba/issues/8975
        if not size_len and p_ndim == 1:
            unif_samples = np.asarray(np.random.uniform(0, 1))
        else:
            unif_samples = np.random.uniform(0, 1, size_tpl)

        res = np.empty(size_tpl, dtype=out_dtype)
        for idx in np.ndindex(*size_tpl):
            # Inverse-CDF sampling per batch element.
            res[idx] = np.searchsorted(np.cumsum(p[idx]), unif_samples[idx])

        # NOTE(review): draws come from the global `np.random` state, not the
        # passed `rng`, which is returned unchanged — presumably a known
        # limitation of this old path; confirm before re-enabling.
        return (rng, res)

    return categorical_rv


@numba_funcify.register(ptr.DirichletRV)
def numba_funcify_DirichletRV(op, node, **kwargs):
out_dtype = node.outputs[1].type.numpy_dtype
Expand Down
Loading

0 comments on commit 769809e

Please sign in to comment.