Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added boundLog #8870

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
242 changes: 205 additions & 37 deletions testdata/default/cheats/Wallet.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -7,43 +7,200 @@ import "cheats/Vm.sol";
contract Foo {}

contract WalletTest is DSTest {
Vm constant vm = Vm(HEVM_ADDRESS);
Vm constant vm = Vm(HEVM_ADDRESS); // Vm contract address

uint256 internal constant Q = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141;
uint256 internal constant Q =
0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141; // constant acc to secp256k1 for generating PK
uint256 private constant UINT256_MAX =
115792089237316195423570985008687907853269984665640564039457584007913129639935;
115792089237316195423570985008687907853269984665640564039457584007913129639935; // max num stored in uin256

enum DistributionType {
Uniform,
Logarithmic
} // enum for the distribution type
enum TypeUint {
Uint8,
Uint16,
Uint32,
Uint64,
Uint128,
Uint256
} // enum for the uint type

struct ParamConfig {
uint256 min;
TypeUint max;
DistributionType distributionType;
uint256[] fixtures;
uint256[] excluded;
} // struct to changes the configs and all

struct UintValue {
TypeUint uintType;
TypeUint value;
}

function addressOf(uint256 x, uint256 y) internal pure returns (address) {
return address(uint160(uint256(keccak256(abi.encode(x, y)))));
ParamConfig public pkConfig;

constructor() {
pkConfig = ParamConfig({
min: 1,
max: TypeUint(Q - 1),
distributionType: DistributionType.Logarithmic,
fixtures: new uint256[](0),
excluded: new uint256[](0)
});
} // the constructor sets the DistributionType = Logarithmic

// Separate functions for different uint types
function getUintType(uint8 value) internal pure returns (UintValue memory) {
return UintValue(TypeUint.Uint8, TypeUint(value));
}
function getUintType(
uint16 value
) internal pure returns (UintValue memory) {
return UintValue(TypeUint.Uint16, TypeUint(value));
}
function getUintType(
uint32 value
) internal pure returns (UintValue memory) {
return UintValue(TypeUint.Uint32, TypeUint(value));
}
function getUintType(
uint64 value
) internal pure returns (UintValue memory) {
return UintValue(TypeUint.Uint64, TypeUint(value));
}
function getUintType(
uint128 value
) internal pure returns (UintValue memory) {
return UintValue(TypeUint.Uint128, TypeUint(value));
}
function getUintType(
uint256 value
) internal pure returns (UintValue memory) {
return UintValue(TypeUint.Uint256, TypeUint(value));
}

function bound(uint256 x, uint256 min, uint256 max) internal pure virtual returns (uint256 result) {
require(min <= max, "min needs to be less than max");
// If x is between min and max, return x directly. This is to ensure that dictionary values
// do not get shifted if the min is nonzero. More info: https://github.com/foundry-rs/forge-std/issues/188
if (x >= min && x <= max) return x;
// solve for this max and make it TypeUint type and remove the uint256 from it in the paramConfig struct
function determineDistributionType(
uint256 min,
TypeUint max
) internal pure returns (DistributionType) {
if (
max == TypeUint.Uint8 ||
max == TypeUint.Uint16 ||
max == TypeUint.Uint32 ||
max == TypeUint.Uint64
) {
return DistributionType.Uniform;
} else {
return DistributionType.Logarithmic;
}
}

uint256 size = max - min + 1;
// converts Public key to Ethereum address using keccak256 hash
function addressOf(uint256 x, uint256 y) internal pure returns (address) {
return address(uint160(uint256(keccak256(abi.encode(x, y)))));
}

// If the value is 0, 1, 2, 3, wrap that to min, min+1, min+2, min+3. Similarly for the UINT256_MAX side.
// This helps ensure coverage of the min/max values.
if (x <= 3 && size > x) return min + x;
if (x >= UINT256_MAX - 3 && size > UINT256_MAX - x) return max - (UINT256_MAX - x);
function getMaxValueForType(
TypeUint uintType
) internal pure returns (uint256) {
if (uintType == TypeUint.Uint8) return type(uint8).max;
if (uintType == TypeUint.Uint16) return type(uint16).max;
if (uintType == TypeUint.Uint32) return type(uint32).max;
if (uintType == TypeUint.Uint64) return type(uint64).max;
if (uintType == TypeUint.Uint128) return type(uint128).max;
return type(uint256).max;
}

// Otherwise, wrap x into the range [min, max], i.e. the range is inclusive.
if (x > max) {
uint256 diff = x - max;
function bound(
uint256 x,
ParamConfig memory config
) internal pure virtual returns (uint256 result) {
uint256 maxValue = getMaxValueForType(config.max);
DistributionType actualDistributionType = determineDistributionType(
config.min,
config.max
);

if (actualDistributionType == DistributionType.Logarithmic) {
return boundLog(x, config.min, config.max);
}
require(config.min <= maxValue, "min needs to be less than max");
if (x >= config.min && x <= maxValue) return x;
uint256 size = maxValue - config.min + 1;
if (x <= 3 && size > x) return config.min + x;
if (x >= UINT256_MAX - 3 && size > UINT256_MAX - x)
return maxValue - (UINT256_MAX - x);
if (x > maxValue) {
uint256 diff = x - maxValue;
uint256 rem = diff % size;
if (rem == 0) return max;
result = min + rem - 1;
} else if (x < min) {
uint256 diff = min - x;
if (rem == 0) return maxValue;
return config.min + rem - 1;
} else if (x < config.min) {
uint256 diff = config.min - x;
uint256 rem = diff % size;
if (rem == 0) return min;
result = max - rem + 1;
if (rem == 0) return config.min;
return maxValue - rem + 1;
}
}

function boundLog(
uint256 x,
uint256 min,
TypeUint max
) internal pure returns (uint256) {
require(min > 0, "min must be greater than 0 for log distribution");
uint256 maxValue = getMaxValueForType(max);
require(min < maxValue, "min must be less than max");
uint256 logMin = log2Approximation(2 * min);
uint256 logMax = log2Approximation(2 ** (maxValue + 1) - 1);
uint256 logValue = bound(
x,
ParamConfig(
logMin,
TypeUint.Uint256,
DistributionType.Uniform,
new uint256[](0),
new uint256[](0)
)
);
return exp2Approximation(logValue);
}

function log2Approximation(uint256 x) internal pure returns (uint256) {
require(x > 0, "log2 of less than equal to zero is undefined");

uint256 n = 0;
while (x > 1) {
x >>= 1;
n++;
}
return n;
}

function exp2Approximation(uint256 x) internal pure returns (uint256) {
if (x == 0) return 1;

uint256 result = 2;
for (uint256 i = 1; i < x; i++) {
result *= 2;
}
return result;
}

function getUintType(uint256 value) internal pure returns (string memory) {
if (value <= type(uint8).max) return "uint8";
if (value <= type(uint16).max) return "uint16";
if (value <= type(uint32).max) return "uint32";
if (value <= type(uint64).max) return "uint64";
if (value <= type(uint128).max) return "uint128";
if (value <= type(uint256).max) return "uint256";
}

// tests that wallet is created with the address derived from PK and label is set correctly
function testCreateWalletStringPrivAndLabel() public {
bytes memory privKey = "this is a priv key";
Vm.Wallet memory wallet = vm.createWallet(string(privKey));
Expand All @@ -60,8 +217,9 @@ contract WalletTest is DSTest {
assertEq(label, string(privKey), "labelled address != wallet.addr");
}

// tests creation of PK using a seed
function testCreateWalletPrivKeyNoLabel(uint256 pkSeed) public {
uint256 pk = bound(pkSeed, 1, Q - 1);
uint256 pk = bound(pkSeed, pkConfig);

Vm.Wallet memory wallet = vm.createWallet(pk);

Expand All @@ -74,10 +232,11 @@ contract WalletTest is DSTest {
assertEq(expectedAddr, wallet.addr);
}

// tests creation of PK using a seed and checks labels too
function testCreateWalletPrivKeyWithLabel(uint256 pkSeed) public {
string memory label = "labelled wallet";

uint256 pk = bound(pkSeed, 1, Q - 1);
uint256 pk = bound(pkSeed, pkConfig);

Vm.Wallet memory wallet = vm.createWallet(pk, label);

Expand All @@ -92,9 +251,9 @@ contract WalletTest is DSTest {
string memory expectedLabel = vm.getLabel(wallet.addr);
assertEq(expectedLabel, label, "labelled address != wallet.addr");
}

// tests signing a has using PK and checks the address recovered from the signautre is correct wallet address
function testSignWithWalletDigest(uint256 pkSeed, bytes32 digest) public {
uint256 pk = bound(pkSeed, 1, Q - 1);
uint256 pk = bound(pkSeed, pkConfig);

Vm.Wallet memory wallet = vm.createWallet(pk);

Expand All @@ -103,9 +262,12 @@ contract WalletTest is DSTest {
address recovered = ecrecover(digest, v, r, s);
assertEq(recovered, wallet.addr);
}

function testSignCompactWithWalletDigest(uint256 pkSeed, bytes32 digest) public {
uint256 pk = bound(pkSeed, 1, Q - 1);
// tests signing a has using PK and checks the address recovered from the signautre is correct wallet address and also checks the signature is compact
function testSignCompactWithWalletDigest(
uint256 pkSeed,
bytes32 digest
) public {
uint256 pk = bound(pkSeed, pkConfig);

Vm.Wallet memory wallet = vm.createWallet(pk);

Expand All @@ -125,17 +287,23 @@ contract WalletTest is DSTest {
address recovered = ecrecover(digest, v, r, s);
assertEq(recovered, wallet.addr);
}

function testSignWithWalletMessage(uint256 pkSeed, bytes memory message) public {
// signs a message after performing the checks in above functions
function testSignWithWalletMessage(
uint256 pkSeed,
bytes memory message
) public {
testSignWithWalletDigest(pkSeed, keccak256(message));
}

function testSignCompactWithWalletMessage(uint256 pkSeed, bytes memory message) public {
// // signs a message after performing the checks in above functions in compact way
function testSignCompactWithWalletMessage(
uint256 pkSeed,
bytes memory message
) public {
testSignCompactWithWalletDigest(pkSeed, keccak256(message));
}

// check sthe nonces of the wallet before and after a prank
function testGetNonceWallet(uint256 pkSeed) public {
uint256 pk = bound(pkSeed, 1, Q - 1);
uint256 pk = bound(pkSeed, pkConfig);

Vm.Wallet memory wallet = vm.createWallet(pk);

Expand Down
Loading