Skip to content

Commit

Permalink
Fix bugs in memcpyopt (#6359)
Browse files Browse the repository at this point in the history
## Description

Fixes  #6321
Fixes #6360
Fixes #6361

---------

Co-authored-by: IGI-111 <igi-111@protonmail.com>
Co-authored-by: Sophie Dankel <47993817+sdankel@users.noreply.github.com>
  • Loading branch information
3 people authored and esdrubal committed Aug 13, 2024
1 parent d4ba4f4 commit 52d2a46
Show file tree
Hide file tree
Showing 29 changed files with 284 additions and 111 deletions.
6 changes: 3 additions & 3 deletions forc-plugins/forc-client/tests/deploy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ async fn test_simple_deploy() {
node.kill().unwrap();
let expected = vec![DeployedContract {
id: ContractId::from_str(
"822c8d3672471f64f14f326447793c7377b6e430122db23b622880ccbd8a33ef",
"ad0bba17e0838ef859abe2693d8a5e3bc4e7cfb901601e30f4dc34999fda6335",
)
.unwrap(),
proxy: None,
Expand Down Expand Up @@ -185,12 +185,12 @@ async fn test_deploy_fresh_proxy() {
node.kill().unwrap();
let impl_contract = DeployedContract {
id: ContractId::from_str(
"822c8d3672471f64f14f326447793c7377b6e430122db23b622880ccbd8a33ef",
"ad0bba17e0838ef859abe2693d8a5e3bc4e7cfb901601e30f4dc34999fda6335",
)
.unwrap(),
proxy: Some(
ContractId::from_str(
"3da2f8ee967c62496db4b71df0acd7c3fea1e494fee1de0cd16e7abd22e6057f",
"5237df8db3edbe825ce83f4292094923c989efe3265b0115ed050925593a3488",
)
.unwrap(),
),
Expand Down
186 changes: 136 additions & 50 deletions sway-ir/src/optimize/memcpyopt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ use rustc_hash::{FxHashMap, FxHashSet};
use sway_types::{FxIndexMap, FxIndexSet};

use crate::{
get_gep_referred_symbols, get_gep_symbol, get_referred_symbol, get_referred_symbols,
memory_utils, AnalysisResults, Block, Context, EscapedSymbols, Function, InstOp, Instruction,
IrError, LocalVar, Pass, PassMutability, ReferredSymbols, ScopedPass, Symbol, Type, Value,
ValueDatum, ESCAPED_SYMBOLS_NAME,
get_gep_symbol, get_referred_symbol, get_referred_symbols, get_stored_symbols, memory_utils,
AnalysisResults, Block, Context, EscapedSymbols, Function, InstOp, Instruction,
InstructionInserter, IrError, LocalVar, Pass, PassMutability, ReferredSymbols, ScopedPass,
Symbol, Type, Value, ValueDatum, ESCAPED_SYMBOLS_NAME,
};

pub const MEMCPYOPT_NAME: &str = "memcpyopt";
Expand Down Expand Up @@ -735,21 +735,53 @@ fn local_copy_prop(
Ok(modified)
}

struct Candidate {
load_val: Value,
store_val: Value,
dst_ptr: Value,
src_ptr: Value,
}

enum CandidateKind {
/// If aggregates are clobbered b/w a load and the store, we still need to,
/// for correctness (because asmgen cannot handle aggregate loads and stores)
/// do the memcpy. So we insert a memcpy to a temporary stack location right after
/// the load, and memcpy it to the store pointer at the point of store.
ClobberedNoncopyType(Candidate),
NonClobbered(Candidate),
}

// Is (an alias of) src_ptr clobbered on any path from load_val to store_val?
fn is_clobbered(
context: &Context,
store_block: Block,
store_val: Value,
load_val: Value,
src_ptr: Value,
Candidate {
load_val,
store_val,
dst_ptr,
src_ptr,
}: &Candidate,
) -> bool {
let store_block = store_val.get_instruction(context).unwrap().parent;

let mut iter = store_block
.instruction_iter(context)
.rev()
.skip_while(|i| i != &store_val);
assert!(iter.next().unwrap() == store_val);
.skip_while(|i| i != store_val);
assert!(iter.next().unwrap() == *store_val);

let ReferredSymbols::Complete(src_symbols) = get_referred_symbols(context, *src_ptr) else {
return true;
};

let src_symbols = get_gep_referred_symbols(context, src_ptr);
let ReferredSymbols::Complete(dst_symbols) = get_referred_symbols(context, *dst_ptr) else {
return true;
};

// If the source and destination may have an overlap, we'll end up generating a mcp
// with overlapping source/destination which is not allowed.
if src_symbols.intersection(&dst_symbols).next().is_some() {
return true;
}

// Scan backwards till we encounter load_val, checking if
// any store aliases with src_ptr.
Expand All @@ -759,25 +791,17 @@ fn is_clobbered(
'next_job: while let Some((block, iter)) = worklist.pop() {
visited.insert(block);
for inst in iter {
if inst == load_val || inst == store_val {
if inst == *load_val || inst == *store_val {
// We don't need to go beyond either the source load or the candidate store.
continue 'next_job;
}
if let Some(Instruction {
op:
InstOp::Store {
dst_val_ptr,
stored_val: _,
},
..
}) = inst.get_instruction(context)
{
if get_gep_referred_symbols(context, *dst_val_ptr)
.iter()
.any(|sym| src_symbols.contains(sym))
{
let stored_syms = get_stored_symbols(context, inst);
if let ReferredSymbols::Complete(syms) = stored_syms {
if syms.iter().any(|sym| src_symbols.contains(sym)) {
return true;
}
} else {
return true;
}
}
for pred in block.pred_iter(context) {
Expand All @@ -793,12 +817,20 @@ fn is_clobbered(
false
}

// This is a copy of sway_core::asm_generation::fuel::fuel_asm_builder::FuelAsmBuilder::is_copy_type.
fn is_copy_type(ty: &Type, context: &Context) -> bool {
ty.is_unit(context)
|| ty.is_never(context)
|| ty.is_bool(context)
|| ty.get_uint_width(context).map(|x| x < 256).unwrap_or(false)
}

fn load_store_to_memcopy(context: &mut Context, function: Function) -> Result<bool, IrError> {
// Find any `store`s of `load`s. These can be replaced with `mem_copy` and are especially
// important for non-copy types on architectures which don't support loading them.
let candidates = function
.instruction_iter(context)
.filter_map(|(block, store_instr_val)| {
.filter_map(|(_, store_instr_val)| {
store_instr_val
.get_instruction(context)
.and_then(|instr| {
Expand Down Expand Up @@ -826,41 +858,95 @@ fn load_store_to_memcopy(context: &mut Context, function: Function) -> Result<bo
..
} = src_instr
{
Some((
block,
src_instr_val,
store_instr_val,
*dst_val_ptr,
*src_val_ptr,
))
Some(Candidate {
load_val: src_instr_val,
store_val: store_instr_val,
dst_ptr: *dst_val_ptr,
src_ptr: *src_val_ptr,
})
} else {
None
}
})
.and_then(|candidate @ Candidate { dst_ptr, .. }| {
// Check that there's no path from load_val to store_val that might overwrite src_ptr.
if !is_clobbered(context, &candidate) {
Some(CandidateKind::NonClobbered(candidate))
} else if !is_copy_type(&dst_ptr.match_ptr_type(context).unwrap(), context) {
Some(CandidateKind::ClobberedNoncopyType(candidate))
} else {
None
}
})
.and_then(
|candidate @ (block, load_val, store_val, _dst_ptr, src_ptr)| {
// Ensure that there's no path from load_val to store_val that might overwrite src_ptr.
(!is_clobbered(context, block, store_val, load_val, src_ptr))
.then_some(candidate)
},
)
})
.collect::<Vec<_>>();

if candidates.is_empty() {
return Ok(false);
}

for (block, _src_instr_val, store_val, dst_val_ptr, src_val_ptr) in candidates {
let mem_copy_val = Value::new_instruction(
context,
block,
InstOp::MemCopyVal {
dst_val_ptr,
src_val_ptr,
},
);
block.replace_instruction(context, store_val, mem_copy_val, true)?;
for candidate in candidates {
match candidate {
CandidateKind::ClobberedNoncopyType(Candidate {
load_val,
store_val,
dst_ptr,
src_ptr,
}) => {
let load_block = load_val.get_instruction(context).unwrap().parent;
let temp = function.new_unique_local_var(
context,
"__aggr_memcpy_0".into(),
src_ptr.match_ptr_type(context).unwrap(),
None,
true,
);
let temp_local =
Value::new_instruction(context, load_block, InstOp::GetLocal(temp));
let to_temp = Value::new_instruction(
context,
load_block,
InstOp::MemCopyVal {
dst_val_ptr: temp_local,
src_val_ptr: src_ptr,
},
);
let mut inserter = InstructionInserter::new(
context,
load_block,
crate::InsertionPosition::After(load_val),
);
inserter.insert_slice(&[temp_local, to_temp]);

let store_block = store_val.get_instruction(context).unwrap().parent;
let mem_copy_val = Value::new_instruction(
context,
store_block,
InstOp::MemCopyVal {
dst_val_ptr: dst_ptr,
src_val_ptr: temp_local,
},
);
store_block.replace_instruction(context, store_val, mem_copy_val, true)?;
}
CandidateKind::NonClobbered(Candidate {
dst_ptr: dst_val_ptr,
src_ptr: src_val_ptr,
store_val,
..
}) => {
let store_block = store_val.get_instruction(context).unwrap().parent;
let mem_copy_val = Value::new_instruction(
context,
store_block,
InstOp::MemCopyVal {
dst_val_ptr,
src_val_ptr,
},
);
store_block.replace_instruction(context, store_val, mem_copy_val, true)?;
}
}
}

Ok(true)
Expand Down
Loading

0 comments on commit 52d2a46

Please sign in to comment.