Inconsistency in numba mode when passing scalar to function #1063

Closed
twiecki opened this issue Jul 19, 2022 · 4 comments · Fixed by #1073
Labels
backend compatibility (Issues relating to compatibility with backends called by this library), bug (Something isn't working), Numba (Involves Numba transpilation)

Comments

@twiecki
Contributor

twiecki commented Jul 19, 2022

In numba mode:

```python
import aesara
aesara.config.mode = "NUMBA"
import aesara.tensor as at

x = at.scalar(name="x")
f = aesara.function([x], 2*x)
print(repr(f(10)))  # prints 20.0
```

In c mode:

```python
import aesara
import aesara.tensor as at

x = at.scalar(name="x")
f = aesara.function([x], 2*x)
print(repr(f(10)))  # prints array(20.)
```

This came up in pymc-devs/pymc#5937. Found by bherwerth.

twiecki added the bug, backend compatibility, and Numba labels on Jul 19, 2022.
@brandonwillard
Member

brandonwillard commented Jul 19, 2022

This is mostly due to Aesara's lack of support for built-in scalar types (e.g. int and float). Every Aesara scalar is actually a NumPy ndarray scalar, i.e. a zero-dimensional ndarray.
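For anyone following along, a minimal sketch (assuming only NumPy) of the two representations at play: a built-in Python float, which is consistent with the 20.0 printed in NUMBA mode, and a zero-dimensional ndarray, which matches the array(20.) printed in C mode. The variable names are illustrative only.

```python
import numpy as np

# A built-in Python float: the kind of scalar Numba readily returns.
py_scalar = 20.0

# A zero-dimensional ndarray: how Aesara represents a scalar tensor value.
nd_scalar = np.asarray(20.0)

print(type(py_scalar).__name__, repr(py_scalar))  # float 20.0
print(type(nd_scalar).__name__, repr(nd_scalar))  # ndarray array(20.)
print(nd_scalar.ndim, nd_scalar.shape)  # 0 ()

# Same value, different container types; only the ndarray has .ndim/.shape.
assert py_scalar == nd_scalar
assert not hasattr(py_scalar, "shape")
```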

Since Numba supports the built-in scalar types and uses/returns them wherever and whenever it's reasonable, we have to decide whether to manually convert all outputs to ndarrays within our Numba implementations, or to let Numba use the built-in scalar types as much as possible and only construct ndarrays when it's necessary (e.g. when a Numba implementation expects or requires an ndarray input).

We've essentially chosen the latter because it's cumulatively more efficient: it produces Numba implementations with fewer ndarray constructions and, in some cases, allows the use of more performant scalars.

In this case, it looks like we might need to add another manual conversion somewhere, but ideally not to the return values of our Numba implementations of Elemwise Ops. Perhaps we can perform the conversion outside of Numba, or judiciously within Numba (e.g. applied only to the final outputs of a Numba-compiled graph).
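One way to picture "applied only to the final outputs" is a wrapper around the compiled graph function that converts each output once, at the boundary, with an explicit dtype. This is a hedged sketch in plain Python: with_ndarray_outputs and double_graph are hypothetical names, not Aesara APIs, and double_graph merely stands in for a Numba-compiled graph that returns a built-in float.

```python
from typing import Callable, List, Sequence, Tuple

import numpy as np


def with_ndarray_outputs(
    graph_fn: Callable[..., Tuple], output_dtypes: Sequence[str]
) -> Callable[..., List[np.ndarray]]:
    """Wrap ``graph_fn`` so that only its final outputs become ndarrays."""

    def wrapped(*args):
        outputs = graph_fn(*args)
        # One conversion per final output; intermediates are untouched.
        return [
            np.asarray(out, dtype=dtype)
            for out, dtype in zip(outputs, output_dtypes)
        ]

    return wrapped


def double_graph(x):
    # Hypothetical stand-in for a compiled ``2 * x`` graph that returns a
    # built-in float.
    return (2.0 * x,)


f = with_ndarray_outputs(double_graph, ["float64"])
print(repr(f(10)[0]))  # array(20.)
```

Inside the wrapped function, intermediate values can stay as built-in scalars; the ndarray construction happens once per output, which is the trade-off described above.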

@bherwerth

If not touching the internals of the Numba implementation, how about applying np.asarray in the return statements in aesara.compile.function.types.Function.__call__?

In the docstring, one could then also specify the return type as List[ndarray]:

```
Returns
-------
list
    List of outputs on indices/keys from ``output_subset`` or all of them,
    if ``output_subset`` is not passed.
```

Btw, I noticed the docstring says the return type is list, but there seem to be dicts returned in some cases:

```python
if output_subset is None:
    return dict(zip(self.output_keys, outputs))
else:
    return {
        self.output_keys[index]: outputs[index]
        for index in output_subset
    }
```
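For reference, the two return shapes could be captured in a type annotation. The sketch below is a simplified, hypothetical stand-in for the branch quoted above (select_outputs and FunctionOutputs are not Aesara names), and it assumes that a list is returned when no output keys are set.

```python
from typing import Dict, List, Optional, Sequence, Union

import numpy as np

# Hypothetical alias for the two shapes ``__call__`` can return.
FunctionOutputs = Union[List[np.ndarray], Dict[str, np.ndarray]]


def select_outputs(
    outputs: List[np.ndarray],
    output_keys: Optional[Sequence[str]] = None,
    output_subset: Optional[Sequence[int]] = None,
) -> FunctionOutputs:
    """Return a dict when output keys are set, otherwise a list."""
    if output_keys is not None:
        if output_subset is None:
            return dict(zip(output_keys, outputs))
        return {output_keys[index]: outputs[index] for index in output_subset}
    if output_subset is None:
        return list(outputs)
    return [outputs[index] for index in output_subset]


outs = [np.asarray(1.0), np.asarray(2.0)]
print(type(select_outputs(outs)).__name__)  # list
print(sorted(select_outputs(outs, output_keys=["a", "b"], output_subset=[1])))  # ['b']
```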

@brandonwillard
Member

> If not touching the internals of the Numba implementation, how about applying np.asarray in the return statements in aesara.compile.function.types.Function.__call__?

That's what I would consider performing the conversion outside of Numba, and I'm very hesitant to take such an approach, mostly because of the way it mixes concerns/contexts. The code in Function is supposed to handle the general use/orchestration of VMs, and the assumption is that a VM will produce valid outputs. Adding a conversion step to Function mixes the responsibilities.

Moving further up the chain, even the VM assumes that the values put in its output storage lists are valid, so I think the right approach involves adding a numpy.asarray to the outputs somewhere after the call to fgraph_to_python in aesara.link.numba.linker.NumbaLinker. The most direct place would probably be here:

aesara/aesara/link/basic.py, lines 666 to 673 in c8908e5:

```python
for o_node, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):
    compute_map[o_node][0] = True
    if len(o_storage) > 1:
        assert len(o_storage) == len(o_val)
        for i, o_sub_val in enumerate(o_val):
            o_storage[i] = o_sub_val
    else:
        o_storage[0] = o_val
```

Since this code is in a base class (i.e. JITLinker), we would need to consider whether or not the addition of numpy.asarray is appropriate in all cases. My first thought is that it generally is.

N.B. We can't use numpy.asarray alone; instead, we would need to use numpy.asarray(x, dtype=o_node.outputs[i].dtype) to make sure that the dtype of the converted scalar matches the graph's specification.
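A minimal sketch of what the dtype-aware conversion might look like, modeled loosely on the storage loop in JITLinker. It is hypothetical: store_outputs is not an Aesara function, output_dtypes stands in for the dtypes of the graph's output variables, the compute_map bookkeeping and the multi-cell branch are omitted, and only NumPy is assumed.

```python
import numpy as np


def store_outputs(output_dtypes, thunk_outputs, outputs):
    """Store each JIT-compiled output as an ndarray with the declared dtype.

    Hypothetical, simplified version of the JITLinker storage loop: each
    storage cell is a one-element list, and ``output_dtypes`` replaces the
    graph's output variables.
    """
    for dtype, o_storage, o_val in zip(output_dtypes, thunk_outputs, outputs):
        # Passing ``dtype`` keeps the stored value consistent with the graph,
        # e.g. a Python float becomes float32 when the graph says float32.
        o_storage[0] = np.asarray(o_val, dtype=dtype)


# Two single-cell storage lists, as a linker would allocate them.
storage = [[None], [None]]
# Pretend the compiled graph returned a Python float and a Python int.
store_outputs(["float32", "int32"], storage, [20.0, 10])

print(repr(storage[0][0]))  # array(20., dtype=float32)
print(storage[1][0].dtype)  # int32

# Without an explicit dtype, NumPy picks float64 for a Python float.
print(np.asarray(20.0).dtype)  # float64
```

Dropping the dtype argument would leave a Python float as float64 even when the graph declares float32, which is the mismatch the N.B. warns about.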

@brandonwillard
Member

> Btw, I noticed the docstring says the return type is list, but there seem to be dicts returned in some cases:

Yeah, we haven't started typing that module yet. It's important that we start doing that work, though.
