diff --git a/src/ast.rs b/src/ast.rs index c961897..694dab8 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -233,6 +233,8 @@ pub enum CallName { UnwrapLeft(ResolvedType), /// [`Either::unwrap_right`]. UnwrapRight(ResolvedType), + /// [`Option::is_none`]. + IsNone(ResolvedType), /// [`Option::unwrap`]. Unwrap, /// [`assert`]. @@ -893,6 +895,22 @@ impl AbstractSyntaxTree for Call { scope, )?]) } + CallName::IsNone(some_ty) => { + if from.args.len() != 1 { + return Err(Error::InvalidNumberOfArguments(1, from.args.len())) + .with_span(from); + } + let out_ty = ResolvedType::boolean(); + if ty != &out_ty { + return Err(Error::ExpressionTypeMismatch(ty.clone(), out_ty)).with_span(from); + } + let arg_ty = ResolvedType::option(some_ty); + Arc::from([Expression::analyze( + from.args.first().unwrap(), + &arg_ty, + scope, + )?]) + } CallName::Unwrap => { let args_ty = ResolvedType::option(ty.clone()); if from.args.len() != 1 { @@ -991,6 +1009,9 @@ impl AbstractSyntaxTree for CallName { .resolve(left_ty) .map(Self::UnwrapRight) .with_span(from), + parse::CallName::IsNone(some_ty) => { + scope.resolve(some_ty).map(Self::IsNone).with_span(from) + } parse::CallName::Unwrap => Ok(Self::Unwrap), parse::CallName::Assert => Ok(Self::Assert), parse::CallName::TypeCast(target) => { diff --git a/src/compile.rs b/src/compile.rs index c4a9032..54bbca1 100644 --- a/src/compile.rs +++ b/src/compile.rs @@ -302,6 +302,11 @@ impl Call { let get_inner = ProgNode::assertr_take(fail_cmr, &ProgNode::iden()); ProgNode::comp(&right_and_unit, &get_inner).with_span(self) } + CallName::IsNone(..) => { + let sum_and_unit = ProgNode::pair_unit(&args); + let is_right = ProgNode::case_true_false(); + ProgNode::comp(&sum_and_unit, &is_right).with_span(self) + } CallName::Assert => { let jet = ProgNode::jet(Elements::Verify); ProgNode::comp(&args, &jet).with_span(self) diff --git a/src/minimal.pest b/src/minimal.pest index 9ca59fc..d90da4a 100644 --- a/src/minimal.pest +++ b/src/minimal.pest @@ -56,10 +56,11 @@ false_expr = @{ "false" } true_expr = @{ "true" } unwrap_left = { "unwrap_left::<" ~ ty ~ ">" } unwrap_right = { "unwrap_right::<" ~ ty ~ ">" } +is_none = { "is_none::<" ~ ty ~ ">" } unwrap = @{ "unwrap" } assert = @{ "assert!" } type_cast = { "<" ~ ty ~ ">::into" } -call_name = { jet | unwrap_left | unwrap_right | unwrap | assert | type_cast | function_name } +call_name = { jet | unwrap_left | unwrap_right | is_none | unwrap | assert | type_cast | function_name } call_args = { "(" ~ (expression ~ ("," ~ expression)*)? ~ ")" } call_expr = { call_name ~ call_args } unsigned_decimal = @{ (ASCII_DIGIT | "_")+ } diff --git a/src/named.rs b/src/named.rs index bf09b50..bf8fe99 100644 --- a/src/named.rs +++ b/src/named.rs @@ -287,6 +287,16 @@ pub trait CoreExt: CoreConstructible + Sized { fn assertr_drop(cmr: Cmr, right: &Self) -> Self { Self::assertr(cmr, &Self::drop_(right)).unwrap() } + + /// `case false true` always type-checks. + fn case_false_true() -> Self { + Self::case(&Self::bit_false(), &Self::bit_true()).unwrap() + } + + /// `case true false` always type-checks. + fn case_true_false() -> Self { + Self::case(&Self::bit_true(), &Self::bit_false()).unwrap() + } } impl CoreExt for N {} diff --git a/src/parse.rs b/src/parse.rs index acd8caf..816d491 100644 --- a/src/parse.rs +++ b/src/parse.rs @@ -278,6 +278,8 @@ pub enum CallName { UnwrapRight(AliasedType), /// [`Option::unwrap`]. Unwrap, + /// [`Option::is_none`]. + IsNone(AliasedType), /// [`assert`]. Assert, /// Cast from the given source type. @@ -816,6 +818,10 @@ impl PestParse for CallName { let inner = pair.into_inner().next().unwrap(); AliasedType::parse(inner).map(Self::UnwrapRight) } + Rule::is_none => { + let inner = pair.into_inner().next().unwrap(); + AliasedType::parse(inner).map(Self::IsNone) + } Rule::unwrap => Ok(Self::Unwrap), Rule::assert => Ok(Self::Assert), Rule::type_cast => {