Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixed loop capture of snapshoted variables #5934

Merged
merged 1 commit into from
Jul 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions corelib/src/test/language_features/while_test.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,34 @@ fn test_outer_loop_break() {
};
assert_eq!(i, 10);
}

#[test]
fn test_borrow_usage() {
let mut i = 0;
let arr = array![1, 2, 3, 4];
while i != arr.len() {
i += 1;
};
assert_eq!(arr.len(), 4);
}

#[derive(Drop)]
struct NonCopy {
x: felt252,
}

fn assert_x_eq(a: @NonCopy, x: felt252) {
assert_eq!(a.x, @x);
}

#[test]
fn test_borrow_with_inner_change() {
let mut a = NonCopy { x: 0 };
let mut i = 0;
while i != 5 {
a.x = i;
assert_x_eq(@a, i);
i += 1;
};
}

49 changes: 49 additions & 0 deletions crates/cairo-lang-lowering/src/lower/block_builder.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use cairo_lang_defs::ids::{MemberId, NamedLanguageElementId};
use cairo_lang_diagnostics::Maybe;
use cairo_lang_semantic as semantic;
use cairo_lang_semantic::types::{peel_snapshots, wrap_in_snapshots};
use cairo_lang_syntax::node::TypedStablePtr;
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
Expand All @@ -26,6 +27,8 @@ use crate::{
pub struct BlockBuilder {
/// A store for semantic variables, owning their OwnedVariable instances.
pub semantics: SemanticLoweringMapping,
/// The semantic variables that are captured as snapshots in this block.
pub snapped_semantics: OrderedHashMap<MemberPath, VariableId>,
/// The semantic variables that are added/changed in this block.
changed_member_paths: OrderedHashSet<MemberPath>,
/// Current sequence of lowered statements emitted.
Expand All @@ -38,6 +41,7 @@ impl BlockBuilder {
pub fn root(_ctx: &mut LoweringContext<'_, '_>, block_id: BlockId) -> Self {
BlockBuilder {
semantics: Default::default(),
snapped_semantics: Default::default(),
changed_member_paths: Default::default(),
statements: Default::default(),
block_id,
Expand All @@ -48,6 +52,7 @@ impl BlockBuilder {
pub fn child_block_builder(&self, block_id: BlockId) -> BlockBuilder {
BlockBuilder {
semantics: self.semantics.clone(),
snapped_semantics: self.snapped_semantics.clone(),
changed_member_paths: Default::default(),
statements: Default::default(),
block_id,
Expand All @@ -59,6 +64,7 @@ impl BlockBuilder {
pub fn sibling_block_builder(&self, block_id: BlockId) -> BlockBuilder {
BlockBuilder {
semantics: self.semantics.clone(),
snapped_semantics: self.snapped_semantics.clone(),
changed_member_paths: self.changed_member_paths.clone(),
statements: Default::default(),
block_id,
Expand Down Expand Up @@ -119,6 +125,49 @@ impl BlockBuilder {
.map(|var_id| VarUsage { var_id, location })
}

/// Updates the reference of a semantic variable to a snapshot of its lowered variable.
pub fn update_snap_ref(&mut self, member_path: &ExprVarMemberPath, var: VariableId) {
self.snapped_semantics.insert(member_path.into(), var);
}

/// Gets the reference of a snapshot of semantic variable, possibly by deconstructing a
/// its parents.
pub fn get_snap_ref(
&mut self,
ctx: &mut LoweringContext<'_, '_>,
member_path: &ExprVarMemberPath,
) -> Option<VarUsage> {
let location = ctx.get_location(member_path.stable_ptr().untyped());
if let Some(var_id) = self.snapped_semantics.get::<MemberPath>(&member_path.into()) {
return Some(VarUsage { var_id: *var_id, location });
}
let ExprVarMemberPath::Member { parent, member_id, concrete_struct_id, .. } = member_path
else {
return None;
};
// TODO(TomerStarkware): Consider adding the result to snap_semantics to avoid
// recomputation.
let parent_var = self.get_snap_ref(ctx, parent)?;
let members = ctx.db.concrete_struct_members(*concrete_struct_id).ok()?;
let (parent_number_of_snapshots, _) =
peel_snapshots(ctx.db.upcast(), ctx.variables[parent_var.var_id].ty);
let member_idx = members.iter().position(|(_, member)| member.id == *member_id)?;
Some(
generators::StructMemberAccess {
input: parent_var,
member_tys: members
.into_iter()
.map(|(_, member)| {
wrap_in_snapshots(ctx.db.upcast(), member.ty, parent_number_of_snapshots)
})
.collect(),
member_idx,
location,
}
.add(ctx, &mut self.statements),
)
}

/// Gets the type of a semantic variable.
pub fn get_ty(
&mut self,
Expand Down
15 changes: 2 additions & 13 deletions crates/cairo-lang-lowering/src/lower/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@ use cairo_lang_utils::Intern;
use defs::diagnostic_utils::StableLocation;
use id_arena::Arena;
use itertools::{zip_eq, Itertools};
use semantic::corelib::{core_module, get_ty_by_name, get_usize_ty};
use semantic::corelib::{core_module, get_ty_by_name};
use semantic::expr::inference::InferenceError;
use semantic::items::constant::value_as_const_value;
use semantic::types::wrap_in_snapshots;
use semantic::{ExprVarMemberPath, MatchArmSelector, TypeLongId};
use {cairo_lang_defs as defs, cairo_lang_semantic as semantic};
Expand Down Expand Up @@ -300,17 +299,7 @@ impl LoweredExpr {
LoweredExpr::Snapshot { expr, .. } => {
wrap_in_snapshots(ctx.db.upcast(), expr.ty(ctx), 1)
}
LoweredExpr::FixedSizeArray { exprs, .. } => semantic::TypeLongId::FixedSizeArray {
type_id: exprs[0].ty(ctx),
size: value_as_const_value(
ctx.db.upcast(),
get_usize_ty(ctx.db.upcast()),
&exprs.len().into(),
)
.unwrap()
.intern(ctx.db),
}
.intern(ctx.db),
LoweredExpr::FixedSizeArray { ty, .. } => *ty,
}
}
pub fn location(&self) -> LocationId {
Expand Down
101 changes: 89 additions & 12 deletions crates/cairo-lang-lowering/src/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ use std::vec;
use block_builder::BlockBuilder;
use cairo_lang_debug::DebugWithDb;
use cairo_lang_diagnostics::{Diagnostics, Maybe};
use cairo_lang_semantic::corelib::{self, unwrap_error_propagation_type, ErrorPropagationType};
use cairo_lang_semantic::corelib::{unwrap_error_propagation_type, ErrorPropagationType};
use cairo_lang_semantic::db::SemanticGroup;
use cairo_lang_semantic::{LocalVariable, VarId};
use cairo_lang_semantic::{corelib, ExprVar, LocalVariable, VarId};
use cairo_lang_syntax::node::ids::SyntaxStablePtrId;
use cairo_lang_syntax::node::TypedStablePtr;
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
Expand All @@ -27,6 +27,7 @@ use semantic::{
ExprFunctionCallArg, ExprId, ExprPropagateError, ExprVarMemberPath, GenericArgumentId,
MatchArmSelector, SemanticDiagnostic, TypeLongId,
};
use usage::MemberPath;
use {cairo_lang_defs as defs, cairo_lang_semantic as semantic};

use self::block_builder::SealedBlockBuilder;
Expand Down Expand Up @@ -413,6 +414,7 @@ pub fn lower_loop_function(
function_id: FunctionWithBodyId,
loop_signature: Signature,
loop_expr_id: semantic::ExprId,
snapped_params: &OrderedHashMap<MemberPath, semantic::ExprVarMemberPath>,
) -> Maybe<FlatLowered> {
let mut ctx = LoweringContext::new(encapsulating_ctx, function_id, loop_signature.clone())?;
let old_loop_expr_id = std::mem::replace(&mut ctx.current_loop_expr_id, Some(loop_expr_id));
Expand All @@ -429,7 +431,11 @@ pub fn lower_loop_function(
.map(|param| {
let location = ctx.get_location(param.stable_ptr().untyped());
let var = ctx.new_var(VarRequest { ty: param.ty(), location });
builder.semantics.introduce((&param).into(), var);
if snapped_params.contains_key::<MemberPath>(&(&param).into()) {
builder.update_snap_ref(&param, var)
} else {
builder.semantics.introduce((&param).into(), var);
}
var
})
.collect_vec();
Expand Down Expand Up @@ -1116,8 +1122,31 @@ fn lower_expr_snapshot(
builder: &mut BlockBuilder,
) -> LoweringResult<LoweredExpr> {
log::trace!("Lowering a snapshot: {:?}", expr.debug(&ctx.expr_formatter));
// If the inner expression is a variable, or a member access, and we already have a snapshot var
// we can use it without creating a new one.
// Note that in a closure we might only have a snapshot of the variable and not the original.
match &ctx.function_body.exprs[expr.inner] {
semantic::Expr::Var(expr_var) => {
let member_path = ExprVarMemberPath::Var(expr_var.clone());
if let Some(var) = builder.get_snap_ref(ctx, &member_path) {
return Ok(LoweredExpr::AtVariable(var));
}
}
semantic::Expr::MemberAccess(expr) => {
if let Some(var) = expr
.member_path
.clone()
.and_then(|member_path| builder.get_snap_ref(ctx, &member_path))
{
return Ok(LoweredExpr::AtVariable(var));
}
}
_ => {}
}
let lowered = lower_expr(ctx, builder, expr.inner)?;

let location = ctx.get_location(expr.stable_ptr.untyped());
let expr = Box::new(lower_expr(ctx, builder, expr.inner)?);
let expr = Box::new(lowered);
Ok(LoweredExpr::Snapshot { expr, location })
}

Expand Down Expand Up @@ -1348,7 +1377,26 @@ fn lower_expr_loop(
let usage = &ctx.block_usages.block_usages[&loop_expr_id];

// Determine signature.
let params = usage.usage.iter().map(|(_, expr)| expr.clone()).collect_vec();
let params = usage
.usage
.iter()
.map(|(_, expr)| expr.clone())
.chain(usage.snap_usage.iter().map(|(_, expr)| match expr {
ExprVarMemberPath::Var(var) => ExprVarMemberPath::Var(ExprVar {
ty: wrap_in_snapshots(ctx.db.upcast(), var.ty, 1),
..*var
}),
ExprVarMemberPath::Member { parent, member_id, stable_ptr, concrete_struct_id, ty } => {
ExprVarMemberPath::Member {
parent: parent.clone(),
member_id: *member_id,
stable_ptr: *stable_ptr,
concrete_struct_id: *concrete_struct_id,
ty: wrap_in_snapshots(ctx.db.upcast(), *ty, 1),
}
}
}))
.collect_vec();
let extra_rets = usage.changes.iter().map(|(_, expr)| expr.clone()).collect_vec();

let loop_signature = Signature {
Expand All @@ -1367,17 +1415,37 @@ fn lower_expr_loop(
}
.intern(ctx.db);

let snap_usage = ctx.block_usages.block_usages[&loop_expr_id].snap_usage.clone();

// Generate the function.
let encapsulating_ctx = std::mem::take(&mut ctx.encapsulating_ctx).unwrap();
let lowered =
lower_loop_function(encapsulating_ctx, function, loop_signature.clone(), loop_expr_id)
.map_err(LoweringFlowError::Failed)?;
let lowered = lower_loop_function(
encapsulating_ctx,
function,
loop_signature.clone(),
loop_expr_id,
&snap_usage,
)
.map_err(LoweringFlowError::Failed)?;
// TODO(spapini): Recursive call.
encapsulating_ctx.lowerings.insert(loop_expr_id, lowered);

ctx.encapsulating_ctx = Some(encapsulating_ctx);
let old_loop_expr_id = std::mem::replace(&mut ctx.current_loop_expr_id, Some(loop_expr_id));
for snapshot_param in snap_usage.values() {
// if we have access to the real member we generate a snapshot, otherwise it should be
// accessible with `builder.get_snap_ref`
if let Some(input) = builder.get_ref(ctx, snapshot_param) {
let (original, snapped) = generators::Snapshot {
input,
location: ctx.get_location(snapshot_param.stable_ptr().untyped()),
}
.add(ctx, &mut builder.statements);
builder.update_snap_ref(snapshot_param, snapped);
builder.update_ref(ctx, snapshot_param, original);
}
}
let call = call_loop_func(ctx, loop_signature, builder, loop_expr_id, stable_ptr.untyped());

ctx.current_loop_expr_id = old_loop_expr_id;
call
}
Expand All @@ -1402,9 +1470,18 @@ fn call_loop_func(
.params
.into_iter()
.map(|param| {
builder.get_ref(ctx, &param).ok_or_else(|| {
LoweringFlowError::Failed(ctx.diagnostics.report(stable_ptr, MemberPathLoop))
})
builder
.get_ref(ctx, &param)
.and_then(|var| (ctx.variables[var.var_id].ty == param.ty()).then_some(var))
.or_else(|| {
let var = builder.get_snap_ref(ctx, &param)?;
(ctx.variables[var.var_id].ty == param.ty()).then_some(var)
})
.ok_or_else(|| {
// TODO(TomerStaskware): make sure this is unreachable and remove
// `MemberPathLoop` diagnostic.
LoweringFlowError::Failed(ctx.diagnostics.report(stable_ptr, MemberPathLoop))
})
})
.collect::<LoweringResult<Vec<_>>>()?;
let extra_ret_tys = loop_signature.extra_rets.iter().map(|path| path.ty()).collect_vec();
Expand Down
Loading