diff --git a/.changeset/good-beans-invent.md b/.changeset/good-beans-invent.md new file mode 100644 index 0000000000..646d4fe64e --- /dev/null +++ b/.changeset/good-beans-invent.md @@ -0,0 +1,5 @@ +--- +'@coinbase/onchainkit': minor +--- + +feat: add `onConnect` handler to ``. By @dschlabach #1529 diff --git a/src/wallet/components/ConnectWallet.test.tsx b/src/wallet/components/ConnectWallet.test.tsx index 1df58c97f9..ab0f1e869f 100644 --- a/src/wallet/components/ConnectWallet.test.tsx +++ b/src/wallet/components/ConnectWallet.test.tsx @@ -88,7 +88,7 @@ describe('ConnectWallet', () => { expect(connectedText).toBeInTheDocument(); }); - it('should calls connect function when connect button is clicked', () => { + it('should call connect function when connect button is clicked', () => { const connectMock = vi.fn(); vi.mocked(useConnect).mockReturnValue({ connectors: [{ id: 'mockConnector' }], @@ -98,9 +98,14 @@ describe('ConnectWallet', () => { render(); const button = screen.getByTestId('ockConnectButton'); fireEvent.click(button); - expect(connectMock).toHaveBeenCalledWith({ - connector: { id: 'mockConnector' }, - }); + expect(connectMock).toHaveBeenCalledWith( + { + connector: { id: 'mockConnector' }, + }, + { + onSuccess: expect.any(Function), + }, + ); }); it('should toggle wallet modal on button click when connected', () => { @@ -162,6 +167,56 @@ describe('ConnectWallet', () => { expect(screen.queryByText('Not Render')).not.toBeInTheDocument(); }); + it('should call onConnect callback when connect button is clicked', async () => { + const mockUseAccount = vi.mocked(useAccount); + const connectMock = vi.fn(); + const onConnectMock = vi.fn(); + + // Initial state: disconnected + mockUseAccount.mockReturnValue({ + address: undefined, + status: 'disconnected', + }); + + vi.mocked(useConnect).mockReturnValue({ + connectors: [{ id: 'mockConnector' }], + connect: connectMock, + status: 'idle', + }); + + render(); + + const button = screen.getByTestId('ockConnectButton'); + fireEvent.click(button); + + // Simulate successful connection + connectMock.mock.calls[0][1].onSuccess(); + + // Update account status to connected + mockUseAccount.mockReturnValue({ + address: '0x123', + status: 'connected', + }); + + // Force a re-render to trigger the useEffect + render(); + + expect(onConnectMock).toHaveBeenCalledTimes(1); + }); + + it('should not call onConnect callback when component is first mounted', () => { + const mockUseAccount = vi.mocked(useAccount); + mockUseAccount.mockReturnValue({ + address: '0x123', + status: 'connected', + }); + + const onConnectMock = vi.fn(); + render(); + + expect(onConnectMock).toHaveBeenCalledTimes(0); + }); + describe('withWalletAggregator', () => { beforeEach(() => { vi.mocked(useAccount).mockReturnValue({ @@ -175,7 +230,7 @@ describe('ConnectWallet', () => { }); }); - it('should render ConnectButtonRainboKit when withWalletAggregator is true', () => { + it('should render ConnectButtonRainbowKit when withWalletAggregator is true', () => { render( , ); @@ -198,12 +253,17 @@ describe('ConnectWallet', () => { ); const connectButton = screen.getByTestId('ockConnectButton'); fireEvent.click(connectButton); - expect(connectMock).toHaveBeenCalledWith({ - connector: { id: 'mockConnector' }, - }); + expect(connectMock).toHaveBeenCalledWith( + { + connector: { id: 'mockConnector' }, + }, + { + onSuccess: expect.any(Function), + }, + ); }); - it('should calls openConnectModal function when connect button is clicked', () => { + it('should call openConnectModal function when connect button is clicked', () => { vi.mocked(useWalletContext).mockReturnValue({ isOpen: false, setIsOpen: vi.fn(), @@ -215,5 +275,32 @@ describe('ConnectWallet', () => { fireEvent.click(button); expect(openConnectModalMock).toHaveBeenCalled(); }); + + it('should call onConnect callback when connect button is clicked', () => { + const mockUseAccount = vi.mocked(useAccount); + mockUseAccount.mockReturnValue({ + address: undefined, + status: 'disconnected', + }); + + const onConnectMock = vi.fn(); + render( + , + ); + const button = screen.getByTestId('ockConnectButton'); + + mockUseAccount.mockReturnValue({ + address: '0x123', + status: 'connected', + }); + + fireEvent.click(button); + + expect(onConnectMock).toHaveBeenCalledTimes(1); + }); }); }); diff --git a/src/wallet/components/ConnectWallet.tsx b/src/wallet/components/ConnectWallet.tsx index 17dc81f8d8..05ff8bb664 100644 --- a/src/wallet/components/ConnectWallet.tsx +++ b/src/wallet/components/ConnectWallet.tsx @@ -1,6 +1,7 @@ import { ConnectButton as ConnectButtonRainbowKit } from '@rainbow-me/rainbowkit'; import { Children, isValidElement, useCallback, useMemo } from 'react'; import type { ReactNode } from 'react'; +import { useEffect, useState } from 'react'; import { useAccount, useConnect } from 'wagmi'; import { IdentityProvider } from '../../identity/components/IdentityProvider'; import { Spinner } from '../../internal/components/Spinner'; @@ -24,12 +25,16 @@ export function ConnectWallet({ // but for now we will keep it for backward compatibility. text = 'Connect Wallet', withWalletAggregator = false, + onConnect, }: ConnectWalletReact) { // Core Hooks const { isOpen, setIsOpen } = useWalletContext(); const { address: accountAddress, status } = useAccount(); const { connectors, connect, status: connectStatus } = useConnect(); + // State + const [hasClickedConnect, setHasClickedConnect] = useState(false); + // Get connectWalletText from children when present, // this is used to customize the connect wallet button text const { connectWalletText } = useMemo(() => { @@ -58,6 +63,14 @@ export function ConnectWallet({ setIsOpen(!isOpen); }, [isOpen, setIsOpen]); + // Effects + useEffect(() => { + if (hasClickedConnect && status === 'connected' && onConnect) { + onConnect(); + setHasClickedConnect(false); + } + }, [status, hasClickedConnect, onConnect]); + if (status === 'disconnected') { if (withWalletAggregator) { return ( @@ -67,7 +80,10 @@ export function ConnectWallet({ openConnectModal()} + onClick={() => { + openConnectModal(); + setHasClickedConnect(true); + }} text={text} /> @@ -80,7 +96,16 @@ export function ConnectWallet({ connect({ connector })} + onClick={() => { + connect( + { connector }, + { + onSuccess: () => { + onConnect?.(); + }, + }, + ); + }} text={text} /> diff --git a/src/wallet/types.ts b/src/wallet/types.ts index f6fbdaf057..3e613d9b2b 100644 --- a/src/wallet/types.ts +++ b/src/wallet/types.ts @@ -20,6 +20,7 @@ export type ConnectWalletReact = { /** @deprecated Prefer `ConnectWalletText component` */ text?: string; // Optional text override for button withWalletAggregator?: boolean; // Optional flag to enable the wallet aggregator like RainbowKit + onConnect?: () => void; // Optional callback function to execute when the wallet is connected. }; /**