-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(prover): Refactor WitnessGenerator (#2845)
## What ❔ Introduce new structure for witness generators. Introduce `ArtifactsManager` trait responsible for operations with object store and artifacts. ## Why ❔ <!-- Why are these changes done? What goal do they contribute to? What are the principles behind them? --> <!-- Example: PR templates ensure PR reviewers, observers, and future iterators are in context about the evolution of repos. --> ## Checklist <!-- Check your PR fulfills the following items. --> <!-- For draft PRs check the boxes as you complete them. --> - [ ] PR title corresponds to the body of PR (we generate changelog entries from PRs). - [ ] Tests for the changes have been added / updated. - [ ] Documentation comments have been added / updated. - [ ] Code has been formatted via `zk fmt` and `zk lint`.
- Loading branch information
1 parent
2f926b2
commit 6992f8c
Showing
19 changed files
with
1,440 additions
and
1,114 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
use std::time::Instant; | ||
|
||
use async_trait::async_trait; | ||
use zksync_object_store::ObjectStore; | ||
use zksync_prover_dal::{ConnectionPool, Prover}; | ||
|
||
#[derive(Debug)] | ||
pub(crate) struct AggregationBlobUrls { | ||
pub aggregations_urls: String, | ||
pub circuit_ids_and_urls: Vec<(u8, String)>, | ||
} | ||
|
||
#[derive(Debug)] | ||
pub(crate) struct SchedulerBlobUrls { | ||
pub circuit_ids_and_urls: Vec<(u8, String)>, | ||
pub closed_form_inputs_and_urls: Vec<(u8, String, usize)>, | ||
pub scheduler_witness_url: String, | ||
} | ||
|
||
pub(crate) enum BlobUrls { | ||
Url(String), | ||
Aggregation(AggregationBlobUrls), | ||
Scheduler(SchedulerBlobUrls), | ||
} | ||
|
||
#[async_trait] | ||
pub(crate) trait ArtifactsManager { | ||
type InputMetadata; | ||
type InputArtifacts; | ||
type OutputArtifacts; | ||
|
||
async fn get_artifacts( | ||
metadata: &Self::InputMetadata, | ||
object_store: &dyn ObjectStore, | ||
) -> anyhow::Result<Self::InputArtifacts>; | ||
|
||
async fn save_artifacts( | ||
job_id: u32, | ||
artifacts: Self::OutputArtifacts, | ||
object_store: &dyn ObjectStore, | ||
) -> BlobUrls; | ||
|
||
async fn update_database( | ||
connection_pool: &ConnectionPool<Prover>, | ||
job_id: u32, | ||
started_at: Instant, | ||
blob_urls: BlobUrls, | ||
artifacts: Self::OutputArtifacts, | ||
) -> anyhow::Result<()>; | ||
} |
108 changes: 108 additions & 0 deletions
108
prover/crates/bin/witness_generator/src/basic_circuits/artifacts.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
use std::time::Instant; | ||
|
||
use async_trait::async_trait; | ||
use zksync_object_store::ObjectStore; | ||
use zksync_prover_dal::{ConnectionPool, Prover, ProverDal}; | ||
use zksync_prover_fri_types::AuxOutputWitnessWrapper; | ||
use zksync_prover_fri_utils::get_recursive_layer_circuit_id_for_base_layer; | ||
use zksync_types::{basic_fri_types::AggregationRound, L1BatchNumber}; | ||
|
||
use crate::{ | ||
artifacts::{ArtifactsManager, BlobUrls}, | ||
basic_circuits::{BasicCircuitArtifacts, BasicWitnessGenerator, BasicWitnessGeneratorJob}, | ||
utils::SchedulerPartialInputWrapper, | ||
}; | ||
|
||
#[async_trait] | ||
impl ArtifactsManager for BasicWitnessGenerator { | ||
type InputMetadata = L1BatchNumber; | ||
type InputArtifacts = BasicWitnessGeneratorJob; | ||
type OutputArtifacts = BasicCircuitArtifacts; | ||
|
||
async fn get_artifacts( | ||
metadata: &Self::InputMetadata, | ||
object_store: &dyn ObjectStore, | ||
) -> anyhow::Result<Self::InputArtifacts> { | ||
let l1_batch_number = *metadata; | ||
let data = object_store.get(l1_batch_number).await.unwrap(); | ||
Ok(BasicWitnessGeneratorJob { | ||
block_number: l1_batch_number, | ||
data, | ||
}) | ||
} | ||
|
||
async fn save_artifacts( | ||
job_id: u32, | ||
artifacts: Self::OutputArtifacts, | ||
object_store: &dyn ObjectStore, | ||
) -> BlobUrls { | ||
let aux_output_witness_wrapper = AuxOutputWitnessWrapper(artifacts.aux_output_witness); | ||
object_store | ||
.put(L1BatchNumber(job_id), &aux_output_witness_wrapper) | ||
.await | ||
.unwrap(); | ||
let wrapper = SchedulerPartialInputWrapper(artifacts.scheduler_witness); | ||
let url = object_store | ||
.put(L1BatchNumber(job_id), &wrapper) | ||
.await | ||
.unwrap(); | ||
|
||
BlobUrls::Url(url) | ||
} | ||
|
||
#[tracing::instrument(skip_all, fields(l1_batch = %job_id))] | ||
async fn update_database( | ||
connection_pool: &ConnectionPool<Prover>, | ||
job_id: u32, | ||
started_at: Instant, | ||
blob_urls: BlobUrls, | ||
_artifacts: Self::OutputArtifacts, | ||
) -> anyhow::Result<()> { | ||
let blob_urls = match blob_urls { | ||
BlobUrls::Scheduler(blobs) => blobs, | ||
_ => unreachable!(), | ||
}; | ||
|
||
let mut connection = connection_pool | ||
.connection() | ||
.await | ||
.expect("failed to get database connection"); | ||
let mut transaction = connection | ||
.start_transaction() | ||
.await | ||
.expect("failed to get database transaction"); | ||
let protocol_version_id = transaction | ||
.fri_witness_generator_dal() | ||
.protocol_version_for_l1_batch(L1BatchNumber(job_id)) | ||
.await; | ||
transaction | ||
.fri_prover_jobs_dal() | ||
.insert_prover_jobs( | ||
L1BatchNumber(job_id), | ||
blob_urls.circuit_ids_and_urls, | ||
AggregationRound::BasicCircuits, | ||
0, | ||
protocol_version_id, | ||
) | ||
.await; | ||
transaction | ||
.fri_witness_generator_dal() | ||
.create_aggregation_jobs( | ||
L1BatchNumber(job_id), | ||
&blob_urls.closed_form_inputs_and_urls, | ||
&blob_urls.scheduler_witness_url, | ||
get_recursive_layer_circuit_id_for_base_layer, | ||
protocol_version_id, | ||
) | ||
.await; | ||
transaction | ||
.fri_witness_generator_dal() | ||
.mark_witness_job_as_successful(L1BatchNumber(job_id), started_at.elapsed()) | ||
.await; | ||
transaction | ||
.commit() | ||
.await | ||
.expect("failed to commit database transaction"); | ||
Ok(()) | ||
} | ||
} |
153 changes: 153 additions & 0 deletions
153
prover/crates/bin/witness_generator/src/basic_circuits/job_processor.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
use std::{sync::Arc, time::Instant}; | ||
|
||
use anyhow::Context as _; | ||
use tracing::Instrument; | ||
use zksync_prover_dal::ProverDal; | ||
use zksync_prover_fri_types::{get_current_pod_name, AuxOutputWitnessWrapper}; | ||
use zksync_queued_job_processor::{async_trait, JobProcessor}; | ||
use zksync_types::{basic_fri_types::AggregationRound, L1BatchNumber}; | ||
|
||
use crate::{ | ||
artifacts::{ArtifactsManager, BlobUrls, SchedulerBlobUrls}, | ||
basic_circuits::{BasicCircuitArtifacts, BasicWitnessGenerator, BasicWitnessGeneratorJob}, | ||
metrics::WITNESS_GENERATOR_METRICS, | ||
}; | ||
|
||
#[async_trait] | ||
impl JobProcessor for BasicWitnessGenerator { | ||
type Job = BasicWitnessGeneratorJob; | ||
type JobId = L1BatchNumber; | ||
// The artifact is optional to support skipping blocks when sampling is enabled. | ||
type JobArtifacts = Option<BasicCircuitArtifacts>; | ||
|
||
const SERVICE_NAME: &'static str = "fri_basic_circuit_witness_generator"; | ||
|
||
async fn get_next_job(&self) -> anyhow::Result<Option<(Self::JobId, Self::Job)>> { | ||
let mut prover_connection = self.prover_connection_pool.connection().await?; | ||
let last_l1_batch_to_process = self.config.last_l1_batch_to_process(); | ||
let pod_name = get_current_pod_name(); | ||
match prover_connection | ||
.fri_witness_generator_dal() | ||
.get_next_basic_circuit_witness_job( | ||
last_l1_batch_to_process, | ||
self.protocol_version, | ||
&pod_name, | ||
) | ||
.await | ||
{ | ||
Some(block_number) => { | ||
tracing::info!( | ||
"Processing FRI basic witness-gen for block {}", | ||
block_number | ||
); | ||
let started_at = Instant::now(); | ||
let job = Self::get_artifacts(&block_number, &*self.object_store).await?; | ||
|
||
WITNESS_GENERATOR_METRICS.blob_fetch_time[&AggregationRound::BasicCircuits.into()] | ||
.observe(started_at.elapsed()); | ||
|
||
Ok(Some((block_number, job))) | ||
} | ||
None => Ok(None), | ||
} | ||
} | ||
|
||
async fn save_failure(&self, job_id: L1BatchNumber, _started_at: Instant, error: String) -> () { | ||
self.prover_connection_pool | ||
.connection() | ||
.await | ||
.unwrap() | ||
.fri_witness_generator_dal() | ||
.mark_witness_job_failed(&error, job_id) | ||
.await; | ||
} | ||
|
||
#[allow(clippy::async_yields_async)] | ||
async fn process_job( | ||
&self, | ||
_job_id: &Self::JobId, | ||
job: BasicWitnessGeneratorJob, | ||
started_at: Instant, | ||
) -> tokio::task::JoinHandle<anyhow::Result<Option<BasicCircuitArtifacts>>> { | ||
let object_store = Arc::clone(&self.object_store); | ||
let max_circuits_in_flight = self.config.max_circuits_in_flight; | ||
tokio::spawn(async move { | ||
let block_number = job.block_number; | ||
Ok( | ||
Self::process_job_impl(object_store, job, started_at, max_circuits_in_flight) | ||
.instrument(tracing::info_span!("basic_circuit", %block_number)) | ||
.await, | ||
) | ||
}) | ||
} | ||
|
||
#[tracing::instrument(skip_all, fields(l1_batch = %job_id))] | ||
async fn save_result( | ||
&self, | ||
job_id: L1BatchNumber, | ||
started_at: Instant, | ||
optional_artifacts: Option<BasicCircuitArtifacts>, | ||
) -> anyhow::Result<()> { | ||
match optional_artifacts { | ||
None => Ok(()), | ||
Some(artifacts) => { | ||
let blob_started_at = Instant::now(); | ||
let circuit_urls = artifacts.circuit_urls.clone(); | ||
let queue_urls = artifacts.queue_urls.clone(); | ||
|
||
let aux_output_witness_wrapper = | ||
AuxOutputWitnessWrapper(artifacts.aux_output_witness.clone()); | ||
if self.config.shall_save_to_public_bucket { | ||
self.public_blob_store.as_deref() | ||
.expect("public_object_store shall not be empty while running with shall_save_to_public_bucket config") | ||
.put(job_id, &aux_output_witness_wrapper) | ||
.await | ||
.unwrap(); | ||
} | ||
|
||
let scheduler_witness_url = | ||
match Self::save_artifacts(job_id.0, artifacts.clone(), &*self.object_store) | ||
.await | ||
{ | ||
BlobUrls::Url(url) => url, | ||
_ => unreachable!(), | ||
}; | ||
|
||
WITNESS_GENERATOR_METRICS.blob_save_time[&AggregationRound::BasicCircuits.into()] | ||
.observe(blob_started_at.elapsed()); | ||
|
||
Self::update_database( | ||
&self.prover_connection_pool, | ||
job_id.0, | ||
started_at, | ||
BlobUrls::Scheduler(SchedulerBlobUrls { | ||
circuit_ids_and_urls: circuit_urls, | ||
closed_form_inputs_and_urls: queue_urls, | ||
scheduler_witness_url, | ||
}), | ||
artifacts, | ||
) | ||
.await?; | ||
Ok(()) | ||
} | ||
} | ||
} | ||
|
||
fn max_attempts(&self) -> u32 { | ||
self.config.max_attempts | ||
} | ||
|
||
async fn get_job_attempts(&self, job_id: &L1BatchNumber) -> anyhow::Result<u32> { | ||
let mut prover_storage = self | ||
.prover_connection_pool | ||
.connection() | ||
.await | ||
.context("failed to acquire DB connection for BasicWitnessGenerator")?; | ||
prover_storage | ||
.fri_witness_generator_dal() | ||
.get_basic_circuit_witness_job_attempts(*job_id) | ||
.await | ||
.map(|attempts| attempts.unwrap_or(0)) | ||
.context("failed to get job attempts for BasicWitnessGenerator") | ||
} | ||
} |
Oops, something went wrong.