Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for custom resets using the @Reset() attribute #1981

Merged
merged 9 commits into from
Oct 29, 2024
71 changes: 70 additions & 1 deletion compiler/qsc/src/codegen/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1080,7 +1080,7 @@ mod adaptive_ri_profile {

declare void @__quantum__qis__x__body(%Qubit*)

declare void @__quantum__qis__reset__body(%Qubit*)
declare void @__quantum__qis__reset__body(%Qubit*) #1

declare void @__quantum__qis__mresetz__body(%Qubit*, %Result*) #1

Expand Down Expand Up @@ -1263,4 +1263,73 @@ mod adaptive_ri_profile {
!10 = !{i32 1, !"multiple_target_branching", i1 false}
"#]].assert_eq(&qir);
}

#[test]
fn custom_reset_generates_correct_qir() {
let source = "namespace Test {
operation Main() : Result {
use q = Qubit();
__quantum__qis__custom_reset__body(q);
M(q)
}

@Reset()
operation __quantum__qis__custom_reset__body(target: Qubit) : Unit {
body intrinsic;
}
}";
let sources = SourceMap::new([("test.qs".into(), source.into())], None);
let language_features = LanguageFeatures::default();
let capabilities = TargetCapabilityFlags::Adaptive
| TargetCapabilityFlags::QubitReset
| TargetCapabilityFlags::IntegerComputations;

let (std_id, store) = crate::compile::package_store_with_stdlib(capabilities);
let qir = get_qir(
sources,
language_features,
capabilities,
store,
&[(std_id, None)],
)
.expect("the input program set in the `source` variable should be valid Q#");
expect![[r#"
%Result = type opaque
%Qubit = type opaque

define void @ENTRYPOINT__main() #0 {
block_0:
call void @__quantum__qis__custom_reset__body(%Qubit* inttoptr (i64 0 to %Qubit*))
call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*))
call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* null)
ret void
}

declare void @__quantum__qis__custom_reset__body(%Qubit*) #1

declare void @__quantum__qis__m__body(%Qubit*, %Result*) #1

declare void @__quantum__rt__result_record_output(%Result*, i8*)

attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="1" "required_num_results"="1" }
attributes #1 = { "irreversible" }

; module flags

!llvm.module.flags = !{!0, !1, !2, !3, !4, !5, !6, !7, !8, !9, !10}

!0 = !{i32 1, !"qir_major_version", i32 1}
!1 = !{i32 7, !"qir_minor_version", i32 0}
!2 = !{i32 1, !"dynamic_qubit_management", i1 false}
!3 = !{i32 1, !"dynamic_result_management", i1 false}
!4 = !{i32 1, !"classical_ints", i1 true}
!5 = !{i32 1, !"qubit_resetting", i1 true}
!6 = !{i32 1, !"classical_floats", i1 false}
!7 = !{i32 1, !"backwards_branching", i1 false}
!8 = !{i32 1, !"classical_fixed_points", i1 false}
!9 = !{i32 1, !"user_functions", i1 false}
!10 = !{i32 1, !"multiple_target_branching", i1 false}
"#]]
.assert_eq(&qir);
}
}
7 changes: 5 additions & 2 deletions compiler/qsc_codegen/src/qir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -504,8 +504,11 @@ impl ToQir<String> for rir::Callable {
return format!(
"declare {output_type} @{}({input_type}){}",
self.name,
if self.call_type == rir::CallableType::Measurement {
// Measurement callables are a special case that needs the irreversable attribute.
if matches!(
self.call_type,
rir::CallableType::Measurement | rir::CallableType::Reset
) {
// These callables are a special case that need the irreversable attribute.
" #1"
} else {
""
Expand Down
4 changes: 4 additions & 0 deletions compiler/qsc_fir/src/fir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,8 @@ pub struct CallableDecl {
pub functors: FunctorSetValue,
/// The callable implementation.
pub implementation: CallableImpl,
/// The attributes of the callable, (e.g.: Measurement or Reset).
pub attrs: Vec<Attr>,
}

impl CallableDecl {
Expand Down Expand Up @@ -1521,6 +1523,8 @@ pub enum Attr {
EntryPoint,
/// Indicates that a callable is a measurement.
Measurement,
/// Indicates that a callable is a reset.
Reset,
}

/// A field.
Expand Down
1 change: 1 addition & 0 deletions compiler/qsc_frontend/src/closure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ pub(super) fn lift(
adj: None,
ctl: None,
ctl_adj: None,
attrs: Vec::default(),
};

(free_vars, callable)
Expand Down
54 changes: 37 additions & 17 deletions compiler/qsc_frontend/src/lower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ pub(super) enum Error {
#[error("invalid attribute arguments: expected {0}")]
#[diagnostic(code("Qsc.LowerAst.InvalidAttrArgs"))]
InvalidAttrArgs(String, #[label] Span),
#[error("invalid use of the Measurement attribute on a function")]
#[error("invalid use of the {0} attribute on a function")]
#[diagnostic(help("try declaring the callable as an operation"))]
#[diagnostic(code("Qsc.LowerAst.InvalidMeasurementAttrOnFunction"))]
InvalidMeasurementAttrOnFunction(#[label] Span),
#[diagnostic(code("Qsc.LowerAst.InvalidAttrOnFunction"))]
InvalidAttrOnFunction(String, #[label] Span),
#[error("missing callable body")]
#[diagnostic(code("Qsc.LowerAst.MissingBody"))]
MissingBody(#[label] Span),
Expand Down Expand Up @@ -368,6 +368,15 @@ impl With<'_> {
None
}
},
Ok(hir::Attr::Reset) => match &*attr.arg.kind {
ast::ExprKind::Tuple(args) if args.is_empty() => Some(hir::Attr::Reset),
_ => {
self.lowerer
.errors
.push(Error::InvalidAttrArgs("()".to_string(), attr.arg.span));
None
}
},
Err(()) => {
self.lowerer.errors.push(Error::UnknownAttr(
attr.name.name.to_string(),
Expand Down Expand Up @@ -429,29 +438,40 @@ impl With<'_> {
adj,
ctl,
ctl_adj,
attrs: attrs.to_vec(),
}
}

fn check_invalid_attrs_on_function(&mut self, attrs: &[hir::Attr], span: Span) {
const INVALID_ATTRS: [hir::Attr; 2] = [hir::Attr::Measurement, hir::Attr::Reset];

for invalid_attr in &INVALID_ATTRS {
if attrs.contains(invalid_attr) {
self.lowerer.errors.push(Error::InvalidAttrOnFunction(
format!("{invalid_attr:?}"),
span,
));
}
}
}

fn lower_callable_kind(
&mut self,
kind: ast::CallableKind,
attrs: &[qsc_hir::hir::Attr],
attrs: &[hir::Attr],
span: Span,
) -> hir::CallableKind {
if attrs.contains(&qsc_hir::hir::Attr::Measurement) {
match kind {
ast::CallableKind::Operation => hir::CallableKind::Measurement,
ast::CallableKind::Function => {
self.lowerer
.errors
.push(Error::InvalidMeasurementAttrOnFunction(span));
hir::CallableKind::Function
}
match kind {
ast::CallableKind::Function => {
self.check_invalid_attrs_on_function(attrs, span);
hir::CallableKind::Function
}
orpuente-MS marked this conversation as resolved.
Show resolved Hide resolved
} else {
match kind {
ast::CallableKind::Operation => hir::CallableKind::Operation,
ast::CallableKind::Function => hir::CallableKind::Function,
ast::CallableKind::Operation => {
if attrs.contains(&hir::Attr::Measurement) {
hir::CallableKind::Measurement
} else {
hir::CallableKind::Operation
}
}
}
}
Expand Down
28 changes: 27 additions & 1 deletion compiler/qsc_frontend/src/lower/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2221,7 +2221,8 @@ fn test_measurement_attr_on_function_issues_error() {
"#},
&expect![[r#"
[
InvalidMeasurementAttrOnFunction(
InvalidAttrOnFunction(
"Measurement",
Span {
lo: 49,
hi: 52,
Expand All @@ -2232,6 +2233,31 @@ fn test_measurement_attr_on_function_issues_error() {
);
}

#[test]
fn test_reset_attr_on_function_issues_error() {
check_errors(
indoc! {r#"
namespace Test {
@Reset()
function Foo(q: Qubit) : Unit {
body intrinsic;
}
}
"#},
&expect![[r#"
[
InvalidAttrOnFunction(
"Reset",
Span {
lo: 43,
hi: 46,
},
),
]
"#]],
);
}

#[test]
fn item_docs() {
check_hir(
Expand Down
8 changes: 7 additions & 1 deletion compiler/qsc_hir/src/hir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,8 @@ pub struct CallableDecl {
pub ctl: Option<SpecDecl>,
/// The controlled adjoint specialization.
pub ctl_adj: Option<SpecDecl>,
/// The attributes of the callable, (e.g.: Measurement or Reset).
pub attrs: Vec<Attr>,
}

impl CallableDecl {
Expand Down Expand Up @@ -1352,8 +1354,11 @@ pub enum Attr {
/// and any implementation should be ignored.
SimulatableIntrinsic,
/// Indicates that a callable is a measurement. This means that the operation will be marked as
/// "irreversible" in the generated QIR.
/// "irreversible" in the generated QIR, and output Result types will be moved to the arguments.
Measurement,
/// Indicates that a callable is a reset. This means that the operation will be marked as
/// "irreversible" in the generated QIR.
Reset,
}

impl FromStr for Attr {
Expand All @@ -1366,6 +1371,7 @@ impl FromStr for Attr {
"Unimplemented" => Ok(Self::Unimplemented),
"SimulatableIntrinsic" => Ok(Self::SimulatableIntrinsic),
"Measurement" => Ok(Self::Measurement),
"Reset" => Ok(Self::Reset),
_ => Err(()),
}
}
Expand Down
3 changes: 3 additions & 0 deletions compiler/qsc_lowerer/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ impl Lowerer {
};
CallableImpl::Spec(specialized_implementation)
};
let attrs = lower_attrs(&decl.attrs);

self.assigner.reset_local();
self.locals.clear();
Expand All @@ -299,6 +300,7 @@ impl Lowerer {
output,
functors,
implementation,
attrs,
}
}

Expand Down Expand Up @@ -888,6 +890,7 @@ fn lower_attrs(attrs: &[hir::Attr]) -> Vec<fir::Attr> {
.filter_map(|attr| match attr {
hir::Attr::EntryPoint => Some(fir::Attr::EntryPoint),
hir::Attr::Measurement => Some(fir::Attr::Measurement),
hir::Attr::Reset => Some(fir::Attr::Reset),
hir::Attr::SimulatableIntrinsic | hir::Attr::Unimplemented | hir::Attr::Config => None,
})
.collect()
Expand Down
57 changes: 57 additions & 0 deletions compiler/qsc_partial_eval/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1313,6 +1313,9 @@ impl<'a> PartialEvaluator<'a> {
if matches!(callable_decl.kind, qsc_fir::fir::CallableKind::Measurement) {
return self.measure_qubits(callable_decl, args_value, args_span);
}
if callable_decl.attrs.contains(&fir::Attr::Reset) {
return self.reset_qubits(store_item_id, callable_decl, args_value);
}

// There are a few special cases regarding intrinsic callables. Identify them and handle them properly.
match callable_decl.name.name.as_ref() {
Expand Down Expand Up @@ -2372,6 +2375,60 @@ impl<'a> PartialEvaluator<'a> {
result_value
}

fn reset_qubits(
orpuente-MS marked this conversation as resolved.
Show resolved Hide resolved
&mut self,
store_item_id: StoreItemId,
callable_decl: &CallableDecl,
args_value: Value,
) -> Result<Value, Error> {
let callable_package = self.package_store.get(store_item_id.package);
let input_type: Vec<rir::Ty> = callable_package
.derive_callable_input_params(callable_decl)
.iter()
.map(|input_param| map_fir_type_to_rir_type(&input_param.ty))
.collect();
let output_type = if callable_decl.output == Ty::UNIT {
None
} else {
panic!("the expressions that make it to this point should return Unit");
};

let measurement_callable = Callable {
name: callable_decl.name.name.to_string(),
input_type,
output_type,
body: None,
call_type: CallableType::Reset,
};

// Resolve the call arguments, create the call instruction and insert it to the current block.
let (args, ctls_arg) = self
.resolve_args(
(store_item_id.package, callable_decl.input).into(),
args_value,
None,
None,
None,
)
.expect("no controls to verify");
assert!(
ctls_arg.is_none(),
"intrinsic operations cannot have controls"
);
let operands = args
.into_iter()
.map(|arg| self.map_eval_value_to_rir_operand(&arg.into_value()))
.collect();

// Check if the callable has already been added to the program and if not do so now.
let measure_callable_id = self.get_or_insert_callable(measurement_callable);
let instruction = Instruction::Call(measure_callable_id, operands, None);
let current_block = self.get_current_rir_block_mut();
current_block.0.push(instruction);

Ok(Value::unit())
}

fn release_qubit(&mut self, args_value: Value) -> Value {
let qubit = args_value.unwrap_qubit();
self.resource_manager.release_qubit(qubit);
Expand Down
Loading
Loading