Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: mangle OpenQASM keywords #28

Merged
18 changes: 17 additions & 1 deletion src/autoqasm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import autoqasm.types as aq_types
from autoqasm import errors
from autoqasm.program.gate_calibrations import GateCalibration
from autoqasm.reserved_keywords import sanitize_parameter_name
from autoqasm.types import QubitIdentifierType as Qubit


Expand Down Expand Up @@ -323,6 +324,11 @@ def _convert_subroutine(
with aq_program.build_program() as program_conversion_context:
oqpy_program = program_conversion_context.get_oqpy_program()

# Iterate over list of dictionary keys to avoid runtime error
for key in list(kwargs):
new_name = sanitize_parameter_name(key)
kwargs[new_name] = kwargs.pop(key)

if f not in program_conversion_context.subroutines_processing:
# Mark that we are starting to process this function to short-circuit recursion
program_conversion_context.subroutines_processing.add(f)
Expand Down Expand Up @@ -419,6 +425,12 @@ def _wrap_for_oqpy_subroutine(f: Callable, options: converter.ConversionOptions)
def _func(*args, **kwargs) -> Any:
inner_program: oqpy.Program = args[0]
with aq_program.get_program_conversion_context().push_oqpy_program(inner_program):
# Bind args and kwargs to '_func' signature
sig = inspect.signature(_func)
bound_args = sig.bind(*args, **kwargs)
bound_args.apply_defaults()
args = bound_args.args
kwargs = bound_args.kwargs
result = aq_transpiler.converted_call(f, args[1:], kwargs, options=options)
inner_program.autodeclare()
return result
Expand All @@ -441,8 +453,12 @@ def _func(*args, **kwargs) -> Any:
"is missing a required type hint."
)

# Check whether 'param.name' is a reserved keyword
new_name = sanitize_parameter_name(param.name)
_func.__annotations__.pop(param.name)

new_param = inspect.Parameter(
name=param.name,
name=new_name,
kind=param.kind,
annotation=aq_types.map_parameter_type(param.annotation),
)
Expand Down
77 changes: 77 additions & 0 deletions src/autoqasm/reserved_keywords.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copied from:
# https://github.com/openqasm/openqasm/blob/main/source/grammar/qasm3Lexer.g4
# https://github.com/openqasm/openpulse-python/blob/main/source/grammar/openpulseLexer.g4

reserved_keywords = {
# openQASM keywords
"angle",
"array",
"barrier",
"bit",
"bool",
"box",
"cal",
"case",
"complex",
"const",
"creg",
"ctrl",
"default",
"defcal",
"defcalgrammar",
"delay",
"duration",
"durationof",
"end",
"euler",
"extern",
"false",
"float",
"gate",
"gphase",
"im",
"include",
"input",
"int",
"inv",
"let",
"OPENQASM",
"measure",
"mutable",
"negctrl",
"output",
"pi",
"pragma",
"qreg",
"qubit",
"readonly",
"reset",
"return",
"sizeof",
"stretch",
"switch",
"tau",
"true",
"U",
"uint",
"void",
# openpulse keywords
"frame",
"port",
"waveform",
}


def sanitize_parameter_name(name: str) -> str:
"""
Method to modify the variable name if it is a
reserved keyword

Args:
name (str): Name of the variable to be checked

Returns:
str: Returns a modified 'name' that has an underscore ('_') appended to it;
otherwise, it returns the original 'name' unchanged
"""
return f"{name}_" if name in reserved_keywords else name
26 changes: 26 additions & 0 deletions test/unit_tests/autoqasm/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
the local simulator.
"""

import math

import pytest
from braket.devices import LocalSimulator
from braket.tasks.local_quantum_task import LocalQuantumTask
Expand Down Expand Up @@ -1266,3 +1268,27 @@ def test(int[32] a, int[32] b) {
test(2, 3);
test(4, 5);"""
assert main.build().to_ir() == expected


def test_subroutine_call_with_reserved_keyword():
"""Test that subroutine call works with reserved keyword as a variable name"""

@aq.subroutine
def make_input_state(input: int, theta: float):
jcjaskula-aws marked this conversation as resolved.
Show resolved Hide resolved
rx(input, theta)
measure(input)

@aq.main(num_qubits=3)
def teleportation():
input, theta = 0, math.pi / 2
make_input_state(theta=theta, input=input)

expected = """OPENQASM 3.0;
def make_input_state(int[32] input_, float[64] theta) {
rx(theta) __qubits__[input_];
bit __bit_0__;
__bit_0__ = measure __qubits__[input_];
}
qubit[3] __qubits__;
make_input_state(0, 1.5707963267948966);"""
assert teleportation.build().to_ir() == expected
8 changes: 4 additions & 4 deletions test/unit_tests/autoqasm/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,8 +401,8 @@ def sub(float[64] alpha, float[64] theta) {
rx(theta) __qubits__[0];
rx(alpha) __qubits__[1];
}
def rx_alpha(int[32] qubit) {
rx(alpha) __qubits__[qubit];
def rx_alpha(int[32] qubit_) {
rx(alpha) __qubits__[qubit_];
}
float alpha = 0.5;
float beta = 1.5;
Expand All @@ -427,8 +427,8 @@ def parametric(alpha: float, beta: float):
bound_prog = parametric.build().make_bound_program({"beta": np.pi})

expected = """OPENQASM 3.0;
def rx_alpha(int[32] qubit, float[64] theta) {
rx(theta) __qubits__[qubit];
def rx_alpha(int[32] qubit_, float[64] theta) {
rx(theta) __qubits__[qubit_];
}
input float alpha;
float beta = 3.141592653589793;
Expand Down
4 changes: 2 additions & 2 deletions test/unit_tests/autoqasm/test_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ def zne() -> aq.BitVar:
def expected(scale, angle):
return (
"""OPENQASM 3.0;
def circuit(float[64] angle) {
rx(angle) __qubits__[0];
def circuit(float[64] angle_) {
rx(angle_) __qubits__[0];
cnot __qubits__[0], __qubits__[1];
}
output bit return_value;
Expand Down
12 changes: 6 additions & 6 deletions test/unit_tests/autoqasm/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def main():
annotation_test(True)

expected = """OPENQASM 3.0;
def annotation_test(bool input) {
def annotation_test(bool input_) {
}
annotation_test(true);"""

Expand All @@ -328,7 +328,7 @@ def main():
annotation_test(1)

expected = """OPENQASM 3.0;
def annotation_test(int[32] input) {
def annotation_test(int[32] input_) {
}
annotation_test(1);"""

Expand All @@ -347,7 +347,7 @@ def main():
annotation_test(1.0)

expected = """OPENQASM 3.0;
def annotation_test(float[64] input) {
def annotation_test(float[64] input_) {
rmshaffer marked this conversation as resolved.
Show resolved Hide resolved
}
annotation_test(1.0);"""

Expand All @@ -366,7 +366,7 @@ def main():
annotation_test(1)

expected = """OPENQASM 3.0;
def annotation_test(qubit input) {
def annotation_test(qubit input_) {
}
qubit[2] __qubits__;
annotation_test(__qubits__[1]);"""
Expand Down Expand Up @@ -403,7 +403,7 @@ def main():
annotation_test(a)

expected = """OPENQASM 3.0;
def annotation_test(bit input) {
def annotation_test(bit input_) {
}
bit a = 1;
annotation_test(a);"""
Expand All @@ -423,7 +423,7 @@ def main():
annotation_test(aq.BitVar(1))

expected = """OPENQASM 3.0;
def annotation_test(bit input) {
def annotation_test(bit input_) {
}
bit __bit_0__ = 1;
annotation_test(__bit_0__);"""
Expand Down
Loading