diff --git a/crates/mun_codegen/src/ir/body.rs b/crates/mun_codegen/src/ir/body.rs index 20bd7c1fe..ce95d50ac 100644 --- a/crates/mun_codegen/src/ir/body.rs +++ b/crates/mun_codegen/src/ir/body.rs @@ -35,6 +35,7 @@ pub(crate) struct BodyIrGenerator<'a, 'b, D: IrDatabase> { function_map: &'a HashMap, dispatch_table: &'b DispatchTable, active_loop: Option, + hir_function: hir::Function, } impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { @@ -69,6 +70,7 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { function_map, dispatch_table, active_loop: None, + hir_function, } } @@ -107,11 +109,20 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { // Construct a return statement from the returned value of the body if a return is expected // in the first place. If the return type of the body is `never` there is no need to // generate a return statement. - let ret_type = &self.infer[self.body.body_expr()]; - if let Some(value) = ret_value { - self.builder.build_return(Some(&value)); - } else if !ret_type.is_never() { - self.builder.build_return(None); + let block_ret_type = &self.infer[self.body.body_expr()]; + let fn_ret_type = self + .hir_function + .ty(self.db) + .callable_sig(self.db) + .unwrap() + .ret() + .clone(); + if !block_ret_type.is_never() { + if fn_ret_type.is_empty() { + self.builder.build_return(None); + } else if let Some(value) = ret_value { + self.builder.build_return(Some(&value)); + } } } @@ -143,6 +154,7 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { } => self.gen_if(expr, *condition, *then_branch, *else_branch), Expr::Return { expr: ret_expr } => self.gen_return(expr, *ret_expr), Expr::Loop { body } => self.gen_loop(expr, *body), + Expr::While { condition, body } => self.gen_while(expr, *condition, *body), Expr::Break { expr: break_expr } => self.gen_break(expr, *break_expr), _ => unimplemented!("unimplemented expr type {:?}", &body[expr]), } @@ -178,6 +190,11 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { } } + /// Constructs an empty struct value e.g. `{}` + fn gen_empty(&mut self) -> BasicValueEnum { + self.module.get_context().const_struct(&[], false).into() + } + /// Generates IR for the specified block expression. fn gen_block( &mut self, @@ -193,16 +210,13 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { self.gen_let_statement(*pat, *initializer); } Statement::Expr(expr) => { - self.gen_expr(*expr); - // No need to generate code after a statement that has a `never` return type. - if self.infer[*expr].is_never() { - return None; - } + self.gen_expr(*expr)?; } }; } tail.and_then(|expr| self.gen_expr(expr)) + .or_else(|| Some(self.gen_empty())) } /// Constructs a builder that should be used to emit an `alloca` instruction. These instructions @@ -365,7 +379,7 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { }; let place = self.gen_place_expr(lhs_expr); self.builder.build_store(place, rhs); - None + Some(self.gen_empty()) } _ => unimplemented!("Operator {:?} is not implemented for float", op), } @@ -422,7 +436,7 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { }; let place = self.gen_place_expr(lhs_expr); self.builder.build_store(place, rhs); - None + Some(self.gen_empty()) } _ => unreachable!(format!("Operator {:?} is not implemented for integer", op)), } @@ -507,10 +521,7 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { else_branch: Option, ) -> Option { // Generate IR for the condition - let condition_ir = self - .gen_expr(condition) - .expect("condition must have a value") - .into_int_value(); + let condition_ir = self.gen_expr(condition)?.into_int_value(); // Generate the code blocks to branch to let context = self.module.get_context(); @@ -570,7 +581,7 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { Some(then_block_ir) } } else { - None + Some(self.gen_empty()) } } @@ -600,34 +611,92 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { None } - fn gen_loop(&mut self, _expr: ExprId, body_expr: ExprId) -> Option { - let context = self.module.get_context(); - let loop_block = context.append_basic_block(&self.fn_value, "loop"); - let exit_block = context.append_basic_block(&self.fn_value, "exit"); - + fn gen_loop_block_expr( + &mut self, + block: ExprId, + exit_block: BasicBlock, + ) -> ( + BasicBlock, + Vec<(BasicValueEnum, BasicBlock)>, + Option, + ) { // Build a new loop info struct let loop_info = LoopInfo { exit_block, break_values: Vec::new(), }; + // Replace previous loop info let prev_loop = std::mem::replace(&mut self.active_loop, Some(loop_info)); - // Insert an explicit fall through from the current block to the loop - self.builder.build_unconditional_branch(&loop_block); - // Start generating code inside the loop - self.builder.position_at_end(&loop_block); - let _ = self.gen_expr(body_expr); - - // Jump to the start of the loop - self.builder.build_unconditional_branch(&loop_block); + let value = self.gen_expr(block); let LoopInfo { exit_block, break_values, } = std::mem::replace(&mut self.active_loop, prev_loop).unwrap(); + (exit_block, break_values, value) + } + + fn gen_while( + &mut self, + _expr: ExprId, + condition_expr: ExprId, + body_expr: ExprId, + ) -> Option { + let context = self.module.get_context(); + let cond_block = context.append_basic_block(&self.fn_value, "whilecond"); + let loop_block = context.append_basic_block(&self.fn_value, "while"); + let exit_block = context.append_basic_block(&self.fn_value, "afterwhile"); + + // Insert an explicit fall through from the current block to the condition check + self.builder.build_unconditional_branch(&cond_block); + + // Generate condition block + self.builder.position_at_end(&cond_block); + let condition_ir = self.gen_expr(condition_expr); + if let Some(condition_ir) = condition_ir { + self.builder.build_conditional_branch( + condition_ir.into_int_value(), + &loop_block, + &exit_block, + ); + } else { + // If the condition doesn't return a value, we also immediately return without a value. + // This can happen if the expression is a `never` expression. + return None; + } + + // Generate loop block + self.builder.position_at_end(&loop_block); + let (exit_block, _, value) = self.gen_loop_block_expr(body_expr, exit_block); + if value.is_some() { + self.builder.build_unconditional_branch(&cond_block); + } + + // Generate exit block + self.builder.position_at_end(&exit_block); + + Some(self.gen_empty()) + } + + fn gen_loop(&mut self, _expr: ExprId, body_expr: ExprId) -> Option { + let context = self.module.get_context(); + let loop_block = context.append_basic_block(&self.fn_value, "loop"); + let exit_block = context.append_basic_block(&self.fn_value, "exit"); + + // Insert an explicit fall through from the current block to the loop + self.builder.build_unconditional_branch(&loop_block); + + // Generate the body of the loop + self.builder.position_at_end(&loop_block); + let (exit_block, break_values, value) = self.gen_loop_block_expr(body_expr, exit_block); + if value.is_some() { + self.builder.build_unconditional_branch(&loop_block); + } + // Move the builder to the exit block self.builder.position_at_end(&exit_block); diff --git a/crates/mun_codegen/src/snapshots/test__while_expr.snap b/crates/mun_codegen/src/snapshots/test__while_expr.snap new file mode 100644 index 000000000..e56439f69 --- /dev/null +++ b/crates/mun_codegen/src/snapshots/test__while_expr.snap @@ -0,0 +1,24 @@ +--- +source: crates/mun_codegen/src/test.rs +expression: "fn foo(n:int) {\n while n<3 {\n n += 1;\n };\n\n // This will be completely optimized out\n while n<4 {\n break;\n };\n}" +--- +; ModuleID = 'main.mun' +source_filename = "main.mun" + +define void @foo(i64) { +body: + br label %whilecond + +whilecond: ; preds = %while, %body + %n.0 = phi i64 [ %0, %body ], [ %add, %while ] + %less = icmp slt i64 %n.0, 3 + br i1 %less, label %while, label %whilecond3 + +while: ; preds = %whilecond + %add = add i64 %n.0, 1 + br label %whilecond + +whilecond3: ; preds = %whilecond + ret void +} + diff --git a/crates/mun_codegen/src/test.rs b/crates/mun_codegen/src/test.rs index eaaed1e2c..9e1f2c61e 100644 --- a/crates/mun_codegen/src/test.rs +++ b/crates/mun_codegen/src/test.rs @@ -346,6 +346,24 @@ fn loop_break_expr() { ) } +#[test] +fn while_expr() { + test_snapshot( + r#" + fn foo(n:int) { + while n<3 { + n += 1; + }; + + // This will be completely optimized out + while n<4 { + break; + }; + } + "#, + ) +} + fn test_snapshot(text: &str) { let text = text.trim().replace("\n ", "\n"); diff --git a/crates/mun_hir/src/diagnostics.rs b/crates/mun_hir/src/diagnostics.rs index ab2896cd0..28f3bb9b3 100644 --- a/crates/mun_hir/src/diagnostics.rs +++ b/crates/mun_hir/src/diagnostics.rs @@ -384,3 +384,27 @@ impl Diagnostic for BreakOutsideLoop { self } } + +#[derive(Debug)] +pub struct BreakWithValueOutsideLoop { + pub file: FileId, + pub break_expr: SyntaxNodePtr, +} + +impl Diagnostic for BreakWithValueOutsideLoop { + fn message(&self) -> String { + "`break` with value can only appear in a `loop`".to_owned() + } + + fn file(&self) -> FileId { + self.file + } + + fn syntax_node_ptr(&self) -> SyntaxNodePtr { + self.break_expr + } + + fn as_any(&self) -> &(dyn Any + Send + 'static) { + self + } +} diff --git a/crates/mun_hir/src/expr.rs b/crates/mun_hir/src/expr.rs index 00b837455..651665ecf 100644 --- a/crates/mun_hir/src/expr.rs +++ b/crates/mun_hir/src/expr.rs @@ -207,6 +207,10 @@ pub enum Expr { Loop { body: ExprId, }, + While { + condition: ExprId, + body: ExprId, + }, Literal(Literal), } @@ -304,6 +308,10 @@ impl Expr { Expr::Loop { body } => { f(*body); } + Expr::While { condition, body } => { + f(*condition); + f(*body); + } } } } @@ -468,6 +476,7 @@ where let syntax_ptr = AstPtr::new(&expr.clone()); match expr.kind() { ast::ExprKind::LoopExpr(expr) => self.collect_loop(expr), + ast::ExprKind::WhileExpr(expr) => self.collect_while(expr), ast::ExprKind::ReturnExpr(r) => self.collect_return(r), ast::ExprKind::BreakExpr(r) => self.collect_break(r), ast::ExprKind::BlockExpr(b) => self.collect_block(b), @@ -587,13 +596,7 @@ where } }); - let condition = match e.condition() { - None => self.missing_expr(), - Some(condition) => match condition.pat() { - None => self.collect_expr_opt(condition.expr()), - _ => unreachable!("patterns in conditions are not yet supported"), - }, - }; + let condition = self.collect_condition_opt(e.condition()); self.alloc_expr( Expr::If { @@ -622,6 +625,21 @@ where } } + fn collect_condition_opt(&mut self, cond: Option) -> ExprId { + if let Some(cond) = cond { + self.collect_condition(cond) + } else { + self.exprs.alloc(Expr::Missing) + } + } + + fn collect_condition(&mut self, cond: ast::Condition) -> ExprId { + match cond.pat() { + None => self.collect_expr_opt(cond.expr()), + _ => unreachable!("patterns in conditions are not yet supported"), + } + } + fn collect_pat(&mut self, pat: ast::Pat) -> PatId { let pattern = match pat.kind() { ast::PatKind::BindPat(bp) => { @@ -655,6 +673,13 @@ where self.alloc_expr(Expr::Loop { body }, syntax_node_ptr) } + fn collect_while(&mut self, expr: ast::WhileExpr) -> ExprId { + let syntax_node_ptr = AstPtr::new(&expr.clone().into()); + let condition = self.collect_condition_opt(expr.condition()); + let body = self.collect_block_opt(expr.loop_body()); + self.alloc_expr(Expr::While { condition, body }, syntax_node_ptr) + } + fn finish(mut self) -> (Body, BodySourceMap) { let (type_refs, type_ref_source_map) = self.type_ref_builder.finish(); let body = Body { diff --git a/crates/mun_hir/src/ty/infer.rs b/crates/mun_hir/src/ty/infer.rs index e4686f89f..db74dcf52 100644 --- a/crates/mun_hir/src/ty/infer.rs +++ b/crates/mun_hir/src/ty/infer.rs @@ -23,15 +23,16 @@ mod type_variable; pub use type_variable::TypeVarId; +#[macro_export] macro_rules! ty_app { ($ctor:pat, $param:pat) => { - $crate::ty::Ty::Apply($crate::ty::ApplicationTy { + $crate::Ty::Apply($crate::ApplicationTy { ctor: $ctor, parameters: $param, }) }; ($ctor:pat) => { - $crate::ty::Ty::Apply($crate::ty::ApplicationTy { + $crate::Ty::Apply($crate::ApplicationTy { ctor: $ctor, .. }) @@ -92,6 +93,12 @@ pub fn infer_query(db: &impl HirDatabase, def: DefWithBody) -> Arc { db: &'a D, @@ -108,7 +115,7 @@ struct InferenceResultBuilder<'a, D: HirDatabase> { /// entry contains the current type of the loop statement (initially `never`) and the expected /// type of the loop expression. Both these values are updated when a break statement is /// encountered. - active_loop: Option<(Ty, Expectation)>, + active_loop: Option, /// The return type of the function being inferred. return_ty: Ty, @@ -315,6 +322,9 @@ impl<'a, D: HirDatabase> InferenceResultBuilder<'a, D> { } Expr::Break { expr } => self.infer_break(tgt_expr, *expr), Expr::Loop { body } => self.infer_loop_expr(tgt_expr, *body, expected), + Expr::While { condition, body } => { + self.infer_while_expr(tgt_expr, *condition, *body, expected) + } _ => Ty::Unknown, // Expr::UnaryOp { expr: _, op: _ } => {} // Expr::Block { statements: _, tail: _ } => {} @@ -519,13 +529,20 @@ impl<'a, D: HirDatabase> InferenceResultBuilder<'a, D> { } fn infer_break(&mut self, tgt_expr: ExprId, expr: Option) -> Ty { - // Fetch the expected type - let expected = if let Some((_, info)) = &self.active_loop { - info.clone() - } else { - self.diagnostics - .push(InferenceDiagnostic::BreakOutsideLoop { id: tgt_expr }); - return Ty::simple(TypeCtor::Never); + let expected = match &self.active_loop { + Some(ActiveLoop::Loop(_, info)) => info.clone(), + Some(_) => { + if expr.is_some() { + self.diagnostics + .push(InferenceDiagnostic::BreakWithValueOutsideLoop { id: tgt_expr }); + } + return Ty::simple(TypeCtor::Never); + } + None => { + self.diagnostics + .push(InferenceDiagnostic::BreakOutsideLoop { id: tgt_expr }); + return Ty::simple(TypeCtor::Never); + } }; // Infer the type of the break expression @@ -548,29 +565,46 @@ impl<'a, D: HirDatabase> InferenceResultBuilder<'a, D> { }; // Update the expected type for the rest of the loop - self.active_loop = Some((ty.clone(), Expectation::has_type(ty))); + self.active_loop = Some(ActiveLoop::Loop(ty.clone(), Expectation::has_type(ty))); Ty::simple(TypeCtor::Never) } fn infer_loop_expr(&mut self, _tgt_expr: ExprId, body: ExprId, expected: &Expectation) -> Ty { - self.infer_loop_block(body, expected) + if let ActiveLoop::Loop(ty, _) = self.infer_loop_block( + body, + ActiveLoop::Loop(Ty::simple(TypeCtor::Never), expected.clone()), + ) { + ty + } else { + panic!("returned active loop must be a loop") + } } - /// Infers the type of a loop body, taking into account breaks. - fn infer_loop_block(&mut self, body: ExprId, expected: &Expectation) -> Ty { - // Take the previous loop information and replace it with a new entry - let top_level_loop = std::mem::replace( - &mut self.active_loop, - Some((Ty::simple(TypeCtor::Never), expected.clone())), - ); + fn infer_loop_block(&mut self, body: ExprId, lp: ActiveLoop) -> ActiveLoop { + let top_level_loop = std::mem::replace(&mut self.active_loop, Some(lp)); // Infer the body of the loop self.infer_expr_coerce(body, &Expectation::has_type(Ty::Empty)); // Take the result of the loop information and replace with top level loop - let (ty, _) = std::mem::replace(&mut self.active_loop, top_level_loop).unwrap(); - ty + std::mem::replace(&mut self.active_loop, top_level_loop).unwrap() + } + + fn infer_while_expr( + &mut self, + _tgt_expr: ExprId, + condition: ExprId, + body: ExprId, + _expected: &Expectation, + ) -> Ty { + self.infer_expr( + condition, + &Expectation::has_type(Ty::simple(TypeCtor::Bool)), + ); + + self.infer_loop_block(body, ActiveLoop::While); + Ty::Empty } pub fn report_pat_inference_failure(&mut self, _pat: PatId) { @@ -633,8 +667,9 @@ impl From for ExprOrPatId { mod diagnostics { use crate::diagnostics::{ - BreakOutsideLoop, CannotApplyBinaryOp, ExpectedFunction, IncompatibleBranch, InvalidLHS, - MismatchedType, MissingElseBranch, ParameterCountMismatch, ReturnMissingExpression, + BreakOutsideLoop, BreakWithValueOutsideLoop, CannotApplyBinaryOp, ExpectedFunction, + IncompatibleBranch, InvalidLHS, MismatchedType, MissingElseBranch, ParameterCountMismatch, + ReturnMissingExpression, }; use crate::{ code_model::src::HasSource, @@ -690,6 +725,9 @@ mod diagnostics { BreakOutsideLoop { id: ExprId, }, + BreakWithValueOutsideLoop { + id: ExprId, + }, } impl InferenceDiagnostic { @@ -806,6 +844,13 @@ mod diagnostics { break_expr: id, }); } + InferenceDiagnostic::BreakWithValueOutsideLoop { id } => { + let id = body.expr_syntax(*id).unwrap().ast.syntax_node_ptr(); + sink.push(BreakWithValueOutsideLoop { + file, + break_expr: id, + }); + } } } } diff --git a/crates/mun_hir/src/ty/snapshots/tests__infer_while.snap b/crates/mun_hir/src/ty/snapshots/tests__infer_while.snap new file mode 100644 index 000000000..094af5cf7 --- /dev/null +++ b/crates/mun_hir/src/ty/snapshots/tests__infer_while.snap @@ -0,0 +1,40 @@ +--- +source: crates/mun_hir/src/ty/tests.rs +expression: "fn foo() {\n let n = 0;\n while n < 3 { n += 1; };\n while n < 3 { n += 1; break; };\n while n < 3 { break 3; }; // error: break with value can only appear in a loop\n while n < 3 { loop { break 3; }; };\n}" +--- +[109; 116): `break` with value can only appear in a `loop` +[9; 217) '{ ...; }; }': nothing +[19; 20) 'n': int +[23; 24) '0': int +[30; 53) 'while ...= 1; }': nothing +[36; 37) 'n': int +[36; 41) 'n < 3': bool +[40; 41) '3': int +[42; 53) '{ n += 1; }': nothing +[44; 45) 'n': int +[44; 50) 'n += 1': nothing +[49; 50) '1': int +[59; 89) 'while ...eak; }': nothing +[65; 66) 'n': int +[65; 70) 'n < 3': bool +[69; 70) '3': int +[71; 89) '{ n +=...eak; }': never +[73; 74) 'n': int +[73; 79) 'n += 1': nothing +[78; 79) '1': int +[81; 86) 'break': never +[95; 119) 'while ...k 3; }': nothing +[101; 102) 'n': int +[101; 106) 'n < 3': bool +[105; 106) '3': int +[107; 119) '{ break 3; }': never +[109; 116) 'break 3': never +[180; 214) 'while ...; }; }': nothing +[186; 187) 'n': int +[186; 191) 'n < 3': bool +[190; 191) '3': int +[192; 214) '{ loop...; }; }': nothing +[194; 211) 'loop {...k 3; }': int +[199; 211) '{ break 3; }': never +[201; 208) 'break 3': never +[207; 208) '3': int diff --git a/crates/mun_hir/src/ty/tests.rs b/crates/mun_hir/src/ty/tests.rs index f65ad8214..439f22907 100644 --- a/crates/mun_hir/src/ty/tests.rs +++ b/crates/mun_hir/src/ty/tests.rs @@ -144,6 +144,21 @@ fn infer_break() { ) } +#[test] +fn infer_while() { + infer_snapshot( + r#" + fn foo() { + let n = 0; + while n < 3 { n += 1; }; + while n < 3 { n += 1; break; }; + while n < 3 { break 3; }; // error: break with value can only appear in a loop + while n < 3 { loop { break 3; }; }; + } + "#, + ) +} + fn infer_snapshot(text: &str) { let text = text.trim().replace("\n ", "\n"); insta::assert_snapshot!(insta::_macro_support::AutoName, infer(&text), &text); diff --git a/crates/mun_runtime/src/test.rs b/crates/mun_runtime/src/test.rs index c2573cb76..203865b90 100644 --- a/crates/mun_runtime/src/test.rs +++ b/crates/mun_runtime/src/test.rs @@ -241,6 +241,31 @@ fn fibonacci_loop_break() { assert_invoke_eq!(i64, 46368, driver, "fibonacci", 24i64); } +#[test] +fn fibonacci_while() { + let mut driver = TestDriver::new( + r#" + fn fibonacci(n:int):int { + let a = 0; + let b = 1; + let i = 1; + while i <= n { + let sum = a + b; + a = b; + b = sum; + i += 1; + } + a + } + "#, + ); + + assert_invoke_eq!(i64, 5, driver, "fibonacci", 5i64); + assert_invoke_eq!(i64, 89, driver, "fibonacci", 11i64); + assert_invoke_eq!(i64, 987, driver, "fibonacci", 16i64); + assert_invoke_eq!(i64, 46368, driver, "fibonacci", 24i64); +} + #[test] fn true_is_true() { let mut driver = TestDriver::new( diff --git a/crates/mun_syntax/src/ast/generated.rs b/crates/mun_syntax/src/ast/generated.rs index 112c644e1..f2f7fabd2 100644 --- a/crates/mun_syntax/src/ast/generated.rs +++ b/crates/mun_syntax/src/ast/generated.rs @@ -249,7 +249,7 @@ impl AstNode for Expr { fn can_cast(kind: SyntaxKind) -> bool { match kind { LITERAL | PREFIX_EXPR | PATH_EXPR | BIN_EXPR | PAREN_EXPR | CALL_EXPR | IF_EXPR - | LOOP_EXPR | RETURN_EXPR | BREAK_EXPR | BLOCK_EXPR => true, + | LOOP_EXPR | WHILE_EXPR | RETURN_EXPR | BREAK_EXPR | BLOCK_EXPR => true, _ => false, } } @@ -274,6 +274,7 @@ pub enum ExprKind { CallExpr(CallExpr), IfExpr(IfExpr), LoopExpr(LoopExpr), + WhileExpr(WhileExpr), ReturnExpr(ReturnExpr), BreakExpr(BreakExpr), BlockExpr(BlockExpr), @@ -318,6 +319,11 @@ impl From for Expr { Expr { syntax: n.syntax } } } +impl From for Expr { + fn from(n: WhileExpr) -> Expr { + Expr { syntax: n.syntax } + } +} impl From for Expr { fn from(n: ReturnExpr) -> Expr { Expr { syntax: n.syntax } @@ -345,6 +351,7 @@ impl Expr { CALL_EXPR => ExprKind::CallExpr(CallExpr::cast(self.syntax.clone()).unwrap()), IF_EXPR => ExprKind::IfExpr(IfExpr::cast(self.syntax.clone()).unwrap()), LOOP_EXPR => ExprKind::LoopExpr(LoopExpr::cast(self.syntax.clone()).unwrap()), + WHILE_EXPR => ExprKind::WhileExpr(WhileExpr::cast(self.syntax.clone()).unwrap()), RETURN_EXPR => ExprKind::ReturnExpr(ReturnExpr::cast(self.syntax.clone()).unwrap()), BREAK_EXPR => ExprKind::BreakExpr(BreakExpr::cast(self.syntax.clone()).unwrap()), BLOCK_EXPR => ExprKind::BlockExpr(BlockExpr::cast(self.syntax.clone()).unwrap()), @@ -1237,3 +1244,35 @@ impl AstNode for Visibility { } } impl Visibility {} + +// WhileExpr + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct WhileExpr { + pub(crate) syntax: SyntaxNode, +} + +impl AstNode for WhileExpr { + fn can_cast(kind: SyntaxKind) -> bool { + match kind { + WHILE_EXPR => true, + _ => false, + } + } + fn cast(syntax: SyntaxNode) -> Option { + if Self::can_cast(syntax.kind()) { + Some(WhileExpr { syntax }) + } else { + None + } + } + fn syntax(&self) -> &SyntaxNode { + &self.syntax + } +} +impl ast::LoopBodyOwner for WhileExpr {} +impl WhileExpr { + pub fn condition(&self) -> Option { + super::child_opt(self) + } +} diff --git a/crates/mun_syntax/src/grammar.ron b/crates/mun_syntax/src/grammar.ron index 787c9c57c..1fda9fb86 100644 --- a/crates/mun_syntax/src/grammar.ron +++ b/crates/mun_syntax/src/grammar.ron @@ -122,6 +122,7 @@ Grammar( "IF_EXPR", "BLOCK_EXPR", "RETURN_EXPR", + "WHILE_EXPR", "LOOP_EXPR", "BREAK_EXPR", "CONDITION", @@ -189,6 +190,11 @@ Grammar( traits: ["LoopBodyOwner"] ), + "WhileExpr": ( + traits: ["LoopBodyOwner"], + options: [ "Condition" ] + ), + "PathExpr": (options: ["Path"]), "PrefixExpr": (options: ["Expr"]), "BinExpr": (), @@ -218,6 +224,7 @@ Grammar( "CallExpr", "IfExpr", "LoopExpr", + "WhileExpr", "ReturnExpr", "BreakExpr", "BlockExpr", diff --git a/crates/mun_syntax/src/parsing/grammar/expressions.rs b/crates/mun_syntax/src/parsing/grammar/expressions.rs index bcbce1b2e..1599aa59b 100644 --- a/crates/mun_syntax/src/parsing/grammar/expressions.rs +++ b/crates/mun_syntax/src/parsing/grammar/expressions.rs @@ -14,6 +14,7 @@ const ATOM_EXPR_FIRST: TokenSet = LITERAL_FIRST.union(PATH_FIRST).union(token_se T![loop], T![return], T![break], + T![while], ]); const LHS_FIRST: TokenSet = ATOM_EXPR_FIRST.union(token_set![EXCLAMATION, MINUS]); @@ -234,6 +235,7 @@ fn atom_expr(p: &mut Parser, r: Restrictions) -> Option { T![if] => if_expr(p), T![loop] => loop_expr(p), T![return] => ret_expr(p), + T![while] => while_expr(p), T![break] => break_expr(p, r), _ => { p.error_recover("expected expression", EXPR_RECOVERY_SET); @@ -317,3 +319,12 @@ fn break_expr(p: &mut Parser, r: Restrictions) -> CompletedMarker { } m.complete(p, BREAK_EXPR) } + +fn while_expr(p: &mut Parser) -> CompletedMarker { + assert!(p.at(T![while])); + let m = p.start(); + p.bump(T![while]); + cond(p); + block(p); + m.complete(p, WHILE_EXPR) +} diff --git a/crates/mun_syntax/src/syntax_kind/generated.rs b/crates/mun_syntax/src/syntax_kind/generated.rs index 0bf4058b2..8e6034f49 100644 --- a/crates/mun_syntax/src/syntax_kind/generated.rs +++ b/crates/mun_syntax/src/syntax_kind/generated.rs @@ -103,6 +103,7 @@ pub enum SyntaxKind { IF_EXPR, BLOCK_EXPR, RETURN_EXPR, + WHILE_EXPR, LOOP_EXPR, BREAK_EXPR, CONDITION, @@ -369,6 +370,7 @@ impl SyntaxKind { IF_EXPR => &SyntaxInfo { name: "IF_EXPR" }, BLOCK_EXPR => &SyntaxInfo { name: "BLOCK_EXPR" }, RETURN_EXPR => &SyntaxInfo { name: "RETURN_EXPR" }, + WHILE_EXPR => &SyntaxInfo { name: "WHILE_EXPR" }, LOOP_EXPR => &SyntaxInfo { name: "LOOP_EXPR" }, BREAK_EXPR => &SyntaxInfo { name: "BREAK_EXPR" }, CONDITION => &SyntaxInfo { name: "CONDITION" }, diff --git a/crates/mun_syntax/src/tests/parser.rs b/crates/mun_syntax/src/tests/parser.rs index 97624773f..bc9128f47 100644 --- a/crates/mun_syntax/src/tests/parser.rs +++ b/crates/mun_syntax/src/tests/parser.rs @@ -199,3 +199,15 @@ fn break_expr() { "#, ) } + +#[test] +fn while_expr() { + ok_snapshot_test( + r#" + fn foo() { + while true {}; + while { true } {}; + } + "#, + ) +} diff --git a/crates/mun_syntax/src/tests/snapshots/parser__while_expr.snap b/crates/mun_syntax/src/tests/snapshots/parser__while_expr.snap new file mode 100644 index 000000000..267b037f2 --- /dev/null +++ b/crates/mun_syntax/src/tests/snapshots/parser__while_expr.snap @@ -0,0 +1,50 @@ +--- +source: crates/mun_syntax/src/tests/parser.rs +expression: "fn foo() {\n while true {};\n while { true } {};\n}" +--- +SOURCE_FILE@[0; 54) + FUNCTION_DEF@[0; 54) + FN_KW@[0; 2) "fn" + WHITESPACE@[2; 3) " " + NAME@[3; 6) + IDENT@[3; 6) "foo" + PARAM_LIST@[6; 8) + L_PAREN@[6; 7) "(" + R_PAREN@[7; 8) ")" + WHITESPACE@[8; 9) " " + BLOCK_EXPR@[9; 54) + L_CURLY@[9; 10) "{" + WHITESPACE@[10; 15) "\n " + EXPR_STMT@[15; 29) + WHILE_EXPR@[15; 28) + WHILE_KW@[15; 20) "while" + WHITESPACE@[20; 21) " " + CONDITION@[21; 25) + LITERAL@[21; 25) + TRUE_KW@[21; 25) "true" + WHITESPACE@[25; 26) " " + BLOCK_EXPR@[26; 28) + L_CURLY@[26; 27) "{" + R_CURLY@[27; 28) "}" + SEMI@[28; 29) ";" + WHITESPACE@[29; 34) "\n " + EXPR_STMT@[34; 52) + WHILE_EXPR@[34; 51) + WHILE_KW@[34; 39) "while" + WHITESPACE@[39; 40) " " + CONDITION@[40; 48) + BLOCK_EXPR@[40; 48) + L_CURLY@[40; 41) "{" + WHITESPACE@[41; 42) " " + LITERAL@[42; 46) + TRUE_KW@[42; 46) "true" + WHITESPACE@[46; 47) " " + R_CURLY@[47; 48) "}" + WHITESPACE@[48; 49) " " + BLOCK_EXPR@[49; 51) + L_CURLY@[49; 50) "{" + R_CURLY@[50; 51) "}" + SEMI@[51; 52) ";" + WHITESPACE@[52; 53) "\n" + R_CURLY@[53; 54) "}" +