Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support generic return types #234

Merged
merged 1 commit into from
Sep 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 46 additions & 38 deletions sylvia-derive/src/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ use crate::parser::{
Custom, MsgAttr, MsgType, OverrideEntryPoint, OverrideEntryPoints,
};
use crate::strip_generics::StripGenerics;
use crate::utils::{extract_return_type, filter_wheres, process_fields};
use crate::utils::{
as_where_clause, brace_generics, extract_return_type, filter_wheres, process_fields,
};
use crate::variant_descs::{AsVariantDescs, VariantDescs};
use convert_case::{Case, Casing};
use proc_macro2::{Span, TokenStream};
Expand Down Expand Up @@ -100,14 +102,6 @@ impl<'a> StructMessage<'a> {
custom,
} = self;

let where_clause = if !wheres.is_empty() {
quote! {
where #(#wheres,)*
}
} else {
quote! {}
};

let ctx_type = msg_attr
.msg_type()
.emit_ctx_type(&custom.query_or_default());
Expand All @@ -119,21 +113,9 @@ impl<'a> StructMessage<'a> {
});
let fields = fields.iter().map(MsgField::emit);

let generics = if generics.is_empty() {
quote! {}
} else {
quote! {
<#(#generics,)*>
}
};

let unused_generics = if unused_generics.is_empty() {
quote! {}
} else {
quote! {
<#(#unused_generics,)*>
}
};
let where_clause = as_where_clause(wheres);
let generics = brace_generics(generics);
let unused_generics = brace_generics(unused_generics);

#[cfg(not(tarpaulin_include))]
{
Expand Down Expand Up @@ -282,18 +264,33 @@ impl<'a> EnumMessage<'a> {
let ctx_type = msg_ty.emit_ctx_type(query_type);
let dispatch_type = msg_ty.emit_result_type(resp_type, &parse_quote!(C::Error));

let all_generics = if all_generics.is_empty() {
let all_generics = brace_generics(all_generics);
let phantom = if generics.is_empty() {
quote! {}
} else if MsgType::Query == *msg_ty {
quote! {
#[returns((#(#generics,)*))]
_Phantom(std::marker::PhantomData<( #(#generics,)* )>),
}
} else {
quote! { <#(#all_generics,)*> }
quote! {
_Phantom(std::marker::PhantomData<( #(#generics,)* )>),
}
};

let generics = if generics.is_empty() {
quote! {}
let match_arms = if !generics.is_empty() {
quote! {
#(#match_arms,)*
_Phantom(_) => unreachable!(),
}
} else {
quote! { <#(#generics,)*> }
quote! {
#(#match_arms,)*
}
};

let generics = brace_generics(generics);

let unique_enum_name = Ident::new(&format!("{}{}", trait_name, name), name.span());

#[cfg(not(tarpaulin_include))]
Expand All @@ -305,6 +302,7 @@ impl<'a> EnumMessage<'a> {
#[serde(rename_all="snake_case")]
pub enum #unique_enum_name #generics {
#(#variants,)*
#phantom
}
pub type #name #generics = #unique_enum_name #generics;
}
Expand All @@ -316,6 +314,7 @@ impl<'a> EnumMessage<'a> {
#[serde(rename_all="snake_case")]
pub enum #unique_enum_name #generics {
#(#variants,)*
#phantom
}
pub type #name #generics = #unique_enum_name #generics;
}
Expand All @@ -334,7 +333,7 @@ impl<'a> EnumMessage<'a> {
use #unique_enum_name::*;

match self {
#(#match_arms,)*
#match_arms
}
}
pub const fn messages() -> [&'static str; #msgs_cnt] {
Expand Down Expand Up @@ -507,10 +506,12 @@ impl<'a> MsgVariant<'a> {
let return_type = if let MsgAttr::Query { resp_type } = msg_attr {
match resp_type {
Some(resp_type) => {
generics_checker.visit_path(&parse_quote! { #resp_type });
quote! {#resp_type}
}
None => {
let return_type = extract_return_type(&sig.output);
generics_checker.visit_path(return_type);
quote! {#return_type}
}
}
Expand Down Expand Up @@ -621,7 +622,11 @@ impl<'a> MsgVariant<'a> {
}
}

pub fn emit_querier_impl(&self, trait_module: Option<&Path>) -> TokenStream {
pub fn emit_querier_impl(
&self,
trait_module: Option<&Path>,
unbonded_generics: &Vec<&GenericParam>,
) -> TokenStream {
let sylvia = crate_module();
let Self {
name,
Expand All @@ -637,6 +642,12 @@ impl<'a> MsgVariant<'a> {
.map(|module| quote! { #module ::QueryMsg })
.unwrap_or_else(|| quote! { QueryMsg });

let msg = if !unbonded_generics.is_empty() {
quote! { #msg ::< #(#unbonded_generics,)* > }
} else {
quote! { #msg }
};

#[cfg(not(tarpaulin_include))]
{
quote! {
Expand Down Expand Up @@ -741,18 +752,15 @@ impl<'a> MsgVariants<'a> {
let methods_impl = variants
.iter()
.filter(|variant| variant.msg_type == MsgType::Query)
.map(|variant| variant.emit_querier_impl(None));
.map(|variant| variant.emit_querier_impl(None, unbonded_generics));

let methods_declaration = variants
.iter()
.filter(|variant| variant.msg_type == MsgType::Query)
.map(MsgVariant::emit_querier_declaration);

let querier = if !unbonded_generics.is_empty() {
quote! { Querier < #(#unbonded_generics,)* > }
} else {
quote! { Querier }
};
let braced_generics = brace_generics(unbonded_generics);
let querier = quote! { Querier #braced_generics };

#[cfg(not(tarpaulin_include))]
{
Expand Down Expand Up @@ -803,7 +811,7 @@ impl<'a> MsgVariants<'a> {
let methods_impl = variants
.iter()
.filter(|variant| variant.msg_type == MsgType::Query)
.map(|variant| variant.emit_querier_impl(trait_module));
.map(|variant| variant.emit_querier_impl(trait_module, unbonded_generics));

let mut querier = trait_module
.map(|module| quote! { #module ::Querier })
Expand Down
20 changes: 18 additions & 2 deletions sylvia-derive/src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use proc_macro2::TokenStream;
use proc_macro_error::emit_error;
use quote::quote;
use syn::spanned::Spanned;
use syn::visit::Visit;
use syn::{
FnArg, GenericArgument, GenericParam, Path, PathArguments, ReturnType, Signature, Type,
WhereClause, WherePredicate,
parse_quote, FnArg, GenericArgument, GenericParam, Path, PathArguments, ReturnType, Signature,
Type, WhereClause, WherePredicate,
};

use crate::check_generics::CheckGenerics;
Expand Down Expand Up @@ -84,3 +86,17 @@

&type_path.path
}

pub fn as_where_clause(where_predicates: &[&WherePredicate]) -> Option<WhereClause> {
match where_predicates.is_empty() {
true => None,
false => Some(parse_quote! { where #(#where_predicates),* }),

Check warning on line 93 in sylvia-derive/src/utils.rs

View check run for this annotation

Codecov / codecov/patch

sylvia-derive/src/utils.rs#L93

Added line #L93 was not covered by tests
}
}

pub fn brace_generics(unbonded_generics: &[&GenericParam]) -> TokenStream {
match unbonded_generics.is_empty() {
true => quote! {},
false => quote! { < #(#unbonded_generics,)* > },
}
}
37 changes: 28 additions & 9 deletions sylvia/tests/generics.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,31 @@
use cosmwasm_schema::cw_serde;

pub mod cw1 {
use cosmwasm_std::{CosmosMsg, CustomMsg, Response, StdError};
use cosmwasm_std::{CosmosMsg, CustomMsg, CustomQuery, Response, StdError};

use serde::Deserialize;
use serde::{de::DeserializeOwned, Deserialize};
use sylvia::types::{ExecCtx, QueryCtx};
use sylvia_derive::interface;

#[interface(module=msg)]
pub trait Cw1<Msg, Param>
#[sv::custom(msg=Msg)]
pub trait Cw1<Msg, Param, QueryRet>
where
for<'msg_de> Msg: CustomMsg + Deserialize<'msg_de>,
Param: sylvia::types::CustomMsg,
for<'msg_de> QueryRet: CustomQuery + DeserializeOwned,
{
type Error: From<StdError>;

#[msg(exec)]
fn execute(&self, ctx: ExecCtx, msgs: Vec<CosmosMsg<Msg>>)
-> Result<Response, Self::Error>;
fn execute(
&self,
ctx: ExecCtx,
msgs: Vec<CosmosMsg<Msg>>,
) -> Result<Response<Msg>, Self::Error>;

#[msg(query)]
fn query(&self, ctx: QueryCtx, param: Param) -> Result<String, Self::Error>;
fn some_query(&self, ctx: QueryCtx, param: Param) -> Result<QueryRet, Self::Error>;
}
}

Expand All @@ -29,16 +34,30 @@ pub struct ExternalMsg;
impl cosmwasm_std::CustomMsg for ExternalMsg {}
impl sylvia::types::CustomMsg for ExternalMsg {}

#[cw_serde]
pub struct ExternalQuery;
impl cosmwasm_std::CustomQuery for ExternalQuery {}

#[cfg(test)]
mod tests {
use cosmwasm_std::{CosmosMsg, Empty};
use cosmwasm_std::{testing::mock_dependencies, Addr, CosmosMsg, Empty, QuerierWrapper};

use crate::ExternalMsg;
use crate::{cw1::Querier, ExternalMsg, ExternalQuery};

#[test]
fn construct_messages() {
let _ = crate::cw1::QueryMsg::query(ExternalMsg {});
let contract = Addr::unchecked("contract");

let _ = crate::cw1::QueryMsg::<_, Empty>::some_query(ExternalMsg {});
let _ = crate::cw1::ExecMsg::execute(vec![CosmosMsg::Custom(ExternalMsg {})]);
let _ = crate::cw1::ExecMsg::execute(vec![CosmosMsg::Custom(Empty {})]);

// Generic Querier
let deps = mock_dependencies();
let querier: QuerierWrapper<ExternalQuery> = QuerierWrapper::new(&deps.querier);

let cw1_querier = crate::cw1::BoundQuerier::borrowed(&contract, &querier);
let _: Result<ExternalQuery, _> = Querier::some_query(&cw1_querier, ExternalMsg {});
let _: Result<ExternalQuery, _> = cw1_querier.some_query(ExternalMsg {});
}
}
Loading