Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

lang: Use associated discriminator constants instead of hardcoding in #[account] #3144

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ The minor version will be incremented upon a breaking change and the patch versi
- cli: Warn if `anchor-spl/idl-build` is missing ([#3133](https://github.com/coral-xyz/anchor/pull/3133)).
- client: Add `internal_rpc` method for `mock` feature ([#3135](https://github.com/coral-xyz/anchor/pull/3135)).
- lang: Add `#[instruction]` attribute proc-macro to override default instruction discriminators ([#3137](https://github.com/coral-xyz/anchor/pull/3137)).
- lang: Use associated discriminator constants instead of hardcoding in `#[account]` ([#3144](https://github.com/coral-xyz/anchor/pull/3144)).

### Fixes

Expand Down
19 changes: 10 additions & 9 deletions lang/attribute/account/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ pub fn account(
);
format!("{discriminator:?}").parse().unwrap()
};
let disc = quote! { #account_name::DISCRIMINATOR };

let owner_impl = {
if namespace.is_empty() {
Expand Down Expand Up @@ -162,18 +163,18 @@ pub fn account(
#[automatically_derived]
impl #impl_gen anchor_lang::AccountDeserialize for #account_name #type_gen #where_clause {
fn try_deserialize(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
if buf.len() < #discriminator.len() {
if buf.len() < #disc.len() {
return Err(anchor_lang::error::ErrorCode::AccountDiscriminatorNotFound.into());
}
let given_disc = &buf[..#discriminator.len()];
if &#discriminator != given_disc {
let given_disc = &buf[..#disc.len()];
if #disc != given_disc {
return Err(anchor_lang::error!(anchor_lang::error::ErrorCode::AccountDiscriminatorMismatch).with_account_name(#account_name_str));
}
Self::try_deserialize_unchecked(buf)
}

fn try_deserialize_unchecked(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
let data: &[u8] = &buf[#discriminator.len()..];
let data: &[u8] = &buf[#disc.len()..];
// Re-interpret raw bytes into the POD data structure.
let account = anchor_lang::__private::bytemuck::from_bytes(data);
// Copy out the bytes into a new, owned data structure.
Expand All @@ -191,7 +192,7 @@ pub fn account(
#[automatically_derived]
impl #impl_gen anchor_lang::AccountSerialize for #account_name #type_gen #where_clause {
fn try_serialize<W: std::io::Write>(&self, writer: &mut W) -> anchor_lang::Result<()> {
if writer.write_all(&#discriminator).is_err() {
if writer.write_all(#disc).is_err() {
return Err(anchor_lang::error::ErrorCode::AccountDidNotSerialize.into());
}

Expand All @@ -205,18 +206,18 @@ pub fn account(
#[automatically_derived]
impl #impl_gen anchor_lang::AccountDeserialize for #account_name #type_gen #where_clause {
fn try_deserialize(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
if buf.len() < #discriminator.len() {
if buf.len() < #disc.len() {
return Err(anchor_lang::error::ErrorCode::AccountDiscriminatorNotFound.into());
}
let given_disc = &buf[..#discriminator.len()];
if &#discriminator != given_disc {
let given_disc = &buf[..#disc.len()];
if #disc != given_disc {
return Err(anchor_lang::error!(anchor_lang::error::ErrorCode::AccountDiscriminatorMismatch).with_account_name(#account_name_str));
}
Self::try_deserialize_unchecked(buf)
}

fn try_deserialize_unchecked(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
let mut data: &[u8] = &buf[#discriminator.len()..];
let mut data: &[u8] = &buf[#disc.len()..];
AnchorDeserialize::deserialize(&mut data)
.map_err(|_| anchor_lang::error::ErrorCode::AccountDidNotDeserialize.into())
}
Expand Down
13 changes: 7 additions & 6 deletions lang/attribute/program/src/declare_program/mods/accounts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ pub fn gen_accounts_mod(idl: &Idl) -> proc_macro2::TokenStream {
let accounts = idl.accounts.iter().map(|acc| {
let name = format_ident!("{}", acc.name);
let discriminator = gen_discriminator(&acc.discriminator);
let disc = quote! { #name::DISCRIMINATOR };

let ty_def = idl
.types
Expand All @@ -17,12 +18,12 @@ pub fn gen_accounts_mod(idl: &Idl) -> proc_macro2::TokenStream {
let impls = {
let try_deserialize = quote! {
fn try_deserialize(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
if buf.len() < #discriminator.len() {
if buf.len() < #disc.len() {
return Err(anchor_lang::error::ErrorCode::AccountDiscriminatorNotFound.into());
}

let given_disc = &buf[..#discriminator.len()];
if &#discriminator != given_disc {
let given_disc = &buf[..#disc.len()];
if #disc != given_disc {
return Err(
anchor_lang::error!(anchor_lang::error::ErrorCode::AccountDiscriminatorMismatch)
.with_account_name(stringify!(#name))
Expand All @@ -36,7 +37,7 @@ pub fn gen_accounts_mod(idl: &Idl) -> proc_macro2::TokenStream {
IdlSerialization::Borsh => quote! {
impl anchor_lang::AccountSerialize for #name {
fn try_serialize<W: std::io::Write>(&self, writer: &mut W) -> anchor_lang::Result<()> {
if writer.write_all(&#discriminator).is_err() {
if writer.write_all(#disc).is_err() {
return Err(anchor_lang::error::ErrorCode::AccountDidNotSerialize.into());
}
if AnchorSerialize::serialize(self, writer).is_err() {
Expand All @@ -51,7 +52,7 @@ pub fn gen_accounts_mod(idl: &Idl) -> proc_macro2::TokenStream {
#try_deserialize

fn try_deserialize_unchecked(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
let mut data: &[u8] = &buf[#discriminator.len()..];
let mut data: &[u8] = &buf[#disc.len()..];
AnchorDeserialize::deserialize(&mut data)
.map_err(|_| anchor_lang::error::ErrorCode::AccountDidNotDeserialize.into())
}
Expand All @@ -75,7 +76,7 @@ pub fn gen_accounts_mod(idl: &Idl) -> proc_macro2::TokenStream {
#try_deserialize

fn try_deserialize_unchecked(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
let data: &[u8] = &buf[#discriminator.len()..];
let data: &[u8] = &buf[#disc.len()..];
let account = anchor_lang::__private::bytemuck::from_bytes(data);
Ok(*account)
}
Expand Down
Loading