Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add get samples function to InstructionToSignals for JAX-jit usage #149

Merged

Conversation

to24toro
Copy link
Contributor

@to24toro to24toro commented Nov 11, 2022

Summary

This PR adds get_samples to InstructionToSignals for JAX-jitting when using qiskit-pulse and removes the usage of get_waveform method of SymbolicPulse.

Details and comments

get_samples function gets the envelope expression form SymbolicPulse and calls sympy.lambdify with numerical backend specified by Array class. The lambdified function is lru cached for performance.

@CLAassistant
Copy link

CLAassistant commented Nov 11, 2022

CLA assistant check
All committers have signed the CLA.

@DanPuzzuoli
Copy link
Collaborator

DanPuzzuoli commented Nov 14, 2022

Thanks for the PR @to24toro !

I think the code you've added to Dynamics looks good - however there is a core problem right now that prevents this from being useful. Your test:

def jit_func(amp, sigma):
    return get_samples(Gaussian(duration=5, amp=amp, sigma=sigma))

jit_samples = jax.jit(jit_func, static_argnums=(0, 1))(0.983, 2)
self.assertAllClose(jit_samples, self.gauss_get_waveform_samples, atol=1e-7, rtol=1e-7)

sets static_argnums=(0, 1), however, for this to be useful, this test should work without usingstatic_argnums. I.e. you should just be able to do:

jit_samples = jax.jit(jit_func)(0.983, 2)

If you try to do this now with the current main branch of terra, you get a JAX error, coming from the line:

...
File /opt/anaconda3/envs/devEnv310/lib/python3.10/site-packages/qiskit/pulse/library/symbolic_pulses.py:436, in SymbolicPulse.__init__(self, pulse_type, duration, parameters, name, limit_amplitude, envelope, constraints, valid_amp_conditions)
    431 # TODO remove this.
    432 #  This is due to convention in IBM Quantum backends where "amp" is treated as a
    433 #  special parameter that must be defined in the form [real, imaginary].
    434 #  this check must be removed because Qiskit pulse should be backend agnostic.
    435 if "amp" in parameters and not isinstance(parameters["amp"], ParameterExpression):
--> 436     parameters["amp"] = complex(parameters["amp"])
...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(float64[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `complex` function. If trying to convert the data type of a value, try using `x.astype(complex)` or `jnp.array(x, complex)` instead.
The error occurred while tracing the function jit_func at /var/folders/v5/5t1xdchn2ws5l1nj9h02vm0m0000gn/T/ipykernel_3597/1345434633.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument 'amp'.

which is because in this case, parameters["amp"] is a JAX tracer, and calling complex(x) for x a JAX tracer type will cause this ConcretizationTypeError.

With this error being raised, there is unfortunately no benefit to having get_samples execute using JAX - the only benefit of JAX is being able to jit/grad/etc. If we can't truly jit a computation, then there is no reason to use JAX over numpy. Obviously, since this error is being raised by terra, it will be necessary to make changes to terra before this can work.

To move forward with this PR, I think it makes sense to figure out whatever changes are necessary in terra to make the above test case work with the static_argnums call kwarg removed, make those changes, then merge this PR to dynamics once those are released in terra. @nkanazawa1989 any comments on this thought process?

@nkanazawa1989
Copy link
Contributor

The special casing of amp must be removed from terra (only duration and amp are the non-float and unfortunately Python is not typed). Currently Qiskit/qiskit#9002 is trying to introduce (amp, float) representation to symbolic pulses, and this makes all parameters float-type except for duration. This approach allows us to remove builtin typecasting and we can eventually remove static_arguments.

Feel free to on hold this until above PR is merged, or you can merge this PR without JIT test and later add the the test in a separate (follow up) PR.

@DanPuzzuoli
Copy link
Collaborator

Okay cool thanks.

Feel free to on hold this until above PR is merged, or you can merge this PR without JIT test and later add the the test in a separate (follow up) PR.

I think it makes sense to hold on this PR, as I don't think the required functionality can truly be verified until terra is at a point where the static_argnums can be removed.

@nkanazawa1989
Copy link
Contributor

Fair enough. Let's merge Qiskit/qiskit#9002 first.

def jit_func(amp):
return get_samples(Constant(100, amp))

jit_samples = jax.jit(jit_func)(0.1)
Copy link
Contributor Author

@to24toro to24toro Nov 25, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We expected this function will pass without using static_argnums after merging Qiskit/qiskit#9002 .

@DanPuzzuoli
Copy link
Collaborator

Hey @to24toro , there are two more tests I think it'd be good to add:

  • Can you add a test case for computing gradients as well? You can take the same jit test, but change it so that at the end, instead of jit, you do jax.jit(jax.grad(jit_func))(0.1). You'll have to change the the function so that it returns a real scalar (as grad only works on real scalars), so you could just change what the function returns to be get_samples(instance).sum().real. Testing this combination of gradding/jitting is important because it can introduce a bit more complexity into what JAX is doing, and sometimes raise errors where jit/grad don't raise them on their own.
  • I'm also thinking we should add a test for Solver that compiles/differentiates a simulation of a pulse schedule with symbolic pulses. In the test/dynamics/solvers/test_solver_classes.py file there is a test class called TestPulseSimulationJAX. Can you add a test to this class that defines a function building a schedule with symbolic pulses, simulating it with a Solver, and returning the final state, and attempts to compile/differentiate it? This may be a little redundant with the other tests but it will be a good integration test.

One last point: the TestJaxBase class has two helper functions: jit_wrap and jit_grad_wrap, which can be used to jit a function, or to jit(grad(func)) a function func, where the latter takes the sum and the real part of the output. These may eb helpful/convenient.

@to24toro to24toro force-pushed the jax_implementation_to_pulse branch from ad0e3aa to b82c152 Compare December 2, 2022 07:29
Comment on lines 285 to 286
self.jit_wrap(jit_func)(0.1)
self.jit_grad_wrap(jit_func)(0.1)
Copy link
Contributor Author

@to24toro to24toro Dec 2, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried to use jit_wrap and jit_grad_wrap.
Very useful.

@to24toro to24toro force-pushed the jax_implementation_to_pulse branch from b82c152 to f36cb66 Compare December 10, 2022 03:53
@to24toro to24toro changed the title Add get samples function to InstructionToSignals Add get samples function to InstructionToSignals for JAX-jit usage Jan 15, 2023
@@ -523,6 +524,8 @@ def solve(
Array.default_backend() == "jax"
and (method == "jax_odeint" or _is_diffrax_method(method))
and all(isinstance(x, Schedule) for x in signals_list)
# check if jit transformation is already performed.
and not (isinstance(jnp.array(0), core.Tracer))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is related with the issue #175

@to24toro
Copy link
Contributor Author

I modified two points at 769e342 for JAX-jitting.

@DanPuzzuoli
Copy link
Collaborator

I've removed the "on hold" label, as you pointed out that terra has been sufficiently updated. Once the errors are resolved and I re-review we can merge this!

Copy link
Collaborator

@DanPuzzuoli DanPuzzuoli left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aside from current test failing issue (which seems to just be a sympy issue), I have some comments about reorganizing the new tests a little.

Looks good though.

qiskit_dynamics/pulse/pulse_to_signals.py Show resolved Hide resolved
qiskit_dynamics/pulse/pulse_to_signals.py Outdated Show resolved Hide resolved
qiskit_dynamics/pulse/pulse_to_signals.py Outdated Show resolved Hide resolved
self.assertTrue(signals[0].carrier_freq == 5.0)
self.assertTrue(signals[1].carrier_freq == 3.1)
self.assertTrue(signals[2].carrier_freq == 0.0)
self.assertTrue(signals[3].carrier_freq == 4.0)

def test_get_samples(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rewrite these tests to still call the InstructionToSignals converter. get_samples and lru_cache_expr are both internal functions, so ideally we won't directly test them. There are exceptions to this in Dynamics, but in this case I don't think it's necessary. The inputs/outputs of these functions are directly fed from/to the converter, so I don't think anything is gained by directly testing these functions, as opposed to verifying the behaviour of the converter on a symbolic pulse.

Maybe change the name of this to test_SymbolicPulse.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

improved and changed the name at e2c5148

)


class TestJaxGetSamples(QiskitDynamicsTestCase, TestJaxBase):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to the preceding comment, I don't think it's necessary to have a separate class for checking the behaviour of get_samples. These tests could be moved to the previous class, by rewriting them to verify that the converter returns the same thing, whether you pass in a Waveform or a SymbolicPulse.

The jit test is great and very important, but again, it can be moved to the previous class and rewritten to work with the converter.

Copy link
Contributor Author

@to24toro to24toro Feb 2, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed get samples test at 72cd9ad

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For this I mean to fully get rid of this test class, and to move the cases it's testing into the preceding test class. by calling InstructionToSignals directly.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The class you said moving to is TestPulseToSignals?
If I move to the class, we fail to pass the test because TestPulseToSignals is not a subclass of TestJaxBase.
On the other hand, if TestPulseToSignals inherits TestJaxBase, all the test in TestPulseToSignals will be skipped in python test and executed in only JAX test.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand now. Maybe just rename this class then to TestPulseToSignalsJAXTransformations or something along these lines.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rename and fix at 454f119 and 2814cf0

test/dynamics/solvers/test_solver_classes.py Show resolved Hide resolved
@to24toro to24toro force-pushed the jax_implementation_to_pulse branch from 3f214c5 to 03592c1 Compare February 2, 2023 09:04
@to24toro to24toro force-pushed the jax_implementation_to_pulse branch from 03592c1 to bf66412 Compare February 2, 2023 09:50
@to24toro
Copy link
Contributor Author

to24toro commented Feb 2, 2023

My author and commit name were not correct. So I am sorry to have modify them and force-push to pass license/cla.

Copy link
Collaborator

@DanPuzzuoli DanPuzzuoli left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some conflicts to resolve, and some final tests calling get_samples to convert to calls directly to InstructionToSignals.


See the :meth:`get_signals` method documentation for a detailed description of how pulse
schedules are interpreted and translated into :class:`.DiscreteSignal` objects.
schedule. The converter applies to instances of :class:`Schedule`. Instances of
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think something's happened here - this is undoing some changes I had made in a previous PR (that I don't think you are intentionally trying to change).

Copy link
Contributor Author

@to24toro to24toro Feb 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am so sorry to confuse you. As you say, it is not my intention. I restored to that of main.

dt: Length of the samples. This is required by the converter as pulse schedule are
specified in units of dt and typically do not carry the value of dt with them.
carriers: A dict of analog carrier frequencies. The keys are the names of the channels
dt: Length of the samples. This is required by the converter as pulse
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These two code blocks are similarly changing some documentation I had changed previously. Maybe some issue with a previous attempt at merging main into this branch?

:class:`~qiskit.pulse.ScheduleBlock` must first be converted to
:class:`~qiskit.pulse.Schedule` using the
:func:`~qiskit.pulse.transforms.block_to_schedule` function in Qiskit Pulse.
:class:`ScheduleBlock` must first be converted to :class:`Schedule` using the
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly here, this is undoing some changes I had made.

)


class TestJaxGetSamples(QiskitDynamicsTestCase, TestJaxBase):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For this I mean to fully get rid of this test class, and to move the cases it's testing into the preceding test class. by calling InstructionToSignals directly.

envelope=envelope_expr,
valid_amp_conditions=valid_amp_conditions_expr,
)
return get_samples(instance)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

E.g. this test should use InstructionToSignals instead of get_samples directly.

jit_samples = jax.jit(jit_func_get_samples)(0.1)
self.assertAllClose(jit_samples, self.constant_get_waveform_samples, atol=1e-7, rtol=1e-7)

def test_pulse_types_combination_with_jax(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a great test, that could be moved to the preceding class.

@DanPuzzuoli DanPuzzuoli added this to the pulse sim v1 milestone Feb 9, 2023
Copy link
Collaborator

@DanPuzzuoli DanPuzzuoli left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good! One final documentation comment.

Comment on lines 46 to 53

The converter can be initialized with the optional arguments ``carriers`` and ``channels``. When
``channels`` is given, only the signals specified by name in ``channels`` are returned. The
``carriers`` dictionary specifies the analog carrier frequency of each channel. Here, the keys
are the channel name, e.g. ``d12`` for drive channel number ``12``, and the values are the
corresponding frequency. If a channel is not present in ``carriers`` it is assumed that the
analog carrier frequency is zero.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My last question is: did you intentionally remove these spaces? Deleting these causes this whole block of text to be a single paragraph in the docs. I prefer these as separate paragraphs, as each is highlighting something different. Unless there is a specific reason for this, can you put the spaces back in.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry. I didn't take into account the paragraph change. I returned back at c3338c5.

@DanPuzzuoli DanPuzzuoli self-requested a review February 15, 2023 22:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants