diff --git a/CHANGELOG.md b/CHANGELOG.md index 6d06d65d96..1ccad0f035 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,7 @@ The minor version will be incremented upon a breaking change and the patch versi - cli: Warn if `anchor-spl/idl-build` is missing ([#3133](https://github.com/coral-xyz/anchor/pull/3133)). - client: Add `internal_rpc` method for `mock` feature ([#3135](https://github.com/coral-xyz/anchor/pull/3135)). - lang: Add `#[instruction]` attribute proc-macro to override default instruction discriminators ([#3137](https://github.com/coral-xyz/anchor/pull/3137)). +- lang: Use associated discriminator constants instead of hardcoding in `#[account]` ([#3144](https://github.com/coral-xyz/anchor/pull/3144)). ### Fixes diff --git a/lang/attribute/account/src/lib.rs b/lang/attribute/account/src/lib.rs index d5bad8d4e4..7d12a9db0d 100644 --- a/lang/attribute/account/src/lib.rs +++ b/lang/attribute/account/src/lib.rs @@ -100,6 +100,7 @@ pub fn account( ); format!("{discriminator:?}").parse().unwrap() }; + let disc = quote! { #account_name::DISCRIMINATOR }; let owner_impl = { if namespace.is_empty() { @@ -162,18 +163,18 @@ pub fn account( #[automatically_derived] impl #impl_gen anchor_lang::AccountDeserialize for #account_name #type_gen #where_clause { fn try_deserialize(buf: &mut &[u8]) -> anchor_lang::Result { - if buf.len() < #discriminator.len() { + if buf.len() < #disc.len() { return Err(anchor_lang::error::ErrorCode::AccountDiscriminatorNotFound.into()); } - let given_disc = &buf[..#discriminator.len()]; - if &#discriminator != given_disc { + let given_disc = &buf[..#disc.len()]; + if #disc != given_disc { return Err(anchor_lang::error!(anchor_lang::error::ErrorCode::AccountDiscriminatorMismatch).with_account_name(#account_name_str)); } Self::try_deserialize_unchecked(buf) } fn try_deserialize_unchecked(buf: &mut &[u8]) -> anchor_lang::Result { - let data: &[u8] = &buf[#discriminator.len()..]; + let data: &[u8] = &buf[#disc.len()..]; // Re-interpret raw bytes into the POD data structure. let account = anchor_lang::__private::bytemuck::from_bytes(data); // Copy out the bytes into a new, owned data structure. @@ -191,7 +192,7 @@ pub fn account( #[automatically_derived] impl #impl_gen anchor_lang::AccountSerialize for #account_name #type_gen #where_clause { fn try_serialize(&self, writer: &mut W) -> anchor_lang::Result<()> { - if writer.write_all(&#discriminator).is_err() { + if writer.write_all(#disc).is_err() { return Err(anchor_lang::error::ErrorCode::AccountDidNotSerialize.into()); } @@ -205,18 +206,18 @@ pub fn account( #[automatically_derived] impl #impl_gen anchor_lang::AccountDeserialize for #account_name #type_gen #where_clause { fn try_deserialize(buf: &mut &[u8]) -> anchor_lang::Result { - if buf.len() < #discriminator.len() { + if buf.len() < #disc.len() { return Err(anchor_lang::error::ErrorCode::AccountDiscriminatorNotFound.into()); } - let given_disc = &buf[..#discriminator.len()]; - if &#discriminator != given_disc { + let given_disc = &buf[..#disc.len()]; + if #disc != given_disc { return Err(anchor_lang::error!(anchor_lang::error::ErrorCode::AccountDiscriminatorMismatch).with_account_name(#account_name_str)); } Self::try_deserialize_unchecked(buf) } fn try_deserialize_unchecked(buf: &mut &[u8]) -> anchor_lang::Result { - let mut data: &[u8] = &buf[#discriminator.len()..]; + let mut data: &[u8] = &buf[#disc.len()..]; AnchorDeserialize::deserialize(&mut data) .map_err(|_| anchor_lang::error::ErrorCode::AccountDidNotDeserialize.into()) } diff --git a/lang/attribute/program/src/declare_program/mods/accounts.rs b/lang/attribute/program/src/declare_program/mods/accounts.rs index 87fdb02fc3..9b4821e992 100644 --- a/lang/attribute/program/src/declare_program/mods/accounts.rs +++ b/lang/attribute/program/src/declare_program/mods/accounts.rs @@ -7,6 +7,7 @@ pub fn gen_accounts_mod(idl: &Idl) -> proc_macro2::TokenStream { let accounts = idl.accounts.iter().map(|acc| { let name = format_ident!("{}", acc.name); let discriminator = gen_discriminator(&acc.discriminator); + let disc = quote! { #name::DISCRIMINATOR }; let ty_def = idl .types @@ -17,12 +18,12 @@ pub fn gen_accounts_mod(idl: &Idl) -> proc_macro2::TokenStream { let impls = { let try_deserialize = quote! { fn try_deserialize(buf: &mut &[u8]) -> anchor_lang::Result { - if buf.len() < #discriminator.len() { + if buf.len() < #disc.len() { return Err(anchor_lang::error::ErrorCode::AccountDiscriminatorNotFound.into()); } - let given_disc = &buf[..#discriminator.len()]; - if &#discriminator != given_disc { + let given_disc = &buf[..#disc.len()]; + if #disc != given_disc { return Err( anchor_lang::error!(anchor_lang::error::ErrorCode::AccountDiscriminatorMismatch) .with_account_name(stringify!(#name)) @@ -36,7 +37,7 @@ pub fn gen_accounts_mod(idl: &Idl) -> proc_macro2::TokenStream { IdlSerialization::Borsh => quote! { impl anchor_lang::AccountSerialize for #name { fn try_serialize(&self, writer: &mut W) -> anchor_lang::Result<()> { - if writer.write_all(&#discriminator).is_err() { + if writer.write_all(#disc).is_err() { return Err(anchor_lang::error::ErrorCode::AccountDidNotSerialize.into()); } if AnchorSerialize::serialize(self, writer).is_err() { @@ -51,7 +52,7 @@ pub fn gen_accounts_mod(idl: &Idl) -> proc_macro2::TokenStream { #try_deserialize fn try_deserialize_unchecked(buf: &mut &[u8]) -> anchor_lang::Result { - let mut data: &[u8] = &buf[#discriminator.len()..]; + let mut data: &[u8] = &buf[#disc.len()..]; AnchorDeserialize::deserialize(&mut data) .map_err(|_| anchor_lang::error::ErrorCode::AccountDidNotDeserialize.into()) } @@ -75,7 +76,7 @@ pub fn gen_accounts_mod(idl: &Idl) -> proc_macro2::TokenStream { #try_deserialize fn try_deserialize_unchecked(buf: &mut &[u8]) -> anchor_lang::Result { - let data: &[u8] = &buf[#discriminator.len()..]; + let data: &[u8] = &buf[#disc.len()..]; let account = anchor_lang::__private::bytemuck::from_bytes(data); Ok(*account) }