Skip to content

Commit

Permalink
feat(core): add state mutability computation from world param (#2049)
Browse files Browse the repository at this point in the history
* rework: add state mutability computation from world param

* fix: run fmt

* fix: apply review comments

* fix: address missing review comment
  • Loading branch information
glihm authored Jun 13, 2024
1 parent 78c88e5 commit af5be66
Show file tree
Hide file tree
Showing 16 changed files with 439 additions and 528 deletions.
235 changes: 50 additions & 185 deletions crates/dojo-lang/src/contract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,37 +4,28 @@ use cairo_lang_defs::patcher::{PatchBuilder, RewriteNode};
use cairo_lang_defs::plugin::{
DynGeneratedFileAuxData, PluginDiagnostic, PluginGeneratedFile, PluginResult,
};
use cairo_lang_diagnostics::Severity;
use cairo_lang_syntax::attribute::structured::{
Attribute, AttributeArg, AttributeArgVariant, AttributeListStructurize,
};
use cairo_lang_syntax::node::ast::MaybeModuleBody;
use cairo_lang_syntax::node::db::SyntaxGroup;
use cairo_lang_syntax::node::{ast, ids, Terminal, TypedStablePtr, TypedSyntaxNode};
use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
use dojo_types::system::Dependency;

use crate::plugin::{DojoAuxData, SystemAuxData, DOJO_CONTRACT_ATTR};
use crate::plugin::{DojoAuxData, SystemAuxData};
use crate::syntax::world_param::{self, WorldParamInjectionKind};
use crate::syntax::{self_param, utils as syntax_utils};

const ALLOW_REF_SELF_ARG: &str = "allow_ref_self";
const DOJO_INIT_FN: &str = "dojo_init";

pub struct DojoContract {
diagnostics: Vec<PluginDiagnostic>,
dependencies: HashMap<smol_str::SmolStr, Dependency>,
do_allow_ref_self: bool,
}

impl DojoContract {
pub fn from_module(db: &dyn SyntaxGroup, module_ast: ast::ItemModule) -> PluginResult {
let name = module_ast.name(db).text(db);

let attrs = module_ast.attributes(db).structurize(db);
let dojo_contract_attr = attrs.iter().find(|attr| attr.id.as_str() == DOJO_CONTRACT_ATTR);
let do_allow_ref_self = extract_allow_ref_self(dojo_contract_attr, db).unwrap_or_default();

let mut system =
DojoContract { diagnostics: vec![], dependencies: HashMap::new(), do_allow_ref_self };
let mut system = DojoContract { diagnostics: vec![], dependencies: HashMap::new() };
let mut has_event = false;
let mut has_storage = false;
let mut has_init = false;
Expand Down Expand Up @@ -182,14 +173,14 @@ impl DojoContract {
let fn_decl = fn_ast.declaration(db);
let fn_name = fn_decl.name(db).text(db);

let (params_str, _, world_removed) = self.rewrite_parameters(
let (params_str, was_world_injected) = self.rewrite_parameters(
db,
fn_decl.signature(db).parameters(db),
fn_ast.stable_ptr().untyped(),
);

let mut world_read = "";
if world_removed {
if was_world_injected {
world_read = "let world = self.world_dispatcher.read();";
}

Expand Down Expand Up @@ -303,166 +294,61 @@ impl DojoContract {
)]
}

/// Gets name, modifiers and type from a function parameter.
pub fn get_parameter_info(
&mut self,
db: &dyn SyntaxGroup,
param: ast::Param,
) -> (String, String, String) {
let name = param.name(db).text(db).trim().to_string();
let modifiers = param.modifiers(db).as_syntax_node().get_text(db).trim().to_string();
let param_type =
param.type_clause(db).ty(db).as_syntax_node().get_text(db).trim().to_string();

(name, modifiers, param_type)
}

/// Check if the function has a self parameter.
///
/// Returns
/// * a boolean indicating if `self` has to be added,
// * a boolean indicating if there is a `ref self` parameter.
pub fn check_self_parameter(
&mut self,
db: &dyn SyntaxGroup,
param_list: ast::ParamList,
) -> (bool, bool) {
let mut add_self = true;
let mut has_ref_self = false;
if !param_list.elements(db).is_empty() {
let (param_name, param_modifiers, param_type) =
self.get_parameter_info(db, param_list.elements(db)[0].clone());

if param_name.eq(&"self".to_string()) {
if param_modifiers.contains(&"ref".to_string())
&& param_type.eq(&"ContractState".to_string())
{
has_ref_self = true;
add_self = false;
}

if param_type.eq(&"@ContractState".to_string()) {
add_self = false;
}
}
};

(add_self, has_ref_self)
}

/// Check if the function has multiple IWorldDispatcher parameters.
///
/// Returns
/// * a boolean indicating if the function has multiple world dispatchers.
pub fn check_world_dispatcher(
&mut self,
db: &dyn SyntaxGroup,
param_list: ast::ParamList,
) -> bool {
let mut count = 0;

param_list.elements(db).iter().for_each(|param| {
let (_, _, param_type) = self.get_parameter_info(db, param.clone());

if param_type.eq(&"IWorldDispatcher".to_string()) {
count += 1;
}
});

count > 1
}

/// Rewrites parameter list by:
/// * adding `self` parameter if missing,
/// * removing `world` if present as first parameter (self excluded), as it will be read from
/// the first function statement.
/// * adding `self` parameter based on the `world` parameter mutability. If `world` is not
/// provided, a `View` is assumed.
/// * removing `world` if present as first parameter, as it will be read from the first
/// function statement.
///
/// Reports an error in case of:
/// * `ref self`, as systems are supposed to be 100% stateless,
/// * multiple IWorldDispatcher parameters.
/// * the `IWorldDispatcher` is not the first parameter (self excluded) and named 'world'.
/// * `self` used explicitly,
/// * multiple world parameters,
/// * the `world` parameter is not the first parameter and named 'world'.
///
/// Returns
/// * the list of parameters in a String
/// * a boolean indicating if `self` has been added
// * a boolean indicating if `world` parameter has been removed
/// * the list of parameters in a String.
/// * true if the world has to be injected (found as the first param).
pub fn rewrite_parameters(
&mut self,
db: &dyn SyntaxGroup,
param_list: ast::ParamList,
diagnostic_item: ids::SyntaxStablePtrId,
) -> (String, bool, bool) {
let (add_self, has_ref_self) = self.check_self_parameter(db, param_list.clone());
let has_multiple_world_dispatchers = self.check_world_dispatcher(db, param_list.clone());
fn_diagnostic_item: ids::SyntaxStablePtrId,
) -> (String, bool) {
self_param::check_parameter(db, &param_list, fn_diagnostic_item, &mut self.diagnostics);

let mut world_removed = false;
let world_injection = world_param::parse_world_injection(
db,
param_list.clone(),
fn_diagnostic_item,
&mut self.diagnostics,
);

let mut params = param_list
.elements(db)
.iter()
.enumerate()
.filter_map(|(idx, param)| {
let (name, modifiers, param_type) = self.get_parameter_info(db, param.clone());

if param_type.eq(&"IWorldDispatcher".to_string())
&& modifiers.eq(&"".to_string())
&& !has_multiple_world_dispatchers
{
let has_good_pos = (add_self && idx == 0) || (!add_self && idx == 1);
let has_good_name = name.eq(&"world".to_string());

if has_good_pos && has_good_name {
world_removed = true;
None
} else {
if !has_good_pos {
self.diagnostics.push(PluginDiagnostic {
stable_ptr: param.stable_ptr().untyped(),
message: "The IWorldDispatcher parameter must be the first \
parameter of the function (self excluded)."
.to_string(),
severity: Severity::Error,
});
}
.filter_map(|param| {
let (name, _, param_type) = syntax_utils::get_parameter_info(db, param.clone());

if !has_good_name {
self.diagnostics.push(PluginDiagnostic {
stable_ptr: param.stable_ptr().untyped(),
message: "The IWorldDispatcher parameter must be named 'world'."
.to_string(),
severity: Severity::Error,
});
}
Some(param.as_syntax_node().get_text(db))
}
// If the param is `IWorldDispatcher`, we don't need to keep it in the param list
// as it is flatten in the first statement.
if world_param::is_world_param(&name, &param_type) {
None
} else {
Some(param.as_syntax_node().get_text(db))
}
})
.collect::<Vec<_>>();

if has_multiple_world_dispatchers {
self.diagnostics.push(PluginDiagnostic {
stable_ptr: diagnostic_item,
message: "Only one parameter of type IWorldDispatcher is allowed.".to_string(),
severity: Severity::Error,
});
}

if has_ref_self && !self.do_allow_ref_self {
self.diagnostics.push(PluginDiagnostic {
stable_ptr: diagnostic_item,
message: "Functions of dojo::contract cannot have 'ref self' parameter."
.to_string(),
severity: Severity::Error,
});
}

if add_self {
params.insert(0, "self: @ContractState".to_string());
match world_injection {
WorldParamInjectionKind::None | WorldParamInjectionKind::View => {
params.insert(0, "self: @ContractState".to_string());
}
WorldParamInjectionKind::External => {
params.insert(0, "ref self: ContractState".to_string());
}
}

(params.join(", "), add_self, world_removed)
(params.join(", "), world_injection != WorldParamInjectionKind::None)
}

/// Rewrites function statements by adding the reading of `world` at first statement.
Expand Down Expand Up @@ -493,21 +379,23 @@ impl DojoContract {
) -> Vec<RewriteNode> {
let mut rewritten_fn = RewriteNode::from_ast(&fn_ast);

let (params_str, self_added, world_removed) = self.rewrite_parameters(
let (params_str, was_world_injected) = self.rewrite_parameters(
db,
fn_ast.declaration(db).signature(db).parameters(db),
fn_ast.stable_ptr().untyped(),
);

if self_added || world_removed {
let rewritten_params = rewritten_fn
.modify_child(db, ast::FunctionWithBody::INDEX_DECLARATION)
.modify_child(db, ast::FunctionDeclaration::INDEX_SIGNATURE)
.modify_child(db, ast::FunctionSignature::INDEX_PARAMETERS);
rewritten_params.set_str(params_str);
}

if world_removed {
// We always rewrite the params as the self parameter is added based on the
// world mutability.
let rewritten_params = rewritten_fn
.modify_child(db, ast::FunctionWithBody::INDEX_DECLARATION)
.modify_child(db, ast::FunctionDeclaration::INDEX_SIGNATURE)
.modify_child(db, ast::FunctionSignature::INDEX_PARAMETERS);
rewritten_params.set_str(params_str);

// If the world was injected, we also need to rewrite the statements of the function
// to ensure the `world` injection is effective.
if was_world_injected {
let rewritten_statements = rewritten_fn
.modify_child(db, ast::FunctionWithBody::INDEX_BODY)
.modify_child(db, ast::ExprBlock::INDEX_STATEMENTS);
Expand Down Expand Up @@ -557,26 +445,3 @@ impl DojoContract {
vec![RewriteNode::Copied(impl_ast.as_syntax_node())]
}
}

/// Extract the allow_ref_self attribute.
pub(crate) fn extract_allow_ref_self(
allow_ref_self_attr: Option<&Attribute>,
db: &dyn SyntaxGroup,
) -> Option<bool> {
let Some(attr) = allow_ref_self_attr else {
return None;
};

#[allow(clippy::collapsible_match)]
match &attr.args[..] {
[AttributeArg { variant: AttributeArgVariant::Unnamed(value), .. }] => match value {
ast::Expr::Path(path)
if path.as_syntax_node().get_text_without_trivia(db) == ALLOW_REF_SELF_ARG =>
{
Some(true)
}
_ => None,
},
_ => None,
}
}
Loading

0 comments on commit af5be66

Please sign in to comment.