Skip to content

Commit

Permalink
Made the plugin compilation be based on the wrapper functions.
Browse files Browse the repository at this point in the history
commit-id:4213f446
  • Loading branch information
orizi committed Nov 12, 2024
1 parent 35bbaa5 commit c0e38c8
Show file tree
Hide file tree
Showing 5 changed files with 255 additions and 16 deletions.
26 changes: 19 additions & 7 deletions crates/cairo-lang-runnable/src/compile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use cairo_lang_sierra_generator::program_generator::SierraProgramWithDebug;
use cairo_lang_utils::Upcast;
use itertools::Itertools;

use crate::plugin::{RUNNABLE_ATTR, RunnablePlugin};
use crate::plugin::{RUNNABLE_PREFIX, RUNNABLE_RAW_ATTR, RunnablePlugin};

/// Compile the function given by path.
/// Errors if there is ambiguity.
Expand Down Expand Up @@ -45,14 +45,11 @@ pub fn compile_runnable_in_prepared_db(
) -> Result<String> {
let mut runnables: Vec<_> = find_executable_function_ids(db, main_crate_ids)
.into_iter()
.filter_map(|(id, labels)| labels.into_iter().any(|l| l == RUNNABLE_ATTR).then_some(id))
.filter_map(|(id, labels)| labels.into_iter().any(|l| l == RUNNABLE_RAW_ATTR).then_some(id))
.collect();

// TODO(ilya): Add contract names.
if let Some(runnable_path) = runnable_path {
runnables.retain(|runnable| {
runnable.base_semantic_function(db).full_path(db.upcast()) == runnable_path
});
runnables.retain(|runnable| originating_function_path(db, *runnable) == runnable_path);
};
let runnable = match runnables.len() {
0 => {
Expand All @@ -64,7 +61,7 @@ pub fn compile_runnable_in_prepared_db(
_ => {
let runnable_names = runnables
.iter()
.map(|runnable| runnable.base_semantic_function(db).full_path(db.upcast()))
.map(|runnable| originating_function_path(db, *runnable))
.join("\n ");
anyhow::bail!(
"More than one runnable found in the main crate: \n {}\nUse --runnable to \
Expand All @@ -77,6 +74,21 @@ pub fn compile_runnable_in_prepared_db(
compile_runnable_function_in_prepared_db(db, runnable, diagnostics_reporter)
}

/// Returns the path to the function that the runnable is wrapping.
///
/// If the runnable is not wrapping a function, returns the full path of the runnable.
fn originating_function_path(db: &RootDatabase, wrapper: ConcreteFunctionWithBodyId) -> String {
let wrapper_name = wrapper.name(db);
let wrapper_full_path = wrapper.base_semantic_function(db).full_path(db.upcast());
let Some(wrapped_name) = wrapper_name.strip_suffix(RUNNABLE_PREFIX) else {
return wrapper_full_path;
};
let Some(wrapper_path_to_module) = wrapper_full_path.strip_suffix(wrapper_name.as_str()) else {
return wrapper_full_path;
};
format!("{}{}", wrapper_path_to_module, wrapped_name)
}

/// Runs compiler for a runnable function.
///
/// # Arguments
Expand Down
233 changes: 227 additions & 6 deletions crates/cairo-lang-runnable/src/compile_test_data/basic
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,27 @@ fn main() {}
//! > generated_casm_code
# builtins:
# header #
%{ raise NotImplementedError("memory[ap + 0].. = params[0])") %}
%{ raise NotImplementedError("memory[ap + 2].. = params[1])") %}
ap += 4;
call rel 3;
ret;
# sierra based code #
[fp + -5] = [ap + 0] + [fp + -6], ap++;
jmp rel 4 if [ap + -1] != 0;
jmp rel 13;
%{ memory[ap + 0] = segments.add() %}
ap += 1;
[ap + 0] = 117999715903629884655797335944760714204113152088920212735095598, ap++;
[ap + -1] = [[ap + -2] + 0];
[ap + 0] = 1, ap++;
[ap + 0] = [ap + -3], ap++;
[ap + 0] = [ap + -4] + 1, ap++;
ret;
ap += 2;
[ap + 0] = 0, ap++;
[ap + 0] = [fp + -4], ap++;
[ap + 0] = [fp + -3], ap++;
ret;
# footer #
ret;
Expand All @@ -36,12 +54,58 @@ fn main(a: felt252, b: felt252) -> felt252 {
# builtins:
# header #
%{ raise NotImplementedError("memory[ap + 0].. = params[0])") %}
%{ raise NotImplementedError("memory[ap + 1].. = params[1])") %}
ap += 2;
%{ raise NotImplementedError("memory[ap + 2].. = params[1])") %}
ap += 4;
call rel 3;
ret;
# sierra based code #
[ap + 0] = [fp + -4] + [fp + -3], ap++;
[fp + -5] = [ap + 0] + [fp + -6], ap++;
jmp rel 4 if [ap + -1] != 0;
jmp rel 54;
[ap + 0] = [fp + -6] + 1, ap++;
[ap + 0] = [fp + -5], ap++;
[ap + 0] = [[fp + -6] + 0], ap++;
[ap + -2] = [ap + 0] + [ap + -3], ap++;
jmp rel 4 if [ap + -1] != 0;
jmp rel 32;
[ap + 0] = [ap + -4] + 1, ap++;
[ap + 0] = [ap + -4], ap++;
[ap + 0] = [[ap + -6] + 0], ap++;
[ap + -2] = [ap + 0] + [ap + -3], ap++;
jmp rel 4 if [ap + -1] != 0;
jmp rel 13;
%{ memory[ap + 0] = segments.add() %}
ap += 1;
[ap + 0] = 117999715903629884655797335944760714204113152088920212735095598, ap++;
[ap + -1] = [[ap + -2] + 0];
[ap + 0] = 1, ap++;
[ap + 0] = [ap + -3], ap++;
[ap + 0] = [ap + -4] + 1, ap++;
ret;
ap += 1;
[ap + 0] = [ap + -7] + [ap + -3], ap++;
[ap + -1] = [[fp + -3] + 0];
[ap + 0] = 0, ap++;
[ap + 0] = [fp + -4], ap++;
[ap + 0] = [fp + -3] + 1, ap++;
ret;
ap += 4;
%{ memory[ap + 0] = segments.add() %}
ap += 1;
[ap + 0] = 485748461484230571791265682659113160264223489397539653310998840191492913, ap++;
[ap + -1] = [[ap + -2] + 0];
[ap + 0] = 1, ap++;
[ap + 0] = [ap + -3], ap++;
[ap + 0] = [ap + -4] + 1, ap++;
ret;
ap += 8;
%{ memory[ap + 0] = segments.add() %}
ap += 1;
[ap + 0] = 485748461484230571791265682659113160264223489397539653310998840191492912, ap++;
[ap + -1] = [[ap + -2] + 0];
[ap + 0] = 1, ap++;
[ap + 0] = [ap + -3], ap++;
[ap + 0] = [ap + -4] + 1, ap++;
ret;
# footer #
ret;
Expand Down Expand Up @@ -70,12 +134,169 @@ fn fib(a: u128, b: u128, n: u128) -> u128 {
# header #
[ap + 0] = [fp + -3], ap++;
%{ raise NotImplementedError("memory[ap + 0].. = params[0])") %}
%{ raise NotImplementedError("memory[ap + 1].. = params[1])") %}
%{ raise NotImplementedError("memory[ap + 2].. = params[2])") %}
ap += 3;
%{ raise NotImplementedError("memory[ap + 2].. = params[1])") %}
ap += 4;
call rel 3;
ret;
# sierra based code #
[fp + -5] = [ap + 0] + [fp + -6], ap++;
jmp rel 4 if [ap + -1] != 0;
jmp rel 10;
[ap + 0] = [fp + -6] + 1, ap++;
[ap + 0] = [fp + -5], ap++;
[ap + 0] = 0, ap++;
[ap + 0] = [fp + -6], ap++;
jmp rel 8;
[ap + 0] = [fp + -6], ap++;
[ap + 0] = [fp + -5], ap++;
[ap + 0] = 1, ap++;
[ap + 0] = 0, ap++;
jmp rel 199 if [ap + -2] != 0;
[ap + 0] = [[ap + -1] + 0], ap++;
%{ memory[ap + 0] = memory[ap + -1] < 340282366920938463463374607431768211456 %}
jmp rel 22 if [ap + 0] != 0, ap++;
%{ (memory[ap + 3], memory[ap + 4]) = divmod(memory[ap + -2], 340282366920938463463374607431768211456) %}
[ap + 3] = [[fp + -7] + 0], ap++;
[ap + 3] = [[fp + -7] + 1], ap++;
[ap + -2] = [ap + 1] * 340282366920938463463374607431768211456, ap++;
[ap + -5] = [ap + -3] + [ap + 1], ap++;
[ap + -3] = [ap + -1] + -10633823966279327296825105735305134080, ap++;
jmp rel 6 if [ap + -4] != 0;
[ap + -3] = [ap + -1] + 340282366920938463463374607431768211455;
jmp rel 4;
[ap + -3] = [ap + -2] + 329648542954659136166549501696463077376;
[ap + -3] = [[fp + -7] + 2];
jmp rel 174 if [ap + -2] != 0;
[fp + -1] = [fp + -1] + 1;
[ap + -2] = [[fp + -7] + 0];
[ap + 0] = [fp + -7] + 1, ap++;
[ap + -6] = [ap + 0] + [ap + -7], ap++;
jmp rel 4 if [ap + -1] != 0;
jmp rel 10;
[ap + 0] = [ap + -8] + 1, ap++;
[ap + 0] = [ap + -8], ap++;
[ap + 0] = 0, ap++;
[ap + 0] = [ap + -11], ap++;
jmp rel 8;
[ap + 0] = [ap + -8], ap++;
[ap + 0] = [ap + -8], ap++;
[ap + 0] = 1, ap++;
[ap + 0] = 0, ap++;
jmp rel 135 if [ap + -2] != 0;
[ap + 0] = [[ap + -1] + 0], ap++;
%{ memory[ap + 0] = memory[ap + -1] < 340282366920938463463374607431768211456 %}
jmp rel 22 if [ap + 0] != 0, ap++;
%{ (memory[ap + 3], memory[ap + 4]) = divmod(memory[ap + -2], 340282366920938463463374607431768211456) %}
[ap + 3] = [[ap + -8] + 0], ap++;
[ap + 3] = [[ap + -9] + 1], ap++;
[ap + -2] = [ap + 1] * 340282366920938463463374607431768211456, ap++;
[ap + -5] = [ap + -3] + [ap + 1], ap++;
[ap + -3] = [ap + -1] + -10633823966279327296825105735305134080, ap++;
jmp rel 6 if [ap + -4] != 0;
[ap + -3] = [ap + -1] + 340282366920938463463374607431768211455;
jmp rel 4;
[ap + -3] = [ap + -2] + 329648542954659136166549501696463077376;
[ap + -3] = [[ap + -13] + 2];
jmp rel 110 if [ap + -2] != 0;
[fp + -1] = [fp + -1] + 1;
[ap + -2] = [[ap + -8] + 0];
[ap + 0] = [ap + -8] + 1, ap++;
[ap + -6] = [ap + 0] + [ap + -7], ap++;
jmp rel 4 if [ap + -1] != 0;
jmp rel 10;
[ap + 0] = [ap + -8] + 1, ap++;
[ap + 0] = [ap + -8], ap++;
[ap + 0] = 0, ap++;
[ap + 0] = [ap + -11], ap++;
jmp rel 8;
[ap + 0] = [ap + -8], ap++;
[ap + 0] = [ap + -8], ap++;
[ap + 0] = 1, ap++;
[ap + 0] = 0, ap++;
jmp rel 71 if [ap + -2] != 0;
[ap + 0] = [[ap + -1] + 0], ap++;
%{ memory[ap + 0] = memory[ap + -1] < 340282366920938463463374607431768211456 %}
jmp rel 22 if [ap + 0] != 0, ap++;
%{ (memory[ap + 3], memory[ap + 4]) = divmod(memory[ap + -2], 340282366920938463463374607431768211456) %}
[ap + 3] = [[ap + -8] + 0], ap++;
[ap + 3] = [[ap + -9] + 1], ap++;
[ap + -2] = [ap + 1] * 340282366920938463463374607431768211456, ap++;
[ap + -5] = [ap + -3] + [ap + 1], ap++;
[ap + -3] = [ap + -1] + -10633823966279327296825105735305134080, ap++;
jmp rel 6 if [ap + -4] != 0;
[ap + -3] = [ap + -1] + 340282366920938463463374607431768211455;
jmp rel 4;
[ap + -3] = [ap + -2] + 329648542954659136166549501696463077376;
[ap + -3] = [[ap + -13] + 2];
jmp rel 46 if [ap + -2] != 0;
[fp + -1] = [fp + -1] + 1;
[ap + -2] = [[ap + -8] + 0];
[ap + 0] = [ap + -8] + 1, ap++;
[ap + -6] = [ap + 0] + [ap + -7], ap++;
jmp rel 4 if [ap + -1] != 0;
jmp rel 14;
%{ memory[ap + 0] = segments.add() %}
ap += 1;
[ap + 0] = 117999715903629884655797335944760714204113152088920212735095598, ap++;
[ap + -1] = [[ap + -2] + 0];
[ap + 0] = [ap + -4], ap++;
[ap + 0] = 1, ap++;
[ap + 0] = [ap + -4], ap++;
[ap + 0] = [ap + -5] + 1, ap++;
ret;
[ap + 0] = [ap + -2], ap++;
[ap + 0] = [ap + -21], ap++;
[ap + 0] = [ap + -14], ap++;
[ap + 0] = [ap + -7], ap++;
call rel 69;
jmp rel 10 if [ap + -3] != 0;
[ap + -1] = [[fp + -3] + 0];
[ap + 0] = [ap + -4], ap++;
[ap + 0] = 0, ap++;
[ap + 0] = [fp + -4], ap++;
[ap + 0] = [fp + -3] + 1, ap++;
ret;
[ap + 0] = [ap + -4], ap++;
[ap + 0] = 1, ap++;
[ap + 0] = [ap + -4], ap++;
[ap + 0] = [ap + -4], ap++;
ret;
[ap + 0] = [ap + -13] + 3, ap++;
jmp rel 3;
[ap + 0] = [ap + -6], ap++;
%{ memory[ap + 0] = segments.add() %}
ap += 1;
[ap + 0] = 485748461484230571791265682659113160264223489397539653310998840191492914, ap++;
[ap + -1] = [[ap + -2] + 0];
[ap + 0] = [ap + -3], ap++;
[ap + 0] = 1, ap++;
[ap + 0] = [ap + -4], ap++;
[ap + 0] = [ap + -5] + 1, ap++;
ret;
[ap + 0] = [ap + -13] + 3, ap++;
jmp rel 3;
[ap + 0] = [ap + -6], ap++;
%{ memory[ap + 0] = segments.add() %}
ap += 1;
[ap + 0] = 485748461484230571791265682659113160264223489397539653310998840191492913, ap++;
[ap + -1] = [[ap + -2] + 0];
[ap + 0] = [ap + -3], ap++;
[ap + 0] = 1, ap++;
[ap + 0] = [ap + -4], ap++;
[ap + 0] = [ap + -5] + 1, ap++;
ret;
[ap + 0] = [fp + -7] + 3, ap++;
jmp rel 3;
[ap + 0] = [fp + -7], ap++;
%{ memory[ap + 0] = segments.add() %}
ap += 1;
[ap + 0] = 485748461484230571791265682659113160264223489397539653310998840191492912, ap++;
[ap + -1] = [[ap + -2] + 0];
[ap + 0] = [ap + -3], ap++;
[ap + 0] = 1, ap++;
[ap + 0] = [ap + -4], ap++;
[ap + 0] = [ap + -5] + 1, ap++;
ret;
jmp rel 9 if [fp + -3] != 0;
[ap + 0] = [fp + -6], ap++;
[ap + 0] = 0, ap++;
Expand Down
8 changes: 5 additions & 3 deletions crates/cairo-lang-runnable/src/plugin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ use indoc::formatdoc;
use itertools::Itertools;

pub const RUNNABLE_ATTR: &str = "runnable";
const RUNNABLE_PREFIX: &str = "__runnable_wrapper__";
pub const RUNNABLE_RAW_ATTR: &str = "runnable_raw";
pub const RUNNABLE_PREFIX: &str = "__runnable_wrapper__";
const IMPLICIT_PRECEDENCE: &[&str] = &[
"core::pedersen::Pedersen",
"core::RangeCheck",
Expand Down Expand Up @@ -58,6 +59,7 @@ impl MacroPlugin for RunnablePlugin {
builder.add_modified(RewriteNode::interpolate_patched(&formatdoc! {"
$implicit_precedence$
#[{RUNNABLE_RAW_ATTR}]
fn {RUNNABLE_PREFIX}$function_name$(mut input: Span<felt252>, ref output: Array<felt252>) {{\n
"},
&[
Expand Down Expand Up @@ -111,10 +113,10 @@ impl MacroPlugin for RunnablePlugin {
}

fn declared_attributes(&self) -> Vec<String> {
vec![RUNNABLE_ATTR.to_string()]
vec![RUNNABLE_ATTR.to_string(), RUNNABLE_RAW_ATTR.to_string()]
}

fn executable_attributes(&self) -> Vec<String> {
vec![RUNNABLE_ATTR.to_string()]
vec![RUNNABLE_RAW_ATTR.to_string()]
}
}
2 changes: 2 additions & 0 deletions crates/cairo-lang-runnable/src/plugin_test_data/diagnostics
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ struct NoSerde {
v: felt252,
}
#[implicit_precedence(core::pedersen::Pedersen, core::RangeCheck, core::integer::Bitwise, core::ec::EcOp, core::poseidon::Poseidon, core::circuit::RangeCheck96, core::circuit::AddMod, core::circuit::MulMod)]
#[runnable_raw]
fn __runnable_wrapper__main(mut input: Span<felt252>, ref output: Array<felt252>) {

let __param__runnable_wrapper__0 = Serde::deserialize(ref input).expect('Failed to deserialize param #0');
Expand Down Expand Up @@ -89,6 +90,7 @@ struct NoSerde {
v: felt252,
}
#[implicit_precedence(core::pedersen::Pedersen, core::RangeCheck, core::integer::Bitwise, core::ec::EcOp, core::poseidon::Poseidon, core::circuit::RangeCheck96, core::circuit::AddMod, core::circuit::MulMod)]
#[runnable_raw]
fn __runnable_wrapper__main(mut input: Span<felt252>, ref output: Array<felt252>) {

let __param__runnable_wrapper__0 = Serde::deserialize(ref input).expect('Failed to deserialize param #0');
Expand Down
2 changes: 2 additions & 0 deletions crates/cairo-lang-runnable/src/plugin_test_data/expansion
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ fn main() {}
#[runnable]
fn main() {}
#[implicit_precedence(core::pedersen::Pedersen, core::RangeCheck, core::integer::Bitwise, core::ec::EcOp, core::poseidon::Poseidon, core::circuit::RangeCheck96, core::circuit::AddMod, core::circuit::MulMod)]
#[runnable_raw]
fn __runnable_wrapper__main(mut input: Span<felt252>, ref output: Array<felt252>) {

assert(core::array::SpanTrait::is_empty(input), 'Input too long for params.');
Expand Down Expand Up @@ -40,6 +41,7 @@ fn main(a: felt252, b: felt252) -> felt252 {
a + b
}
#[implicit_precedence(core::pedersen::Pedersen, core::RangeCheck, core::integer::Bitwise, core::ec::EcOp, core::poseidon::Poseidon, core::circuit::RangeCheck96, core::circuit::AddMod, core::circuit::MulMod)]
#[runnable_raw]
fn __runnable_wrapper__main(mut input: Span<felt252>, ref output: Array<felt252>) {

let __param__runnable_wrapper__0 = Serde::deserialize(ref input).expect('Failed to deserialize param #0');
Expand Down

0 comments on commit c0e38c8

Please sign in to comment.