Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: improve how to specify the callee contract address #215

Merged
merged 9 commits into from
Jul 6, 2022
24 changes: 13 additions & 11 deletions contracts/dynamic-callee-contract/src/contract.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use cosmwasm_std::{
callable_point, dynamic_link, entry_point, to_vec, Addr, DepsMut, Env, GlobalApi, MessageInfo,
Response,
callable_point, dynamic_link, entry_point, Addr, Contract, DepsMut, Env, GlobalApi,
MessageInfo, Response,
};
use serde::{Deserialize, Serialize};

Expand Down Expand Up @@ -53,18 +53,20 @@ fn pong_env() -> Env {
GlobalApi::env()
}

#[dynamic_link(contract_name = "dynamic_caller_contract")]
extern "C" {
fn should_never_be_called();
#[derive(Contract)]
struct Me {
address: Addr,
}

#[dynamic_link(Me)]
trait ReEntrance: Contract {
fn should_never_be_called(&self);
}

#[callable_point]
fn reentrancy(addr: Addr) {
GlobalApi::with_deps_mut(|deps| {
deps.storage
.set(b"dynamic_caller_contract", &to_vec(&addr).unwrap());
});
should_never_be_called()
fn reentrancy(address: Addr) {
let me = Me { address };
me.should_never_be_called()
}

// And declare a custom Error variant for the ones where you will want to make use of it
Expand Down
3 changes: 1 addition & 2 deletions contracts/dynamic-caller-contract/schema/query_msg.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
{
"$schema": "http://json-schema.org/draft-07/schema#",
"title": "QueryMsg",
"type": "string",
"enum": []
"type": "object"
}
47 changes: 28 additions & 19 deletions contracts/dynamic-caller-contract/src/contract.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use cosmwasm_std::{
callable_point, dynamic_link, entry_point, to_vec, Addr, DepsMut, Env, MessageInfo, Response,
Uint128,
callable_point, dynamic_link, entry_point, from_slice, to_vec, Addr, Contract, DepsMut, Env,
MessageInfo, Response, Uint128,
};
use serde::{Deserialize, Serialize};
use std::fmt;
Expand All @@ -19,14 +19,19 @@ impl fmt::Display for ExampleStruct {
}
}

#[dynamic_link(contract_name = "dynamic_callee_contract")]
extern "C" {
fn pong(ping_num: u64) -> u64;
fn pong_with_struct(example: ExampleStruct) -> ExampleStruct;
fn pong_with_tuple(input: (String, i32)) -> (String, i32);
fn pong_with_tuple_takes_2_args(input1: String, input2: i32) -> (String, i32);
fn pong_env() -> Env;
fn reentrancy(addr: Addr);
#[derive(Contract)]
struct CalleeContract {
address: Addr,
}

#[dynamic_link(CalleeContract)]
trait Callee: Contract {
fn pong(&self, ping_num: u64) -> u64;
fn pong_with_struct(&self, example: ExampleStruct) -> ExampleStruct;
fn pong_with_tuple(&self, input: (String, i32)) -> (String, i32);
fn pong_with_tuple_takes_2_args(&self, input1: String, input2: i32) -> (String, i32);
fn pong_env(&self) -> Env;
fn reentrancy(&self, addr: Addr);
}

// Note, you can use StdResult in some functions where you do not
Expand Down Expand Up @@ -54,18 +59,20 @@ pub fn execute(
) -> Result<Response, ContractError> {
match msg {
ExecuteMsg::Ping { ping_num } => try_ping(deps, ping_num),
ExecuteMsg::TryReEntrancy {} => try_re_entrancy(env),
ExecuteMsg::TryReEntrancy {} => try_re_entrancy(deps, env),
}
}

pub fn try_ping(_deps: DepsMut, ping_num: Uint128) -> Result<Response, ContractError> {
let pong_ret = pong(ping_num.u128() as u64);
let struct_ret = pong_with_struct(ExampleStruct {
pub fn try_ping(deps: DepsMut, ping_num: Uint128) -> Result<Response, ContractError> {
let address = from_slice(&deps.storage.get(b"dynamic_callee_contract").unwrap())?;
let contract = CalleeContract { address };
let pong_ret = contract.pong(ping_num.u128() as u64);
let struct_ret = contract.pong_with_struct(ExampleStruct {
str_field: String::from("hello"),
u64_field: 100u64,
});
let tuple_ret = pong_with_tuple((String::from("hello"), 41));
let tuple_ret2 = pong_with_tuple_takes_2_args(String::from("hello"), 41);
let tuple_ret = contract.pong_with_tuple((String::from("hello"), 41));
let tuple_ret2 = contract.pong_with_tuple_takes_2_args(String::from("hello"), 41);

let mut res = Response::default();
res.add_attribute("returned_pong", pong_ret.to_string());
Expand All @@ -80,15 +87,17 @@ pub fn try_ping(_deps: DepsMut, ping_num: Uint128) -> Result<Response, ContractE
);
res.add_attribute(
"returned_contract_address",
pong_env().contract.address.to_string(),
contract.pong_env().contract.address.to_string(),
);
Ok(res)
}

pub fn try_re_entrancy(env: Env) -> Result<Response, ContractError> {
pub fn try_re_entrancy(deps: DepsMut, env: Env) -> Result<Response, ContractError> {
// It will be tried to call the should_never_be_called function below.
// But, should be blocked by VM host side normally because it's a reentrancy case.
reentrancy(env.contract.address);
let address = from_slice(&deps.storage.get(b"dynamic_callee_contract").unwrap())?;
let contract = CalleeContract { address };
contract.reentrancy(env.contract.address);
Ok(Response::default())
}

Expand Down
3 changes: 3 additions & 0 deletions contracts/dynamic-caller-contract/src/msg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,6 @@ pub enum ExecuteMsg {
Ping { ping_num: Uint128 },
TryReEntrancy {},
}

#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, JsonSchema)]
pub struct QueryMsg {}
20 changes: 10 additions & 10 deletions contracts/dynamic-caller-contract/tests/integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,28 +9,28 @@ fn required_imports() -> Vec<(String, String, FunctionType)> {
vec![
(
String::from("stub_pong"),
String::from("dynamic_callee_contract"),
([Type::I32], [Type::I32]).into(),
String::from("CalleeContract"),
([Type::I32, Type::I32], [Type::I32]).into(),
),
(
String::from("stub_pong_with_struct"),
String::from("dynamic_callee_contract"),
([Type::I32], [Type::I32]).into(),
String::from("CalleeContract"),
([Type::I32, Type::I32], [Type::I32]).into(),
),
(
String::from("stub_pong_with_tuple"),
String::from("dynamic_callee_contract"),
([Type::I32], [Type::I32]).into(),
String::from("CalleeContract"),
([Type::I32, Type::I32], [Type::I32]).into(),
),
(
String::from("stub_pong_with_tuple_takes_2_args"),
String::from("dynamic_callee_contract"),
([Type::I32, Type::I32], [Type::I32]).into(),
String::from("CalleeContract"),
([Type::I32, Type::I32, Type::I32], [Type::I32]).into(),
),
(
String::from("stub_pong_env"),
String::from("dynamic_callee_contract"),
([], [Type::I32]).into(),
String::from("CalleeContract"),
([Type::I32], [Type::I32]).into(),
),
]
}
Expand Down
104 changes: 104 additions & 0 deletions packages/derive/src/contract.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
use proc_macro2::TokenStream;
use quote::quote;
use syn::spanned::Spanned;
use syn::{Attribute, DeriveInput, Error, Fields, Ident, Result};

fn generate_get_address(address_field: &Ident) -> TokenStream {
quote! {
fn get_address(&self) -> cosmwasm_std::Addr {
self.#address_field.clone()
}
}
}

fn generate_set_address(address_field: &Ident) -> TokenStream {
quote! {
fn set_address(&mut self, address: cosmwasm_std::Addr) {
self.#address_field = address
}
}
}

fn generate_impl_contract(struct_id: &Ident, address_field: &Ident) -> TokenStream {
let get_fn = generate_get_address(address_field);
let set_fn = generate_set_address(address_field);
quote! {
impl Contract for #struct_id {
#get_fn
#set_fn
}
}
}

fn has_address_attribute(attrs: &[Attribute]) -> bool {
attrs.iter().filter(|a| a.path.is_ident("address")).count() > 0
}

/// scan fields and extraction a field specifying address.
/// The priority is "contract_address" -> "contract_addr"
/// -> "address" -> "addr"
fn scan_address_field(fields: &Fields) -> Option<Ident> {
let candidates = vec!["contract_address", "contract_addr", "address", "addr"];
for field in fields {
match &field.ident {
Some(id) => {
for candidate in &candidates {
if id == candidate {
return Some(id.clone());
}
}
}
None => continue,
}
}
None
}

fn find_address_field_id(fields: Fields) -> Result<Ident> {
let filtered = fields
.iter()
.filter(|field| has_address_attribute(&field.attrs));
match filtered.clone().count() {
0 => match scan_address_field(&fields) {
Some(id) => Ok(id),
None => Err(Error::new(
fields.span(),
"[Contract] There are no field specifying address.",
)),
},
1 => {
let field = filtered.last().unwrap().clone();
match field.ident {
Some(id) => Ok(id),
None => Err(Error::new(
field.span(),
"[Contract] The field attributed `address` has no name.",
)),
}
}
_ => Err(Error::new(
fields.span(),
"[Contract] Only one or zero fields can have `address` attribute.",
)),
}
}

/// derive `Contract` from a derive input. The input needs to be a struct.
pub fn derive_contract(input: DeriveInput) -> TokenStream {
match input.data {
syn::Data::Struct(struct_data) => match find_address_field_id(struct_data.fields) {
Ok(address_field_id) => generate_impl_contract(&input.ident, &address_field_id),
Err(e) => e.to_compile_error(),
},
syn::Data::Enum(enum_data) => Error::new(
enum_data.enum_token.span,
"[Contract] `derive(Contract)` cannot be applied to Enum.",
)
.to_compile_error(),
syn::Data::Union(union_data) => Error::new(
union_data.union_token.span,
"[Contract] `derive(Contract)` cannot be applied to Union.",
)
.to_compile_error(),
}
}
Loading