Skip to content

Commit

Permalink
feat(prover): Refactor WitnessGenerator (#2845)
Browse files Browse the repository at this point in the history
## What ❔

Introduce a new structure for witness generators.
Introduce the `ArtifactsManager` trait, which is responsible for operations on the
object store and artifacts.

## Why ❔

<!-- Why are these changes done? What goal do they contribute to? What
are the principles behind them? -->
<!-- Example: PR templates ensure PR reviewers, observers, and future
iterators are in context about the evolution of repos. -->

## Checklist

<!-- Check your PR fulfills the following items. -->
<!-- For draft PRs check the boxes as you complete them. -->

- [ ] PR title corresponds to the body of PR (we generate changelog
entries from PRs).
- [ ] Tests for the changes have been added / updated.
- [ ] Documentation comments have been added / updated.
- [ ] Code has been formatted via `zk fmt` and `zk lint`.
  • Loading branch information
Artemka374 authored Sep 12, 2024
1 parent a5ffaf1 commit 934634b
Show file tree
Hide file tree
Showing 19 changed files with 1,440 additions and 1,114 deletions.
8 changes: 2 additions & 6 deletions prover/crates/bin/proof_fri_compressor/src/compressor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ impl ProofCompressor {

#[tracing::instrument(skip(proof, _compression_mode))]
pub fn compress_proof(
l1_batch: L1BatchNumber,
proof: ZkSyncRecursionLayerProof,
_compression_mode: u8,
keystore: Keystore,
Expand Down Expand Up @@ -171,16 +170,13 @@ impl JobProcessor for ProofCompressor {

async fn process_job(
&self,
job_id: &L1BatchNumber,
_job_id: &L1BatchNumber,
job: ZkSyncRecursionLayerProof,
_started_at: Instant,
) -> JoinHandle<anyhow::Result<Self::JobArtifacts>> {
let compression_mode = self.compression_mode;
let block_number = *job_id;
let keystore = self.keystore.clone();
tokio::task::spawn_blocking(move || {
Self::compress_proof(block_number, job, compression_mode, keystore)
})
tokio::task::spawn_blocking(move || Self::compress_proof(job, compression_mode, keystore))
}

async fn save_result(
Expand Down
50 changes: 50 additions & 0 deletions prover/crates/bin/witness_generator/src/artifacts.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
use std::time::Instant;

use async_trait::async_trait;
use zksync_object_store::ObjectStore;
use zksync_prover_dal::{ConnectionPool, Prover};

/// Blob URLs recorded for an aggregation step.
///
/// NOTE(review): the exact meaning of each field is not established by the code that
/// declares this type; the descriptions below follow the field names and should be
/// confirmed against the code that builds and consumes this struct.
#[derive(Debug)]
pub(crate) struct AggregationBlobUrls {
    /// A single URL string (despite the plural name); presumably where the aggregations
    /// blob was stored.
    pub aggregations_urls: String,
    /// `(u8, url)` pairs; going by the field name the `u8` is a circuit id.
    pub circuit_ids_and_urls: Vec<(u8, String)>,
}

/// Blob URLs that feed scheduler-related bookkeeping.
#[derive(Debug)]
pub(crate) struct SchedulerBlobUrls {
    /// `(u8, url)` pairs; going by the field name the `u8` is a circuit id.
    pub circuit_ids_and_urls: Vec<(u8, String)>,
    /// `(u8, url, usize)` triples for closed-form inputs.
    ///
    /// NOTE(review): the meaning of the trailing `usize` is not visible here; confirm
    /// with the producer before relying on it.
    pub closed_form_inputs_and_urls: Vec<(u8, String, usize)>,
    /// URL of the stored scheduler witness blob.
    pub scheduler_witness_url: String,
}

/// The different shapes of "where did the blobs go" information that an
/// `ArtifactsManager` implementation can hand back after saving its artifacts.
///
/// Derives `Debug` (as both payload structs already do) so that values can be logged
/// and embedded in other `Debug` types.
#[derive(Debug)]
pub(crate) enum BlobUrls {
    /// A single blob URL.
    Url(String),
    /// URLs grouped as [`AggregationBlobUrls`].
    Aggregation(AggregationBlobUrls),
    /// URLs grouped as [`SchedulerBlobUrls`].
    Scheduler(SchedulerBlobUrls),
}

/// Operations a witness generator performs on its artifacts: fetching inputs,
/// saving outputs, and updating the prover database.
///
/// Every method is an associated function (there is no `self` receiver), so all state
/// an implementation needs is passed in through the arguments.
#[async_trait]
pub(crate) trait ArtifactsManager {
/// Type of the `metadata` argument accepted by `get_artifacts`.
type InputMetadata;
/// Type returned by `get_artifacts` on success.
type InputArtifacts;
/// Type consumed (by value) by both `save_artifacts` and `update_database`.
type OutputArtifacts;

/// Produces the input artifacts described by `metadata`, with access to `object_store`.
///
/// # Errors
///
/// Returns an `anyhow` error when the artifacts cannot be produced.
async fn get_artifacts(
metadata: &Self::InputMetadata,
object_store: &dyn ObjectStore,
) -> anyhow::Result<Self::InputArtifacts>;

/// Takes ownership of `artifacts` for job `job_id` and returns the resulting
/// [`BlobUrls`]; `object_store` is available for persisting them.
///
/// The return type is not a `Result`, so an implementation has no way to report a
/// failure to the caller through the return value.
async fn save_artifacts(
job_id: u32,
artifacts: Self::OutputArtifacts,
object_store: &dyn ObjectStore,
) -> BlobUrls;

/// Updates the prover database for job `job_id` using a connection from
/// `connection_pool`.
///
/// `started_at` is the `Instant` supplied by the caller — presumably the moment
/// processing of the job began; confirm with callers. `blob_urls` and `artifacts` are
/// both taken by value.
///
/// # Errors
///
/// Returns an `anyhow` error when the update fails.
async fn update_database(
connection_pool: &ConnectionPool<Prover>,
job_id: u32,
started_at: Instant,
blob_urls: BlobUrls,
artifacts: Self::OutputArtifacts,
) -> anyhow::Result<()>;
}
108 changes: 108 additions & 0 deletions prover/crates/bin/witness_generator/src/basic_circuits/artifacts.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
use std::time::Instant;

use async_trait::async_trait;
use zksync_object_store::ObjectStore;
use zksync_prover_dal::{ConnectionPool, Prover, ProverDal};
use zksync_prover_fri_types::AuxOutputWitnessWrapper;
use zksync_prover_fri_utils::get_recursive_layer_circuit_id_for_base_layer;
use zksync_types::{basic_fri_types::AggregationRound, L1BatchNumber};

use crate::{
artifacts::{ArtifactsManager, BlobUrls},
basic_circuits::{BasicCircuitArtifacts, BasicWitnessGenerator, BasicWitnessGeneratorJob},
utils::SchedulerPartialInputWrapper,
};

#[async_trait]
impl ArtifactsManager for BasicWitnessGenerator {
type InputMetadata = L1BatchNumber;
type InputArtifacts = BasicWitnessGeneratorJob;
type OutputArtifacts = BasicCircuitArtifacts;

async fn get_artifacts(
metadata: &Self::InputMetadata,
object_store: &dyn ObjectStore,
) -> anyhow::Result<Self::InputArtifacts> {
let l1_batch_number = *metadata;
let data = object_store.get(l1_batch_number).await.unwrap();
Ok(BasicWitnessGeneratorJob {
block_number: l1_batch_number,
data,
})
}

async fn save_artifacts(
job_id: u32,
artifacts: Self::OutputArtifacts,
object_store: &dyn ObjectStore,
) -> BlobUrls {
let aux_output_witness_wrapper = AuxOutputWitnessWrapper(artifacts.aux_output_witness);
object_store
.put(L1BatchNumber(job_id), &aux_output_witness_wrapper)
.await
.unwrap();
let wrapper = SchedulerPartialInputWrapper(artifacts.scheduler_witness);
let url = object_store
.put(L1BatchNumber(job_id), &wrapper)
.await
.unwrap();

BlobUrls::Url(url)
}

#[tracing::instrument(skip_all, fields(l1_batch = %job_id))]
async fn update_database(
connection_pool: &ConnectionPool<Prover>,
job_id: u32,
started_at: Instant,
blob_urls: BlobUrls,
_artifacts: Self::OutputArtifacts,
) -> anyhow::Result<()> {
let blob_urls = match blob_urls {
BlobUrls::Scheduler(blobs) => blobs,
_ => unreachable!(),
};

let mut connection = connection_pool
.connection()
.await
.expect("failed to get database connection");
let mut transaction = connection
.start_transaction()
.await
.expect("failed to get database transaction");
let protocol_version_id = transaction
.fri_witness_generator_dal()
.protocol_version_for_l1_batch(L1BatchNumber(job_id))
.await;
transaction
.fri_prover_jobs_dal()
.insert_prover_jobs(
L1BatchNumber(job_id),
blob_urls.circuit_ids_and_urls,
AggregationRound::BasicCircuits,
0,
protocol_version_id,
)
.await;
transaction
.fri_witness_generator_dal()
.create_aggregation_jobs(
L1BatchNumber(job_id),
&blob_urls.closed_form_inputs_and_urls,
&blob_urls.scheduler_witness_url,
get_recursive_layer_circuit_id_for_base_layer,
protocol_version_id,
)
.await;
transaction
.fri_witness_generator_dal()
.mark_witness_job_as_successful(L1BatchNumber(job_id), started_at.elapsed())
.await;
transaction
.commit()
.await
.expect("failed to commit database transaction");
Ok(())
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
use std::{sync::Arc, time::Instant};

use anyhow::Context as _;
use tracing::Instrument;
use zksync_prover_dal::ProverDal;
use zksync_prover_fri_types::{get_current_pod_name, AuxOutputWitnessWrapper};
use zksync_queued_job_processor::{async_trait, JobProcessor};
use zksync_types::{basic_fri_types::AggregationRound, L1BatchNumber};

use crate::{
artifacts::{ArtifactsManager, BlobUrls, SchedulerBlobUrls},
basic_circuits::{BasicCircuitArtifacts, BasicWitnessGenerator, BasicWitnessGeneratorJob},
metrics::WITNESS_GENERATOR_METRICS,
};

#[async_trait]
impl JobProcessor for BasicWitnessGenerator {
type Job = BasicWitnessGeneratorJob;
type JobId = L1BatchNumber;
// The artifact is optional to support skipping blocks when sampling is enabled.
type JobArtifacts = Option<BasicCircuitArtifacts>;

const SERVICE_NAME: &'static str = "fri_basic_circuit_witness_generator";

async fn get_next_job(&self) -> anyhow::Result<Option<(Self::JobId, Self::Job)>> {
let mut prover_connection = self.prover_connection_pool.connection().await?;
let last_l1_batch_to_process = self.config.last_l1_batch_to_process();
let pod_name = get_current_pod_name();
match prover_connection
.fri_witness_generator_dal()
.get_next_basic_circuit_witness_job(
last_l1_batch_to_process,
self.protocol_version,
&pod_name,
)
.await
{
Some(block_number) => {
tracing::info!(
"Processing FRI basic witness-gen for block {}",
block_number
);
let started_at = Instant::now();
let job = Self::get_artifacts(&block_number, &*self.object_store).await?;

WITNESS_GENERATOR_METRICS.blob_fetch_time[&AggregationRound::BasicCircuits.into()]
.observe(started_at.elapsed());

Ok(Some((block_number, job)))
}
None => Ok(None),
}
}

async fn save_failure(&self, job_id: L1BatchNumber, _started_at: Instant, error: String) -> () {
self.prover_connection_pool
.connection()
.await
.unwrap()
.fri_witness_generator_dal()
.mark_witness_job_failed(&error, job_id)
.await;
}

#[allow(clippy::async_yields_async)]
async fn process_job(
&self,
_job_id: &Self::JobId,
job: BasicWitnessGeneratorJob,
started_at: Instant,
) -> tokio::task::JoinHandle<anyhow::Result<Option<BasicCircuitArtifacts>>> {
let object_store = Arc::clone(&self.object_store);
let max_circuits_in_flight = self.config.max_circuits_in_flight;
tokio::spawn(async move {
let block_number = job.block_number;
Ok(
Self::process_job_impl(object_store, job, started_at, max_circuits_in_flight)
.instrument(tracing::info_span!("basic_circuit", %block_number))
.await,
)
})
}

#[tracing::instrument(skip_all, fields(l1_batch = %job_id))]
async fn save_result(
&self,
job_id: L1BatchNumber,
started_at: Instant,
optional_artifacts: Option<BasicCircuitArtifacts>,
) -> anyhow::Result<()> {
match optional_artifacts {
None => Ok(()),
Some(artifacts) => {
let blob_started_at = Instant::now();
let circuit_urls = artifacts.circuit_urls.clone();
let queue_urls = artifacts.queue_urls.clone();

let aux_output_witness_wrapper =
AuxOutputWitnessWrapper(artifacts.aux_output_witness.clone());
if self.config.shall_save_to_public_bucket {
self.public_blob_store.as_deref()
.expect("public_object_store shall not be empty while running with shall_save_to_public_bucket config")
.put(job_id, &aux_output_witness_wrapper)
.await
.unwrap();
}

let scheduler_witness_url =
match Self::save_artifacts(job_id.0, artifacts.clone(), &*self.object_store)
.await
{
BlobUrls::Url(url) => url,
_ => unreachable!(),
};

WITNESS_GENERATOR_METRICS.blob_save_time[&AggregationRound::BasicCircuits.into()]
.observe(blob_started_at.elapsed());

Self::update_database(
&self.prover_connection_pool,
job_id.0,
started_at,
BlobUrls::Scheduler(SchedulerBlobUrls {
circuit_ids_and_urls: circuit_urls,
closed_form_inputs_and_urls: queue_urls,
scheduler_witness_url,
}),
artifacts,
)
.await?;
Ok(())
}
}
}

fn max_attempts(&self) -> u32 {
self.config.max_attempts
}

async fn get_job_attempts(&self, job_id: &L1BatchNumber) -> anyhow::Result<u32> {
let mut prover_storage = self
.prover_connection_pool
.connection()
.await
.context("failed to acquire DB connection for BasicWitnessGenerator")?;
prover_storage
.fri_witness_generator_dal()
.get_basic_circuit_witness_job_attempts(*job_id)
.await
.map(|attempts| attempts.unwrap_or(0))
.context("failed to get job attempts for BasicWitnessGenerator")
}
}
Loading

0 comments on commit 934634b

Please sign in to comment.