feat: E2E benchmark halo2 generate flamegraphs #1203

Merged: 12 commits, Jan 13, 2025
2 changes: 1 addition & 1 deletion benchmarks/Cargo.toml
@@ -55,7 +55,7 @@ pprof = { version = "0.13", features = [

[features]
default = ["parallel", "mimalloc", "bench-metrics"]
bench-metrics = ["openvm-native-recursion/bench-metrics"]
bench-metrics = ["openvm-native-recursion/bench-metrics", "openvm-native-compiler/bench-metrics"]
profiling = ["openvm-sdk/profiling"]
aggregation = []
static-verifier = ["openvm-native-recursion/static-verifier"]
1 change: 1 addition & 0 deletions benchmarks/src/utils.rs
@@ -122,6 +122,7 @@ impl BenchmarkCli {
halo2_config: Halo2Config {
verifier_k: self.halo2_outer_k.unwrap_or(24),
wrapper_k: self.halo2_wrapper_k,
profiling: self.profiling,
},
}
}
30 changes: 22 additions & 8 deletions ci/scripts/metric_unify/flamegraph.py
@@ -6,20 +6,25 @@

from utils import FLAMEGRAPHS_DIR, get_git_root

def get_stack_lines(metrics_dict, group_by_kvs, stack_keys, metric_name):
def get_stack_lines(metrics_dict, group_by_kvs, stack_keys, metric_name, sum_metrics=None):
"""
Filters a metrics_dict obtained from json for entries that look like:
[ { labels: [["key1", "span1;span2"], ["key2", "span3"]], "metric": metric_name, "value": 2 } ]

It finds entries whose labels contain all of stack_keys, concatenates the corresponding label values into a single flat stack entry, and appends the metric value at the end.
The result is written out with one line per stack for flamegraph.pl or inferno-flamegraph to consume.
If sum_metrics is not None, instead of searching for metric_name, it will sum the values of the metrics in sum_metrics.
"""
lines = []
stack_sums = {}
non_zero = False

# Process counters
for counter in metrics_dict.get('counter', []):
if counter['metric'] != metric_name:
if (sum_metrics is not None and counter['metric'] not in sum_metrics) or \
(sum_metrics is None and counter['metric'] != metric_name):
continue

# list of pairs -> dict
labels = dict(counter['labels'])
filter = False
@@ -41,15 +46,21 @@ def get_stack_lines(metrics_dict, group_by_kvs, stack_keys, metric_name):

stack = ';'.join(stack_values)
value = int(counter['value'])
stack_sums[stack] = stack_sums.get(stack, 0) + value

if value != 0:
non_zero = True

lines.append(f"{stack} {value}")
lines = [f"{stack} {value}" for stack, value in stack_sums.items() if value != 0]

# Currently cycle tracker does not use gauge
return lines
return lines if non_zero else []


def create_flamegraph(fname, metrics_dict, group_by_kvs, stack_keys, metric_name, reverse=False):
lines = get_stack_lines(metrics_dict, group_by_kvs, stack_keys, metric_name)
def create_flamegraph(fname, metrics_dict, group_by_kvs, stack_keys, metric_name, sum_metrics=None, reverse=False):
lines = get_stack_lines(metrics_dict, group_by_kvs, stack_keys, metric_name, sum_metrics)
if not lines:
return

suffixes = [key for key in stack_keys if key != "cycle_tracker_span"]

@@ -74,7 +85,7 @@ def create_flamegraph(fname, metrics_dict, group_by_kvs, stack_keys, metric_name
print(f"Created flamegraph at {flamegraph_path}")


def create_flamegraphs(metrics_file, group_by, stack_keys, metric_name, reverse=False):
def create_flamegraphs(metrics_file, group_by, stack_keys, metric_name, sum_metrics=None, reverse=False):
fname_prefix = os.path.splitext(os.path.basename(metrics_file))[0]

with open(metrics_file, 'r') as f:
@@ -92,7 +103,7 @@ def create_flamegraphs(metrics_file, group_by, stack_keys, metric_name, reverse=
for group_by_values in group_by_values_list:
group_by_kvs = list(zip(group_by, group_by_values))
fname = fname_prefix + '-' + '-'.join(group_by_values)
create_flamegraph(fname, metrics_dict, group_by_kvs, stack_keys, metric_name, reverse=reverse)
create_flamegraph(fname, metrics_dict, group_by_kvs, stack_keys, metric_name, sum_metrics, reverse=reverse)


def create_custom_flamegraphs(metrics_file, group_by=["group"]):
@@ -101,6 +112,9 @@ def create_custom_flamegraphs(metrics_file, group_by=["group"]):
reverse=reverse)
create_flamegraphs(metrics_file, group_by, ["cycle_tracker_span", "dsl_ir", "opcode", "air_name"], "cells_used",
reverse=reverse)
create_flamegraphs(metrics_file, group_by, ["cell_tracker_span"], "cells_used",
sum_metrics=["simple_advice_cells", "fixed_cells", "lookup_advice_cells"],
reverse=reverse)


def main():
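For context, a minimal sketch (not part of this PR) of the counter shape that get_stack_lines folds into flamegraph lines. It assumes the metrics 0.22+ macro API; the metric name and the group/cell_tracker_span labels mirror the sum_metrics call added to create_custom_flamegraphs above, and the real counters are presumably emitted from openvm-native-compiler once its bench-metrics feature is enabled (see the Cargo.toml changes in this PR).

use metrics::counter;

fn main() {
    // Hypothetical emission site; with no metrics recorder installed this is a no-op.
    // The label value carries the stack hierarchy, separated by semicolons.
    counter!(
        "simple_advice_cells",
        "group" => "halo2_outer",
        "cell_tracker_span" => "halo2_outer;poseidon2"
    )
    .increment(42);
    // With stack_keys=["cell_tracker_span"] and sum_metrics covering the three cell
    // counters, get_stack_lines folds a counter like this into the flamegraph line:
    //   halo2_outer;poseidon2 42
}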
1 change: 1 addition & 0 deletions crates/sdk/Cargo.toml
@@ -50,6 +50,7 @@ default = ["parallel"]
bench-metrics = [
"openvm-circuit/bench-metrics",
"openvm-native-recursion/bench-metrics",
"openvm-native-compiler/bench-metrics",
]
profiling = ["openvm-circuit/function-span", "openvm-transpiler/function-span"]
parallel = ["openvm-circuit/parallel"]
3 changes: 3 additions & 0 deletions crates/sdk/src/config/mod.rs
@@ -49,6 +49,8 @@ pub struct Halo2Config {
pub verifier_k: usize,
/// If not specified, keygen will tune wrapper_k automatically.
pub wrapper_k: Option<usize>,
/// Sets the profiling mode of the halo2 VM.
pub profiling: bool,
}

impl<VC> AppConfig<VC> {
@@ -101,6 +103,7 @@ impl Default for AggConfig {
halo2_config: Halo2Config {
verifier_k: 24,
wrapper_k: None,
profiling: false,
},
}
}
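To see how the new flag is meant to be switched on, here is a minimal caller-side sketch; it assumes only AggConfig::default() (shown above) and the fields visible in this diff, with everything else left at its defaults.

// Hypothetical setup; paths and other AggConfig fields are assumed, not shown in this PR.
let mut agg_config = AggConfig::default();
agg_config.halo2_config.profiling = true;
// At keygen time the flag is copied into Halo2ProvingKey (see keygen/mod.rs below), and
// Halo2Prover::prove_for_evm forwards it to Halo2VerifierProvingKey::prove (prover/halo2.rs).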
8 changes: 7 additions & 1 deletion crates/sdk/src/keygen/mod.rs
@@ -76,6 +76,8 @@ pub struct Halo2ProvingKey {
pub verifier: Halo2VerifierProvingKey,
/// Wrapper circuit to verify static verifier and reduce the verification costs in the final proof.
pub wrapper: Halo2WrapperProvingKey,
/// Whether to collect detailed profiling metrics.
pub profiling: bool,
}

impl<VC: VmConfig<F>> AppProvingKey<VC>
@@ -315,7 +317,11 @@ impl AggProvingKey {
} else {
Halo2WrapperProvingKey::keygen_auto_tune(reader, dummy_snark)
};
let halo2_pk = Halo2ProvingKey { verifier, wrapper };
let halo2_pk = Halo2ProvingKey {
verifier,
wrapper,
profiling: halo2_config.profiling,
};
Self {
agg_stark_pk,
halo2_pk,
7 changes: 5 additions & 2 deletions crates/sdk/src/prover/halo2.rs
@@ -30,8 +30,11 @@ impl Halo2Prover {
pub fn prove_for_evm(&self, root_proof: &Proof<RootSC>) -> EvmProof {
let mut witness = Witness::default();
root_proof.write(&mut witness);
let snark = info_span!("prove", group = "halo2_outer")
.in_scope(|| self.halo2_pk.verifier.prove(&self.verifier_srs, witness));
let snark = info_span!("prove", group = "halo2_outer").in_scope(|| {
self.halo2_pk
.verifier
.prove(&self.verifier_srs, witness, self.halo2_pk.profiling)
});
info_span!("prove_for_evm", group = "halo2_wrapper").in_scope(|| {
self.halo2_pk
.wrapper
1 change: 1 addition & 0 deletions crates/sdk/tests/integration_test.rs
@@ -90,6 +90,7 @@ fn agg_config_for_test() -> AggConfig {
halo2_config: Halo2Config {
verifier_k: 24,
wrapper_k: None,
profiling: false,
},
}
}
2 changes: 1 addition & 1 deletion extensions/native/recursion/Cargo.toml
@@ -45,7 +45,7 @@ static-verifier = [
"dep:once_cell",
]
test-utils = ["openvm-circuit/test-utils"]
bench-metrics = ["dep:metrics", "openvm-circuit/bench-metrics"]
bench-metrics = ["dep:metrics", "openvm-circuit/bench-metrics", "openvm-native-compiler/bench-metrics"]
mimalloc = ["openvm-stark-backend/mimalloc"]
jemalloc = ["openvm-stark-backend/jemalloc"]
nightly-features = ["openvm-circuit/nightly-features"]
3 changes: 2 additions & 1 deletion extensions/native/recursion/src/halo2/mod.rs
@@ -199,14 +199,15 @@ impl Halo2Prover {
pk: &ProvingKey<G1Affine>,
dsl_operations: DslOperations<C>,
witness: Witness<C>,
profiling: bool,
) -> Snark {
let k = config_params.k;
#[cfg(feature = "bench-metrics")]
let start = std::time::Instant::now();
let builder = Self::builder(CircuitBuilderStage::Prover, k)
.use_params(config_params)
.use_break_points(break_points);
let builder = Self::populate(builder, dsl_operations, witness, false);
let builder = Self::populate(builder, dsl_operations, witness, profiling);
#[cfg(feature = "bench-metrics")]
{
let stats = builder.statistics();
2 changes: 1 addition & 1 deletion extensions/native/recursion/src/halo2/testing_utils.rs
@@ -57,7 +57,7 @@ pub fn run_static_verifier_test(
.entered();
let mut witness = Witness::default();
vparams.data.proof.write(&mut witness);
let static_verifier_snark = stark_verifier_circuit.prove(params, witness);
let static_verifier_snark = stark_verifier_circuit.prove(params, witness, false);
info_span.exit();
(stark_verifier_circuit, static_verifier_snark)
}
8 changes: 7 additions & 1 deletion extensions/native/recursion/src/halo2/verifier.rs
@@ -39,14 +39,20 @@ pub fn generate_halo2_verifier_proving_key(
}

impl Halo2VerifierProvingKey {
pub fn prove(&self, params: &Halo2Params, witness: Witness<OuterConfig>) -> Snark {
pub fn prove(
&self,
params: &Halo2Params,
witness: Witness<OuterConfig>,
profiling: bool,
) -> Snark {
Halo2Prover::prove(
params,
self.pinning.metadata.config_params.clone(),
self.pinning.metadata.break_points.clone(),
&self.pinning.pk,
self.dsl_ops.clone(),
witness,
profiling,
)
}
// TODO: Add verify method
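Finally, a sketch of the updated call site with profiling turned on, mirroring run_static_verifier_test above; stark_verifier_circuit, params, and vparams are assumed to be set up exactly as in that helper, and the only change from it is passing true instead of false.

// Hypothetical variant of the test helper's call; identifiers come from run_static_verifier_test.
let mut witness = Witness::default();
vparams.data.proof.write(&mut witness);
// Passing `true` enables the halo2 cell metrics that flamegraph.py aggregates;
// the test helper itself keeps profiling off by passing `false`.
let static_verifier_snark = stark_verifier_circuit.prove(params, witness, true);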