diff --git a/crates/mun_codegen/src/ir/body.rs b/crates/mun_codegen/src/ir/body.rs index ae77b23c6..20bd7c1fe 100644 --- a/crates/mun_codegen/src/ir/body.rs +++ b/crates/mun_codegen/src/ir/body.rs @@ -11,8 +11,17 @@ use mun_hir::{ }; use std::{collections::HashMap, mem, sync::Arc}; +use inkwell::basic_block::BasicBlock; use inkwell::values::PointerValue; +struct LoopInfo { + break_values: Vec<( + inkwell::values::BasicValueEnum, + inkwell::basic_block::BasicBlock, + )>, + exit_block: BasicBlock, +} + pub(crate) struct BodyIrGenerator<'a, 'b, D: IrDatabase> { db: &'a D, module: &'a Module, @@ -25,6 +34,7 @@ pub(crate) struct BodyIrGenerator<'a, 'b, D: IrDatabase> { pat_to_name: HashMap, function_map: &'a HashMap, dispatch_table: &'b DispatchTable, + active_loop: Option, } impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { @@ -58,6 +68,7 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { pat_to_name: HashMap::default(), function_map, dispatch_table, + active_loop: None, } } @@ -132,6 +143,7 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { } => self.gen_if(expr, *condition, *then_branch, *else_branch), Expr::Return { expr: ret_expr } => self.gen_return(expr, *ret_expr), Expr::Loop { body } => self.gen_loop(expr, *body), + Expr::Break { expr: break_expr } => self.gen_break(expr, *break_expr), _ => unimplemented!("unimplemented expr type {:?}", &body[expr]), } } @@ -575,9 +587,31 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { None } + fn gen_break(&mut self, _expr: ExprId, break_expr: Option) -> Option { + let break_value = break_expr.and_then(|expr| self.gen_expr(expr)); + let loop_info = self.active_loop.as_mut().unwrap(); + if let Some(break_value) = break_value { + loop_info + .break_values + .push((break_value, self.builder.get_insert_block().unwrap())); + } + self.builder + .build_unconditional_branch(&loop_info.exit_block); + None + } + fn gen_loop(&mut self, _expr: ExprId, body_expr: ExprId) -> Option { let context = self.module.get_context(); let loop_block = context.append_basic_block(&self.fn_value, "loop"); + let exit_block = context.append_basic_block(&self.fn_value, "exit"); + + // Build a new loop info struct + let loop_info = LoopInfo { + exit_block, + break_values: Vec::new(), + }; + + let prev_loop = std::mem::replace(&mut self.active_loop, Some(loop_info)); // Insert an explicit fall through from the current block to the loop self.builder.build_unconditional_branch(&loop_block); @@ -589,6 +623,23 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { // Jump to the start of the loop self.builder.build_unconditional_branch(&loop_block); - None + let LoopInfo { + exit_block, + break_values, + } = std::mem::replace(&mut self.active_loop, prev_loop).unwrap(); + + // Move the builder to the exit block + self.builder.position_at_end(&exit_block); + + if !break_values.is_empty() { + let (value, _) = break_values.first().unwrap(); + let phi = self.builder.build_phi(value.get_type(), "exit"); + for (ref value, ref block) in break_values { + phi.add_incoming(&[(value, block)]) + } + Some(phi.as_basic_value()) + } else { + None + } } } diff --git a/crates/mun_codegen/src/snapshots/test__loop_break_expr.snap b/crates/mun_codegen/src/snapshots/test__loop_break_expr.snap new file mode 100644 index 000000000..d6e3e9909 --- /dev/null +++ b/crates/mun_codegen/src/snapshots/test__loop_break_expr.snap @@ -0,0 +1,29 @@ +--- +source: crates/mun_codegen/src/test.rs +expression: "fn foo(n:int):int {\n loop {\n if n > 5 {\n break n;\n }\n if n > 10 {\n break 10;\n }\n n += 1;\n }\n}" +--- +; ModuleID = 'main.mun' +source_filename = "main.mun" + +define i64 @foo(i64) { +body: + br label %loop + +loop: ; preds = %if_merge6, %body + %n.0 = phi i64 [ %0, %body ], [ %add, %if_merge6 ] + %greater = icmp sgt i64 %n.0, 5 + br i1 %greater, label %exit, label %if_merge + +exit: ; preds = %if_merge, %loop + %exit8 = phi i64 [ %n.0, %loop ], [ 10, %if_merge ] + ret i64 %exit8 + +if_merge: ; preds = %loop + %greater4 = icmp sgt i64 %n.0, 10 + br i1 %greater4, label %exit, label %if_merge6 + +if_merge6: ; preds = %if_merge + %add = add i64 %n.0, 1 + br label %loop +} + diff --git a/crates/mun_codegen/src/test.rs b/crates/mun_codegen/src/test.rs index 29285116b..eaaed1e2c 100644 --- a/crates/mun_codegen/src/test.rs +++ b/crates/mun_codegen/src/test.rs @@ -327,6 +327,25 @@ fn loop_expr() { ) } +#[test] +fn loop_break_expr() { + test_snapshot( + r#" + fn foo(n:int):int { + loop { + if n > 5 { + break n; + } + if n > 10 { + break 10; + } + n += 1; + } + } + "#, + ) +} + fn test_snapshot(text: &str) { let text = text.trim().replace("\n ", "\n"); diff --git a/crates/mun_runtime/src/test.rs b/crates/mun_runtime/src/test.rs index 9facd2de6..c2573cb76 100644 --- a/crates/mun_runtime/src/test.rs +++ b/crates/mun_runtime/src/test.rs @@ -214,6 +214,33 @@ fn fibonacci_loop() { assert_invoke_eq!(i64, 46368, driver, "fibonacci", 24i64); } +#[test] +fn fibonacci_loop_break() { + let mut driver = TestDriver::new( + r#" + fn fibonacci(n:int):int { + let a = 0; + let b = 1; + let i = 1; + loop { + if i > n { + break a; + } + let sum = a + b; + a = b; + b = sum; + i += 1; + } + } + "#, + ); + + assert_invoke_eq!(i64, 5, driver, "fibonacci", 5i64); + assert_invoke_eq!(i64, 89, driver, "fibonacci", 11i64); + assert_invoke_eq!(i64, 987, driver, "fibonacci", 16i64); + assert_invoke_eq!(i64, 46368, driver, "fibonacci", 24i64); +} + #[test] fn true_is_true() { let mut driver = TestDriver::new(