Skip to content

Commit

Permalink
Added selector inline macro for starknet.
Browse files Browse the repository at this point in the history
commit-id:5bae488f
  • Loading branch information
orizi committed Aug 16, 2023
1 parent 3b1c538 commit d62fa48
Show file tree
Hide file tree
Showing 18 changed files with 359 additions and 37 deletions.
4 changes: 2 additions & 2 deletions crates/cairo-lang-compiler/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,10 @@ impl RootDatabaseBuilder {

pub fn with_inline_macro_plugin(
&mut self,
name: String,
name: &str,
plugin: Arc<dyn InlineMacroExprPlugin>,
) -> &mut Self {
self.inline_macro_plugins.insert(name, plugin);
self.inline_macro_plugins.insert(name.into(), plugin);
self
}

Expand Down
2 changes: 2 additions & 0 deletions crates/cairo-lang-language-server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ use cairo_lang_semantic::items::imp::ImplId;
use cairo_lang_semantic::items::us::get_use_segments;
use cairo_lang_semantic::resolve::{AsSegments, ResolvedConcreteItem, ResolvedGenericItem};
use cairo_lang_semantic::{SemanticDiagnostic, TypeLongId};
use cairo_lang_starknet::inline_macros::selector::SelectorMacro;
use cairo_lang_starknet::plugin::StarkNetPlugin;
use cairo_lang_syntax::node::ast::PathSegment;
use cairo_lang_syntax::node::helpers::GetIdentifier;
Expand Down Expand Up @@ -81,6 +82,7 @@ pub async fn serve_language_service() {
let db = RootDatabase::builder()
.with_cfg(CfgSet::from_iter([Cfg::name("test")]))
.with_macro_plugin(Arc::new(StarkNetPlugin::default()))
.with_inline_macro_plugin(SelectorMacro::NAME, Arc::new(SelectorMacro))
.build()
.expect("Failed to initialize Cairo compiler database.");

Expand Down
2 changes: 1 addition & 1 deletion crates/cairo-lang-semantic/src/inline_macros/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pub fn get_default_inline_macro_plugins() -> OrderedHashMap<String, Arc<dyn Inli
res
}

fn unsupported_bracket_diagnostic(
pub fn unsupported_bracket_diagnostic(
db: &dyn SyntaxGroup,
macro_ast: &ast::ExprInlineMacro,
) -> InlinePluginResult {
Expand Down
271 changes: 271 additions & 0 deletions crates/cairo-lang-semantic/src/patcher.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,271 @@
use cairo_lang_defs::db::DefsGroup;
use cairo_lang_filesystem::span::{TextOffset, TextSpan, TextWidth};
use cairo_lang_syntax::node::db::SyntaxGroup;
use cairo_lang_syntax::node::{SyntaxNode, TypedSyntaxNode};
use cairo_lang_utils::extract_matches;
use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;

/// Interface for modifying syntax nodes.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum RewriteNode {
/// A rewrite node that represents a trimmed copy of a syntax node:
/// one with the leading and trailing trivia excluded.
Trimmed {
node: SyntaxNode,
trim_left: bool,
trim_right: bool,
},
Copied(SyntaxNode),
Modified(ModifiedNode),
Text(String),
}
impl RewriteNode {
pub fn new_trimmed(syntax_node: SyntaxNode) -> Self {
Self::Trimmed { node: syntax_node, trim_left: true, trim_right: true }
}

pub fn new_modified(children: Vec<RewriteNode>) -> Self {
Self::Modified(ModifiedNode { children: Some(children) })
}

/// Creates a rewrite node from an AST object.
pub fn from_ast<T: TypedSyntaxNode>(node: &T) -> Self {
RewriteNode::Copied(node.as_syntax_node())
}

/// Prepares a node for modification.
pub fn modify(&mut self, db: &dyn SyntaxGroup) -> &mut ModifiedNode {
match self {
RewriteNode::Copied(syntax_node) => {
*self = RewriteNode::new_modified(
syntax_node.children(db).map(RewriteNode::Copied).collect(),
);
extract_matches!(self, RewriteNode::Modified)
}

RewriteNode::Trimmed { node, trim_left, trim_right } => {
let num_children = node.children(db).len();
let mut new_children = Vec::new();

// Get the index of the leftmost nonempty child.
let Some(left_idx) =
node.children(db).position(|child| child.width(db) != TextWidth::default())
else {
*self = RewriteNode::Modified(ModifiedNode { children: None });
return extract_matches!(self, RewriteNode::Modified);
};
// Get the index of the rightmost nonempty child.
let right_idx = node
.children(db)
.rposition(|child| child.width(db) != TextWidth::default())
.unwrap();
new_children.extend(itertools::repeat_n(
RewriteNode::Modified(ModifiedNode { children: None }),
left_idx,
));

// The number of children between the first and last nonempty nodes.
let num_middle = right_idx - left_idx + 1;
let mut children_iter = node.children(db).skip(left_idx);
match num_middle {
1 => {
new_children.push(RewriteNode::Trimmed {
node: children_iter.next().unwrap(),
trim_left: *trim_left,
trim_right: *trim_right,
});
}
_ => {
new_children.push(RewriteNode::Trimmed {
node: children_iter.next().unwrap(),
trim_left: *trim_left,
trim_right: false,
});
for _ in 0..(num_middle - 2) {
let child = children_iter.next().unwrap();
new_children.push(RewriteNode::Copied(child));
}
new_children.push(RewriteNode::Trimmed {
node: children_iter.next().unwrap(),
trim_left: false,
trim_right: *trim_right,
});
}
};
new_children.extend(itertools::repeat_n(
RewriteNode::Modified(ModifiedNode { children: None }),
num_children - right_idx - 1,
));

*self = RewriteNode::Modified(ModifiedNode { children: Some(new_children) });
extract_matches!(self, RewriteNode::Modified)
}
RewriteNode::Modified(modified) => modified,
RewriteNode::Text(_) => panic!("A text node can't be modified"),
}
}

/// Prepares a node for modification and returns a specific child.
pub fn modify_child(&mut self, db: &dyn SyntaxGroup, index: usize) -> &mut RewriteNode {
if matches!(self, RewriteNode::Modified(ModifiedNode { children: None })) {
// Modification of an empty node is idempotent.
return self;
}
&mut self.modify(db).children.as_mut().unwrap()[index]
}

/// Replaces this node with text.
pub fn set_str(&mut self, s: String) {
*self = RewriteNode::Text(s)
}
/// Creates a new Rewrite node by interpolating a string with patches.
/// Each substring of the form `$<name>$` is replaced with syntax nodes from `patches`.
/// A `$$` substring is replaced with `$`.
pub fn interpolate_patched(
code: &str,
patches: UnorderedHashMap<String, RewriteNode>,
) -> RewriteNode {
let mut chars = code.chars().peekable();
let mut pending_text = String::new();
let mut children = Vec::new();
while let Some(c) = chars.next() {
if c != '$' {
pending_text.push(c);
continue;
}

// An opening $ was detected.

// Read the name
let mut name = String::new();
for c in chars.by_ref() {
if c == '$' {
break;
}
name.push(c);
}

// A closing $ was found.
// If the string between the `$`s is empty - push a single `$` to the output.
if name.is_empty() {
pending_text.push('$');
continue;
}
// If the string wasn't empty and there is some pending text, first flush it as a text
// child.
if !pending_text.is_empty() {
children.push(RewriteNode::Text(pending_text.clone()));
pending_text.clear();
}
// Replace the substring with the relevant rewrite node.
// TODO(yuval): this currently panics. Fix it.
children.push(patches[&name].clone());
}
// Flush the remaining text as a text child.
if !pending_text.is_empty() {
children.push(RewriteNode::Text(pending_text.clone()));
}

RewriteNode::new_modified(children)
}
}
impl From<SyntaxNode> for RewriteNode {
fn from(node: SyntaxNode) -> Self {
RewriteNode::Copied(node)
}
}

/// A modified rewrite node.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ModifiedNode {
/// Children of the node.
/// Can be None, in which case this is an empty node (of width 0). It's not the same as
/// Some(vec![]) - None can be (idempotently) modified, whereas modifying Some(vec![]) would
/// panic.
pub children: Option<Vec<RewriteNode>>,
}

#[derive(Debug, PartialEq, Eq)]
pub struct Patch {
span: TextSpan,
origin_span: TextSpan,
}

#[derive(Debug, Default, PartialEq, Eq)]
pub struct Patches {
patches: Vec<Patch>,
}
impl Patches {
pub fn translate(&self, _db: &dyn DefsGroup, span: TextSpan) -> Option<TextSpan> {
for Patch { span: patch_span, origin_span } in &self.patches {
if patch_span.contains(span) {
let start = origin_span.start.add_width(span.start - patch_span.start);
return Some(TextSpan { start, end: start.add_width(span.end - span.start) });
}
}
None
}
}

pub struct PatchBuilder<'a> {
pub db: &'a dyn SyntaxGroup,
pub code: String,
pub patches: Patches,
}
impl<'a> PatchBuilder<'a> {
pub fn new(db: &'a dyn SyntaxGroup) -> Self {
Self { db, code: String::default(), patches: Patches::default() }
}

pub fn add_char(&mut self, c: char) {
self.code.push(c);
}

pub fn add_str(&mut self, s: &str) {
self.code += s;
}

pub fn add_modified(&mut self, node: RewriteNode) {
match node {
RewriteNode::Copied(node) => self.add_node(node),
RewriteNode::Trimmed { node, trim_left, trim_right } => {
self.add_trimmed_node(node, trim_left, trim_right)
}
RewriteNode::Modified(modified) => {
if let Some(children) = modified.children {
for child in children {
self.add_modified(child)
}
}
}
RewriteNode::Text(s) => self.add_str(s.as_str()),
}
}

pub fn add_node(&mut self, node: SyntaxNode) {
let orig_span = node.span(self.db);
let start = TextOffset::default().add_width(TextWidth::from_str(&self.code));
self.patches.patches.push(Patch {
span: TextSpan { start, end: start.add_width(orig_span.end - orig_span.start) },
origin_span: node.span(self.db),
});
self.code += node.get_text(self.db).as_str();
}

fn add_trimmed_node(&mut self, node: SyntaxNode, trim_left: bool, trim_right: bool) {
let TextSpan { start: trimmed_start, end: trimmed_end } = node.span_without_trivia(self.db);
let orig_start = if trim_left { trimmed_start } else { node.span(self.db).start };
let orig_end = if trim_right { trimmed_end } else { node.span(self.db).end };
let origin_span = TextSpan { start: orig_start, end: orig_end };

let text = node.get_text_of_span(self.db, origin_span);
let start = TextOffset::default().add_width(TextWidth::from_str(&self.code));

self.code += &text;

self.patches.patches.push(Patch {
span: TextSpan { start, end: start.add_width(TextWidth::from_str(&text)) },
origin_span,
});
}
}
2 changes: 2 additions & 0 deletions crates/cairo-lang-starknet/src/abi_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use indoc::indoc;
use pretty_assertions::assert_eq;

use crate::abi::AbiBuilder;
use crate::inline_macros::selector::SelectorMacro;
use crate::plugin::StarkNetPlugin;

#[test]
Expand Down Expand Up @@ -201,6 +202,7 @@ fn test_abi_failure() {
let db = &mut RootDatabase::builder()
.detect_corelib()
.with_macro_plugin(Arc::new(StarkNetPlugin::default()))
.with_inline_macro_plugin(SelectorMacro::NAME, Arc::new(SelectorMacro))
.build()
.unwrap();
let module_id = setup_test_module(
Expand Down
2 changes: 2 additions & 0 deletions crates/cairo-lang-starknet/src/contract_class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ use crate::contract::{
find_contracts, get_module_abi_functions, get_selector_and_sierra_function, ContractDeclaration,
};
use crate::felt252_serde::sierra_to_felt252s;
use crate::inline_macros::selector::SelectorMacro;
use crate::plugin::consts::{CONSTRUCTOR_MODULE, EXTERNAL_MODULE, L1_HANDLER_MODULE};
use crate::plugin::StarkNetPlugin;

Expand Down Expand Up @@ -85,6 +86,7 @@ pub fn compile_path(
let mut db = RootDatabase::builder()
.detect_corelib()
.with_macro_plugin(Arc::new(StarkNetPlugin::default()))
.with_inline_macro_plugin(SelectorMacro::NAME, Arc::new(SelectorMacro))
.build()?;

let main_crate_ids = setup_project(&mut db, Path::new(&path))?;
Expand Down
1 change: 1 addition & 0 deletions crates/cairo-lang-starknet/src/inline_macros/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod selector;
46 changes: 46 additions & 0 deletions crates/cairo-lang-starknet/src/inline_macros/selector.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
use cairo_lang_defs::plugin::{InlineMacroExprPlugin, InlinePluginResult, PluginDiagnostic};
use cairo_lang_semantic::inline_macros::unsupported_bracket_diagnostic;
use cairo_lang_syntax::node::db::SyntaxGroup;
use cairo_lang_syntax::node::{ast, TypedSyntaxNode};

use crate::contract::starknet_keccak;

/// Macro for expanding a selector to a string literal.
#[derive(Debug)]
pub struct SelectorMacro;
impl SelectorMacro {
pub const NAME: &'static str = "selector";
}
impl InlineMacroExprPlugin for SelectorMacro {
fn generate_code(
&self,
db: &dyn SyntaxGroup,
syntax: &ast::ExprInlineMacro,
) -> InlinePluginResult {
let ast::WrappedExprList::ParenthesizedExprList(args) = syntax.arguments(db) else {
return unsupported_bracket_diagnostic(db, syntax);
};

let arguments = &args.expressions(db).elements(db);
if arguments.len() != 1 {
let diagnostics = vec![PluginDiagnostic {
stable_ptr: syntax.stable_ptr().untyped(),
message: "selector macro must have a single argument".to_string(),
}];
return InlinePluginResult { code: None, diagnostics };
}

let ast::Expr::String(input_string) = &arguments[0] else {
let diagnostics = vec![PluginDiagnostic {
stable_ptr: syntax.stable_ptr().untyped(),
message: "selector macro argument must be a string".to_string(),
}];
return InlinePluginResult { code: None, diagnostics };
};
let selector_string = input_string.string_value(db).unwrap();

let selector = starknet_keccak(selector_string.as_bytes());
let code: String = format!("0x{}", selector.to_str_radix(16));
InlinePluginResult { code: Some(code), diagnostics: vec![] }
}
}
1 change: 1 addition & 0 deletions crates/cairo-lang-starknet/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ pub mod contract;
pub mod contract_class;
mod felt252_serde;
mod felt252_vec_compression;
pub mod inline_macros;
pub mod plugin;

#[cfg(test)]
Expand Down
Loading

0 comments on commit d62fa48

Please sign in to comment.