-
Notifications
You must be signed in to change notification settings - Fork 60
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add get samples function to InstructionToSignals for JAX-jit usage #149
Add get samples function to InstructionToSignals for JAX-jit usage #149
Conversation
Thanks for the PR @to24toro ! I think the code you've added to Dynamics looks good - however there is a core problem right now that prevents this from being useful. Your test:
sets
If you try to do this now with the current
which is because in this case, With this error being raised, there is unfortunately no benefit to having To move forward with this PR, I think it makes sense to figure out whatever changes are necessary in terra to make the above test case work with the |
The special casing of amp must be removed from terra (only duration and amp are the non-float and unfortunately Python is not typed). Currently Qiskit/qiskit#9002 is trying to introduce (amp, float) representation to symbolic pulses, and this makes all parameters float-type except for duration. This approach allows us to remove builtin typecasting and we can eventually remove Feel free to on hold this until above PR is merged, or you can merge this PR without JIT test and later add the the test in a separate (follow up) PR. |
Okay cool thanks.
I think it makes sense to hold on this PR, as I don't think the required functionality can truly be verified until terra is at a point where the |
Fair enough. Let's merge Qiskit/qiskit#9002 first. |
def jit_func(amp): | ||
return get_samples(Constant(100, amp)) | ||
|
||
jit_samples = jax.jit(jit_func)(0.1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We expected this function will pass without using static_argnums after merging Qiskit/qiskit#9002 .
Hey @to24toro , there are two more tests I think it'd be good to add:
One last point: the |
ad0e3aa
to
b82c152
Compare
self.jit_wrap(jit_func)(0.1) | ||
self.jit_grad_wrap(jit_func)(0.1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried to use jit_wrap
and jit_grad_wrap
.
Very useful.
b82c152
to
f36cb66
Compare
@@ -523,6 +524,8 @@ def solve( | |||
Array.default_backend() == "jax" | |||
and (method == "jax_odeint" or _is_diffrax_method(method)) | |||
and all(isinstance(x, Schedule) for x in signals_list) | |||
# check if jit transformation is already performed. | |||
and not (isinstance(jnp.array(0), core.Tracer)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is related with the issue #175
I modified two points at 769e342 for JAX-jitting. |
I've removed the "on hold" label, as you pointed out that terra has been sufficiently updated. Once the errors are resolved and I re-review we can merge this! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Aside from current test failing issue (which seems to just be a sympy issue), I have some comments about reorganizing the new tests a little.
Looks good though.
self.assertTrue(signals[0].carrier_freq == 5.0) | ||
self.assertTrue(signals[1].carrier_freq == 3.1) | ||
self.assertTrue(signals[2].carrier_freq == 0.0) | ||
self.assertTrue(signals[3].carrier_freq == 4.0) | ||
|
||
def test_get_samples(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Rewrite these tests to still call the InstructionToSignals
converter. get_samples
and lru_cache_expr
are both internal functions, so ideally we won't directly test them. There are exceptions to this in Dynamics, but in this case I don't think it's necessary. The inputs/outputs of these functions are directly fed from/to the converter, so I don't think anything is gained by directly testing these functions, as opposed to verifying the behaviour of the converter on a symbolic pulse.
Maybe change the name of this to test_SymbolicPulse
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
improved and changed the name at e2c5148
) | ||
|
||
|
||
class TestJaxGetSamples(QiskitDynamicsTestCase, TestJaxBase): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similar to the preceding comment, I don't think it's necessary to have a separate class for checking the behaviour of get_samples
. These tests could be moved to the previous class, by rewriting them to verify that the converter returns the same thing, whether you pass in a Waveform
or a SymbolicPulse
.
The jit
test is great and very important, but again, it can be moved to the previous class and rewritten to work with the converter.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
removed get samples test at 72cd9ad
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For this I mean to fully get rid of this test class, and to move the cases it's testing into the preceding test class. by calling InstructionToSignals
directly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The class you said moving to is TestPulseToSignals
?
If I move to the class, we fail to pass the test because TestPulseToSignals
is not a subclass of TestJaxBase
.
On the other hand, if TestPulseToSignals
inherits TestJaxBase
, all the test in TestPulseToSignals
will be skipped in python test and executed in only JAX test.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I understand now. Maybe just rename this class then to TestPulseToSignalsJAXTransformations
or something along these lines.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
3f214c5
to
03592c1
Compare
Co-authored-by: Daniel Puzzuoli <dan.puzzuoli@gmail.com>
03592c1
to
bf66412
Compare
My author and commit name were not correct. So I am sorry to have modify them and force-push to pass license/cla. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some conflicts to resolve, and some final tests calling get_samples
to convert to calls directly to InstructionToSignals
.
|
||
See the :meth:`get_signals` method documentation for a detailed description of how pulse | ||
schedules are interpreted and translated into :class:`.DiscreteSignal` objects. | ||
schedule. The converter applies to instances of :class:`Schedule`. Instances of |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think something's happened here - this is undoing some changes I had made in a previous PR (that I don't think you are intentionally trying to change).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am so sorry to confuse you. As you say, it is not my intention. I restored to that of main.
dt: Length of the samples. This is required by the converter as pulse schedule are | ||
specified in units of dt and typically do not carry the value of dt with them. | ||
carriers: A dict of analog carrier frequencies. The keys are the names of the channels | ||
dt: Length of the samples. This is required by the converter as pulse |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These two code blocks are similarly changing some documentation I had changed previously. Maybe some issue with a previous attempt at merging main
into this branch?
:class:`~qiskit.pulse.ScheduleBlock` must first be converted to | ||
:class:`~qiskit.pulse.Schedule` using the | ||
:func:`~qiskit.pulse.transforms.block_to_schedule` function in Qiskit Pulse. | ||
:class:`ScheduleBlock` must first be converted to :class:`Schedule` using the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similarly here, this is undoing some changes I had made.
) | ||
|
||
|
||
class TestJaxGetSamples(QiskitDynamicsTestCase, TestJaxBase): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For this I mean to fully get rid of this test class, and to move the cases it's testing into the preceding test class. by calling InstructionToSignals
directly.
envelope=envelope_expr, | ||
valid_amp_conditions=valid_amp_conditions_expr, | ||
) | ||
return get_samples(instance) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
E.g. this test should use InstructionToSignals
instead of get_samples
directly.
jit_samples = jax.jit(jit_func_get_samples)(0.1) | ||
self.assertAllClose(jit_samples, self.constant_get_waveform_samples, atol=1e-7, rtol=1e-7) | ||
|
||
def test_pulse_types_combination_with_jax(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a great test, that could be moved to the preceding class.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good! One final documentation comment.
|
||
The converter can be initialized with the optional arguments ``carriers`` and ``channels``. When | ||
``channels`` is given, only the signals specified by name in ``channels`` are returned. The | ||
``carriers`` dictionary specifies the analog carrier frequency of each channel. Here, the keys | ||
are the channel name, e.g. ``d12`` for drive channel number ``12``, and the values are the | ||
corresponding frequency. If a channel is not present in ``carriers`` it is assumed that the | ||
analog carrier frequency is zero. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My last question is: did you intentionally remove these spaces? Deleting these causes this whole block of text to be a single paragraph in the docs. I prefer these as separate paragraphs, as each is highlighting something different. Unless there is a specific reason for this, can you put the spaces back in.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry. I didn't take into account the paragraph change. I returned back at c3338c5.
Summary
This PR adds
get_samples
toInstructionToSignals
for JAX-jitting when using qiskit-pulse and removes the usage ofget_waveform
method ofSymbolicPulse
.Details and comments
get_samples
function gets the envelope expression formSymbolicPulse
and calls sympy.lambdify with numerical backend specified by Array class. The lambdified function is lru cached for performance.