Skip to content

Commit

Permalink
impl jax-jit for pulse simulation
Browse files Browse the repository at this point in the history
  • Loading branch information
to24toro committed Jan 15, 2023
1 parent 2a28f31 commit 769e342
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions qiskit_dynamics/pulse/pulse_to_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,31 +130,31 @@ def get_signals(self, schedule: Schedule) -> List[DiscreteSignal]:
# build sample array to append to signal
times = self._dt * (start_sample + np.arange(len(inst_samples)))
samples = inst_samples * np.exp(
# Array(
Array(

This comment has been minimized.

Copy link
@to24toro

to24toro Jan 16, 2023

Author Contributor

In order to avoid using complex type at JAX-jitting, I wrapped it with Array

2.0j * np.pi * freq * times
+ 1.0j * phi
+ 2.0j * np.pi * phase_accumulations[chan]
# )
)
)
signals[chan].add_samples(start_sample, samples)

if isinstance(inst, ShiftPhase):
phases[chan] += inst.phase

if isinstance(inst, ShiftFrequency):
frequency_shifts[chan] += inst.frequency
phase_accumulations[chan] -= inst.frequency * start_sample * self._dt
frequency_shifts[chan] = frequency_shifts[chan] + Array(inst.frequency)
phase_accumulations[chan] = phase_accumulations[chan] - inst.frequency * start_sample * self._dt

if isinstance(inst, SetPhase):
phases[chan] = inst.phase

if isinstance(inst, SetFrequency):
phase_accumulations[chan] -= (
phase_accumulations[chan] = phase_accumulations[chan] - (
(inst.frequency - (frequency_shifts[chan] + signals[chan].carrier_freq))
* start_sample
* self._dt
)
frequency_shifts[chan] = inst.frequency - signals[chan].carrier_freq
frequency_shifts[chan] = inst.frequency- signals[chan].carrier_freq

# ensure all signals have the same number of samples
max_duration = 0
Expand Down

0 comments on commit 769e342

Please sign in to comment.