From d22f3f40149902717736761adf9c1ced0fdd2426 Mon Sep 17 00:00:00 2001 From: Kingster Date: Sun, 14 Apr 2024 19:44:29 -0700 Subject: [PATCH 1/4] only register modules or registry can write to IPAccount --- contracts/IPAccountImpl.sol | 13 +++- contracts/IPAccountStorage.sol | 30 +++++++- contracts/lib/Errors.sol | 5 ++ script/foundry/utils/DeployHelper.sol | 5 ++ test/foundry/IPAccountStorage.t.sol | 34 +++++---- test/foundry/IPAccountStorageOps.t.sol | 102 +++++++++++++------------ 6 files changed, 120 insertions(+), 69 deletions(-) diff --git a/contracts/IPAccountImpl.sol b/contracts/IPAccountImpl.sol index 3b3bfbe6..16ef7f02 100644 --- a/contracts/IPAccountImpl.sol +++ b/contracts/IPAccountImpl.sol @@ -30,10 +30,15 @@ contract IPAccountImpl is IPAccountStorage, IIPAccount { /// in the implementation code's storage. /// This means that each cloned IPAccount will inherently use the same AccessController /// without the need for individual configuration. - /// @param accessController_ The address of the AccessController contract to be used for permission checks - constructor(address accessController_) { - if (accessController_ == address(0)) revert Errors.IPAccount__InvalidAccessController(); - accessController = accessController_; + /// @param accessController The address of the AccessController contract to be used for permission checks + constructor( + address accessController, + address ipAssetRegistry, + address licenseRegistry, + address moduleRegistry + ) IPAccountStorage(ipAssetRegistry, licenseRegistry, moduleRegistry) { + if (accessController == address(0)) revert Errors.IPAccount__InvalidAccessController(); + accessController = accessController; } /// @notice Checks if the contract supports a specific interface diff --git a/contracts/IPAccountStorage.sol b/contracts/IPAccountStorage.sol index 18d088d3..45f4cfab 100644 --- a/contracts/IPAccountStorage.sol +++ b/contracts/IPAccountStorage.sol @@ -2,6 +2,8 @@ pragma solidity ^0.8.23; import { IIPAccountStorage } from "./interfaces/IIPAccountStorage.sol"; +import { IModuleRegistry } from "./interfaces/registries/IModuleRegistry.sol"; +import { Errors } from "./lib/Errors.sol"; import { ERC165 } from "@openzeppelin/contracts/utils/introspection/ERC165.sol"; import { ShortString, ShortStrings } from "@openzeppelin/contracts/utils/ShortStrings.sol"; /// @title IPAccount Storage @@ -13,15 +15,28 @@ import { ShortString, ShortStrings } from "@openzeppelin/contracts/utils/ShortSt contract IPAccountStorage is ERC165, IIPAccountStorage { using ShortStrings for *; - mapping(bytes32 => mapping(bytes32 => string)) public stringData; + address public immutable MODULE_REGISTRY; + address public immutable LICENSE_REGISTRY; + address public immutable IP_ASSET_REGISTRY; + mapping(bytes32 => mapping(bytes32 => bytes)) public bytesData; mapping(bytes32 => mapping(bytes32 => bytes32)) public bytes32Data; - mapping(bytes32 => mapping(bytes32 => uint256)) public uint256Data; - mapping(bytes32 => mapping(bytes32 => address)) public addressData; - mapping(bytes32 => mapping(bytes32 => bool)) public boolData; + + constructor(address ipAssetRegistry, address licenseRegistry, address moduleRegistry) { + MODULE_REGISTRY = moduleRegistry; + LICENSE_REGISTRY = licenseRegistry; + IP_ASSET_REGISTRY = ipAssetRegistry; + } /// @inheritdoc IIPAccountStorage function setBytes(bytes32 key, bytes calldata value) external { + if ( + msg.sender != IP_ASSET_REGISTRY && + msg.sender != LICENSE_REGISTRY && + !IModuleRegistry(MODULE_REGISTRY).isRegistered(msg.sender) + ) { + revert Errors.IPAccountStorage__NotRegisteredModule(msg.sender); + } bytesData[_toBytes32(msg.sender)][key] = value; } /// @inheritdoc IIPAccountStorage @@ -35,6 +50,13 @@ contract IPAccountStorage is ERC165, IIPAccountStorage { /// @inheritdoc IIPAccountStorage function setBytes32(bytes32 key, bytes32 value) external { + if ( + msg.sender != IP_ASSET_REGISTRY && + msg.sender != LICENSE_REGISTRY && + !IModuleRegistry(MODULE_REGISTRY).isRegistered(msg.sender) + ) { + revert Errors.IPAccountStorage__NotRegisteredModule(msg.sender); + } bytes32Data[_toBytes32(msg.sender)][key] = value; } /// @inheritdoc IIPAccountStorage diff --git a/contracts/lib/Errors.sol b/contracts/lib/Errors.sol index 787fa929..c563a4ec 100644 --- a/contracts/lib/Errors.sol +++ b/contracts/lib/Errors.sol @@ -13,6 +13,11 @@ library Errors { error IPAccount__InvalidCalldata(); error IPAccount__InvalidAccessController(); + //////////////////////////////////////////////////////////////////////////// + // IPAccountStorage // + //////////////////////////////////////////////////////////////////////////// + error IPAccountStorage__NotRegisteredModule(address module); + //////////////////////////////////////////////////////////////////////////// // Module // //////////////////////////////////////////////////////////////////////////// diff --git a/script/foundry/utils/DeployHelper.sol b/script/foundry/utils/DeployHelper.sol index 0db43bbb..f4275195 100644 --- a/script/foundry/utils/DeployHelper.sol +++ b/script/foundry/utils/DeployHelper.sol @@ -208,6 +208,11 @@ contract DeployHelper is Script, BroadcastManager, JsonDeploymentHandler, Storag impl = address(0); // Make sure we don't deploy wrong impl _postdeploy(contractKey, address(moduleRegistry)); + contractKey = "IPAccountImpl"; + _predeploy(contractKey); + ipAccountImpl = new IPAccountImpl(address(accessController), address(moduleRegistry)); + _postdeploy(contractKey, address(ipAccountImpl)); + contractKey = "IPAssetRegistry"; _predeploy(contractKey); impl = address(new IPAssetRegistry(address(erc6551Registry), ipAccountImplAddr)); diff --git a/test/foundry/IPAccountStorage.t.sol b/test/foundry/IPAccountStorage.t.sol index 6bba7c0d..00b65ee3 100644 --- a/test/foundry/IPAccountStorage.t.sol +++ b/test/foundry/IPAccountStorage.t.sol @@ -2,14 +2,17 @@ pragma solidity ^0.8.23; import { IIPAccount } from "../../contracts/interfaces/IIPAccount.sol"; +import { BaseModule } from "../../contracts/modules/BaseModule.sol"; import { MockModule } from "./mocks/module/MockModule.sol"; import { BaseTest } from "./utils/BaseTest.t.sol"; -contract IPAccountStorageTest is BaseTest { +contract IPAccountStorageTest is BaseTest, BaseModule { MockModule public module; IIPAccount public ipAccount; + string public override name = "IPAccountStorageTest"; + function setUp() public override { super.setUp(); @@ -19,6 +22,11 @@ contract IPAccountStorageTest is BaseTest { uint256 tokenId = 100; mockNFT.mintId(owner, tokenId); ipAccount = IIPAccount(payable(ipAccountRegistry.registerIpAccount(block.chainid, address(mockNFT), tokenId))); + vm.startPrank(admin); + moduleRegistry.registerModule("MockModule", address(module)); + moduleRegistry.registerModule("IPAccountStorageTest", address(this)); + vm.stopPrank(); + } function test_IPAccountStorage_storeBytes() public { @@ -27,10 +35,10 @@ contract IPAccountStorageTest is BaseTest { } function test_IPAccountStorage_readBytes_DifferentNamespace() public { - vm.prank(vm.addr(1)); + vm.prank(address(module)); ipAccount.setBytes("test", abi.encodePacked("test")); vm.prank(vm.addr(2)); - assertEq(ipAccount.getBytes(_toBytes32(vm.addr(1)), "test"), "test"); + assertEq(ipAccount.getBytes(_toBytes32(address(module)), "test"), "test"); } function test_IPAccountStorage_storeAddressArray() public { @@ -47,10 +55,10 @@ contract IPAccountStorageTest is BaseTest { address[] memory addresses = new address[](2); addresses[0] = vm.addr(1); addresses[1] = vm.addr(2); - vm.prank(vm.addr(1)); + vm.prank(address(module)); ipAccount.setBytes("test", abi.encode(addresses)); vm.prank(vm.addr(2)); - address[] memory result = abi.decode(ipAccount.getBytes(_toBytes32(vm.addr(1)), "test"), (address[])); + address[] memory result = abi.decode(ipAccount.getBytes(_toBytes32(address(module)), "test"), (address[])); assertEq(result[0], vm.addr(1)); assertEq(result[1], vm.addr(2)); } @@ -69,10 +77,10 @@ contract IPAccountStorageTest is BaseTest { uint256[] memory uints = new uint256[](2); uints[0] = 1; uints[1] = 2; - vm.prank(vm.addr(1)); + vm.prank(address(module)); ipAccount.setBytes("test", abi.encode(uints)); vm.prank(vm.addr(2)); - uint256[] memory result = abi.decode(ipAccount.getBytes(_toBytes32(vm.addr(1)), "test"), (uint256[])); + uint256[] memory result = abi.decode(ipAccount.getBytes(_toBytes32(address(module)), "test"), (uint256[])); assertEq(result[0], 1); assertEq(result[1], 2); } @@ -91,10 +99,10 @@ contract IPAccountStorageTest is BaseTest { string[] memory strings = new string[](2); strings[0] = "test1"; strings[1] = "test2"; - vm.prank(vm.addr(1)); + vm.prank(address(module)); ipAccount.setBytes("test", abi.encode(strings)); vm.prank(vm.addr(2)); - string[] memory result = abi.decode(ipAccount.getBytes(_toBytes32(vm.addr(1)), "test"), (string[])); + string[] memory result = abi.decode(ipAccount.getBytes(_toBytes32(address(module)), "test"), (string[])); assertEq(result[0], "test1"); assertEq(result[1], "test2"); } @@ -105,10 +113,10 @@ contract IPAccountStorageTest is BaseTest { } function test_IPAccountStorage_readBytes32_differentNameSpace() public { - vm.prank(vm.addr(1)); + vm.prank(address(module)); ipAccount.setBytes32("test", bytes32(uint256(111))); vm.prank(vm.addr(2)); - assertEq(ipAccount.getBytes32(_toBytes32(vm.addr(1)), "test"), bytes32(uint256(111))); + assertEq(ipAccount.getBytes32(_toBytes32(address(module)), "test"), bytes32(uint256(111))); } function test_IPAccountStorage_storeBytes32String() public { @@ -117,10 +125,10 @@ contract IPAccountStorageTest is BaseTest { } function test_IPAccountStorage_readBytes32String_differentNameSpace() public { - vm.prank(vm.addr(1)); + vm.prank(address(module)); ipAccount.setBytes32("test", "testData"); vm.prank(vm.addr(2)); - assertEq(ipAccount.getBytes32(_toBytes32(vm.addr(1)), "test"), "testData"); + assertEq(ipAccount.getBytes32(_toBytes32(address(module)), "test"), "testData"); } function _toBytes32(address a) internal pure returns (bytes32) { diff --git a/test/foundry/IPAccountStorageOps.t.sol b/test/foundry/IPAccountStorageOps.t.sol index 8478e744..31e5ea00 100644 --- a/test/foundry/IPAccountStorageOps.t.sol +++ b/test/foundry/IPAccountStorageOps.t.sol @@ -8,12 +8,14 @@ import { IIPAccount } from "../../contracts/interfaces/IIPAccount.sol"; import { MockModule } from "./mocks/module/MockModule.sol"; import { BaseTest } from "./utils/BaseTest.t.sol"; +import { BaseModule } from "../../contracts/modules/BaseModule.sol"; -contract IPAccountStorageOpsTest is BaseTest { +contract IPAccountStorageOpsTest is BaseTest, BaseModule { using ShortStrings for *; MockModule public module; IIPAccount public ipAccount; + string public override name = "IPAccountStorageOpsTest"; function setUp() public override { super.setUp(); @@ -24,124 +26,128 @@ contract IPAccountStorageOpsTest is BaseTest { uint256 tokenId = 100; mockNFT.mintId(owner, tokenId); ipAccount = IIPAccount(payable(ipAccountRegistry.registerIpAccount(block.chainid, address(mockNFT), tokenId))); + vm.startPrank(admin); + moduleRegistry.registerModule("MockModule", address(module)); + moduleRegistry.registerModule("IPAccountStorageOpsTest", address(this)); + vm.stopPrank(); } function test_IPAccountStorageOps_setString_ShortString() public { - vm.prank(vm.addr(1)); + vm.prank(address(module)); IPAccountStorageOps.setString(ipAccount, "test".toShortString(), "test"); - vm.prank(vm.addr(1)); + vm.prank(address(module)); assertEq(IPAccountStorageOps.getString(ipAccount, "test".toShortString()), "test"); vm.prank(vm.addr(2)); - assertEq(IPAccountStorageOps.getString(ipAccount, vm.addr(1), "test".toShortString()), "test"); + assertEq(IPAccountStorageOps.getString(ipAccount, address(module), "test".toShortString()), "test"); } function test_IPAccountStorageOps_setString_bytes32() public { - vm.prank(vm.addr(1)); + vm.prank(address(module)); IPAccountStorageOps.setString(ipAccount, bytes32("test"), "test"); - vm.prank(vm.addr(1)); + vm.prank(address(module)); assertEq(IPAccountStorageOps.getString(ipAccount, "test".toShortString()), "test"); vm.prank(vm.addr(2)); - assertEq(IPAccountStorageOps.getString(ipAccount, vm.addr(1), "test".toShortString()), "test"); + assertEq(IPAccountStorageOps.getString(ipAccount, address(module), "test".toShortString()), "test"); vm.prank(vm.addr(2)); - assertEq(IPAccountStorageOps.getString(ipAccount, vm.addr(1), bytes32("test")), "test"); + assertEq(IPAccountStorageOps.getString(ipAccount, address(module), bytes32("test")), "test"); } function test_IPAccountStorageOps_setAddress_ShortString() public { - vm.prank(vm.addr(1)); + vm.prank(address(module)); IPAccountStorageOps.setAddress(ipAccount, "test".toShortString(), vm.addr(2)); - vm.prank(vm.addr(1)); + vm.prank(address(module)); assertEq(IPAccountStorageOps.getAddress(ipAccount, "test".toShortString()), vm.addr(2)); vm.prank(vm.addr(2)); - assertEq(IPAccountStorageOps.getAddress(ipAccount, vm.addr(1), "test".toShortString()), vm.addr(2)); + assertEq(IPAccountStorageOps.getAddress(ipAccount, address(module), "test".toShortString()), vm.addr(2)); } function test_IPAccountStorageOps_setAddress_bytes32() public { - vm.prank(vm.addr(1)); + vm.prank(address(module)); IPAccountStorageOps.setAddress(ipAccount, bytes32("test"), vm.addr(2)); - vm.prank(vm.addr(1)); + vm.prank(address(module)); assertEq(IPAccountStorageOps.getAddress(ipAccount, "test".toShortString()), vm.addr(2)); vm.prank(vm.addr(2)); - assertEq(IPAccountStorageOps.getAddress(ipAccount, vm.addr(1), "test".toShortString()), vm.addr(2)); + assertEq(IPAccountStorageOps.getAddress(ipAccount, address(module), "test".toShortString()), vm.addr(2)); vm.prank(vm.addr(2)); - assertEq(IPAccountStorageOps.getAddress(ipAccount, vm.addr(1), bytes32("test")), vm.addr(2)); + assertEq(IPAccountStorageOps.getAddress(ipAccount, address(module), bytes32("test")), vm.addr(2)); } function test_IPAccountStorageOps_setUint256_ShortString() public { - vm.prank(vm.addr(1)); + vm.prank(address(module)); IPAccountStorageOps.setUint256(ipAccount, "test".toShortString(), 1); - vm.prank(vm.addr(1)); + vm.prank(address(module)); assertEq(IPAccountStorageOps.getUint256(ipAccount, "test".toShortString()), 1); vm.prank(vm.addr(2)); - assertEq(IPAccountStorageOps.getUint256(ipAccount, vm.addr(1), "test".toShortString()), 1); + assertEq(IPAccountStorageOps.getUint256(ipAccount, address(module), "test".toShortString()), 1); } function test_IPAccountStorageOps_setUint256_bytes32() public { - vm.prank(vm.addr(1)); + vm.prank(address(module)); IPAccountStorageOps.setUint256(ipAccount, bytes32("test"), 1); - vm.prank(vm.addr(1)); + vm.prank(address(module)); assertEq(IPAccountStorageOps.getUint256(ipAccount, "test".toShortString()), 1); vm.prank(vm.addr(2)); - assertEq(IPAccountStorageOps.getUint256(ipAccount, vm.addr(1), "test".toShortString()), 1); + assertEq(IPAccountStorageOps.getUint256(ipAccount, address(module), "test".toShortString()), 1); vm.prank(vm.addr(2)); - assertEq(IPAccountStorageOps.getUint256(ipAccount, vm.addr(1), bytes32("test")), 1); + assertEq(IPAccountStorageOps.getUint256(ipAccount, address(module), bytes32("test")), 1); } function test_IPAccountStorageOps_setBool_ShortString() public { - vm.prank(vm.addr(1)); + vm.prank(address(module)); IPAccountStorageOps.setBool(ipAccount, "test".toShortString(), true); - vm.prank(vm.addr(1)); + vm.prank(address(module)); assertTrue(IPAccountStorageOps.getBool(ipAccount, "test".toShortString())); vm.prank(vm.addr(2)); - assertTrue(IPAccountStorageOps.getBool(ipAccount, vm.addr(1), "test".toShortString())); + assertTrue(IPAccountStorageOps.getBool(ipAccount, address(module), "test".toShortString())); } function test_IPAccountStorageOps_setBool_bytes32() public { - vm.prank(vm.addr(1)); + vm.prank(address(module)); IPAccountStorageOps.setBool(ipAccount, bytes32("test"), true); - vm.prank(vm.addr(1)); + vm.prank(address(module)); assertTrue(IPAccountStorageOps.getBool(ipAccount, "test".toShortString())); vm.prank(vm.addr(2)); - assertTrue(IPAccountStorageOps.getBool(ipAccount, vm.addr(1), "test".toShortString())); + assertTrue(IPAccountStorageOps.getBool(ipAccount, address(module), "test".toShortString())); vm.prank(vm.addr(2)); - assertTrue(IPAccountStorageOps.getBool(ipAccount, vm.addr(1), bytes32("test"))); + assertTrue(IPAccountStorageOps.getBool(ipAccount, address(module), bytes32("test"))); } function test_IPAccountStorageOps_setBytes_ShortString() public { - vm.prank(vm.addr(1)); + vm.prank(address(module)); IPAccountStorageOps.setBytes(ipAccount, "test".toShortString(), abi.encodePacked("test")); - vm.prank(vm.addr(1)); + vm.prank(address(module)); assertEq(IPAccountStorageOps.getBytes(ipAccount, "test".toShortString()), abi.encodePacked("test")); vm.prank(vm.addr(2)); - assertEq(IPAccountStorageOps.getBytes(ipAccount, vm.addr(1), "test".toShortString()), abi.encodePacked("test")); + assertEq(IPAccountStorageOps.getBytes(ipAccount, address(module), "test".toShortString()), abi.encodePacked("test")); } function test_IPAccountStorageOps_setBytes_bytes32() public { - vm.prank(vm.addr(1)); + vm.prank(address(module)); ipAccount.setBytes(bytes32("test"), abi.encodePacked("test")); - vm.prank(vm.addr(1)); + vm.prank(address(module)); assertEq(IPAccountStorageOps.getBytes(ipAccount, "test".toShortString()), abi.encodePacked("test")); vm.prank(vm.addr(2)); - assertEq(IPAccountStorageOps.getBytes(ipAccount, vm.addr(1), "test".toShortString()), abi.encodePacked("test")); + assertEq(IPAccountStorageOps.getBytes(ipAccount, address(module), "test".toShortString()), abi.encodePacked("test")); vm.prank(vm.addr(2)); - assertEq(IPAccountStorageOps.getBytes(ipAccount, vm.addr(1), bytes32("test")), abi.encodePacked("test")); + assertEq(IPAccountStorageOps.getBytes(ipAccount, address(module), bytes32("test")), abi.encodePacked("test")); } function test_IPAccountStorageOps_setBytes_2_keys() public { - vm.prank(vm.addr(1)); + vm.prank(address(module)); IPAccountStorageOps.setBytes( ipAccount, "key1".toShortString(), "key2".toShortString(), abi.encodePacked("test") ); - vm.prank(vm.addr(1)); + vm.prank(address(module)); assertEq( IPAccountStorageOps.getBytes(ipAccount, "key1".toShortString(), "key2".toShortString()), abi.encodePacked("test") ); vm.prank(vm.addr(2)); assertEq( - IPAccountStorageOps.getBytes(ipAccount, vm.addr(1), "key1".toShortString(), "key2".toShortString()), + IPAccountStorageOps.getBytes(ipAccount, address(module), "key1".toShortString(), "key2".toShortString()), abi.encodePacked("test") ); } @@ -152,10 +158,10 @@ contract IPAccountStorageOpsTest is BaseTest { } function test_IPAccountStorage_readUint256_differentNameSpace() public { - vm.prank(vm.addr(1)); + vm.prank(address(module)); IPAccountStorageOps.setUint256(ipAccount, "test", 1); vm.prank(vm.addr(2)); - assertEq(IPAccountStorageOps.getUint256(ipAccount, _toBytes32(vm.addr(1)), "test"), 1); + assertEq(IPAccountStorageOps.getUint256(ipAccount, _toBytes32(address(module)), "test"), 1); } function test_IPAccountStorage_storeBool() public { @@ -164,10 +170,10 @@ contract IPAccountStorageOpsTest is BaseTest { } function test_IPAccountStorage_readBool_differentNameSpace() public { - vm.prank(vm.addr(1)); + vm.prank(address(module)); IPAccountStorageOps.setBool(ipAccount, "test", true); vm.prank(vm.addr(2)); - assertTrue(IPAccountStorageOps.getBool(ipAccount, _toBytes32(vm.addr(1)), "test")); + assertTrue(IPAccountStorageOps.getBool(ipAccount, _toBytes32(address(module)), "test")); } function test_IPAccountStorage_storeString() public { @@ -176,10 +182,10 @@ contract IPAccountStorageOpsTest is BaseTest { } function test_IPAccountStorage_readString_differentNameSpace() public { - vm.prank(vm.addr(1)); + vm.prank(address(module)); IPAccountStorageOps.setString(ipAccount, "test", "test"); vm.prank(vm.addr(2)); - assertEq(IPAccountStorageOps.getString(ipAccount, _toBytes32(vm.addr(1)), "test"), "test"); + assertEq(IPAccountStorageOps.getString(ipAccount, _toBytes32(address(module)), "test"), "test"); } function test_IPAccountStorage_storeAddress() public { @@ -188,10 +194,10 @@ contract IPAccountStorageOpsTest is BaseTest { } function test_IPAccountStorage_readAddress_differentNameSpace() public { - vm.prank(vm.addr(1)); - IPAccountStorageOps.setAddress(ipAccount, "test", vm.addr(1)); + vm.prank(address(module)); + IPAccountStorageOps.setAddress(ipAccount, "test", address(module)); vm.prank(vm.addr(2)); - assertEq(IPAccountStorageOps.getAddress(ipAccount, _toBytes32(vm.addr(1)), "test"), vm.addr(1)); + assertEq(IPAccountStorageOps.getAddress(ipAccount, _toBytes32(address(module)), "test"), address(module)); } function _toBytes32(address a) internal pure returns (bytes32) { From e319b39131c24d322be2a66a5c7db41e77279ac8 Mon Sep 17 00:00:00 2001 From: Kingster Date: Sun, 14 Apr 2024 23:15:48 -0700 Subject: [PATCH 2/4] deploy with GREATE3 --- contracts/IPAccountImpl.sol | 6 ++-- script/foundry/utils/DeployHelper.sol | 10 +++---- test/foundry/IPAccountStorage.t.sol | 38 +++++++++++++++++++++++++- test/foundry/IPAccountStorageOps.t.sol | 10 +++++-- 4 files changed, 52 insertions(+), 12 deletions(-) diff --git a/contracts/IPAccountImpl.sol b/contracts/IPAccountImpl.sol index 16ef7f02..62e4cd19 100644 --- a/contracts/IPAccountImpl.sol +++ b/contracts/IPAccountImpl.sol @@ -18,7 +18,7 @@ import { IPAccountStorage } from "./IPAccountStorage.sol"; /// @title IPAccountImpl /// @notice The Story Protocol's implementation of the IPAccount. contract IPAccountImpl is IPAccountStorage, IIPAccount { - address public immutable accessController; + address public immutable ACCESS_CONTROLLER; /// @notice Returns the IPAccount's internal nonce for transaction ordering. uint256 public state; @@ -38,7 +38,7 @@ contract IPAccountImpl is IPAccountStorage, IIPAccount { address moduleRegistry ) IPAccountStorage(ipAssetRegistry, licenseRegistry, moduleRegistry) { if (accessController == address(0)) revert Errors.IPAccount__InvalidAccessController(); - accessController = accessController; + ACCESS_CONTROLLER = accessController; } /// @notice Checks if the contract supports a specific interface @@ -108,7 +108,7 @@ contract IPAccountImpl is IPAccountStorage, IIPAccount { selector = bytes4(data[:4]); } // the check will revert if permission is denied - IAccessController(accessController).checkPermission(address(this), signer, to, selector); + IAccessController(ACCESS_CONTROLLER).checkPermission(address(this), signer, to, selector); return true; } diff --git a/script/foundry/utils/DeployHelper.sol b/script/foundry/utils/DeployHelper.sol index f4275195..64be42fa 100644 --- a/script/foundry/utils/DeployHelper.sol +++ b/script/foundry/utils/DeployHelper.sol @@ -208,11 +208,6 @@ contract DeployHelper is Script, BroadcastManager, JsonDeploymentHandler, Storag impl = address(0); // Make sure we don't deploy wrong impl _postdeploy(contractKey, address(moduleRegistry)); - contractKey = "IPAccountImpl"; - _predeploy(contractKey); - ipAccountImpl = new IPAccountImpl(address(accessController), address(moduleRegistry)); - _postdeploy(contractKey, address(ipAccountImpl)); - contractKey = "IPAssetRegistry"; _predeploy(contractKey); impl = address(new IPAssetRegistry(address(erc6551Registry), ipAccountImplAddr)); @@ -243,7 +238,10 @@ contract DeployHelper is Script, BroadcastManager, JsonDeploymentHandler, Storag bytes memory ipAccountImplCode = abi.encodePacked( type(IPAccountImpl).creationCode, abi.encode( - address(accessController) + address(accessController), + address(ipAssetRegistry), + address(licenseRegistry), + address(moduleRegistry) ) ); _predeploy(contractKey); diff --git a/test/foundry/IPAccountStorage.t.sol b/test/foundry/IPAccountStorage.t.sol index 00b65ee3..c8af8c3d 100644 --- a/test/foundry/IPAccountStorage.t.sol +++ b/test/foundry/IPAccountStorage.t.sol @@ -3,6 +3,7 @@ pragma solidity ^0.8.23; import { IIPAccount } from "../../contracts/interfaces/IIPAccount.sol"; import { BaseModule } from "../../contracts/modules/BaseModule.sol"; +import { Errors } from "../../contracts/lib/Errors.sol"; import { MockModule } from "./mocks/module/MockModule.sol"; import { BaseTest } from "./utils/BaseTest.t.sol"; @@ -26,7 +27,6 @@ contract IPAccountStorageTest is BaseTest, BaseModule { moduleRegistry.registerModule("MockModule", address(module)); moduleRegistry.registerModule("IPAccountStorageTest", address(this)); vm.stopPrank(); - } function test_IPAccountStorage_storeBytes() public { @@ -131,6 +131,42 @@ contract IPAccountStorageTest is BaseTest, BaseModule { assertEq(ipAccount.getBytes32(_toBytes32(address(module)), "test"), "testData"); } + function test_IPAccountStorage_setBytes32_revert_NonRegisteredModule() public { + vm.expectRevert(abi.encodeWithSelector(Errors.IPAccountStorage__NotRegisteredModule.selector, address(0x123))); + vm.prank(address(0x123)); + ipAccount.setBytes32("test", "testData"); + } + + function test_IPAccountStorage_setBytes_revert_NonRegisteredModule() public { + vm.expectRevert(abi.encodeWithSelector(Errors.IPAccountStorage__NotRegisteredModule.selector, address(0x123))); + vm.prank(address(0x123)); + ipAccount.setBytes("test", "testData"); + } + + function test_IPAccountStorage_setBytes_ByIpAssetRegistry() public { + vm.prank(address(ipAssetRegistry)); + ipAccount.setBytes("test", "testData"); + assertEq(ipAccount.getBytes(_toBytes32(address(ipAssetRegistry)), "test"), "testData"); + } + + function test_IPAccountStorage_setBytes32_ByIpAssetRegistry() public { + vm.prank(address(ipAssetRegistry)); + ipAccount.setBytes32("test", "testData"); + assertEq(ipAccount.getBytes32(_toBytes32(address(ipAssetRegistry)), "test"), "testData"); + } + + function test_IPAccountStorage_setBytes_ByLicenseRegistry() public { + vm.prank(address(licenseRegistry)); + ipAccount.setBytes("test", "testData"); + assertEq(ipAccount.getBytes(_toBytes32(address(licenseRegistry)), "test"), "testData"); + } + + function test_IPAccountStorage_setBytes32_ByLicenseRegistry() public { + vm.prank(address(licenseRegistry)); + ipAccount.setBytes32("test", "testData"); + assertEq(ipAccount.getBytes32(_toBytes32(address(licenseRegistry)), "test"), "testData"); + } + function _toBytes32(address a) internal pure returns (bytes32) { return bytes32(uint256(uint160(a))); } diff --git a/test/foundry/IPAccountStorageOps.t.sol b/test/foundry/IPAccountStorageOps.t.sol index 31e5ea00..24afec0b 100644 --- a/test/foundry/IPAccountStorageOps.t.sol +++ b/test/foundry/IPAccountStorageOps.t.sol @@ -118,7 +118,10 @@ contract IPAccountStorageOpsTest is BaseTest, BaseModule { vm.prank(address(module)); assertEq(IPAccountStorageOps.getBytes(ipAccount, "test".toShortString()), abi.encodePacked("test")); vm.prank(vm.addr(2)); - assertEq(IPAccountStorageOps.getBytes(ipAccount, address(module), "test".toShortString()), abi.encodePacked("test")); + assertEq( + IPAccountStorageOps.getBytes(ipAccount, address(module), "test".toShortString()), + abi.encodePacked("test") + ); } function test_IPAccountStorageOps_setBytes_bytes32() public { @@ -127,7 +130,10 @@ contract IPAccountStorageOpsTest is BaseTest, BaseModule { vm.prank(address(module)); assertEq(IPAccountStorageOps.getBytes(ipAccount, "test".toShortString()), abi.encodePacked("test")); vm.prank(vm.addr(2)); - assertEq(IPAccountStorageOps.getBytes(ipAccount, address(module), "test".toShortString()), abi.encodePacked("test")); + assertEq( + IPAccountStorageOps.getBytes(ipAccount, address(module), "test".toShortString()), + abi.encodePacked("test") + ); vm.prank(vm.addr(2)); assertEq(IPAccountStorageOps.getBytes(ipAccount, address(module), bytes32("test")), abi.encodePacked("test")); } From dcb4f8ffebf365d2c1be7834dab13ac204761aec Mon Sep 17 00:00:00 2001 From: Kingster Date: Mon, 15 Apr 2024 04:53:56 -0700 Subject: [PATCH 3/4] add modifier --- contracts/IPAccountStorage.sol | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/contracts/IPAccountStorage.sol b/contracts/IPAccountStorage.sol index 45f4cfab..87d6e682 100644 --- a/contracts/IPAccountStorage.sol +++ b/contracts/IPAccountStorage.sol @@ -22,14 +22,8 @@ contract IPAccountStorage is ERC165, IIPAccountStorage { mapping(bytes32 => mapping(bytes32 => bytes)) public bytesData; mapping(bytes32 => mapping(bytes32 => bytes32)) public bytes32Data; - constructor(address ipAssetRegistry, address licenseRegistry, address moduleRegistry) { - MODULE_REGISTRY = moduleRegistry; - LICENSE_REGISTRY = licenseRegistry; - IP_ASSET_REGISTRY = ipAssetRegistry; - } - /// @inheritdoc IIPAccountStorage - function setBytes(bytes32 key, bytes calldata value) external { + modifier onlyRegisteredModule() { if ( msg.sender != IP_ASSET_REGISTRY && msg.sender != LICENSE_REGISTRY && @@ -37,6 +31,17 @@ contract IPAccountStorage is ERC165, IIPAccountStorage { ) { revert Errors.IPAccountStorage__NotRegisteredModule(msg.sender); } + _; + } + + constructor(address ipAssetRegistry, address licenseRegistry, address moduleRegistry) { + MODULE_REGISTRY = moduleRegistry; + LICENSE_REGISTRY = licenseRegistry; + IP_ASSET_REGISTRY = ipAssetRegistry; + } + + /// @inheritdoc IIPAccountStorage + function setBytes(bytes32 key, bytes calldata value) external onlyRegisteredModule { bytesData[_toBytes32(msg.sender)][key] = value; } /// @inheritdoc IIPAccountStorage @@ -49,14 +54,7 @@ contract IPAccountStorage is ERC165, IIPAccountStorage { } /// @inheritdoc IIPAccountStorage - function setBytes32(bytes32 key, bytes32 value) external { - if ( - msg.sender != IP_ASSET_REGISTRY && - msg.sender != LICENSE_REGISTRY && - !IModuleRegistry(MODULE_REGISTRY).isRegistered(msg.sender) - ) { - revert Errors.IPAccountStorage__NotRegisteredModule(msg.sender); - } + function setBytes32(bytes32 key, bytes32 value) external onlyRegisteredModule { bytes32Data[_toBytes32(msg.sender)][key] = value; } /// @inheritdoc IIPAccountStorage From 7178ac31ac03e7f1073a61c1f8c65242e76f8080 Mon Sep 17 00:00:00 2001 From: Kingster Date: Mon, 15 Apr 2024 05:05:11 -0700 Subject: [PATCH 4/4] fix lint --- contracts/IPAccountStorage.sol | 1 - .../mocks/module/MockAccessControllerV2.sol | 3 +- test/foundry/upgrades/Upgrades.t.sol | 107 ++++++++++++++---- 3 files changed, 90 insertions(+), 21 deletions(-) diff --git a/contracts/IPAccountStorage.sol b/contracts/IPAccountStorage.sol index 87d6e682..cc772f86 100644 --- a/contracts/IPAccountStorage.sol +++ b/contracts/IPAccountStorage.sol @@ -22,7 +22,6 @@ contract IPAccountStorage is ERC165, IIPAccountStorage { mapping(bytes32 => mapping(bytes32 => bytes)) public bytesData; mapping(bytes32 => mapping(bytes32 => bytes32)) public bytes32Data; - modifier onlyRegisteredModule() { if ( msg.sender != IP_ASSET_REGISTRY && diff --git a/test/foundry/mocks/module/MockAccessControllerV2.sol b/test/foundry/mocks/module/MockAccessControllerV2.sol index e5ac7ee7..ac55250e 100644 --- a/test/foundry/mocks/module/MockAccessControllerV2.sol +++ b/test/foundry/mocks/module/MockAccessControllerV2.sol @@ -11,7 +11,8 @@ contract MockAccessControllerV2 is AccessController { } // keccak256(abi.encode(uint256(keccak256("story-protocol.AccessControllerV2")) - 1)) & ~bytes32(uint256(0xff)); - bytes32 private constant AccessControllerV2StorageLocation = 0xf328f2cdee4ae4df23921504bfa43e3156fb4d18b23549ca0a43fd1e64947a00; + bytes32 private constant AccessControllerV2StorageLocation = + 0xf328f2cdee4ae4df23921504bfa43e3156fb4d18b23549ca0a43fd1e64947a00; function initialize() public reinitializer(2) { _getAccessControllerV2Storage().newState = "initialized"; diff --git a/test/foundry/upgrades/Upgrades.t.sol b/test/foundry/upgrades/Upgrades.t.sol index 5243c3a7..cdb37049 100644 --- a/test/foundry/upgrades/Upgrades.t.sol +++ b/test/foundry/upgrades/Upgrades.t.sol @@ -12,7 +12,6 @@ import { MockIpRoyaltyVaultV2 } from "../mocks/module/MockIpRoyaltyVaultV2.sol"; import { MockAccessControllerV2 } from "../mocks/module/MockAccessControllerV2.sol"; contract UpgradesTest is BaseTest { - uint32 execDelay = 600; function setUp() public override { @@ -45,7 +44,6 @@ contract UpgradesTest is BaseTest { } function test_upgradeAccessController() public { - (bool immediate, uint32 delay) = protocolAccessManager.canCall( u.bob, address(accessController), @@ -54,7 +52,6 @@ contract UpgradesTest is BaseTest { assertFalse(immediate); assertEq(delay, execDelay); - address newAccessController = address(new MockAccessControllerV2()); vm.prank(u.bob); (bytes32 operationId, uint32 nonce) = protocolAccessManager.schedule( @@ -75,7 +72,10 @@ contract UpgradesTest is BaseTest { function test_deploymentSetup() public { // Deployer doesn't have the roles - (bool isMember, uint32 executionDelay) = protocolAccessManager.hasRole(ProtocolAdmin.PROTOCOL_ADMIN_ROLE, deployer); + (bool isMember, uint32 executionDelay) = protocolAccessManager.hasRole( + ProtocolAdmin.PROTOCOL_ADMIN_ROLE, + deployer + ); assertFalse(isMember); assertEq(executionDelay, 0); (isMember, executionDelay) = protocolAccessManager.hasRole(ProtocolAdmin.UPGRADER_ROLE, deployer); @@ -84,14 +84,17 @@ contract UpgradesTest is BaseTest { (isMember, executionDelay) = protocolAccessManager.hasRole(ProtocolAdmin.PAUSE_ADMIN_ROLE, deployer); assertFalse(isMember); assertEq(executionDelay, 0); - + (isMember, executionDelay) = protocolAccessManager.hasRole(ProtocolAdmin.PROTOCOL_ADMIN_ROLE, multisig); assertTrue(isMember); assertEq(executionDelay, 0); (isMember, executionDelay) = protocolAccessManager.hasRole(ProtocolAdmin.PAUSE_ADMIN_ROLE, multisig); assertTrue(isMember); assertEq(executionDelay, 0); - (isMember, executionDelay) = protocolAccessManager.hasRole(ProtocolAdmin.PAUSE_ADMIN_ROLE, address(protocolPauser)); + (isMember, executionDelay) = protocolAccessManager.hasRole( + ProtocolAdmin.PAUSE_ADMIN_ROLE, + address(protocolPauser) + ); assertTrue(isMember); assertEq(executionDelay, 0); (isMember, executionDelay) = protocolAccessManager.hasRole(ProtocolAdmin.UPGRADER_ROLE, multisig); @@ -99,7 +102,7 @@ contract UpgradesTest is BaseTest { assertEq(executionDelay, execDelay); // Target function role wiring - + (bool immediate, uint32 delay) = protocolAccessManager.canCall( multisig, address(royaltyPolicyLAP), @@ -107,8 +110,14 @@ contract UpgradesTest is BaseTest { ); assertFalse(immediate); assertEq(delay, 600); - assertEq(protocolAccessManager.getTargetFunctionRole(address(royaltyPolicyLAP), RoyaltyPolicyLAP.upgradeVaults.selector), ProtocolAdmin.UPGRADER_ROLE); - + assertEq( + protocolAccessManager.getTargetFunctionRole( + address(royaltyPolicyLAP), + RoyaltyPolicyLAP.upgradeVaults.selector + ), + ProtocolAdmin.UPGRADER_ROLE + ); + (immediate, delay) = protocolAccessManager.canCall( multisig, address(accessController), @@ -116,7 +125,13 @@ contract UpgradesTest is BaseTest { ); assertFalse(immediate); assertEq(delay, execDelay); - assertEq(protocolAccessManager.getTargetFunctionRole(address(accessController), UUPSUpgradeable.upgradeToAndCall.selector), ProtocolAdmin.UPGRADER_ROLE); + assertEq( + protocolAccessManager.getTargetFunctionRole( + address(accessController), + UUPSUpgradeable.upgradeToAndCall.selector + ), + ProtocolAdmin.UPGRADER_ROLE + ); (immediate, delay) = protocolAccessManager.canCall( multisig, @@ -125,7 +140,13 @@ contract UpgradesTest is BaseTest { ); assertFalse(immediate); assertEq(delay, execDelay); - assertEq(protocolAccessManager.getTargetFunctionRole(address(licenseToken), UUPSUpgradeable.upgradeToAndCall.selector), ProtocolAdmin.UPGRADER_ROLE); + assertEq( + protocolAccessManager.getTargetFunctionRole( + address(licenseToken), + UUPSUpgradeable.upgradeToAndCall.selector + ), + ProtocolAdmin.UPGRADER_ROLE + ); (immediate, delay) = protocolAccessManager.canCall( multisig, @@ -134,7 +155,13 @@ contract UpgradesTest is BaseTest { ); assertFalse(immediate); assertEq(delay, execDelay); - assertEq(protocolAccessManager.getTargetFunctionRole(address(disputeModule), UUPSUpgradeable.upgradeToAndCall.selector), ProtocolAdmin.UPGRADER_ROLE); + assertEq( + protocolAccessManager.getTargetFunctionRole( + address(disputeModule), + UUPSUpgradeable.upgradeToAndCall.selector + ), + ProtocolAdmin.UPGRADER_ROLE + ); (immediate, delay) = protocolAccessManager.canCall( multisig, @@ -143,7 +170,13 @@ contract UpgradesTest is BaseTest { ); assertFalse(immediate); assertEq(delay, execDelay); - assertEq(protocolAccessManager.getTargetFunctionRole(address(arbitrationPolicySP), UUPSUpgradeable.upgradeToAndCall.selector), ProtocolAdmin.UPGRADER_ROLE); + assertEq( + protocolAccessManager.getTargetFunctionRole( + address(arbitrationPolicySP), + UUPSUpgradeable.upgradeToAndCall.selector + ), + ProtocolAdmin.UPGRADER_ROLE + ); (immediate, delay) = protocolAccessManager.canCall( multisig, @@ -152,7 +185,13 @@ contract UpgradesTest is BaseTest { ); assertFalse(immediate); assertEq(delay, execDelay); - assertEq(protocolAccessManager.getTargetFunctionRole(address(licensingModule), UUPSUpgradeable.upgradeToAndCall.selector), ProtocolAdmin.UPGRADER_ROLE); + assertEq( + protocolAccessManager.getTargetFunctionRole( + address(licensingModule), + UUPSUpgradeable.upgradeToAndCall.selector + ), + ProtocolAdmin.UPGRADER_ROLE + ); (immediate, delay) = protocolAccessManager.canCall( multisig, @@ -161,7 +200,13 @@ contract UpgradesTest is BaseTest { ); assertFalse(immediate); assertEq(delay, execDelay); - assertEq(protocolAccessManager.getTargetFunctionRole(address(royaltyModule), UUPSUpgradeable.upgradeToAndCall.selector), ProtocolAdmin.UPGRADER_ROLE); + assertEq( + protocolAccessManager.getTargetFunctionRole( + address(royaltyModule), + UUPSUpgradeable.upgradeToAndCall.selector + ), + ProtocolAdmin.UPGRADER_ROLE + ); (immediate, delay) = protocolAccessManager.canCall( multisig, @@ -170,7 +215,13 @@ contract UpgradesTest is BaseTest { ); assertFalse(immediate); assertEq(delay, execDelay); - assertEq(protocolAccessManager.getTargetFunctionRole(address(licenseRegistry), UUPSUpgradeable.upgradeToAndCall.selector), ProtocolAdmin.UPGRADER_ROLE); + assertEq( + protocolAccessManager.getTargetFunctionRole( + address(licenseRegistry), + UUPSUpgradeable.upgradeToAndCall.selector + ), + ProtocolAdmin.UPGRADER_ROLE + ); (immediate, delay) = protocolAccessManager.canCall( multisig, @@ -179,7 +230,13 @@ contract UpgradesTest is BaseTest { ); assertFalse(immediate); assertEq(delay, execDelay); - assertEq(protocolAccessManager.getTargetFunctionRole(address(moduleRegistry), UUPSUpgradeable.upgradeToAndCall.selector), ProtocolAdmin.UPGRADER_ROLE); + assertEq( + protocolAccessManager.getTargetFunctionRole( + address(moduleRegistry), + UUPSUpgradeable.upgradeToAndCall.selector + ), + ProtocolAdmin.UPGRADER_ROLE + ); (immediate, delay) = protocolAccessManager.canCall( multisig, @@ -188,7 +245,13 @@ contract UpgradesTest is BaseTest { ); assertFalse(immediate); assertEq(delay, execDelay); - assertEq(protocolAccessManager.getTargetFunctionRole(address(ipAssetRegistry), UUPSUpgradeable.upgradeToAndCall.selector), ProtocolAdmin.UPGRADER_ROLE); + assertEq( + protocolAccessManager.getTargetFunctionRole( + address(ipAssetRegistry), + UUPSUpgradeable.upgradeToAndCall.selector + ), + ProtocolAdmin.UPGRADER_ROLE + ); (immediate, delay) = protocolAccessManager.canCall( multisig, @@ -197,6 +260,12 @@ contract UpgradesTest is BaseTest { ); assertFalse(immediate); assertEq(delay, execDelay); - assertEq(protocolAccessManager.getTargetFunctionRole(address(royaltyPolicyLAP), UUPSUpgradeable.upgradeToAndCall.selector), ProtocolAdmin.UPGRADER_ROLE); + assertEq( + protocolAccessManager.getTargetFunctionRole( + address(royaltyPolicyLAP), + UUPSUpgradeable.upgradeToAndCall.selector + ), + ProtocolAdmin.UPGRADER_ROLE + ); } }