-
Notifications
You must be signed in to change notification settings - Fork 620
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow vector-valued QNodes with JAX using host_callback.call
#2034
Conversation
Codecov Report
@@ Coverage Diff @@
## master #2034 +/- ##
=======================================
Coverage 99.61% 99.61%
=======================================
Files 251 251
Lines 20553 20569 +16
=======================================
+ Hits 20473 20489 +16
Misses 80 80
Continue to review full report at Codecov.
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great job @antalszava! 😍 It's awesome to see the JAX support for vector-value QNodes. I only have a few minor suggestions; but I am happy to approve.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @antalszava ! Thanks for this awesome work 💯 No major blockers, but I have some questions. I think you can also update the documentation (chart with configurations) :)
Co-authored-by: Ali Asadi <ali@xanadu.ai>
Co-authored-by: Ali Asadi <ali@xanadu.ai>
Thank you @maliasadi, @rmoyard for the reviews! The comments should be addressed now. (One thing todo for me will be to double-check an edge case of multiple tapes with multiple parameters for forward mode.) |
…not relevant, it's important measurements.py
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome work @antalszava 💯 I've left a comment about the problem you mentioned, we can always come back to it 👍 Thank you for changing VnEntropy
and MutualInformation
measurements process!
Context:
Vector-valued QNodes may include:
return qml.expval(qml.PauliZ(0), qml.expval(qml.PauliZ(1)
;qml.probs
,qml.state
orqml.density_matrix
.The JAX interface doesn't support these return types. The main reason is that
host_callback.call
, the underlying function that is being used requires the output shape to be passed. The current logic always considers tapes with scalar outputs.Uses the machinery introduced in #2044.
Description of the Change:
MeasurementProcess
class such thatshape
andnumeric_type
are methods because there were uncovered edge cases of not having pre-set shapes/numeric types (e.g.,self._shape
wasNone
);host_callback.call
.Benefits:
Vector-valued QNodes can be evaluated using the JAX JIT interface that uses
host_callback.call
;host_callback.call
is jittable.Possible Drawbacks:
The JAX JIT interface doesn't support
jax.jacobian
(see discussion in #2163).Related GitHub Issues:
Closes #1208, #2404
Testing:
Testing categories:
Multiple scalar outputs:
Single vector valued outputs:
Mixed vector and scalar-valued outputs: