Eager functions in API appear in conflict with lazy implementation #642
It's impossible for `__bool__` to return anything except `True` or `False` (or raise an exception). I guess it's up to how you want to design your library, but you really only have two choices: either make `bool()` implicitly perform a compute on the lazy graph, or make it raise an exception. Note that the standard already discusses this in the context of data-dependent output shapes (https://data-apis.org/array-api/latest/design_topics/data_dependent_output_shapes.html), but perhaps we also need to identify other functions, such as `__bool__`, that are problematic for lazy evaluation.

One question I have is: for a lazy evaluation library, how should it work for array-consuming libraries (e.g. scipy) that are written against a generic array API expecting both eager and non-eager array backends? Is it the case that a scipy function should just take in a lazy array object and return a computation graph object, and fail if it tries to call a function that cannot be performed lazily? Or do these libraries need to know how to trigger computation themselves?

There's also a question of how we can support libraries like this in the test suite. The test suite calls things like `bool()` and `float()` on result arrays when asserting values. |
I think clearly advertising whether the array backend currently in use is lazy or eager might make sense as part of the standard. If "lazy", any API return types are also lazy and must be explicitly materialised by the user with a standard top-level function (e.g. a `compute()`-like call). This may make things easier to work with in the test suite as well?
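A minimal sketch of what the suggestion above could look like. None of these names exist in the standard; `is_lazy`, `materialize`, and `compute` are hypothetical placeholders for whatever the spec might eventually standardise.

```python
# Hypothetical capability flag plus a generic "materialise" helper.
# `is_lazy`, `materialize`, and `compute` are invented names for this sketch.

class EagerNamespace:
    is_lazy = False

class LazyNamespace:
    is_lazy = True

class ToyLazyArray:
    """Stand-in for a lazy array with some backend-specific trigger."""
    def __init__(self, value):
        self._value = value

    def compute(self):
        return self._value

def materialize(xp, array):
    """Force evaluation only when the backend advertises itself as lazy."""
    if getattr(xp, "is_lazy", False):
        return array.compute()
    return array

assert materialize(EagerNamespace, 3.0) == 3.0               # passes through
assert materialize(LazyNamespace, ToyLazyArray(3.0)) == 3.0  # triggers compute
```

A consuming library (or the test suite) could then call `materialize(xp, result)` before value assertions without caring which kind of backend it was handed.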
I think doing implicit computation is not possible in situations where you are looking to construct a computation graph but do not have your "inputs"/"arguments" available yet. It would also lead to a situation where some parts of the API yield concrete values while others lazily construct the computation graph, which may be hard to navigate for a user. I suppose this is very similar to https://data-apis.org/array-api/latest/design_topics/data_dependent_output_shapes.html, as you have mentioned. Just raising an exception for `__bool__` seems like the cleanest option in that case. |
Sorry, but it makes zero sense to return anything other than the built-in `bool` from `__bool__`. If there's a computation which could compute a Boolean that you could then use in the subsequent evaluation, lazily, what you need is simply a lazy boolean array kept in the graph, not the result of `__bool__`. |
I don't disagree; the lazy/eager suggestion was for the specific question, not for `__bool__`. |
Not sure which question you're referring to, Aditya? Regarding Aaron's question about when to call `compute()` in consuming libraries: that seems to be the crux of this issue. |
Regarding lazy evaluation and `__bool__`:
I'm not sure I agree with that. We must indeed specify what `__bool__` does here. Does anyone know why Dask forces evaluation? My guess would be that it's a pragmatic decision only - keeping conditionals as graph nodes can explode the graph size very quickly, since everything that follows bifurcates each time you call `bool()`. |
Thanks for the insightful discussion and links! A little more about our use case: we are building a lazy library that can be used to "trace" standard-compliant code and then export the resulting graph to ONNX. I.e. we are (primarily) interested in the computational graph. The values on which it will run are not necessarily available. That said, we happen to have some eager computations on top of that for debugging and testing. I would be fine with raising an error when `__bool__` is called. Should the standard be updated to reflect the fact that lazy libraries may raise from `__bool__`? |
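To make the tracing use case concrete, here is a minimal sketch of a graph-building array whose `__bool__` raises. This is purely illustrative and not the actual API of Spox or the ONNX-based library discussed here.

```python
# Minimal "tracing" lazy array: operations build a graph of (op, args)
# tuples and no values exist until the graph is exported and run elsewhere.
# All names are illustrative, not the real Spox/ONNX API.

class TracedArray:
    def __init__(self, node):
        self.node = node  # e.g. ("input", "x") or ("add", left, right)

    def __add__(self, other):
        return TracedArray(("add", self.node, other.node))

    def __gt__(self, other):
        # A comparison is fine: it just adds a boolean *array* node.
        return TracedArray(("gt", self.node, other))

    def __bool__(self):
        # There is no value to return: the graph has no concrete inputs yet.
        raise TypeError(
            "the truth value of a traced array is undefined; "
            "export the graph and run it with concrete inputs instead"
        )

x = TracedArray(("input", "x"))
y = x + x        # fine: extends the graph
cond = y > 0     # fine: a lazy boolean array node
try:
    if cond:     # implicitly calls bool(cond) -> must raise
        pass
except TypeError as exc:
    print(exc)
```

Everything up to the `if` statement traces cleanly; only the point where Python itself demands a concrete `True`/`False` has to fail.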
Just to be completely clear, it's physically impossible for `__bool__` to return anything other than an actual `bool`:

>>> class Test:
...     def __bool__(self):
...         return 1
>>> bool(Test())
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: __bool__ should return bool, returned int

(the same is true for `__int__`, `__float__`, `__complex__`, and `__index__`). And even if this limitation didn't exist, that wouldn't really help, because 90% of the time `bool()` is invoked implicitly, e.g. by an if-statement, where Python itself requires a concrete `True` or `False`.

I think the main takeaways from this discussion are: `__bool__` must either trigger evaluation and return a real `bool`, or raise an exception - there is no lazy third option.
|
I don't see how you can follow this policy in an Array API consuming library that implements an iterative solver (e.g. scikit-learn fitting a machine learning model). At each iteration you typically compute some summary scalar value and compare it to a scalar tolerance level (the result being a scalar boolean) to decide whether to do another iteration. In this case it seems that we have no choice but to let the Array API consuming library decide explicitly when it needs to trigger the evaluation to collect the scalar value, even when the outcome of this iterative loop is a collection of n-dimensional arrays of the same type as the input. |
Thanks for clarifying @asmeurer, that's the thing I was missing in my answer.
The array-consuming library doesn't have to; as pointed out by Leo and Aaron, the Python language already forces the evaluation. And this is then what Dask does:

>>> import dask.array as da
>>> x = da.ones((2,3))
>>> y = x + 2*x
>>> y
dask.array<add, shape=(2, 3), dtype=float64, chunksize=(2, 3), chunktype=numpy.ndarray>
>>> y.compute()
array([[3., 3., 3.],
[3., 3., 3.]])
>>> y.sum()
dask.array<sum-aggregate, shape=(), dtype=float64, chunksize=(), chunktype=numpy.ndarray>
>>> y.sum() > 1
dask.array<gt, shape=(), dtype=bool, chunksize=(), chunktype=numpy.ndarray>
>>> bool(y.sum() > 1) # an if-statement does the same as `bool()` here
True

So no need for any explicit `compute()` call here. |
FWIW it's the same in cuNumeric too. When a Python scalar is needed, an expression is force-evaluated in a blocking manner. |
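The Dask/cuNumeric behaviour described above can be mimicked with a toy deferred scalar: nothing runs until Python itself demands a concrete value. This is an illustrative sketch, not either library's implementation.

```python
# Toy deferred scalar mirroring the "force-evaluate when a Python scalar
# is needed" behaviour of Dask and cuNumeric. Illustrative only.

class Deferred:
    def __init__(self, thunk):
        self._thunk = thunk
        self.evaluated = False

    def _compute(self):
        self.evaluated = True
        return self._thunk()

    def __bool__(self):   # triggered by if-statements and bool()
        return bool(self._compute())

    def __float__(self):  # triggered by float()
        return float(self._compute())

expr = Deferred(lambda: sum([1.0, 2.0]) > 1)
assert not expr.evaluated   # still lazy at this point
if expr:                    # implicit bool() forces evaluation, blocking
    result = "took the branch"
assert expr.evaluated
```

The key point is that the consuming code contains no explicit `compute()` call; the language's own coercion protocol is the trigger.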
The question here remains whether the standard allows for raising when `__bool__` is called. That would still mean that sklearn would have to call some form of explicit evaluation for such backends. |
If the Python language conclusively answers with "not allowed" here, as it seems to do, I think we should adhere to that. Given that both Dask and cuNumeric also comply with it, that should be fine, right? |
The Python language isn't that conclusive about it. We raise errors all the time, although for things where it's truly impossible. Aaron specifically mentioned that as a possible path above, and it is what @cbourjau was eyeing, I think. cuNumeric, Dask, and maybe others do the implicit evaluation. |
Alright, so an iterative function would have to call `bool()`/`float()` explicitly to trigger evaluation. However, if evaluation is implicit, consider:

from collections import defaultdict

def iterative_solver(data, params, tol=1e-4, maxiter=1_000):
    record = defaultdict(list)
    for iter_idx in range(maxiter):
        params = compute_one_step(data, params)
        record["iter"].append(iter_idx)
        record["a"].append(float(metric_a(data, params)))
        record["b"].append(float(metric_b(data, params)))
        if stopping_criterion(data, params) < tol:  # calls bool() implicitly
            break
    return params, record

My understanding is that with the current implicit semantics, when data & params are lazy arrays, each call to `float()` or `bool()` triggers an evaluation of the whole graph built up to that point. The only way around this would be to insert explicit checkpoints (such as dask.array's `persist()`) inside the loop.

EDIT: fixed missing `float()` calls. |
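The recomputation concern can be made concrete with a toy lazy graph that counts how often shared work actually runs. `persist()` here is a stand-in for dask.array's checkpointing; the whole `Node` class is an illustrative sketch, not any real library's API.

```python
# Toy lazy node that counts evaluations of its computation, to show the
# "everything reruns on each float()/bool()" failure mode and how an
# explicit persist()-style checkpoint avoids it. Names are invented.

class Node:
    def __init__(self, fn, deps=()):
        self.fn, self.deps = fn, deps
        self.runs = 0
        self._cache = None

    def evaluate(self):
        if self._cache is not None:  # only populated after persist()
            return self._cache
        self.runs += 1
        return self.fn(*[d.evaluate() for d in self.deps])

    def persist(self):
        self._cache = self.evaluate()
        return self

step = Node(lambda: 41)                  # expensive shared work
metric = Node(lambda v: v + 1, (step,))  # like metric_a(data, params)
stop = Node(lambda v: v > 10, (step,))   # like the stopping criterion

float_val = metric.evaluate()  # like the implicit float(...)
bool_val = stop.evaluate()     # like the implicit bool(...)
assert step.runs == 2          # the shared work ran twice!

step.persist()                 # explicit checkpoint (persist runs it once more)
metric.evaluate()
stop.evaluate()
assert step.runs == 3          # ...after which the cached value is reused
```

Without the checkpoint, each scalar coercion in the solver loop would redo the shared upstream work; whether a real backend avoids this depends on its caching, as discussed below for dask.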
I think it's actively non-idiomatic to do so. You want to write code that does not care whether evaluation is triggered, but rather expresses only the logic and leaves execution semantics to the array library.
Isn't that just a quality-of-implementation issue? A good lazy library should automatically cache/persist calculated values that it knows have to be reused again. In this particular case, though, the problem may be that Dask won't see the line that reuses the intermediate results until after it has decided what to keep. |
Thanks for this great discussion! I think it might be useful to reiterate the following point: a lazy library (such as the one we are building on top of ONNX) may have an eager mode on top of it for debugging purposes, but those eager values must never influence the lazy computational graph that we are building. We are essentially trying to compile a sklearn-like pipeline ahead of time into an ONNX graph. We don't have any (meaningful) values available when executing the Python code that produces our lazy arrays. We have no other choice but to throw an exception if an eager value is requested. It would, however, be a pity if that fact stopped a lazy array implementation from being standard compliant. Hence this issue, to clarify whether it would be ok by the standard to raise in those cases. On the topic of control flow:
The number of necessary iterations is data-dependent, so it cannot be known while tracing. So we need to at least keep the ability to express such loops explicitly (e.g. with dedicated loop operators) rather than relying on Python control flow.
You are right that dask is probably clever enough to not recompute everything from the start in the code I wrote above, because it would not garbage-collect intermediate results for which there are still live references in the driver Python program. Not sure if other libraries such as jax would tolerate this pattern, though. |
I gave it a try and indeed dask is smart enough to avoid any recomputation while triggering computation as needed as part of the control flow: https://gist.github.com/ogrisel/6a4304e1831051203a98118875ead2d4 I am not sure if we can expect all the other lazy Array API implementations to follow the same semantics without a more explicit API though. |
I updated the above gist to also try JAX, and it has the same semantics as dask w.r.t. implicit evaluation via `bool()` and `float()`. If that behaviour is common across lazy implementations, maybe no new API is needed. |
I think the current behavior of jax and dask is convenient: to come back to the original question of this issue, I think the Array API specification/documentation for `__bool__` (and `__float__`, `__int__`) should recommend implicitly triggering the computation in this manner. |
As @cbourjau mentioned, it is not always possible to trigger a computation for lazy array implementations where you cannot use any "concrete input values" when building the computation graph (in ONNX, users serialise the computation graph for later execution with concrete inputs). This solution would make it impossible for such libraries to be fully standard compliant which would be a bit of a shame. |
In that case it would be helpful to have: a standard way to detect that an implementation is lazy, and a standard way to trigger evaluation explicitly. Otherwise an array-consuming library that implements something akin to the iterative solver above cannot support such implementations. |
Can you explain what you mean? Coming from a scikit-learn background/use-case, I can't quite imagine what "without concrete input values" means. In what use case would it happen that `__bool__` gets called before any input data exists? Are you "tracing" the computation to build a graph? |
EDIT: decorating the function with `jax.jit` changes this, see the updated gist. |
See also: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-jit and later in that same document: "As noted previously, this function runs fine without the `jit` decorator." I updated the gist with torch and jax if you are interested in reproducing the above. |
Yes that's right. |
If we didn't make it so that the dask/jax behaviour becomes what the standard says should be done, wouldn't you still end up in trouble for tracing? In scikit-learn we'd have to explicitly trigger the computation (instead of implicitly via `bool()`). I'm still not sure I fully understand how spox (I assume this is the library you are thinking about) does its thing, but it feels like tracing is not a use-case for the array API? Like, it is a neat trick, but you wouldn't change the design to make it easier to do tracing if more mainstream uses got harder. The work PyTorch has done is pretty exciting (as Olivier already said); in particular I think TorchDynamo is the bit that does the tracing (or is it TorchInductor?). Maybe worth investigating how they do it. |
From https://pytorch.org/docs/master/func.ux_limitations.html#data-dependent-python-control-flow: "JAX supports transforming over data-dependent control flow using special control flow operators (e.g. `jax.lax.cond`, `jax.lax.while_loop`)." In fact there now is a prototype `torch.cond` control flow operator in PyTorch as well.
I believe the majority of tracing use cases will work; using Python control flow based on values is one of the very few things that won't. And such code isn't a good fit for tracing anyway. So I think:

- the standard should explicitly allow lazy implementations to raise from `__bool__` & co. when no concrete value can be produced, and
- implementations that are able to compute the value implicitly (dask, cuNumeric, jax) should be free to keep doing so.
|
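The control flow operators mentioned in this thread (`jax.lax.cond`, `torch.cond`) all share one idea: replace a Python-level branch, which needs `bool()`, with a data-level select that stays in the graph. A minimal pure-Python sketch of that idea (`where` here is a hand-rolled stand-in for the array API's `where`, not the real function):

```python
# Data-level select: both branches are computed/represented, and the
# condition picks element-wise. No bool() coercion is ever needed, which
# is what makes this form traceable by lazy libraries.

def where(cond, a, b):
    """Element-wise select over plain lists, like xp.where over arrays."""
    return [x if c else y for c, x, y in zip(cond, a, b)]

x = [1.0, -2.0, 3.0]

# Value-dependent Python control flow -- needs bool(), breaks tracing:
#   result = [v if v > 0 else 0.0 for v in x]
# Data-independent formulation -- the condition stays array-valued:
result = where([v > 0 for v in x], x, [0.0] * len(x))
assert result == [1.0, 0.0, 3.0]
```

In a real lazy backend the condition `v > 0` would itself be a lazy boolean array node, so the whole expression remains a graph.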
I agree with @rgommers's summary as a pragmatic stance for the short term. I think we need to wait a year or two w.r.t. how tensor libraries with JIT compiler support will evolve before starting to think about how to standardize an API for data-dependent control flow operators (and maybe even a standard JIT compiler decorator). However, those compiler-related aspects are an important evolution of the numerical / data science Python ecosystem, and I think we should keep them in mind to later consider an Array API spec extension (similar to what is done for the `linalg` extension). |
I fully agree with @rgommers' summary, too. Thanks for the great discussion. Should I make a PR that clarifies the different behaviors in the standard? |
@cbourjau that would be great, thank you |
We are looking at adapting this API for a lazy array library built on top of ONNX (and Spox). It seems to be an excellent fit for most parts. However, certain functions in the specification appear to be in conflict with a lazy implementation. One instance is `__bool__`, whose output is documented as "a Python bool object representing the single element of the array". This is problematic for lazy implementations since the value is not available at this point. How should a standard-compliant lazy implementation deal with this apparent conflict?