Skip to content

Commit

Permalink
lang: Add non-8-byte discriminator support in declare_program! (#3103)
Browse files Browse the repository at this point in the history
  • Loading branch information
acheroncrypto authored Jul 22, 2024
1 parent ba33d5e commit e5bed20
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 43 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ The minor version will be incremented upon a breaking change and the patch versi
- lang: Add `Account` utility type to get accounts from bytes ([#3091](https://github.com/coral-xyz/anchor/pull/3091)).
- client: Add option to pass in mock rpc client when using anchor_client ([#3053](https://github.com/coral-xyz/anchor/pull/3053)).
- lang: Get discriminator length dynamically ([#3101](https://github.com/coral-xyz/anchor/pull/3101)).
- lang: Add non-8-byte discriminator support in `declare_program!` ([#3103](https://github.com/coral-xyz/anchor/pull/3103)).

### Fixes

Expand Down
6 changes: 3 additions & 3 deletions lang/attribute/program/src/declare_program/mods/accounts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pub fn gen_accounts_mod(idl: &Idl) -> proc_macro2::TokenStream {
return Err(anchor_lang::error::ErrorCode::AccountDiscriminatorNotFound.into());
}

let given_disc = &buf[..8];
let given_disc = &buf[..#discriminator.len()];
if &#discriminator != given_disc {
return Err(
anchor_lang::error!(anchor_lang::error::ErrorCode::AccountDiscriminatorMismatch)
Expand Down Expand Up @@ -51,7 +51,7 @@ pub fn gen_accounts_mod(idl: &Idl) -> proc_macro2::TokenStream {
#try_deserialize

fn try_deserialize_unchecked(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
let mut data: &[u8] = &buf[8..];
let mut data: &[u8] = &buf[#discriminator.len()..];
AnchorDeserialize::deserialize(&mut data)
.map_err(|_| anchor_lang::error::ErrorCode::AccountDidNotDeserialize.into())
}
Expand All @@ -75,7 +75,7 @@ pub fn gen_accounts_mod(idl: &Idl) -> proc_macro2::TokenStream {
#try_deserialize

fn try_deserialize_unchecked(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
let data: &[u8] = &buf[8..];
let data: &[u8] = &buf[#discriminator.len()..];
let account = anchor_lang::__private::bytemuck::from_bytes(data);
Ok(*account)
}
Expand Down
12 changes: 4 additions & 8 deletions lang/attribute/program/src/declare_program/mods/internal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,11 @@ fn gen_internal_args_mod(idl: &Idl) -> proc_macro2::TokenStream {
}
};

let impl_discriminator = if ix.discriminator.len() == 8 {
let discriminator = gen_discriminator(&ix.discriminator);
quote! {
impl anchor_lang::Discriminator for #ix_struct_name {
const DISCRIMINATOR: &'static [u8] = &#discriminator;
}
let discriminator = gen_discriminator(&ix.discriminator);
let impl_discriminator = quote! {
impl anchor_lang::Discriminator for #ix_struct_name {
const DISCRIMINATOR: &'static [u8] = &#discriminator;
}
} else {
quote! {}
};

let impl_ix_data = quote! {
Expand Down
56 changes: 24 additions & 32 deletions lang/attribute/program/src/declare_program/mods/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,17 @@ fn gen_account(idl: &Idl) -> proc_macro2::TokenStream {
.iter()
.map(|acc| format_ident!("{}", acc.name))
.map(|name| quote! { #name(#name) });
let match_arms = idl.accounts.iter().map(|acc| {
let disc = gen_discriminator(&acc.discriminator);
let if_statements = idl.accounts.iter().map(|acc| {
let name = format_ident!("{}", acc.name);
let account = quote! {
#name::try_from_slice(&value[8..])
.map(Self::#name)
.map_err(Into::into)
};
quote! { #disc => #account }
let disc = gen_discriminator(&acc.discriminator);
let disc_len = acc.discriminator.len();
quote! {
if value.starts_with(&#disc) {
return #name::try_from_slice(&value[#disc_len..])
.map(Self::#name)
.map_err(Into::into)
}
}
});

quote! {
Expand All @@ -57,14 +59,8 @@ fn gen_account(idl: &Idl) -> proc_macro2::TokenStream {
type Error = anchor_lang::error::Error;

fn try_from(value: &[u8]) -> Result<Self> {
if value.len() < 8 {
return Err(ProgramError::InvalidArgument.into());
}

match &value[..8] {
#(#match_arms,)*
_ => Err(ProgramError::InvalidArgument.into()),
}
#(#if_statements)*
Err(ProgramError::InvalidArgument.into())
}
}
}
Expand All @@ -76,15 +72,17 @@ fn gen_event(idl: &Idl) -> proc_macro2::TokenStream {
.iter()
.map(|ev| format_ident!("{}", ev.name))
.map(|name| quote! { #name(#name) });
let match_arms = idl.events.iter().map(|ev| {
let disc = gen_discriminator(&ev.discriminator);
let if_statements = idl.events.iter().map(|ev| {
let name = format_ident!("{}", ev.name);
let event = quote! {
#name::try_from_slice(&value[8..])
.map(Self::#name)
.map_err(Into::into)
};
quote! { #disc => #event }
let disc = gen_discriminator(&ev.discriminator);
let disc_len = ev.discriminator.len();
quote! {
if value.starts_with(&#disc) {
return #name::try_from_slice(&value[#disc_len..])
.map(Self::#name)
.map_err(Into::into)
}
}
});

quote! {
Expand All @@ -109,14 +107,8 @@ fn gen_event(idl: &Idl) -> proc_macro2::TokenStream {
type Error = anchor_lang::error::Error;

fn try_from(value: &[u8]) -> Result<Self> {
if value.len() < 8 {
return Err(ProgramError::InvalidArgument.into());
}

match &value[..8] {
#(#match_arms,)*
_ => Err(ProgramError::InvalidArgument.into()),
}
#(#if_statements)*
Err(ProgramError::InvalidArgument.into())
}
}
}
Expand Down

0 comments on commit e5bed20

Please sign in to comment.