Skip to content

Commit

Permalink
The actual issue with #945 is that the curry builtin reducer could so…
Browse files Browse the repository at this point in the history
…metimes place fully applied builtins that could evaluate and fail above where they were actually used. This happened with builtins that were called with the same constants enough times for the curry builtin to try hoist to a higher scope. This is now fixed by detecting which builtins are safe to evaluate in advance before we hoist fully applied builtins
  • Loading branch information
MicroProofs committed May 22, 2024
1 parent c1a913f commit d6cc450
Showing 1 changed file with 181 additions and 33 deletions.
214 changes: 181 additions & 33 deletions crates/uplc/src/optimize/shrinker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,146 @@ impl DefaultFunction {
| DefaultFunction::ConstrData
)
}

pub fn is_error_safe(self, arg_stack: &[&Term<Name>]) -> bool {
match self {
DefaultFunction::AddInteger
| DefaultFunction::SubtractInteger
| DefaultFunction::MultiplyInteger
| DefaultFunction::EqualsInteger
| DefaultFunction::LessThanInteger
| DefaultFunction::LessThanEqualsInteger => arg_stack.iter().all(|arg| {
if let Term::Constant(c) = arg {
matches!(c.as_ref(), Constant::Integer(_))
} else {
false
}
}),
DefaultFunction::DivideInteger
| DefaultFunction::ModInteger
| DefaultFunction::QuotientInteger
| DefaultFunction::RemainderInteger => arg_stack.iter().all(|arg| {
if let Term::Constant(c) = arg {
if let Constant::Integer(i) = c.as_ref() {
*i != 0.into()
} else {
false
}
} else {
false
}
}),
DefaultFunction::EqualsByteString
| DefaultFunction::AppendByteString
| DefaultFunction::LessThanEqualsByteString
| DefaultFunction::LessThanByteString => arg_stack.iter().all(|arg| {
if let Term::Constant(c) = arg {
matches!(c.as_ref(), Constant::ByteString(_))
} else {
false
}
}),

DefaultFunction::ConsByteString => {
if let (Term::Constant(c), Term::Constant(c2)) = (&arg_stack[0], &arg_stack[1]) {
if let (Constant::Integer(i), Constant::ByteString(_)) =
(c.as_ref(), c2.as_ref())
{
i >= &0.into() && i < &255.into()
} else {
false
}
} else {
false
}
}

DefaultFunction::SliceByteString => {
if let (Term::Constant(c), Term::Constant(c2), Term::Constant(c3)) =
(&arg_stack[0], &arg_stack[1], &arg_stack[2])
{
matches!(
(c.as_ref(), c2.as_ref(), c3.as_ref()),
(
Constant::Integer(_),
Constant::Integer(_),
Constant::ByteString(_)
)
)
} else {
false
}
}

DefaultFunction::IndexByteString => {
if let (Term::Constant(c), Term::Constant(c2)) = (&arg_stack[0], &arg_stack[1]) {
if let (Constant::ByteString(bs), Constant::Integer(i)) =
(c.as_ref(), c2.as_ref())
{
i >= &0.into() && i < &bs.len().into()
} else {
false
}
} else {
false
}
}

DefaultFunction::EqualsString | DefaultFunction::AppendString => {
arg_stack.iter().all(|arg| {
if let Term::Constant(c) = arg {
matches!(c.as_ref(), Constant::String(_))
} else {
false
}
})
}

DefaultFunction::EqualsData => arg_stack.iter().all(|arg| {
if let Term::Constant(c) = arg {
matches!(c.as_ref(), Constant::Data(_))
} else {
false
}
}),

DefaultFunction::Bls12_381_G1_Equal | DefaultFunction::Bls12_381_G1_Add => {
arg_stack.iter().all(|arg| {
if let Term::Constant(c) = arg {
matches!(c.as_ref(), Constant::Bls12_381G1Element(_))
} else {
false
}
})
}

DefaultFunction::Bls12_381_G2_Equal | DefaultFunction::Bls12_381_G2_Add => {
arg_stack.iter().all(|arg| {
if let Term::Constant(c) = arg {
matches!(c.as_ref(), Constant::Bls12_381G2Element(_))
} else {
false
}
})
}

DefaultFunction::ConstrData => {
if let (Term::Constant(c), Term::Constant(c2)) = (&arg_stack[0], &arg_stack[1]) {
if let (Constant::Integer(i), Constant::ProtoList(Type::Data, _)) =
(c.as_ref(), c2.as_ref())
{
i >= &0.into()
} else {
false
}
} else {
false
}
}

_ => false,
}
}
}

#[derive(PartialEq, Clone, Debug)]
Expand All @@ -219,10 +359,12 @@ pub enum BuiltinArgs {
}

impl BuiltinArgs {
fn args_from_arg_stack(stack: Vec<(usize, Term<Name>)>, is_order_agnostic: bool) -> Self {
fn args_from_arg_stack(stack: Vec<(usize, Term<Name>)>, func: DefaultFunction) -> Self {
let error_safe = func.is_error_safe(&stack.iter().map(|(_, term)| term).collect_vec());

let mut ordered_arg_stack = stack.into_iter().sorted_by(|(_, arg1), (_, arg2)| {
// sort by constant first if the builtin is order agnostic
if is_order_agnostic {
if func.is_order_agnostic_builtin() {
if matches!(arg1, Term::Constant(_)) == matches!(arg2, Term::Constant(_)) {
Ordering::Equal
} else if matches!(arg1, Term::Constant(_)) {
Expand All @@ -235,23 +377,35 @@ impl BuiltinArgs {
}
});

if ordered_arg_stack.len() == 2 && is_order_agnostic {
if ordered_arg_stack.len() == 2 && func.is_order_agnostic_builtin() {
// This is the special case where the order of args is irrelevant to the builtin
// An example is addInteger or multiplyInteger
BuiltinArgs::TwoArgsAnyOrder {
fst: ordered_arg_stack.next().unwrap(),
snd: ordered_arg_stack.next(),
snd: if error_safe {
ordered_arg_stack.next()
} else {
None
},
}
} else if ordered_arg_stack.len() == 2 {
BuiltinArgs::TwoArgs {
fst: ordered_arg_stack.next().unwrap(),
snd: ordered_arg_stack.next(),
snd: if error_safe {
ordered_arg_stack.next()
} else {
None
},
}
} else {
BuiltinArgs::ThreeArgs {
fst: ordered_arg_stack.next().unwrap(),
snd: ordered_arg_stack.next(),
thd: ordered_arg_stack.next(),
thd: if error_safe {
ordered_arg_stack.next()
} else {
None
},
}
}
}
Expand Down Expand Up @@ -855,7 +1009,7 @@ impl Program<Name> {
pub fn lambda_reducer(self) -> Self {
let mut lambda_applied_ids = vec![];

self.traverse_uplc_with(false, &mut |id, term, mut arg_stack, _scope| {
self.traverse_uplc_with(true, &mut |id, term, mut arg_stack, _scope| {
match term {
Term::Apply { function, .. } => {
// We are applying some arg so now we unwrap the id of the applied arg
Expand Down Expand Up @@ -905,7 +1059,7 @@ impl Program<Name> {
pub fn builtin_force_reducer(self) -> Self {
let mut builtin_map = IndexMap::new();

let program = self.traverse_uplc_with(false, &mut |_id, term, _arg_stack, _scope| {
let program = self.traverse_uplc_with(true, &mut |_id, term, _arg_stack, _scope| {
if let Term::Force(f) = term {
let f = Rc::make_mut(f);
match f {
Expand Down Expand Up @@ -965,7 +1119,7 @@ impl Program<Name> {

pub fn identity_reducer(self) -> Self {
let mut identity_applied_ids = vec![];
self.traverse_uplc_with(false, &mut |id, term, mut arg_stack, _scope| {
self.traverse_uplc_with(true, &mut |id, term, mut arg_stack, _scope| {
match term {
Term::Apply { function, .. } => {
// We are applying some arg so now we unwrap the id of the applied arg
Expand Down Expand Up @@ -1074,7 +1228,7 @@ impl Program<Name> {
pub fn inline_reducer(self) -> Self {
let mut lambda_applied_ids = vec![];

self.traverse_uplc_with(false, &mut |id, term, mut arg_stack, _scope| match term {
self.traverse_uplc_with(true, &mut |id, term, mut arg_stack, _scope| match term {
Term::Apply { function, .. } => {
// We are applying some arg so now we unwrap the id of the applied arg
let id = id.unwrap();
Expand Down Expand Up @@ -1140,7 +1294,7 @@ impl Program<Name> {
}

pub fn force_delay_reducer(self) -> Self {
self.traverse_uplc_with(false, &mut |_id, term, _arg_stack, _scope| {
self.traverse_uplc_with(true, &mut |_id, term, _arg_stack, _scope| {
if let Term::Force(f) = term {
let f = f.as_ref();

Expand All @@ -1152,7 +1306,7 @@ impl Program<Name> {
}

pub fn remove_no_inlines(self) -> Self {
self.traverse_uplc_with(false, &mut |_, term, _, _| match term {
self.traverse_uplc_with(true, &mut |_, term, _, _| match term {
Term::Lambda {
parameter_name,
body,
Expand All @@ -1162,7 +1316,7 @@ impl Program<Name> {
}

pub fn inline_constr_ops(self) -> Self {
self.traverse_uplc_with(false, &mut |_, term, _, _| {
self.traverse_uplc_with(true, &mut |_, term, _, _| {
if let Term::Apply { function, argument } = term {
if let Term::Var(name) = function.as_ref() {
if name.text == CONSTR_FIELDS_EXPOSER {
Expand All @@ -1184,7 +1338,7 @@ impl Program<Name> {
pub fn cast_data_reducer(self) -> Self {
let mut applied_ids = vec![];

self.traverse_uplc_with(false, &mut |id, term, mut arg_stack, _scope| {
self.traverse_uplc_with(true, &mut |id, term, mut arg_stack, _scope| {
match term {
Term::Apply { function, .. } => {
// We are apply some arg so now we unwrap the id of the applied arg
Expand Down Expand Up @@ -1312,7 +1466,7 @@ impl Program<Name> {
pub fn convert_arithmetic_ops(self) -> Self {
let mut constants_to_flip = vec![];

self.traverse_uplc_with(false, &mut |id, term, arg_stack, _scope| match term {
self.traverse_uplc_with(true, &mut |id, term, arg_stack, _scope| match term {
Term::Apply { argument, .. } => {
let id = id.unwrap();

Expand Down Expand Up @@ -1360,13 +1514,10 @@ impl Program<Name> {
self.traverse_uplc_with(false, &mut |_id, term, arg_stack, scope| match term {
Term::Builtin(func) => {
if func.can_curry_builtin() && arg_stack.len() == func.arity() {
let is_order_agnostic = func.is_order_agnostic_builtin();

// In the case of order agnostic builtins we want to sort the args by constant first
// This gives us the opportunity to curry constants that often pop up in the code

let builtin_args =
BuiltinArgs::args_from_arg_stack(arg_stack, is_order_agnostic);
let builtin_args = BuiltinArgs::args_from_arg_stack(arg_stack, *func);

// First we see if we have already curried this builtin before
let mut id_vec = if let Some((index, _)) =
Expand Down Expand Up @@ -1480,10 +1631,7 @@ impl Program<Name> {
arg_stack.reverse();
}

let builtin_args = BuiltinArgs::args_from_arg_stack(
arg_stack,
func.is_order_agnostic_builtin(),
);
let builtin_args = BuiltinArgs::args_from_arg_stack(arg_stack, *func);

let Some(mut id_vec) = curried_builtin.get_id_args(&builtin_args) else {
return;
Expand Down Expand Up @@ -1594,14 +1742,14 @@ fn var_occurrences(
if parameter_name.text == NO_INLINE {
var_occurrences(body.as_ref(), search_for, arg_stack, force_stack)
.no_inline_if_found()
} else if parameter_name.text != search_for.text
|| parameter_name.unique != search_for.unique
} else if parameter_name.text == search_for.text
&& parameter_name.unique == search_for.unique
{
VarLookup::new()
} else {
let not_applied: isize = isize::from(arg_stack.pop().is_none());
var_occurrences(body.as_ref(), search_for, arg_stack, force_stack)
.delay_if_found(not_applied)
} else {
VarLookup::new()
}
}
Term::Apply { function, argument } => {
Expand Down Expand Up @@ -1646,15 +1794,15 @@ fn substitute_var(term: &Term<Name>, original: Rc<Name>, replace_with: &Term<Nam
parameter_name,
body,
} => {
if parameter_name.text != original.text || parameter_name.unique != original.unique {
if parameter_name.text == original.text && parameter_name.unique == original.unique {
Term::Lambda {
parameter_name: parameter_name.clone(),
body: substitute_var(body.as_ref(), original, replace_with).into(),
body: body.clone(),
}
} else {
Term::Lambda {
parameter_name: parameter_name.clone(),
body: body.clone(),
body: substitute_var(body.as_ref(), original, replace_with).into(),
}
}
}
Expand All @@ -1676,15 +1824,15 @@ fn replace_identity_usage(term: &Term<Name>, original: Rc<Name>) -> Term<Name> {
parameter_name,
body,
} => {
if parameter_name.text != original.text || parameter_name.unique != original.unique {
if parameter_name.text == original.text && parameter_name.unique == original.unique {
Term::Lambda {
parameter_name: parameter_name.clone(),
body: Rc::new(replace_identity_usage(body.as_ref(), original)),
body: body.clone(),
}
} else {
Term::Lambda {
parameter_name: parameter_name.clone(),
body: body.clone(),
body: Rc::new(replace_identity_usage(body.as_ref(), original)),
}
}
}
Expand Down

0 comments on commit d6cc450

Please sign in to comment.