diff --git a/.changeset/good-beans-invent.md b/.changeset/good-beans-invent.md
new file mode 100644
index 0000000000..646d4fe64e
--- /dev/null
+++ b/.changeset/good-beans-invent.md
@@ -0,0 +1,5 @@
+---
+'@coinbase/onchainkit': minor
+---
+
+feat: add `onConnect` handler to ``. By @dschlabach #1529
diff --git a/src/wallet/components/ConnectWallet.test.tsx b/src/wallet/components/ConnectWallet.test.tsx
index 1df58c97f9..ab0f1e869f 100644
--- a/src/wallet/components/ConnectWallet.test.tsx
+++ b/src/wallet/components/ConnectWallet.test.tsx
@@ -88,7 +88,7 @@ describe('ConnectWallet', () => {
expect(connectedText).toBeInTheDocument();
});
- it('should calls connect function when connect button is clicked', () => {
+ it('should call connect function when connect button is clicked', () => {
const connectMock = vi.fn();
vi.mocked(useConnect).mockReturnValue({
connectors: [{ id: 'mockConnector' }],
@@ -98,9 +98,14 @@ describe('ConnectWallet', () => {
render();
const button = screen.getByTestId('ockConnectButton');
fireEvent.click(button);
- expect(connectMock).toHaveBeenCalledWith({
- connector: { id: 'mockConnector' },
- });
+ expect(connectMock).toHaveBeenCalledWith(
+ {
+ connector: { id: 'mockConnector' },
+ },
+ {
+ onSuccess: expect.any(Function),
+ },
+ );
});
it('should toggle wallet modal on button click when connected', () => {
@@ -162,6 +167,56 @@ describe('ConnectWallet', () => {
expect(screen.queryByText('Not Render')).not.toBeInTheDocument();
});
+ it('should call onConnect callback when connect button is clicked', async () => {
+ const mockUseAccount = vi.mocked(useAccount);
+ const connectMock = vi.fn();
+ const onConnectMock = vi.fn();
+
+ // Initial state: disconnected
+ mockUseAccount.mockReturnValue({
+ address: undefined,
+ status: 'disconnected',
+ });
+
+ vi.mocked(useConnect).mockReturnValue({
+ connectors: [{ id: 'mockConnector' }],
+ connect: connectMock,
+ status: 'idle',
+ });
+
+ render();
+
+ const button = screen.getByTestId('ockConnectButton');
+ fireEvent.click(button);
+
+ // Simulate successful connection
+ connectMock.mock.calls[0][1].onSuccess();
+
+ // Update account status to connected
+ mockUseAccount.mockReturnValue({
+ address: '0x123',
+ status: 'connected',
+ });
+
+ // Force a re-render to trigger the useEffect
+ render();
+
+ expect(onConnectMock).toHaveBeenCalledTimes(1);
+ });
+
+ it('should not call onConnect callback when component is first mounted', () => {
+ const mockUseAccount = vi.mocked(useAccount);
+ mockUseAccount.mockReturnValue({
+ address: '0x123',
+ status: 'connected',
+ });
+
+ const onConnectMock = vi.fn();
+ render();
+
+ expect(onConnectMock).toHaveBeenCalledTimes(0);
+ });
+
describe('withWalletAggregator', () => {
beforeEach(() => {
vi.mocked(useAccount).mockReturnValue({
@@ -175,7 +230,7 @@ describe('ConnectWallet', () => {
});
});
- it('should render ConnectButtonRainboKit when withWalletAggregator is true', () => {
+ it('should render ConnectButtonRainbowKit when withWalletAggregator is true', () => {
render(
,
);
@@ -198,12 +253,17 @@ describe('ConnectWallet', () => {
);
const connectButton = screen.getByTestId('ockConnectButton');
fireEvent.click(connectButton);
- expect(connectMock).toHaveBeenCalledWith({
- connector: { id: 'mockConnector' },
- });
+ expect(connectMock).toHaveBeenCalledWith(
+ {
+ connector: { id: 'mockConnector' },
+ },
+ {
+ onSuccess: expect.any(Function),
+ },
+ );
});
- it('should calls openConnectModal function when connect button is clicked', () => {
+ it('should call openConnectModal function when connect button is clicked', () => {
vi.mocked(useWalletContext).mockReturnValue({
isOpen: false,
setIsOpen: vi.fn(),
@@ -215,5 +275,32 @@ describe('ConnectWallet', () => {
fireEvent.click(button);
expect(openConnectModalMock).toHaveBeenCalled();
});
+
+ it('should call onConnect callback when connect button is clicked', () => {
+ const mockUseAccount = vi.mocked(useAccount);
+ mockUseAccount.mockReturnValue({
+ address: undefined,
+ status: 'disconnected',
+ });
+
+ const onConnectMock = vi.fn();
+ render(
+ ,
+ );
+ const button = screen.getByTestId('ockConnectButton');
+
+ mockUseAccount.mockReturnValue({
+ address: '0x123',
+ status: 'connected',
+ });
+
+ fireEvent.click(button);
+
+ expect(onConnectMock).toHaveBeenCalledTimes(1);
+ });
});
});
diff --git a/src/wallet/components/ConnectWallet.tsx b/src/wallet/components/ConnectWallet.tsx
index 17dc81f8d8..05ff8bb664 100644
--- a/src/wallet/components/ConnectWallet.tsx
+++ b/src/wallet/components/ConnectWallet.tsx
@@ -1,6 +1,7 @@
import { ConnectButton as ConnectButtonRainbowKit } from '@rainbow-me/rainbowkit';
import { Children, isValidElement, useCallback, useMemo } from 'react';
import type { ReactNode } from 'react';
+import { useEffect, useState } from 'react';
import { useAccount, useConnect } from 'wagmi';
import { IdentityProvider } from '../../identity/components/IdentityProvider';
import { Spinner } from '../../internal/components/Spinner';
@@ -24,12 +25,16 @@ export function ConnectWallet({
// but for now we will keep it for backward compatibility.
text = 'Connect Wallet',
withWalletAggregator = false,
+ onConnect,
}: ConnectWalletReact) {
// Core Hooks
const { isOpen, setIsOpen } = useWalletContext();
const { address: accountAddress, status } = useAccount();
const { connectors, connect, status: connectStatus } = useConnect();
+ // State
+ const [hasClickedConnect, setHasClickedConnect] = useState(false);
+
// Get connectWalletText from children when present,
// this is used to customize the connect wallet button text
const { connectWalletText } = useMemo(() => {
@@ -58,6 +63,14 @@ export function ConnectWallet({
setIsOpen(!isOpen);
}, [isOpen, setIsOpen]);
+ // Effects
+ useEffect(() => {
+ if (hasClickedConnect && status === 'connected' && onConnect) {
+ onConnect();
+ setHasClickedConnect(false);
+ }
+ }, [status, hasClickedConnect, onConnect]);
+
if (status === 'disconnected') {
if (withWalletAggregator) {
return (
@@ -67,7 +80,10 @@ export function ConnectWallet({
openConnectModal()}
+ onClick={() => {
+ openConnectModal();
+ setHasClickedConnect(true);
+ }}
text={text}
/>
@@ -80,7 +96,16 @@ export function ConnectWallet({
connect({ connector })}
+ onClick={() => {
+ connect(
+ { connector },
+ {
+ onSuccess: () => {
+ onConnect?.();
+ },
+ },
+ );
+ }}
text={text}
/>
diff --git a/src/wallet/types.ts b/src/wallet/types.ts
index f6fbdaf057..3e613d9b2b 100644
--- a/src/wallet/types.ts
+++ b/src/wallet/types.ts
@@ -20,6 +20,7 @@ export type ConnectWalletReact = {
/** @deprecated Prefer `ConnectWalletText component` */
text?: string; // Optional text override for button
withWalletAggregator?: boolean; // Optional flag to enable the wallet aggregator like RainbowKit
+ onConnect?: () => void; // Optional callback function to execute when the wallet is connected.
};
/**