diff --git a/deployment/ccip/changeset/cs_rmn_curse_uncurse.go b/deployment/ccip/changeset/cs_rmn_curse_uncurse.go new file mode 100644 index 00000000000..88dd99fe523 --- /dev/null +++ b/deployment/ccip/changeset/cs_rmn_curse_uncurse.go @@ -0,0 +1,275 @@ +package changeset + +import ( + "encoding/binary" + "fmt" + "slices" + + "github.com/pkg/errors" + + "github.com/smartcontractkit/chainlink/deployment" + commoncs "github.com/smartcontractkit/chainlink/deployment/common/changeset" +) + +func GlobalCurseSubject() Subject { + return Subject{0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01} +} + +type RMNCurseAction struct { + ChainSelector uint64 + SubjectToCurse Subject +} + +type CurseAction func(e deployment.Environment) []RMNCurseAction + +type RMNCurseConfig struct { + MCMS *MCMSConfig + CurseActions []CurseAction + Reason string +} + +func (c RMNCurseConfig) Validate(e deployment.Environment) error { + state, err := LoadOnchainState(e) + + if err != nil { + return errors.Errorf("failed to load onchain state: %v", err) + } + + if len(c.CurseActions) == 0 { + return errors.Errorf("curse actions are required") + } + + if c.Reason == "" { + return errors.Errorf("reason is required") + } + + validSelectors := e.AllChainSelectors() + validSubjects := make([]Subject, 0, len(validSelectors)+1) + for _, selector := range validSelectors { + validSubjects = append(validSubjects, SelectorToSubject(selector)) + } + validSubjects = append(validSubjects, GlobalCurseSubject()) + + for _, curseAction := range c.CurseActions { + result := curseAction(e) + for _, action := range result { + targetChain := e.Chains[action.ChainSelector] + targetChainState := state.Chains[action.ChainSelector] + + if err := commoncs.ValidateOwnership(e.GetContext(), c.MCMS != nil, targetChain.DeployerKey.From, targetChainState.Timelock.Address(), targetChainState.RMNRemote); err != nil { + return fmt.Errorf("chain %s: %w", targetChain.String(), err) + } + + if !slices.Contains(validSelectors, action.ChainSelector) { + return errors.Errorf("invalid chain selector %d for chain %s", action.ChainSelector, targetChain.String()) + } + + if !slices.Contains(validSubjects, action.SubjectToCurse) { + return errors.Errorf("invalid subject %x for chain %s", action.SubjectToCurse, targetChain.String()) + } + } + } + + return nil +} + +type Subject = [16]byte + +func SelectorToSubject(subject uint64) Subject { + var b Subject + binary.BigEndian.PutUint64(b[8:], subject) + return b +} + +// CurseLaneOnlyOnSource curses a lane only on the source chain +// This will prevent message from source to destination to be initiated +// One noteworthy behaviour is that this means that message can be sent from destination to source but will not be executed on the source +func CurseLaneOnlyOnSource(sourceSelector uint64, destinationSelector uint64) CurseAction { + // Curse from source to destination + return func(e deployment.Environment) []RMNCurseAction { + return []RMNCurseAction{ + { + ChainSelector: sourceSelector, + SubjectToCurse: SelectorToSubject(destinationSelector), + }, + } + } +} + +func CurseLane(sourceSelector uint64, destinationSelector uint64) CurseAction { + // Bidirectional curse between two chains + return func(e deployment.Environment) []RMNCurseAction { + return append( + CurseLaneOnlyOnSource(sourceSelector, destinationSelector)(e), + CurseLaneOnlyOnSource(destinationSelector, sourceSelector)(e)..., + ) + } +} + +func CurseChain(chainSelector uint64) CurseAction { + return func(e deployment.Environment) []RMNCurseAction { + chainSelectors := e.AllChainSelectors() + + // Curse all other chains to prevent onramp from sending message to the cursed chain + var curseActions []RMNCurseAction + for _, otherChainSelector := range chainSelectors { + if otherChainSelector != chainSelector { + curseActions = append(curseActions, RMNCurseAction{ + ChainSelector: otherChainSelector, + SubjectToCurse: SelectorToSubject(chainSelector), + }) + } + } + + // Curse the chain with a global curse to prevent any onramp or offramp message from send message in and out of the chain + curseActions = append(curseActions, RMNCurseAction{ + ChainSelector: chainSelector, + SubjectToCurse: GlobalCurseSubject(), + }) + + return curseActions + } +} + +func groupRMNSubjectBySelector(rmnSubjects []RMNCurseAction, filter bool) map[uint64][]Subject { + grouped := make(map[uint64][]Subject) + for _, subject := range rmnSubjects { + grouped[subject.ChainSelector] = append(grouped[subject.ChainSelector], subject.SubjectToCurse) + } + + // Only keep unique subjects, preserve only global curse if present and eliminate any curse where the selector is the same as the subject + // If filter is false then only make sure that there is no duplicate subject + for chainSelector, subjects := range grouped { + uniqueSubjects := make(map[Subject]struct{}) + for _, subject := range subjects { + if subject == SelectorToSubject(chainSelector) && filter { + continue + } + uniqueSubjects[subject] = struct{}{} + } + + if _, ok := uniqueSubjects[GlobalCurseSubject()]; ok && filter { + grouped[chainSelector] = []Subject{GlobalCurseSubject()} + } else { + var uniqueSubjectsSlice []Subject + for subject := range uniqueSubjects { + uniqueSubjectsSlice = append(uniqueSubjectsSlice, subject) + } + grouped[chainSelector] = uniqueSubjectsSlice + } + } + + return grouped +} + +// NewRMNCurseChangeset creates a new changeset for cursing chains or lanes on RMNRemote contracts. +// Example usage: +// +// cfg := RMNCurseConfig{ +// CurseActions: []CurseAction{ +// CurseChain(SEPOLIA_CHAIN_SELECTOR), +// CurseLane(SEPOLIA_CHAIN_SELECTOR, AVAX_FUJI_CHAIN_SELECTOR), +// }, +// CurseReason: "test curse", +// MCMS: &MCMSConfig{MinDelay: 0}, +// } +// output, err := NewRMNCurseChangeset(env, cfg) +func NewRMNCurseChangeset(e deployment.Environment, cfg RMNCurseConfig) (deployment.ChangesetOutput, error) { + state, err := LoadOnchainState(e) + if err != nil { + return deployment.ChangesetOutput{}, errors.Errorf("failed to load onchain state: %v", err) + } + deployerGroup := NewDeployerGroup(e, state, cfg.MCMS) + + // Generate curse actions + var curseActions []RMNCurseAction + for _, curseAction := range cfg.CurseActions { + curseActions = append(curseActions, curseAction(e)...) + } + // Group curse actions by chain selector + grouped := groupRMNSubjectBySelector(curseActions, true) + + // For each chain in the environement get the RMNRemote contract and call curse + for selector, chain := range state.Chains { + deployer := deployerGroup.getDeployer(selector) + if curseSubjects, ok := grouped[selector]; ok { + // Only curse the subject that are not actually cursed + notAlreadyCursedSubjects := make([]Subject, 0) + for _, subject := range curseSubjects { + cursed, err := chain.RMNRemote.IsCursed(nil, subject) + if err != nil { + return deployment.ChangesetOutput{}, errors.Errorf("failed to check if chain %d is cursed: %v", selector, err) + } + + if !cursed { + notAlreadyCursedSubjects = append(notAlreadyCursedSubjects, subject) + } else { + e.Logger.Warnf("chain %s subject %x is already cursed, ignoring it while cursing", e.Chains[selector].Name(), subject) + } + } + _, err := chain.RMNRemote.Curse0(deployer, notAlreadyCursedSubjects) + if err != nil { + return deployment.ChangesetOutput{}, errors.Errorf("failed to curse chain %d: %v", selector, err) + } + } + } + + return deployerGroup.enact("proposal to curse RMNs: " + cfg.Reason) +} + +// NewRMNUncurseChangeset creates a new changeset for uncursing chains or lanes on RMNRemote contracts. +// Example usage: +// +// cfg := RMNCurseConfig{ +// CurseActions: []CurseAction{ +// CurseChain(SEPOLIA_CHAIN_SELECTOR), +// CurseLane(SEPOLIA_CHAIN_SELECTOR, AVAX_FUJI_CHAIN_SELECTOR), +// }, +// MCMS: &MCMSConfig{MinDelay: 0}, +// } +// output, err := NewRMNUncurseChangeset(env, cfg) +// +// Curse actions are reused and reverted instead of applied in this changeset +func NewRMNUncurseChangeset(e deployment.Environment, cfg RMNCurseConfig) (deployment.ChangesetOutput, error) { + state, err := LoadOnchainState(e) + if err != nil { + return deployment.ChangesetOutput{}, errors.Errorf("failed to load onchain state: %v", err) + } + deployerGroup := NewDeployerGroup(e, state, cfg.MCMS) + + // Generate curse actions + var curseActions []RMNCurseAction + for _, curseAction := range cfg.CurseActions { + curseActions = append(curseActions, curseAction(e)...) + } + // Group curse actions by chain selector + grouped := groupRMNSubjectBySelector(curseActions, false) + + // For each chain in the environement get the RMNRemote contract and call uncurse + for selector, chain := range state.Chains { + deployer := deployerGroup.getDeployer(selector) + if curseSubjects, ok := grouped[selector]; ok { + // Only keep the subject that are actually cursed + actuallyCursedSubjects := make([]Subject, 0) + for _, subject := range curseSubjects { + cursed, err := chain.RMNRemote.IsCursed(nil, subject) + if err != nil { + return deployment.ChangesetOutput{}, errors.Errorf("failed to check if chain %d is cursed: %v", selector, err) + } + + if cursed { + actuallyCursedSubjects = append(actuallyCursedSubjects, subject) + } else { + e.Logger.Warnf("chain %s subject %x is not cursed, ignoring it while uncursing", e.Chains[selector].Name(), subject) + } + } + + _, err := chain.RMNRemote.Uncurse0(deployer, actuallyCursedSubjects) + if err != nil { + return deployment.ChangesetOutput{}, errors.Errorf("failed to uncurse chain %d: %v", selector, err) + } + } + } + + return deployerGroup.enact("proposal to uncurse RMNs: %s" + cfg.Reason) +} diff --git a/deployment/ccip/changeset/cs_rmn_curse_uncurse_test.go b/deployment/ccip/changeset/cs_rmn_curse_uncurse_test.go new file mode 100644 index 00000000000..325e2f17300 --- /dev/null +++ b/deployment/ccip/changeset/cs_rmn_curse_uncurse_test.go @@ -0,0 +1,312 @@ +package changeset + +import ( + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/stretchr/testify/require" + + commonchangeset "github.com/smartcontractkit/chainlink/deployment/common/changeset" + "github.com/smartcontractkit/chainlink/deployment/common/proposalutils" +) + +type curseAssertion struct { + chainID uint64 + subject uint64 + globalCurse bool + cursed bool +} + +type CurseTestCase struct { + name string + curseActionsBuilder func(mapIDToSelectorFunc) []CurseAction + curseAssertions []curseAssertion +} + +type mapIDToSelectorFunc func(uint64) uint64 + +var testCases = []CurseTestCase{ + { + name: "lane", + curseActionsBuilder: func(mapIDToSelector mapIDToSelectorFunc) []CurseAction { + return []CurseAction{CurseLane(mapIDToSelector(0), mapIDToSelector(1))} + }, + curseAssertions: []curseAssertion{ + {chainID: 0, subject: 1, cursed: true}, + {chainID: 0, subject: 2, cursed: false}, + {chainID: 1, subject: 0, cursed: true}, + {chainID: 1, subject: 2, cursed: false}, + {chainID: 2, subject: 0, cursed: false}, + {chainID: 2, subject: 1, cursed: false}, + }, + }, + { + name: "lane duplicate", + curseActionsBuilder: func(mapIDToSelector mapIDToSelectorFunc) []CurseAction { + return []CurseAction{CurseLane(mapIDToSelector(0), mapIDToSelector(1)), CurseLane(mapIDToSelector(0), mapIDToSelector(1))} + }, + curseAssertions: []curseAssertion{ + {chainID: 0, subject: 1, cursed: true}, + {chainID: 0, subject: 2, cursed: false}, + {chainID: 1, subject: 0, cursed: true}, + {chainID: 1, subject: 2, cursed: false}, + {chainID: 2, subject: 0, cursed: false}, + {chainID: 2, subject: 1, cursed: false}, + }, + }, + { + name: "chain", + curseActionsBuilder: func(mapIDToSelector mapIDToSelectorFunc) []CurseAction { + return []CurseAction{CurseChain(mapIDToSelector(0))} + }, + curseAssertions: []curseAssertion{ + {chainID: 0, globalCurse: true, cursed: true}, + {chainID: 1, subject: 0, cursed: true}, + {chainID: 1, subject: 2, cursed: false}, + {chainID: 2, subject: 0, cursed: true}, + {chainID: 2, subject: 1, cursed: false}, + }, + }, + { + name: "chain duplicate", + curseActionsBuilder: func(mapIDToSelector mapIDToSelectorFunc) []CurseAction { + return []CurseAction{CurseChain(mapIDToSelector(0)), CurseChain(mapIDToSelector(0))} + }, + curseAssertions: []curseAssertion{ + {chainID: 0, globalCurse: true, cursed: true}, + {chainID: 1, subject: 0, cursed: true}, + {chainID: 1, subject: 2, cursed: false}, + {chainID: 2, subject: 0, cursed: true}, + {chainID: 2, subject: 1, cursed: false}, + }, + }, + { + name: "chain and lanes", + curseActionsBuilder: func(mapIDToSelector mapIDToSelectorFunc) []CurseAction { + return []CurseAction{CurseChain(mapIDToSelector(0)), CurseLane(mapIDToSelector(1), mapIDToSelector(2))} + }, + curseAssertions: []curseAssertion{ + {chainID: 0, globalCurse: true, cursed: true}, + {chainID: 1, subject: 0, cursed: true}, + {chainID: 1, subject: 2, cursed: true}, + {chainID: 2, subject: 0, cursed: true}, + {chainID: 2, subject: 1, cursed: true}, + }, + }, +} + +func TestRMNCurse(t *testing.T) { + for _, tc := range testCases { + t.Run(tc.name+"_NO_MCMS", func(t *testing.T) { + runRmnCurseTest(t, tc) + }) + t.Run(tc.name+"_MCMS", func(t *testing.T) { + runRmnCurseMCMSTest(t, tc) + }) + } +} + +func TestRMNUncurse(t *testing.T) { + for _, tc := range testCases { + t.Run(tc.name+"_UNCURSE", func(t *testing.T) { + runRmnUncurseTest(t, tc) + }) + t.Run(tc.name+"_UNCURSE_MCMS", func(t *testing.T) { + runRmnUncurseMCMSTest(t, tc) + }) + } +} + +func TestRMNCurseConfigValidate(t *testing.T) { + for _, tc := range testCases { + t.Run(tc.name+"_VALIDATE", func(t *testing.T) { + runRmnCurseConfigValidateTest(t, tc) + }) + } +} + +func runRmnUncurseTest(t *testing.T, tc CurseTestCase) { + e := NewMemoryEnvironment(t, WithChains(3)) + + mapIDToSelector := func(id uint64) uint64 { + return e.Env.AllChainSelectors()[id] + } + + verifyNoActiveCurseOnAllChains(t, &e) + + config := RMNCurseConfig{ + CurseActions: tc.curseActionsBuilder(mapIDToSelector), + Reason: "test curse", + } + + _, err := NewRMNCurseChangeset(e.Env, config) + require.NoError(t, err) + + verifyTestCaseAssertions(t, &e, tc, mapIDToSelector) + + _, err = NewRMNUncurseChangeset(e.Env, config) + require.NoError(t, err) + + verifyNoActiveCurseOnAllChains(t, &e) +} + +func transferRMNContractToMCMS(t *testing.T, e *DeployedEnv, state CCIPOnChainState, timelocksPerChain map[uint64]*proposalutils.TimelockExecutionContracts) { + contractsByChain := make(map[uint64][]common.Address) + rmnRemoteAddressesByChain := buildRMNRemoteAddressPerChain(e.Env, state) + for chainSelector, rmnRemoteAddress := range rmnRemoteAddressesByChain { + contractsByChain[chainSelector] = []common.Address{rmnRemoteAddress} + } + + contractsByChain[e.HomeChainSel] = append(contractsByChain[e.HomeChainSel], state.Chains[e.HomeChainSel].RMNHome.Address()) + + // This is required because RMN Contracts is initially owned by the deployer + _, err := commonchangeset.ApplyChangesets(t, e.Env, timelocksPerChain, []commonchangeset.ChangesetApplication{ + { + Changeset: commonchangeset.WrapChangeSet(commonchangeset.TransferToMCMSWithTimelock), + Config: commonchangeset.TransferToMCMSWithTimelockConfig{ + ContractsByChain: contractsByChain, + MinDelay: 0, + }, + }, + }) + require.NoError(t, err) +} + +func runRmnUncurseMCMSTest(t *testing.T, tc CurseTestCase) { + e := NewMemoryEnvironment(t, WithChains(3)) + + mapIDToSelector := func(id uint64) uint64 { + return e.Env.AllChainSelectors()[id] + } + + config := RMNCurseConfig{ + CurseActions: tc.curseActionsBuilder(mapIDToSelector), + Reason: "test curse", + MCMS: &MCMSConfig{MinDelay: 0}, + } + + state, err := LoadOnchainState(e.Env) + require.NoError(t, err) + + verifyNoActiveCurseOnAllChains(t, &e) + + timelocksPerChain := buildTimelockPerChain(e.Env, state) + + transferRMNContractToMCMS(t, &e, state, timelocksPerChain) + + _, err = commonchangeset.ApplyChangesets(t, e.Env, timelocksPerChain, []commonchangeset.ChangesetApplication{ + { + Changeset: commonchangeset.WrapChangeSet(NewRMNCurseChangeset), + Config: config, + }, + }) + require.NoError(t, err) + + verifyTestCaseAssertions(t, &e, tc, mapIDToSelector) + + _, err = commonchangeset.ApplyChangesets(t, e.Env, timelocksPerChain, []commonchangeset.ChangesetApplication{ + { + Changeset: commonchangeset.WrapChangeSet(NewRMNUncurseChangeset), + Config: config, + }, + }) + require.NoError(t, err) + + verifyNoActiveCurseOnAllChains(t, &e) +} + +func runRmnCurseConfigValidateTest(t *testing.T, tc CurseTestCase) { + e := NewMemoryEnvironment(t, WithChains(3)) + + mapIDToSelector := func(id uint64) uint64 { + return e.Env.AllChainSelectors()[id] + } + + config := RMNCurseConfig{ + CurseActions: tc.curseActionsBuilder(mapIDToSelector), + Reason: "test curse", + } + + err := config.Validate(e.Env) + require.NoError(t, err) +} + +func runRmnCurseTest(t *testing.T, tc CurseTestCase) { + e := NewMemoryEnvironment(t, WithChains(3)) + + mapIDToSelector := func(id uint64) uint64 { + return e.Env.AllChainSelectors()[id] + } + + verifyNoActiveCurseOnAllChains(t, &e) + + config := RMNCurseConfig{ + CurseActions: tc.curseActionsBuilder(mapIDToSelector), + Reason: "test curse", + } + + _, err := NewRMNCurseChangeset(e.Env, config) + require.NoError(t, err) + + verifyTestCaseAssertions(t, &e, tc, mapIDToSelector) +} + +func runRmnCurseMCMSTest(t *testing.T, tc CurseTestCase) { + e := NewMemoryEnvironment(t, WithChains(3)) + + mapIDToSelector := func(id uint64) uint64 { + return e.Env.AllChainSelectors()[id] + } + + config := RMNCurseConfig{ + CurseActions: tc.curseActionsBuilder(mapIDToSelector), + Reason: "test curse", + MCMS: &MCMSConfig{MinDelay: 0}, + } + + state, err := LoadOnchainState(e.Env) + require.NoError(t, err) + + verifyNoActiveCurseOnAllChains(t, &e) + + timelocksPerChain := buildTimelockPerChain(e.Env, state) + + transferRMNContractToMCMS(t, &e, state, timelocksPerChain) + + _, err = commonchangeset.ApplyChangesets(t, e.Env, timelocksPerChain, []commonchangeset.ChangesetApplication{ + { + Changeset: commonchangeset.WrapChangeSet(NewRMNCurseChangeset), + Config: config, + }, + }) + require.NoError(t, err) + + verifyTestCaseAssertions(t, &e, tc, mapIDToSelector) +} + +func verifyTestCaseAssertions(t *testing.T, e *DeployedEnv, tc CurseTestCase, mapIDToSelector mapIDToSelectorFunc) { + state, err := LoadOnchainState(e.Env) + require.NoError(t, err) + + for _, assertion := range tc.curseAssertions { + cursedSubject := SelectorToSubject(mapIDToSelector(assertion.subject)) + if assertion.globalCurse { + cursedSubject = GlobalCurseSubject() + } + + isCursed, err := state.Chains[mapIDToSelector(assertion.chainID)].RMNRemote.IsCursed(nil, cursedSubject) + require.NoError(t, err) + require.Equal(t, assertion.cursed, isCursed, "chain %d subject %d", assertion.chainID, assertion.subject) + } +} + +func verifyNoActiveCurseOnAllChains(t *testing.T, e *DeployedEnv) { + state, err := LoadOnchainState(e.Env) + require.NoError(t, err) + + for _, chain := range e.Env.Chains { + isCursed, err := state.Chains[chain.Selector].RMNRemote.IsCursed0(nil) + require.NoError(t, err) + require.False(t, isCursed, "chain %d", chain.Selector) + } +} diff --git a/deployment/ccip/changeset/cs_update_rmn_config.go b/deployment/ccip/changeset/cs_update_rmn_config.go index 337b3756881..e26342eb494 100644 --- a/deployment/ccip/changeset/cs_update_rmn_config.go +++ b/deployment/ccip/changeset/cs_update_rmn_config.go @@ -9,7 +9,6 @@ import ( "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" - "github.com/smartcontractkit/ccip-owner-contracts/pkg/gethwrappers" "github.com/smartcontractkit/ccip-owner-contracts/pkg/proposal/mcms" "github.com/smartcontractkit/ccip-owner-contracts/pkg/proposal/timelock" @@ -408,34 +407,6 @@ func NewPromoteCandidateConfigChangeset(e deployment.Environment, config Promote }, nil } -func buildTimelockPerChain(e deployment.Environment, state CCIPOnChainState) map[uint64]*proposalutils.TimelockExecutionContracts { - timelocksPerChain := make(map[uint64]*proposalutils.TimelockExecutionContracts) - for _, chain := range e.Chains { - timelocksPerChain[chain.Selector] = &proposalutils.TimelockExecutionContracts{ - Timelock: state.Chains[chain.Selector].Timelock, - CallProxy: state.Chains[chain.Selector].CallProxy, - } - } - return timelocksPerChain -} - -func buildTimelockAddressPerChain(e deployment.Environment, state CCIPOnChainState) map[uint64]common.Address { - timelocksPerChain := buildTimelockPerChain(e, state) - timelockAddressPerChain := make(map[uint64]common.Address) - for chain, timelock := range timelocksPerChain { - timelockAddressPerChain[chain] = timelock.Timelock.Address() - } - return timelockAddressPerChain -} - -func buildProposerPerChain(e deployment.Environment, state CCIPOnChainState) map[uint64]*gethwrappers.ManyChainMultiSig { - proposerPerChain := make(map[uint64]*gethwrappers.ManyChainMultiSig) - for _, chain := range e.Chains { - proposerPerChain[chain.Selector] = state.Chains[chain.Selector].ProposerMcm - } - return proposerPerChain -} - func buildRMNRemotePerChain(e deployment.Environment, state CCIPOnChainState) map[uint64]*rmn_remote.RMNRemote { timelocksPerChain := make(map[uint64]*rmn_remote.RMNRemote) for _, chain := range e.Chains { diff --git a/deployment/ccip/changeset/deployer_group.go b/deployment/ccip/changeset/deployer_group.go new file mode 100644 index 00000000000..5fd2c0d390e --- /dev/null +++ b/deployment/ccip/changeset/deployer_group.go @@ -0,0 +1,165 @@ +package changeset + +import ( + "context" + "fmt" + + "github.com/ethereum/go-ethereum/accounts/abi/bind" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + "github.com/smartcontractkit/ccip-owner-contracts/pkg/gethwrappers" + "github.com/smartcontractkit/ccip-owner-contracts/pkg/proposal/mcms" + "github.com/smartcontractkit/ccip-owner-contracts/pkg/proposal/timelock" + + "github.com/smartcontractkit/chainlink/deployment" + "github.com/smartcontractkit/chainlink/deployment/common/proposalutils" +) + +type DeployerGroup struct { + e deployment.Environment + state CCIPOnChainState + mcmConfig *MCMSConfig + transactions map[uint64][]*types.Transaction +} + +// DeployerGroup is an abstraction that lets developers write their changeset +// without needing to know if it's executed using a DeployerKey or an MCMS proposal. +// +// Example usage: +// +// deployerGroup := NewDeployerGroup(e, state, mcmConfig) +// selector := 0 +// # Get the right deployer key for the chain +// deployer := deployerGroup.getDeployer(selector) +// state.Chains[selector].RMNRemote.Curse() +// # Execute the transaction or create the proposal +// deployerGroup.enact("Curse RMNRemote") +func NewDeployerGroup(e deployment.Environment, state CCIPOnChainState, mcmConfig *MCMSConfig) *DeployerGroup { + return &DeployerGroup{ + e: e, + mcmConfig: mcmConfig, + state: state, + transactions: make(map[uint64][]*types.Transaction), + } +} + +func (d *DeployerGroup) getDeployer(chain uint64) *bind.TransactOpts { + txOpts := d.e.Chains[chain].DeployerKey + if d.mcmConfig != nil { + txOpts = deployment.SimTransactOpts() + } + sim := &bind.TransactOpts{ + From: txOpts.From, + Signer: txOpts.Signer, + GasLimit: txOpts.GasLimit, + GasPrice: txOpts.GasPrice, + Nonce: txOpts.Nonce, + Value: txOpts.Value, + GasFeeCap: txOpts.GasFeeCap, + GasTipCap: txOpts.GasTipCap, + Context: txOpts.Context, + AccessList: txOpts.AccessList, + NoSend: true, + } + oldSigner := sim.Signer + sim.Signer = func(a common.Address, t *types.Transaction) (*types.Transaction, error) { + tx, err := oldSigner(a, t) + if err != nil { + return nil, err + } + d.transactions[chain] = append(d.transactions[chain], tx) + return tx, nil + } + return sim +} + +func (d *DeployerGroup) enact(deploymentDescription string) (deployment.ChangesetOutput, error) { + if d.mcmConfig != nil { + return d.enactMcms(deploymentDescription) + } + + return d.enactDeployer() +} + +func (d *DeployerGroup) enactMcms(deploymentDescription string) (deployment.ChangesetOutput, error) { + batches := make([]timelock.BatchChainOperation, 0) + for selector, txs := range d.transactions { + mcmOps := make([]mcms.Operation, len(txs)) + for i, tx := range txs { + mcmOps[i] = mcms.Operation{ + To: *tx.To(), + Data: tx.Data(), + Value: tx.Value(), + } + } + batches = append(batches, timelock.BatchChainOperation{ + ChainIdentifier: mcms.ChainIdentifier(selector), + Batch: mcmOps, + }) + } + + timelocksPerChain := buildTimelockAddressPerChain(d.e, d.state) + + proposerMCMSes := buildProposerPerChain(d.e, d.state) + + prop, err := proposalutils.BuildProposalFromBatches( + timelocksPerChain, + proposerMCMSes, + batches, + deploymentDescription, + d.mcmConfig.MinDelay, + ) + + if err != nil { + return deployment.ChangesetOutput{}, fmt.Errorf("failed to build proposal %w", err) + } + + return deployment.ChangesetOutput{ + Proposals: []timelock.MCMSWithTimelockProposal{*prop}, + }, nil +} + +func (d *DeployerGroup) enactDeployer() (deployment.ChangesetOutput, error) { + for selector, txs := range d.transactions { + for _, tx := range txs { + err := d.e.Chains[selector].Client.SendTransaction(context.Background(), tx) + if err != nil { + return deployment.ChangesetOutput{}, fmt.Errorf("failed to send transaction: %w", err) + } + + _, err = d.e.Chains[selector].Confirm(tx) + if err != nil { + return deployment.ChangesetOutput{}, fmt.Errorf("waiting for tx to be mined failed: %w", err) + } + } + } + return deployment.ChangesetOutput{}, nil +} + +func buildTimelockPerChain(e deployment.Environment, state CCIPOnChainState) map[uint64]*proposalutils.TimelockExecutionContracts { + timelocksPerChain := make(map[uint64]*proposalutils.TimelockExecutionContracts) + for _, chain := range e.Chains { + timelocksPerChain[chain.Selector] = &proposalutils.TimelockExecutionContracts{ + Timelock: state.Chains[chain.Selector].Timelock, + CallProxy: state.Chains[chain.Selector].CallProxy, + } + } + return timelocksPerChain +} + +func buildTimelockAddressPerChain(e deployment.Environment, state CCIPOnChainState) map[uint64]common.Address { + timelocksPerChain := buildTimelockPerChain(e, state) + timelockAddressPerChain := make(map[uint64]common.Address) + for chain, timelock := range timelocksPerChain { + timelockAddressPerChain[chain] = timelock.Timelock.Address() + } + return timelockAddressPerChain +} + +func buildProposerPerChain(e deployment.Environment, state CCIPOnChainState) map[uint64]*gethwrappers.ManyChainMultiSig { + proposerPerChain := make(map[uint64]*gethwrappers.ManyChainMultiSig) + for _, chain := range e.Chains { + proposerPerChain[chain.Selector] = state.Chains[chain.Selector].ProposerMcm + } + return proposerPerChain +} diff --git a/integration-tests/smoke/ccip/ccip_rmn_test.go b/integration-tests/smoke/ccip/ccip_rmn_test.go index a3877013103..c70037d3e20 100644 --- a/integration-tests/smoke/ccip/ccip_rmn_test.go +++ b/integration-tests/smoke/ccip/ccip_rmn_test.go @@ -2,7 +2,6 @@ package smoke import ( "context" - "encoding/binary" "errors" "math/big" "os" @@ -17,17 +16,15 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/rs/zerolog" "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" "github.com/smartcontractkit/chainlink-protos/job-distributor/v1/node" "github.com/smartcontractkit/chainlink-testing-framework/lib/utils/osutil" "github.com/smartcontractkit/chainlink-testing-framework/lib/utils/testcontext" - "github.com/smartcontractkit/chainlink-ccip/pkg/reader" - "github.com/smartcontractkit/chainlink/deployment/ccip/changeset" "github.com/smartcontractkit/chainlink/deployment/environment/devenv" - "github.com/smartcontractkit/chainlink/deployment" "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/ccip/generated/rmn_home" "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/ccip/generated/rmn_remote" "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/ccip/generated/router" @@ -614,7 +611,7 @@ func (tc rmnTestCase) callContractsToCurseChains(ctx context.Context, t *testing remoteSel := tc.pf.chainSelectors[remoteCfg.chainIdx] chState, ok := onChainState.Chains[remoteSel] require.True(t, ok) - chain, ok := envWithRMN.Env.Chains[remoteSel] + _, ok = envWithRMN.Env.Chains[remoteSel] require.True(t, ok) cursedSubjects, ok := tc.cursedSubjectsPerChain[remoteCfg.chainIdx] @@ -623,14 +620,19 @@ func (tc rmnTestCase) callContractsToCurseChains(ctx context.Context, t *testing } for _, subjectDescription := range cursedSubjects { - subj := reader.GlobalCurseSubject - if subjectDescription != globalCurse { - subj = chainSelectorToBytes16(tc.pf.chainSelectors[subjectDescription]) + curseActions := make([]changeset.CurseAction, 0) + + if subjectDescription == globalCurse { + curseActions = append(curseActions, changeset.CurseChain(remoteSel)) + } else { + curseActions = append(curseActions, changeset.CurseLane(remoteSel, (tc.pf.chainSelectors[subjectDescription]))) } - t.Logf("cursing subject %d (%d)", subj, subjectDescription) - txCurse, errCurse := chState.RMNRemote.Curse(chain.DeployerKey, subj) - _, errConfirm := deployment.ConfirmIfNoError(chain, txCurse, errCurse) - require.NoError(t, errConfirm) + + _, err := changeset.NewRMNCurseChangeset(envWithRMN.Env, changeset.RMNCurseConfig{ + CurseActions: curseActions, + Reason: "test curse", + }) + t.Error(err) } cs, err := chState.RMNRemote.GetCursedSubjects(&bind.CallOpts{Context: ctx}) @@ -644,32 +646,53 @@ func (tc rmnTestCase) callContractsToCurseAndRevokeCurse(ctx context.Context, t remoteSel := tc.pf.chainSelectors[remoteCfg.chainIdx] chState, ok := onChainState.Chains[remoteSel] require.True(t, ok) - chain, ok := envWithRMN.Env.Chains[remoteSel] + _, ok = envWithRMN.Env.Chains[remoteSel] require.True(t, ok) cursedSubjects, ok := tc.revokedCursedSubjectsPerChain[remoteCfg.chainIdx] if !ok { continue // nothing to curse on this chain } + eg := errgroup.Group{} for subjectDescription, revokeAfter := range cursedSubjects { - subj := reader.GlobalCurseSubject - if subjectDescription != globalCurse { - subj = chainSelectorToBytes16(tc.pf.chainSelectors[subjectDescription]) + curseActions := make([]changeset.CurseAction, 0) + + if subjectDescription == globalCurse { + curseActions = append(curseActions, changeset.CurseChain(remoteSel)) + } else { + curseActions = append(curseActions, changeset.CurseLane(remoteSel, (tc.pf.chainSelectors[subjectDescription]))) } - t.Logf("cursing subject %d (%d)", subj, subjectDescription) - txCurse, errCurse := chState.RMNRemote.Curse(chain.DeployerKey, subj) - _, errConfirm := deployment.ConfirmIfNoError(chain, txCurse, errCurse) - require.NoError(t, errConfirm) - go func() { + _, err := changeset.NewRMNCurseChangeset(envWithRMN.Env, changeset.RMNCurseConfig{ + CurseActions: curseActions, + Reason: "test curse", + }) + t.Error(err) + + eg.Go(func() error { <-time.NewTimer(revokeAfter).C - t.Logf("revoking curse on subject %d (%d)", subj, subjectDescription) - txUncurse, errUncurse := chState.RMNRemote.Uncurse(chain.DeployerKey, subj) - _, errConfirm = deployment.ConfirmIfNoError(chain, txUncurse, errUncurse) - require.NoError(t, errConfirm) - }() + t.Logf("revoking curse on subject %d (%d)", subjectDescription, subjectDescription) + + _, err := changeset.NewRMNUncurseChangeset(envWithRMN.Env, changeset.RMNCurseConfig{ + CurseActions: curseActions, + Reason: "test uncurse", + }) + if err != nil { + return err + } + + cs, err := chState.RMNRemote.GetCursedSubjects(&bind.CallOpts{Context: ctx}) + + if err != nil { + return err + } + + t.Logf("Cursed subjects after revoking: %v", cs) + return nil + }) } + require.NoError(t, eg.Wait()) cs, err := chState.RMNRemote.GetCursedSubjects(&bind.CallOpts{Context: ctx}) require.NoError(t, err) @@ -684,10 +707,3 @@ func (tc rmnTestCase) enableOracles(ctx context.Context, t *testing.T, envWithRM t.Logf("node %s enabled", n) } } - -func chainSelectorToBytes16(chainSel uint64) [16]byte { - var result [16]byte - // Convert the uint64 to bytes and place it in the last 8 bytes of the array - binary.BigEndian.PutUint64(result[8:], chainSel) - return result -}