From e9c48ed48d36c7c58c0d820d94fc8fb459d32538 Mon Sep 17 00:00:00 2001 From: Ajay Vasisht <43521356+avasisht23@users.noreply.github.com> Date: Wed, 1 Nov 2023 13:19:03 -0400 Subject: [PATCH] feat: add zod runtime validation for base account (#186) * feat: add zod runtime validation for base account * feat: add zod runtime validation for base account * feat: add zod runtime validation for simple account * refactor: clean up base schemas * refactor: rebase * refactor: rename abitype import --- .../core/src/account/__tests__/simple.test.ts | 59 +++++++++++++++++++ packages/core/src/account/base.ts | 34 +++-------- packages/core/src/account/schema.ts | 22 +++++++ packages/core/src/account/simple.ts | 8 +-- packages/core/src/account/types.ts | 9 ++- packages/core/src/client/schema.ts | 17 ++++++ packages/core/src/index.ts | 11 +++- .../core/src/provider/__tests__/base.test.ts | 5 +- packages/core/src/provider/base.ts | 4 +- packages/core/src/provider/schema.ts | 46 +++------------ packages/core/src/provider/types.ts | 6 +- packages/core/src/signer/schema.ts | 13 ++++ packages/core/src/utils/index.ts | 1 + packages/core/src/utils/schema.ts | 20 +++++++ .../src/__tests__/provider-adapter.test.ts | 4 -- 15 files changed, 179 insertions(+), 80 deletions(-) create mode 100644 packages/core/src/account/schema.ts create mode 100644 packages/core/src/client/schema.ts create mode 100644 packages/core/src/signer/schema.ts create mode 100644 packages/core/src/utils/schema.ts diff --git a/packages/core/src/account/__tests__/simple.test.ts b/packages/core/src/account/__tests__/simple.test.ts index b5c0c3db02..04a047b403 100644 --- a/packages/core/src/account/__tests__/simple.test.ts +++ b/packages/core/src/account/__tests__/simple.test.ts @@ -1,3 +1,4 @@ +import type { Address } from "viem"; import { polygonMumbai, type Chain } from "viem/chains"; import { describe, it } from "vitest"; import { getDefaultSimpleAccountFactoryAddress } from "../../index.js"; @@ -43,6 +44,64 @@ describe("Account Simple Tests", () => { '"0x18dfb3c7000000000000000000000000000000000000000000000000000000000000004000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000002000000000000000000000000deadbeefdeadbeefdeadbeefdeadbeefdeadbeef0000000000000000000000008ba1f109551bd432803012645ac136ddd64dba720000000000000000000000000000000000000000000000000000000000000002000000000000000000000000000000000000000000000000000000000000004000000000000000000000000000000000000000000000000000000000000000800000000000000000000000000000000000000000000000000000000000000004deadbeef000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000004cafebabe00000000000000000000000000000000000000000000000000000000"' ); }); + + it("should correctly do base runtime validation when entrypoint are invalid", () => { + expect( + () => + new SimpleSmartContractAccount({ + entryPointAddress: 1 as unknown as Address, + chain, + owner, + factoryAddress: "0xdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef", + rpcClient: "ALCHEMY_RPC_URL", + }) + ).toThrowErrorMatchingInlineSnapshot(` + "[ + { + \\"code\\": \\"invalid_type\\", + \\"expected\\": \\"string\\", + \\"received\\": \\"number\\", + \\"path\\": [ + \\"entryPointAddress\\" + ], + \\"message\\": \\"Expected string, received number\\" + } + ]" + `); + }); + + it("should correctly do base runtime validation when multiple inputs are invalid", () => { + expect( + () => + new SimpleSmartContractAccount({ + entryPointAddress: 1 as unknown as Address, + chain: "0x1" as unknown as Chain, + owner, + factoryAddress: "0xdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef", + rpcClient: "ALCHEMY_RPC_URL", + }) + ).toThrowErrorMatchingInlineSnapshot(` + "[ + { + \\"code\\": \\"invalid_type\\", + \\"expected\\": \\"string\\", + \\"received\\": \\"number\\", + \\"path\\": [ + \\"entryPointAddress\\" + ], + \\"message\\": \\"Expected string, received number\\" + }, + { + \\"code\\": \\"custom\\", + \\"fatal\\": true, + \\"path\\": [ + \\"chain\\" + ], + \\"message\\": \\"Invalid input\\" + } + ]" + `); + }); }); const givenConnectedProvider = ({ diff --git a/packages/core/src/account/base.ts b/packages/core/src/account/base.ts index 670caf56c2..bcff8b3517 100644 --- a/packages/core/src/account/base.ts +++ b/packages/core/src/account/base.ts @@ -20,7 +20,12 @@ import type { SmartAccountSigner } from "../signer/types.js"; import { wrapSignatureWith6492 } from "../signer/utils.js"; import type { BatchUserOperationCallData } from "../types.js"; import { getDefaultEntryPointAddress } from "../utils/defaults.js"; -import type { ISmartContractAccount, SignTypedDataParams } from "./types.js"; +import { createBaseSmartAccountParamsSchema } from "./schema.js"; +import type { + BaseSmartAccountParams, + ISmartContractAccount, + SignTypedDataParams, +} from "./types.js"; export enum DeploymentState { UNDEFINED = "0x0", @@ -28,31 +33,6 @@ export enum DeploymentState { DEPLOYED = "0x2", } -export interface BaseSmartAccountParams< - TTransport extends SupportedTransports = Transport -> { - rpcClient: string | PublicErc4337Client; - factoryAddress: Address; - chain: Chain; - - /** - * The address of the entry point contract. - * If not provided, the default entry point contract will be used. - * Check out https://docs.alchemy.com/reference/eth-supportedentrypoints for all the supported entrypoints - */ - entryPointAddress?: Address; - - /** - * Owner account signer for the account if there is one. - */ - owner?: SmartAccountSigner | undefined; - - /** - * The address of the account if it is already deployed. - */ - accountAddress?: Address; -} - export abstract class BaseSmartContractAccount< TTransport extends SupportedTransports = Transport > implements ISmartContractAccount @@ -72,6 +52,8 @@ export abstract class BaseSmartContractAccount< | PublicErc4337Client; constructor(params: BaseSmartAccountParams) { + createBaseSmartAccountParamsSchema().parse(params); + this.entryPointAddress = params.entryPointAddress ?? getDefaultEntryPointAddress(params.chain); diff --git a/packages/core/src/account/schema.ts b/packages/core/src/account/schema.ts new file mode 100644 index 0000000000..7d25a0ce56 --- /dev/null +++ b/packages/core/src/account/schema.ts @@ -0,0 +1,22 @@ +import { Address } from "abitype/zod"; +import type { Transport } from "viem"; +import z from "zod"; +import { createPublicErc4337ClientSchema } from "../client/schema.js"; +import type { SupportedTransports } from "../client/types"; +import { SignerSchema } from "../signer/schema.js"; +import { ChainSchema } from "../utils/index.js"; + +export const createBaseSmartAccountParamsSchema = < + TTransport extends SupportedTransports = Transport +>() => + z.object({ + rpcClient: z.union([ + z.string(), + createPublicErc4337ClientSchema(), + ]), + factoryAddress: Address, + owner: SignerSchema.optional(), + entryPointAddress: Address.optional(), + chain: ChainSchema, + accountAddress: Address.optional(), + }); diff --git a/packages/core/src/account/simple.ts b/packages/core/src/account/simple.ts index 6c2a2aec70..c88f0c56c0 100644 --- a/packages/core/src/account/simple.ts +++ b/packages/core/src/account/simple.ts @@ -9,12 +9,10 @@ import { } from "viem"; import { SimpleAccountAbi } from "../abis/SimpleAccountAbi.js"; import { SimpleAccountFactoryAbi } from "../abis/SimpleAccountFactoryAbi.js"; -import type { BatchUserOperationCallData } from "../types.js"; -import { - BaseSmartContractAccount, - type BaseSmartAccountParams, -} from "./base.js"; import type { SmartAccountSigner } from "../signer/types.js"; +import type { BatchUserOperationCallData } from "../types.js"; +import { BaseSmartContractAccount } from "./base.js"; +import type { BaseSmartAccountParams } from "./types.js"; export interface SimpleSmartAccountParams< TTransport extends Transport | FallbackTransport = Transport diff --git a/packages/core/src/account/types.ts b/packages/core/src/account/types.ts index 9a044a6b8c..2b4cccadce 100644 --- a/packages/core/src/account/types.ts +++ b/packages/core/src/account/types.ts @@ -1,11 +1,18 @@ import type { Address } from "abitype"; -import type { Hash, Hex } from "viem"; +import type { Hash, Hex, Transport } from "viem"; import type { SignTypedDataParameters } from "viem/accounts"; +import type { z } from "zod"; +import type { SupportedTransports } from "../client/types"; import type { SmartAccountSigner } from "../signer/types"; import type { BatchUserOperationCallData } from "../types"; +import type { createBaseSmartAccountParamsSchema } from "./schema"; export type SignTypedDataParams = Omit; +export type BaseSmartAccountParams< + TTransport extends SupportedTransports = Transport +> = z.infer>>; + export interface ISmartContractAccount { /** * @returns the init code for the account diff --git a/packages/core/src/client/schema.ts b/packages/core/src/client/schema.ts new file mode 100644 index 0000000000..147b0213ed --- /dev/null +++ b/packages/core/src/client/schema.ts @@ -0,0 +1,17 @@ +import type { Transport } from "viem"; +import { z } from "zod"; +import type { PublicErc4337Client, SupportedTransports } from "./types"; + +export const createPublicErc4337ClientSchema = < + TTransport extends SupportedTransports = Transport +>() => + z.custom>((provider) => { + return ( + provider != null && + typeof provider === "object" && + "request" in provider && + "type" in provider && + "key" in provider && + "name" in provider + ); + }); diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index ea59f25511..79f11583cf 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -7,11 +7,14 @@ export { SimpleAccountAbi } from "./abis/SimpleAccountAbi.js"; export { SimpleAccountFactoryAbi } from "./abis/SimpleAccountFactoryAbi.js"; export { BaseSmartContractAccount } from "./account/base.js"; -export type { BaseSmartAccountParams } from "./account/base.js"; +export { createBaseSmartAccountParamsSchema } from "./account/schema.js"; export { SimpleSmartContractAccount } from "./account/simple.js"; export type { SimpleSmartAccountParams } from "./account/simple.js"; export type * from "./account/types.js"; +export type { BaseSmartAccountParams } from "./account/types.js"; + export { LocalAccountSigner } from "./signer/local-account.js"; +export { SignerSchema } from "./signer/schema.js"; export type { SmartAccountSigner } from "./signer/types.js"; export { verifyEIP6492Signature, @@ -24,6 +27,7 @@ export { createPublicErc4337FromClient, erc4337ClientActions, } from "./client/create-client.js"; +export { createPublicErc4337ClientSchema } from "./client/schema.js"; export type * from "./client/types.js"; export { @@ -33,6 +37,10 @@ export { } from "./ens/utils.js"; export { SmartAccountProvider, noOpMiddleware } from "./provider/base.js"; +export { + createSmartAccountProviderConfigSchema, + SmartAccountProviderOptsSchema, +} from "./provider/schema.js"; export type * from "./provider/types.js"; export type * from "./types.js"; @@ -48,6 +56,7 @@ export { getDefaultSimpleAccountFactoryAddress, getUserOperationHash, resolveProperties, + ChainSchema, } from "./utils/index.js"; export { Logger } from "./logger.js"; diff --git a/packages/core/src/provider/__tests__/base.test.ts b/packages/core/src/provider/__tests__/base.test.ts index ca6e110b96..096f0e90d4 100644 --- a/packages/core/src/provider/__tests__/base.test.ts +++ b/packages/core/src/provider/__tests__/base.test.ts @@ -244,10 +244,11 @@ describe("Base Tests", () => { "[ { \\"code\\": \\"custom\\", - \\"message\\": \\"Invalid input\\", + \\"fatal\\": true, \\"path\\": [ \\"chain\\" - ] + ], + \\"message\\": \\"Invalid input\\" }, { \\"code\\": \\"invalid_type\\", diff --git a/packages/core/src/provider/base.ts b/packages/core/src/provider/base.ts index cd8938ecf3..5f8844e21e 100644 --- a/packages/core/src/provider/base.ts +++ b/packages/core/src/provider/base.ts @@ -41,7 +41,7 @@ import { resolveProperties, type Deferrable, } from "../utils/index.js"; -import { SmartAccountProviderConfigSchema } from "./schema.js"; +import { createSmartAccountProviderConfigSchema } from "./schema.js"; import type { AccountMiddlewareFn, AccountMiddlewareOverrideFn, @@ -85,7 +85,7 @@ export class SmartAccountProvider< | PublicErc4337Client; constructor(config: SmartAccountProviderConfig) { - SmartAccountProviderConfigSchema().parse(config); + createSmartAccountProviderConfigSchema().parse(config); const { rpcProvider, entryPointAddress, chain, opts } = config; diff --git a/packages/core/src/provider/schema.ts b/packages/core/src/provider/schema.ts index 74ec5585db..5d0e17304c 100644 --- a/packages/core/src/provider/schema.ts +++ b/packages/core/src/provider/schema.ts @@ -1,8 +1,9 @@ -import { Address as zAddress } from "abitype/zod"; -import type { Chain, Transport } from "viem"; +import { Address } from "abitype/zod"; +import type { Transport } from "viem"; import z from "zod"; -import type { PublicErc4337Client, SupportedTransports } from "../client/types"; -import { getChain } from "../utils/index.js"; +import { createPublicErc4337ClientSchema } from "../client/schema.js"; +import type { SupportedTransports } from "../client/types"; +import { ChainSchema } from "../utils/index.js"; export const SmartAccountProviderOptsSchema = z.object({ /** @@ -26,43 +27,15 @@ export const SmartAccountProviderOptsSchema = z.object({ minPriorityFeePerBid: z.bigint().optional().default(100_000_000n), }); -export const SmartAccountProviderConfigSchema = < +export const createSmartAccountProviderConfigSchema = < TTransport extends SupportedTransports = Transport >() => { return z.object({ rpcProvider: z.union([ z.string(), - z - .any() - .refine>( - (provider): provider is PublicErc4337Client => { - return ( - typeof provider === "object" && - "request" in provider && - "type" in provider && - "key" in provider && - "name" in provider - ); - } - ), + createPublicErc4337ClientSchema(), ]), - - chain: z.any().refine((chain): chain is Chain => { - if ( - !(typeof chain === "object") || - !("id" in chain) || - typeof chain.id !== "number" - ) { - return false; - } - - try { - return getChain(chain.id) !== undefined; - } catch { - return false; - } - }), - + chain: ChainSchema, /** * Optional entry point contract address for override if needed. * If not provided, the entry point contract address for the provider is the connected account's entry point contract, @@ -71,8 +44,7 @@ export const SmartAccountProviderConfigSchema = < * Refer to https://docs.alchemy.com/reference/eth-supportedentrypoints for all the supported entrypoints * when using Alchemy as your RPC provider. */ - entryPointAddress: zAddress.optional(), - + entryPointAddress: Address.optional(), opts: SmartAccountProviderOptsSchema.optional(), }); }; diff --git a/packages/core/src/provider/types.ts b/packages/core/src/provider/types.ts index 1052144b77..22cd8038cd 100644 --- a/packages/core/src/provider/types.ts +++ b/packages/core/src/provider/types.ts @@ -26,8 +26,8 @@ import type { } from "../types.js"; import type { Deferrable } from "../utils"; import type { - SmartAccountProviderConfigSchema, SmartAccountProviderOptsSchema, + createSmartAccountProviderConfigSchema, } from "./schema.js"; type WithRequired = Required>; @@ -87,7 +87,9 @@ export type SmartAccountProviderOpts = z.infer< export type SmartAccountProviderConfig< TTransport extends SupportedTransports = Transport -> = z.infer>>; +> = z.infer< + ReturnType> +>; // TODO: this also will need to implement EventEmitteer export interface ISmartAccountProvider< diff --git a/packages/core/src/signer/schema.ts b/packages/core/src/signer/schema.ts new file mode 100644 index 0000000000..bd03edbde9 --- /dev/null +++ b/packages/core/src/signer/schema.ts @@ -0,0 +1,13 @@ +import { z } from "zod"; +import type { SmartAccountSigner } from "./types"; + +export const SignerSchema = z.custom((signer) => { + return ( + signer != null && + typeof signer === "object" && + "signerType" in signer && + "signMessage" in signer && + "signTypedData" in signer && + "getAddress" in signer + ); +}); diff --git a/packages/core/src/utils/index.ts b/packages/core/src/utils/index.ts index 33b1fba35c..3b6c49269e 100644 --- a/packages/core/src/utils/index.ts +++ b/packages/core/src/utils/index.ts @@ -158,4 +158,5 @@ export function defineReadOnly( export * from "./bigint.js"; export * from "./defaults.js"; +export * from "./schema.js"; export * from "./userop.js"; diff --git a/packages/core/src/utils/schema.ts b/packages/core/src/utils/schema.ts new file mode 100644 index 0000000000..140dd4f42b --- /dev/null +++ b/packages/core/src/utils/schema.ts @@ -0,0 +1,20 @@ +import type { Chain } from "viem"; +import { z } from "zod"; +import { getChain } from "./index.js"; + +export const ChainSchema = z.custom((chain) => { + if ( + chain == null || + !(typeof chain === "object") || + !("id" in chain) || + typeof chain.id !== "number" + ) { + return false; + } + + try { + return getChain(chain.id) !== undefined; + } catch { + return false; + } +}); diff --git a/packages/ethers/src/__tests__/provider-adapter.test.ts b/packages/ethers/src/__tests__/provider-adapter.test.ts index 303c3526b3..a873cde4ec 100644 --- a/packages/ethers/src/__tests__/provider-adapter.test.ts +++ b/packages/ethers/src/__tests__/provider-adapter.test.ts @@ -49,10 +49,6 @@ const givenConnectedProvider = ({ rpcClient, }); - account.getAddress = vi.fn( - async () => "0xb856DBD4fA1A79a46D426f537455e7d3E79ab7c4" - ); - return account; } );