Skip to content

Commit

Permalink
fix(coverage): proper branch handling for if statement (#8414)
Browse files Browse the repository at this point in the history
* fix(coverage): proper instruction mapping for first branch

* Add tests
  • Loading branch information
grandizzy authored Jul 12, 2024
1 parent 758630e commit 86d583f
Show file tree
Hide file tree
Showing 4 changed files with 237 additions and 72 deletions.
94 changes: 56 additions & 38 deletions crates/evm/coverage/src/analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ impl<'a> ContractVisitor<'a> {
self.visit_expression(
&node
.attribute("condition")
.ok_or_else(|| eyre::eyre!("while statement had no condition"))?,
.ok_or_else(|| eyre::eyre!("if statement had no condition"))?,
)?;

let body = node
Expand Down Expand Up @@ -211,32 +211,63 @@ impl<'a> ContractVisitor<'a> {
let branch_id = self.branch_id;

// We increase the branch ID here such that nested branches do not use the same
// branch ID as we do
// branch ID as we do.
self.branch_id += 1;

// The relevant source range for the branch is the `if(...)` statement itself and
// the true body of the if statement. The false body of the statement (if any) is
// processed as its own thing. If this source range is not processed like this, it
// is virtually impossible to correctly map instructions back to branches that
// include more complex logic like conditional logic.
self.push_branches(
&foundry_compilers::artifacts::ast::LowFidelitySourceLocation {
start: node.src.start,
length: true_body
.src
.length
.map(|length| true_body.src.start - node.src.start + length),
index: node.src.index,
},
branch_id,
);

// Process the true branch
self.visit_block_or_statement(&true_body)?;

// Process the false branch
if let Some(false_body) = node.attribute("falseBody") {
self.visit_block_or_statement(&false_body)?;
// The relevant source range for the true branch is the `if(...)` statement itself
// and the true body of the if statement. The false body of the
// statement (if any) is processed as its own thing. If this source
// range is not processed like this, it is virtually impossible to
// correctly map instructions back to branches that include more
// complex logic like conditional logic.
let true_branch_loc = &ast::LowFidelitySourceLocation {
start: node.src.start,
length: true_body
.src
.length
.map(|length| true_body.src.start - node.src.start + length),
index: node.src.index,
};

// Add the coverage item for branch 0 (true body).
self.push_item(CoverageItem {
kind: CoverageItemKind::Branch { branch_id, path_id: 0 },
loc: self.source_location_for(true_branch_loc),
hits: 0,
});

match node.attribute::<Node>("falseBody") {
// Both if/else statements.
Some(false_body) => {
// Add the coverage item for branch 1 (false body).
// The relevant source range for the false branch is the `else` statement
// itself and the false body of the else statement.
self.push_item(CoverageItem {
kind: CoverageItemKind::Branch { branch_id, path_id: 1 },
loc: self.source_location_for(&ast::LowFidelitySourceLocation {
start: node.src.start,
length: false_body.src.length.map(|length| {
false_body.src.start - true_body.src.start + length
}),
index: node.src.index,
}),
hits: 0,
});
// Process the true body.
self.visit_block_or_statement(&true_body)?;
// Process the false body.
self.visit_block_or_statement(&false_body)?;
}
None => {
// Add the coverage item for branch 1 (same true body).
self.push_item(CoverageItem {
kind: CoverageItemKind::Branch { branch_id, path_id: 1 },
loc: self.source_location_for(true_branch_loc),
hits: 0,
});
// Process the true body.
self.visit_block_or_statement(&true_body)?;
}
}

Ok(())
Expand Down Expand Up @@ -393,19 +424,6 @@ impl<'a> ContractVisitor<'a> {
line: self.source[..loc.start].lines().count(),
}
}

fn push_branches(&mut self, loc: &ast::LowFidelitySourceLocation, branch_id: usize) {
self.push_item(CoverageItem {
kind: CoverageItemKind::Branch { branch_id, path_id: 0 },
loc: self.source_location_for(loc),
hits: 0,
});
self.push_item(CoverageItem {
kind: CoverageItemKind::Branch { branch_id, path_id: 1 },
loc: self.source_location_for(loc),
hits: 0,
});
}
}

/// [`SourceAnalyzer`] result type.
Expand Down
2 changes: 1 addition & 1 deletion crates/evm/coverage/src/anchors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ pub fn find_anchor_branch(
ItemAnchor {
item_id,
// The first branch is the opcode directly after JUMPI
instruction: pc + 2,
instruction: pc + 1,
},
ItemAnchor { item_id, instruction: pc_jump },
));
Expand Down
34 changes: 1 addition & 33 deletions crates/evm/coverage/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,7 @@ impl CoverageReport {
else {
continue;
};
let mut summary = summaries.entry(path).or_default();
summary += item;
*summaries.entry(path).or_default() += item;
}
}

Expand Down Expand Up @@ -424,34 +423,3 @@ impl AddAssign<&CoverageItem> for CoverageSummary {
}
}
}

impl AddAssign<&CoverageItem> for &mut CoverageSummary {
fn add_assign(&mut self, item: &CoverageItem) {
match item.kind {
CoverageItemKind::Line => {
self.line_count += 1;
if item.hits > 0 {
self.line_hits += 1;
}
}
CoverageItemKind::Statement => {
self.statement_count += 1;
if item.hits > 0 {
self.statement_hits += 1;
}
}
CoverageItemKind::Branch { .. } => {
self.branch_count += 1;
if item.hits > 0 {
self.branch_hits += 1;
}
}
CoverageItemKind::Function { .. } => {
self.function_count += 1;
if item.hits > 0 {
self.function_hits += 1;
}
}
}
}
}
179 changes: 179 additions & 0 deletions crates/forge/tests/cli/coverage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,3 +271,182 @@ end[..]
"#]]
);
});

forgetest!(test_branch_coverage, |prj, cmd| {
prj.insert_ds_test();
prj.add_source(
"Foo.sol",
r#"
contract Foo {
error Gte1(uint256 number, uint256 firstElement);
enum Status {
NULL,
OPEN,
CLOSED
}
struct Item {
Status status;
uint256 value;
}
mapping(uint256 => Item) internal items;
uint256 public nextId = 1;
function getItem(uint256 id) public view returns (Item memory item) {
item = items[id];
}
function addItem(uint256 value) public returns (uint256 id) {
id = nextId;
items[id] = Item(Status.OPEN, value);
nextId++;
}
function closeIfEqValue(uint256 id, uint256 value) public {
if (items[id].value == value) {
items[id].status = Status.CLOSED;
}
}
function incrementIfEqValue(uint256 id, uint256 value) public {
if (items[id].value == value) {
items[id].value = value + 1;
}
}
function foo(uint256 a) external pure {
if (a < 10) {
if (a == 1) {
assert(a == 1);
} else {
assert(a == 5);
}
} else {
assert(a == 60);
}
}
function countOdd(uint256[] memory arr) external pure returns (uint256 count) {
uint256 length = arr.length;
for (uint256 i = 0; i < length; ++i) {
if (arr[i] % 2 == 1) {
count++;
arr[0];
}
}
}
function checkLt(uint256 number, uint256[] memory arr) external pure returns (bool) {
if (number >= arr[0]) {
revert Gte1(number, arr[0]);
}
return true;
}
}
"#,
)
.unwrap();

prj.add_source(
"FooTest.sol",
r#"
import "./test.sol";
import {Foo} from "./Foo.sol";
interface Vm {
function expectRevert(bytes calldata revertData) external;
}
contract FooTest is DSTest {
Vm constant vm = Vm(HEVM_ADDRESS);
Foo internal foo = new Foo();
function test_issue_7784() external view {
foo.foo(1);
foo.foo(5);
foo.foo(60);
}
function test_issue_4310() external {
uint256[] memory arr = new uint256[](3);
arr[0] = 78;
arr[1] = 493;
arr[2] = 700;
uint256 count = foo.countOdd(arr);
assertEq(count, 1);
arr = new uint256[](4);
arr[0] = 78;
arr[1] = 493;
arr[2] = 700;
arr[3] = 1729;
count = foo.countOdd(arr);
assertEq(count, 2);
}
function test_issue_4315() external {
uint256 value = 42;
uint256 id = foo.addItem(value);
assertEq(id, 1);
assertEq(foo.nextId(), 2);
Foo.Item memory item = foo.getItem(id);
assertEq(uint8(item.status), uint8(Foo.Status.OPEN));
assertEq(item.value, value);
foo = new Foo();
id = foo.addItem(value);
foo.closeIfEqValue(id, 903);
item = foo.getItem(id);
assertEq(uint8(item.status), uint8(Foo.Status.OPEN));
foo = new Foo();
foo.addItem(value);
foo.closeIfEqValue(id, 42);
item = foo.getItem(id);
assertEq(uint8(item.status), uint8(Foo.Status.CLOSED));
foo = new Foo();
id = foo.addItem(value);
foo.incrementIfEqValue(id, 903);
item = foo.getItem(id);
assertEq(item.value, 42);
foo = new Foo();
id = foo.addItem(value);
foo.incrementIfEqValue(id, 42);
item = foo.getItem(id);
assertEq(item.value, 43);
}
function test_issue_4309() external {
uint256[] memory arr = new uint256[](1);
arr[0] = 1;
uint256 number = 2;
vm.expectRevert(abi.encodeWithSelector(Foo.Gte1.selector, number, arr[0]));
foo.checkLt(number, arr);
number = 1;
vm.expectRevert(abi.encodeWithSelector(Foo.Gte1.selector, number, arr[0]));
foo.checkLt(number, arr);
number = 0;
bool result = foo.checkLt(number, arr);
assertTrue(result);
}
}
"#,
)
.unwrap();

// Assert 100% coverage.
cmd.arg("coverage").args(["--summary".to_string()]).assert_success().stdout_eq(str![[r#"
...
| File | % Lines | % Statements | % Branches | % Funcs |
|-------------|-----------------|-----------------|-----------------|---------------|
| src/Foo.sol | 100.00% (20/20) | 100.00% (23/23) | 100.00% (12/12) | 100.00% (7/7) |
| Total | 100.00% (20/20) | 100.00% (23/23) | 100.00% (12/12) | 100.00% (7/7) |
"#]]);
});

0 comments on commit 86d583f

Please sign in to comment.