diff --git a/cli/tests/snapshot/snapshots/snapshot__eval_stderr_array_at_empty_array.ncl.snap b/cli/tests/snapshot/snapshots/snapshot__eval_stderr_array_at_empty_array.ncl.snap index 1c19afdb79..401ef8b1b5 100644 --- a/cli/tests/snapshot/snapshots/snapshot__eval_stderr_array_at_empty_array.ncl.snap +++ b/cli/tests/snapshot/snapshots/snapshot__eval_stderr_array_at_empty_array.ncl.snap @@ -4,9 +4,9 @@ expression: err --- error: contract broken by the caller of `at` invalid array indexing - ┌─ :162:9 + ┌─ :166:9 │ -162 │ | std.contract.unstable.IndexedArrayFun 'Index +166 │ | std.contract.unstable.IndexedArrayFun 'Index │ -------------------------------------------- expected type │ ┌─ [INPUTS_PATH]/errors/array_at_empty_array.ncl:3:16 @@ -21,5 +21,3 @@ note: │ 3 │ std.array.at 0 [] │ ----------------- (1) calling at - - diff --git a/cli/tests/snapshot/snapshots/snapshot__eval_stderr_array_at_out_of_bound.ncl.snap b/cli/tests/snapshot/snapshots/snapshot__eval_stderr_array_at_out_of_bound.ncl.snap index 749e2a0706..7e00925c79 100644 --- a/cli/tests/snapshot/snapshots/snapshot__eval_stderr_array_at_out_of_bound.ncl.snap +++ b/cli/tests/snapshot/snapshots/snapshot__eval_stderr_array_at_out_of_bound.ncl.snap @@ -4,9 +4,9 @@ expression: err --- error: contract broken by the caller of `at` invalid array indexing - ┌─ :162:9 + ┌─ :166:9 │ -162 │ | std.contract.unstable.IndexedArrayFun 'Index +166 │ | std.contract.unstable.IndexedArrayFun 'Index │ -------------------------------------------- expected type │ ┌─ [INPUTS_PATH]/errors/array_at_out_of_bound.ncl:3:16 @@ -21,5 +21,3 @@ note: │ 3 │ std.array.at 2 [1] │ ------------------ (1) calling at - - diff --git a/cli/tests/snapshot/snapshots/snapshot__eval_stderr_array_range_reversed_indices.ncl.snap b/cli/tests/snapshot/snapshots/snapshot__eval_stderr_array_range_reversed_indices.ncl.snap index dcd6406129..83a3dffa24 100644 --- a/cli/tests/snapshot/snapshots/snapshot__eval_stderr_array_range_reversed_indices.ncl.snap +++ b/cli/tests/snapshot/snapshots/snapshot__eval_stderr_array_range_reversed_indices.ncl.snap @@ -4,9 +4,9 @@ expression: err --- error: contract broken by the caller of `range` invalid range - ┌─ :673:9 + ┌─ :677:9 │ -673 │ | std.contract.unstable.RangeFun Dyn +677 │ | std.contract.unstable.RangeFun Dyn │ ---------------------------------- expected type │ ┌─ [INPUTS_PATH]/errors/array_range_reversed_indices.ncl:3:19 diff --git a/cli/tests/snapshot/snapshots/snapshot__eval_stderr_array_range_step_negative_step.ncl.snap b/cli/tests/snapshot/snapshots/snapshot__eval_stderr_array_range_step_negative_step.ncl.snap index b80bb0992b..f4c4f6b975 100644 --- a/cli/tests/snapshot/snapshots/snapshot__eval_stderr_array_range_step_negative_step.ncl.snap +++ b/cli/tests/snapshot/snapshots/snapshot__eval_stderr_array_range_step_negative_step.ncl.snap @@ -4,9 +4,9 @@ expression: err --- error: contract broken by the caller of `range_step` invalid range step - ┌─ :648:9 + ┌─ :652:9 │ -648 │ | std.contract.unstable.RangeFun (std.contract.unstable.RangeStep -> Dyn) +652 │ | std.contract.unstable.RangeFun (std.contract.unstable.RangeStep -> Dyn) │ ----------------------------------------------------------------------- expected type │ ┌─ [INPUTS_PATH]/errors/array_range_step_negative_step.ncl:3:27 diff --git a/cli/tests/snapshot/snapshots/snapshot__eval_stderr_caller_contract_violation.ncl.snap b/cli/tests/snapshot/snapshots/snapshot__eval_stderr_caller_contract_violation.ncl.snap index f5e40cca7f..17416597ea 100644 --- a/cli/tests/snapshot/snapshots/snapshot__eval_stderr_caller_contract_violation.ncl.snap +++ b/cli/tests/snapshot/snapshots/snapshot__eval_stderr_caller_contract_violation.ncl.snap @@ -3,9 +3,9 @@ source: cli/tests/snapshot/main.rs expression: err --- error: contract broken by the caller of `map` - ┌─ :146:33 + ┌─ :150:33 │ -146 │ : forall a b. (a -> b) -> Array a -> Array b +150 │ : forall a b. (a -> b) -> Array a -> Array b │ ------- expected type of the argument provided by the caller │ ┌─ [INPUTS_PATH]/errors/caller_contract_violation.ncl:3:31 @@ -18,5 +18,3 @@ note: │ 3 │ std.array.map std.function.id 'not-an-array │ ------------------------------------------- (1) calling map - - diff --git a/core/src/eval/mod.rs b/core/src/eval/mod.rs index 11bfa49413..d6347c8c06 100644 --- a/core/src/eval/mod.rs +++ b/core/src/eval/mod.rs @@ -90,8 +90,8 @@ use crate::{ pattern::compile::Compile, record::{Field, RecordData}, string::NickelString, - BinaryOp, BindingType, CustomContract, LetAttrs, MatchBranch, MatchData, RecordOpKind, - RichTerm, RuntimeContract, StrChunk, Term, UnaryOp, + BinaryOp, BindingType, LetAttrs, MatchBranch, MatchData, RecordOpKind, RichTerm, + RuntimeContract, StrChunk, Term, UnaryOp, }, }; @@ -1151,7 +1151,7 @@ pub fn subst( // Do not substitute under lambdas: mutually recursive function could cause an infinite // loop. Although avoidable, this requires some care and is not currently needed. | v @ Term::Fun(..) - | v @ Term::CustomContract(CustomContract::Predicate(..)) + | v @ Term::CustomContract(_) | v @ Term::Lbl(_) | v @ Term::ForeignId(_) | v @ Term::SealingKey(_) diff --git a/core/src/eval/operation.rs b/core/src/eval/operation.rs index 32c05ff3b4..a9e0e88e92 100644 --- a/core/src/eval/operation.rs +++ b/core/src/eval/operation.rs @@ -233,6 +233,7 @@ impl VirtualMachine { Term::Str(_) => "String", Term::Enum(_) | Term::EnumVariant { .. } => "Enum", Term::Fun(..) | Term::Match { .. } => "Function", + Term::CustomContract(_) => "CustomContract", Term::Array(..) => "Array", Term::Record(..) | Term::RecRecord(..) => "Record", Term::Lbl(..) => "Label", @@ -1195,18 +1196,44 @@ impl VirtualMachine { }) } UnaryOp::ContractFromPredicate => { - if let Term::Fun(id, body) = &*t { + if matches!(&*t, Term::Fun(..) | Term::Match(_)) { Ok(Closure { body: RichTerm::new( - Term::CustomContract(CustomContract::Predicate(*id, body.clone())), + Term::CustomContract(CustomContract::Predicate(RichTerm { + term: t, + pos, + })), pos, ), env, }) } else { - Err(mk_type_error!("contract_from_predicate", "Function")) + Err(mk_type_error!( + "contract/from_predicate", + "Function or MatchExpression" + )) + } + } + UnaryOp::ContractCustom => { + if matches!(&*t, Term::Fun(..) | Term::Match(_)) { + Ok(Closure { + body: RichTerm::new( + Term::CustomContract(CustomContract::PartialIdentity(RichTerm { + term: t, + pos, + })), + pos, + ), + env, + }) + } else { + Err(mk_type_error!( + "contract/custom", + "Function or MatchExpression" + )) } } + #[cfg(feature = "nix-experimental")] UnaryOp::EvalNix => { if let Term::Str(s) = &*t { @@ -1540,8 +1567,8 @@ impl VirtualMachine { pos2.into_inherited(), ); - match *t1 { - Term::Type(ref typ) => Ok(Closure { + match &*t1 { + Term::Type(typ) => Ok(Closure { body: typ.contract()?, env: env1, }), @@ -1552,16 +1579,15 @@ impl VirtualMachine { }, env: env1, }), - Term::CustomContract(CustomContract::Predicate(ref id, ref body)) => { - Ok(Closure { - body: mk_app!( - internals::predicate_to_ctr(), - RichTerm::new(Term::Fun(*id, body.clone()), pos1) - ) + Term::CustomContract(CustomContract::PartialIdentity(ctr)) => Ok(Closure { + body: ctr.clone(), + env: env1, + }), + Term::CustomContract(CustomContract::Predicate(pred)) => Ok(Closure { + body: mk_app!(internals::predicate_to_ctr(), pred.clone()) .with_pos(pos1), - env: env1, - }) - } + env: env1, + }), Term::Record(..) => { let closurized = RichTerm { term: t1, diff --git a/core/src/parser/grammar.lalrpop b/core/src/parser/grammar.lalrpop index 00c711e223..8eb00a4d74 100644 --- a/core/src/parser/grammar.lalrpop +++ b/core/src/parser/grammar.lalrpop @@ -1082,6 +1082,7 @@ UOp: UnaryOp = { "label/go_array" => UnaryOp::LabelGoArray, "label/go_dict" => UnaryOp::LabelGoDict, "contract/from_predicate" => UnaryOp::ContractFromPredicate, + "contract/custom" => UnaryOp::ContractCustom, "enum/embed" => UnaryOp::EnumEmbed(<>), "array/map" => UnaryOp::ArrayMap, "array/generate" => UnaryOp::ArrayGen, @@ -1514,6 +1515,7 @@ extern { "contract/array_lazy_app" => Token::Normal(NormalToken::ContractArrayLazyApp), "contract/record_lazy_app" => Token::Normal(NormalToken::ContractRecordLazyApp), "contract/from_predicate" => Token::Normal(NormalToken::ContractFromPredicate), + "contract/custom" => Token::Normal(NormalToken::ContractCustom), "op force" => Token::Normal(NormalToken::OpForce), "blame" => Token::Normal(NormalToken::Blame), "label/flip_polarity" => Token::Normal(NormalToken::LabelFlipPol), diff --git a/core/src/parser/lexer.rs b/core/src/parser/lexer.rs index d816b74adb..94c0c7daff 100644 --- a/core/src/parser/lexer.rs +++ b/core/src/parser/lexer.rs @@ -202,6 +202,8 @@ pub enum NormalToken<'input> { ContractRecordLazyApp, #[token("%contract/from_predicate%")] ContractFromPredicate, + #[token("%contract/custom%")] + ContractCustom, #[token("%blame%")] Blame, #[token("%label/flip_polarity%")] diff --git a/core/src/pretty.rs b/core/src/pretty.rs index 98bd078def..9c4cad7a6f 100644 --- a/core/src/pretty.rs +++ b/core/src/pretty.rs @@ -268,7 +268,7 @@ where loop { match body.as_ref() { - Term::Fun(id, rt) | Term::CustomContract(CustomContract::Predicate(id, rt)) => { + Term::Fun(id, rt) => { builder = docs![self, builder, self.line(), self.as_string(id)]; body = rt; } @@ -821,18 +821,21 @@ where Str(v) => allocator.escaped_string(v).double_quotes(), StrChunks(chunks) => allocator.chunks(chunks, StringRenderStyle::Multiline), Fun(id, body) => allocator.function(allocator.as_string(id), body), + CustomContract(ContractNode::PartialIdentity(ctr)) => docs![ + allocator, + "%contract/custom%", + docs![allocator, allocator.line(), ctr.pretty(allocator).parens()] + .nest(2) + .group() + ], FunPattern(pat, body) => allocator.function(allocator.pat_with_parens(pat), body), // Format this as the application `std.contract.from_predicate `. - CustomContract(ContractNode::Predicate(id, pred)) => docs![ + CustomContract(ContractNode::Predicate(pred)) => docs![ allocator, "%contract/from_predicate%", - docs![ - allocator, - allocator.line(), - allocator.function(allocator.as_string(id), pred).parens() - ] - .nest(2) - .group() + docs![allocator, allocator.line(), pred.pretty(allocator).parens()] + .nest(2) + .group() ], Lbl(_lbl) => allocator.text("%