diff --git a/.changeset/unlucky-walls-tie.md b/.changeset/unlucky-walls-tie.md new file mode 100644 index 0000000000..6b7be30782 --- /dev/null +++ b/.changeset/unlucky-walls-tie.md @@ -0,0 +1,5 @@ +--- +"@coinbase/onchainkit": minor +--- + +**feat**: Swap success state - refetch balances and clear inputs by @0xAlec #1089 diff --git a/src/swap/components/SwapProvider.test.tsx b/src/swap/components/SwapProvider.test.tsx index a1f465f3cd..fde3a09cbc 100644 --- a/src/swap/components/SwapProvider.test.tsx +++ b/src/swap/components/SwapProvider.test.tsx @@ -6,7 +6,7 @@ import { screen, waitFor, } from '@testing-library/react'; -import React, { act, useEffect } from 'react'; +import React, { act, useCallback, useEffect } from 'react'; import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; import { http, WagmiProvider, createConfig, useAccount } from 'wagmi'; import { base } from 'wagmi/chains'; @@ -17,6 +17,11 @@ import { DEGEN_TOKEN, ETH_TOKEN } from '../mocks'; import { getSwapErrorCode } from '../utils/getSwapErrorCode'; import { SwapProvider, useSwapContext } from './SwapProvider'; +const mockResetFunction = vi.fn(); +vi.mock('../hooks/useResetInputs', () => ({ + useResetInputs: () => useCallback(mockResetFunction, []), +})); + vi.mock('../../api/getSwapQuote', () => ({ getSwapQuote: vi.fn(), })); @@ -217,6 +222,7 @@ describe('useSwapContext', () => { describe('SwapProvider', () => { beforeEach(async () => { + vi.resetAllMocks(); (useAccount as ReturnType).mockReturnValue({ address: '0x123', }); @@ -249,6 +255,20 @@ describe('SwapProvider', () => { expect(result.current.error).toBeUndefined(); }); + it('should reset inputs when setLifeCycleStatus is called with success', async () => { + const { result } = renderHook(() => useSwapContext(), { wrapper }); + await act(async () => { + result.current.setLifeCycleStatus({ + statusName: 'success', + statusData: { transactionReceipt: '0x123' }, + }); + }); + await waitFor(() => { + expect(mockResetFunction).toHaveBeenCalled(); + }); + expect(mockResetFunction).toHaveBeenCalledTimes(1); + }); + it('should emit onError when setLifeCycleStatus is called with error', async () => { const onErrorMock = vi.fn(); renderWithProviders({ Component: TestSwapComponent, onError: onErrorMock }); diff --git a/src/swap/components/SwapProvider.tsx b/src/swap/components/SwapProvider.tsx index 7516bab465..c565d796d2 100644 --- a/src/swap/components/SwapProvider.tsx +++ b/src/swap/components/SwapProvider.tsx @@ -14,6 +14,7 @@ import type { Token } from '../../token'; import { GENERIC_ERROR_MESSAGE } from '../../transaction/constants'; import { isUserRejectedRequestError } from '../../transaction/utils/isUserRejectedRequestError'; import { useFromTo } from '../hooks/useFromTo'; +import { useResetInputs } from '../hooks/useResetInputs'; import type { LifeCycleStatus, SwapContextType, @@ -60,6 +61,9 @@ export function SwapProvider({ const { from, to } = useFromTo(address); const { sendTransactionAsync } = useSendTransaction(); // Sending the transaction (and approval, if applicable) + // Refreshes balances and inputs post-swap + const resetInputs = useResetInputs({ from, to }); + // Component lifecycle emitters useEffect(() => { // Error @@ -83,6 +87,7 @@ export function SwapProvider({ if (lifeCycleStatus.statusName === 'success') { setError(undefined); setLoading(false); + resetInputs(); setPendingTransaction(false); onSuccess?.(lifeCycleStatus.statusData.transactionReceipt); } @@ -95,6 +100,7 @@ export function SwapProvider({ lifeCycleStatus, lifeCycleStatus.statusData, // Keep statusData, so that the effect runs when it changes lifeCycleStatus.statusName, // Keep statusName, so that the effect runs when it changes + resetInputs, ]); const handleToggle = useCallback(() => { diff --git a/src/swap/hooks/useFromTo.test.ts b/src/swap/hooks/useFromTo.test.ts index 39b021a5e0..d5159e652c 100644 --- a/src/swap/hooks/useFromTo.test.ts +++ b/src/swap/hooks/useFromTo.test.ts @@ -1,5 +1,5 @@ -import { renderHook } from '@testing-library/react'; -import { describe, expect, it, vi } from 'vitest'; +import { act, renderHook } from '@testing-library/react'; +import { beforeEach, describe, expect, it, vi } from 'vitest'; import { useValue } from '../../internal/hooks/useValue'; import { USDC_TOKEN } from '../mocks'; import { useFromTo } from './useFromTo'; @@ -22,41 +22,80 @@ describe('useFromTo', () => { (useSwapBalances as vi.Mock).mockReturnValue({ fromBalanceString: '100', fromTokenBalanceError: null, + fromTokenResponse: { refetch: vi.fn() }, toBalanceString: '200', toTokenBalanceError: null, + toTokenResponse: { refetch: vi.fn() }, }); - (useValue as vi.Mock).mockImplementation((props) => ({ ...props, amount: '100', + response: props.response, setAmount: vi.fn(), - token: USDC_TOKEN, - setToken: vi.fn(), setLoading: vi.fn(), + setToken: vi.fn(), + token: USDC_TOKEN, })); - const { result } = renderHook(() => useFromTo('0x123')); - expect(result.current.from).toEqual({ - balance: '100', amount: '100', - setAmount: expect.any(Function), - token: USDC_TOKEN, - setToken: expect.any(Function), + balance: '100', + balanceResponse: { refetch: expect.any(Function) }, + error: null, loading: false, + setAmount: expect.any(Function), setLoading: expect.any(Function), - error: null, + setToken: expect.any(Function), + token: USDC_TOKEN, }); - expect(result.current.to).toEqual({ - balance: '200', amount: '100', - setAmount: expect.any(Function), - token: USDC_TOKEN, - setToken: expect.any(Function), + balance: '200', + balanceResponse: { refetch: expect.any(Function) }, + error: null, loading: false, + setAmount: expect.any(Function), setLoading: expect.any(Function), - error: null, + setToken: expect.any(Function), + token: USDC_TOKEN, + }); + }); + + it('should call fromTokenResponse.refetch when from.response.refetch is called', async () => { + const mockFromRefetch = vi.fn().mockResolvedValue(undefined); + const mockToRefetch = vi.fn().mockResolvedValue(undefined); + (useSwapBalances as vi.Mock).mockReturnValue({ + fromTokenResponse: { refetch: mockFromRefetch }, + toTokenResponse: { refetch: mockToRefetch }, + }); + (useValue as vi.Mock).mockImplementation((props) => ({ + ...props, + response: props.response, + })); + const { result } = renderHook(() => useFromTo('0x123')); + await act(async () => { + await result.current.from.balanceResponse?.refetch(); + }); + expect(mockFromRefetch).toHaveBeenCalledTimes(1); + expect(mockToRefetch).not.toHaveBeenCalled(); + }); + + it('should call toTokenResponse.refetch when to.response.refetch is called', async () => { + const mockFromRefetch = vi.fn().mockResolvedValue(undefined); + const mockToRefetch = vi.fn().mockResolvedValue(undefined); + (useSwapBalances as vi.Mock).mockReturnValue({ + fromTokenResponse: { refetch: mockFromRefetch }, + toTokenResponse: { refetch: mockToRefetch }, + }); + (useValue as vi.Mock).mockImplementation((props) => ({ + ...props, + response: props.response, + })); + const { result } = renderHook(() => useFromTo('0x123')); + await act(async () => { + await result.current.to.balanceResponse?.refetch(); }); + expect(mockToRefetch).toHaveBeenCalledTimes(1); + expect(mockFromRefetch).not.toHaveBeenCalled(); }); }); diff --git a/src/swap/hooks/useFromTo.ts b/src/swap/hooks/useFromTo.ts index 57ff35fcaa..12db7f2a4f 100644 --- a/src/swap/hooks/useFromTo.ts +++ b/src/swap/hooks/useFromTo.ts @@ -2,9 +2,10 @@ import { useState } from 'react'; import type { Address } from 'viem'; import { useValue } from '../../internal/hooks/useValue'; import type { Token } from '../../token'; +import type { FromTo } from '../types'; import { useSwapBalances } from './useSwapBalances'; -export const useFromTo = (address?: Address) => { +export const useFromTo = (address?: Address): FromTo => { const [fromAmount, setFromAmount] = useState(''); const [fromToken, setFromToken] = useState(); const [toAmount, setToAmount] = useState(''); @@ -17,10 +18,13 @@ export const useFromTo = (address?: Address) => { fromTokenBalanceError, toBalanceString, toTokenBalanceError, + fromTokenResponse, + toTokenResponse, } = useSwapBalances({ address, fromToken, toToken }); const from = useValue({ balance: fromBalanceString, + balanceResponse: fromTokenResponse, amount: fromAmount, setAmount: setFromAmount, token: fromToken, @@ -32,6 +36,7 @@ export const useFromTo = (address?: Address) => { const to = useValue({ balance: toBalanceString, + balanceResponse: toTokenResponse, amount: toAmount, setAmount: setToAmount, token: toToken, diff --git a/src/swap/hooks/useResetInputs.test.ts b/src/swap/hooks/useResetInputs.test.ts new file mode 100644 index 0000000000..0f3b27e13f --- /dev/null +++ b/src/swap/hooks/useResetInputs.test.ts @@ -0,0 +1,96 @@ +import { act, renderHook } from '@testing-library/react'; +import { beforeEach, describe, expect, it, vi } from 'vitest'; +import type { SwapUnit } from '../types'; +import { useResetInputs } from './useResetInputs'; + +describe('useResetInputs', () => { + const mockFromTokenResponse = { + refetch: vi.fn().mockResolvedValue(undefined), + }; + const mockToTokenResponse = { refetch: vi.fn().mockResolvedValue(undefined) }; + const mockFrom: SwapUnit = { + balance: '100', + balanceResponse: mockFromTokenResponse, + amount: '50', + setAmount: vi.fn(), + token: undefined, + setToken: vi.fn(), + loading: false, + setLoading: vi.fn(), + error: undefined, + }; + const mockTo: SwapUnit = { + balance: '200', + balanceResponse: mockToTokenResponse, + amount: '75', + setAmount: vi.fn(), + token: undefined, + setToken: vi.fn(), + loading: false, + setLoading: vi.fn(), + error: undefined, + }; + + beforeEach(() => { + vi.clearAllMocks(); + }); + + it('should return a function', () => { + const { result } = renderHook(() => + useResetInputs({ from: mockFrom, to: mockTo }), + ); + expect(typeof result.current).toBe('function'); + }); + + it('should call refetch on responses and setAmount on both from and to when executed', async () => { + const { result } = renderHook(() => + useResetInputs({ from: mockFrom, to: mockTo }), + ); + await act(async () => { + await result.current(); + }); + expect(mockFromTokenResponse.refetch).toHaveBeenCalledTimes(1); + expect(mockToTokenResponse.refetch).toHaveBeenCalledTimes(1); + expect(mockFrom.setAmount).toHaveBeenCalledWith(''); + expect(mockTo.setAmount).toHaveBeenCalledWith(''); + }); + + it("should not create a new function reference if from and to haven't changed", () => { + const { result, rerender } = renderHook(() => + useResetInputs({ from: mockFrom, to: mockTo }), + ); + const firstRender = result.current; + rerender(); + expect(result.current).toBe(firstRender); + }); + + it('should create a new function reference if from or to change', () => { + const { result, rerender } = renderHook( + ({ from, to }) => useResetInputs({ from, to }), + { initialProps: { from: mockFrom, to: mockTo } }, + ); + const firstRender = result.current; + const newMockFrom = { + ...mockFrom, + response: { refetch: vi.fn().mockResolvedValue(undefined) }, + }; + rerender({ from: newMockFrom, to: mockTo }); + expect(result.current).not.toBe(firstRender); + }); + + it('should handle null responses gracefully', async () => { + const mockFromWithNullResponse = { ...mockFrom, response: null }; + const mockToWithNullResponse = { ...mockTo, response: null }; + const { result } = renderHook(() => + useResetInputs({ + from: mockFromWithNullResponse, + to: mockToWithNullResponse, + }), + ); + await act(async () => { + await result.current(); + }); + expect(mockFromWithNullResponse.setAmount).toHaveBeenCalledWith(''); + expect(mockToWithNullResponse.setAmount).toHaveBeenCalledWith(''); + }); +}); diff --git a/src/swap/hooks/useResetInputs.ts b/src/swap/hooks/useResetInputs.ts new file mode 100644 index 0000000000..ee52bd4356 --- /dev/null +++ b/src/swap/hooks/useResetInputs.ts @@ -0,0 +1,14 @@ +import { useCallback } from 'react'; +import type { FromTo } from '../types'; + +// Refreshes balances and inputs post-swap +export const useResetInputs = ({ from, to }: FromTo) => { + return useCallback(async () => { + await Promise.all([ + from.balanceResponse?.refetch(), + to.balanceResponse?.refetch(), + from.setAmount(''), + to.setAmount(''), + ]); + }, [from, to]); +}; diff --git a/src/swap/hooks/useSwapBalances.tsx b/src/swap/hooks/useSwapBalances.tsx index eb35389e92..ebd36bbdd6 100644 --- a/src/swap/hooks/useSwapBalances.tsx +++ b/src/swap/hooks/useSwapBalances.tsx @@ -13,14 +13,23 @@ export function useSwapBalances({ fromToken?: Token; toToken?: Token; }) { - const { convertedBalance: convertedEthBalance, error: ethBalanceError } = - useGetETHBalance(address); + const { + convertedBalance: convertedEthBalance, + error: ethBalanceError, + response: ethBalanceResponse, + } = useGetETHBalance(address); - const { convertedBalance: convertedFromBalance, error: fromBalanceError } = - useGetTokenBalance(address, fromToken); + const { + convertedBalance: convertedFromBalance, + error: fromBalanceError, + response: _fromTokenResponse, + } = useGetTokenBalance(address, fromToken); - const { convertedBalance: convertedToBalance, error: toBalanceError } = - useGetTokenBalance(address, toToken); + const { + convertedBalance: convertedToBalance, + error: toBalanceError, + response: _toTokenResponse, + } = useGetTokenBalance(address, toToken); const isFromNativeToken = fromToken?.symbol === 'ETH'; const isToNativeToken = toToken?.symbol === 'ETH'; @@ -37,12 +46,20 @@ export function useSwapBalances({ const toTokenBalanceError = isToNativeToken ? ethBalanceError : toBalanceError; + const fromTokenResponse = isFromNativeToken + ? ethBalanceResponse + : _fromTokenResponse; + const toTokenResponse = isToNativeToken + ? ethBalanceResponse + : _toTokenResponse; return useValue({ fromBalanceString, fromTokenBalanceError, + fromTokenResponse, toBalanceString, toTokenBalanceError, + toTokenResponse, }); } diff --git a/src/swap/types.ts b/src/swap/types.ts index 86f4c3e7dc..540c45e34d 100644 --- a/src/swap/types.ts +++ b/src/swap/types.ts @@ -1,6 +1,10 @@ import type { Dispatch, ReactNode, SetStateAction } from 'react'; import type { Address, Hex, TransactionReceipt } from 'viem'; -import type { Config } from 'wagmi'; +import type { + Config, + UseBalanceReturnType, + UseReadContractReturnType, +} from 'wagmi'; import type { SendTransactionMutateAsync } from 'wagmi/query'; import type { RawTransactionData } from '../api/types'; import type { Token } from '../token/types'; @@ -30,6 +34,11 @@ export type Fee = { percentage: string; // The percentage of the fee }; +export type FromTo = { + from: SwapUnit; + to: SwapUnit; +}; + export type GetSwapMessageParams = { address?: Address; error?: SwapError; @@ -228,6 +237,7 @@ export type SwapToggleButtonReact = { export type SwapUnit = { amount: string; balance?: string; + balanceResponse?: UseBalanceReturnType | UseReadContractReturnType; error?: SwapError; loading: boolean; setAmount: Dispatch>;