Skip to content

Commit

Permalink
chore: swap comparison ops to be lookupless
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-camuto committed Sep 15, 2024
1 parent bc017a2 commit 34e814c
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 55 deletions.
30 changes: 12 additions & 18 deletions src/circuit/ops/layouts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use crate::{
ops::{accumulated, add, mult, sub},
Tensor, TensorError, ValType,
},
EZKL_DECOMP_BASE, EZKL_DECOMP_LEN,
};

use super::*;
Expand Down Expand Up @@ -2500,12 +2501,9 @@ pub fn greater<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(

let diff = pairwise(config, region, &[lhs, rhs], BaseOp::Sub)?;

nonlinearity(
config,
region,
&[diff],
&LookupOp::GreaterThan { a: utils::F32(0.) },
)
let sign = sign(config, region, &[diff], &EZKL_DECOMP_BASE, &EZKL_DECOMP_LEN)?;

equals(config, region, &[sign, create_unit_tensor(1)])
}

/// Greater equals than operation.
Expand Down Expand Up @@ -2544,21 +2542,17 @@ pub fn greater_equal<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
) -> Result<ValTensor<F>, CircuitError> {
let (mut lhs, mut rhs) = (values[0].clone(), values[1].clone());

let broadcasted_shape = get_broadcasted_shape(lhs.dims(), rhs.dims())?;

lhs.expand(&broadcasted_shape)?;
rhs.expand(&broadcasted_shape)?;

let diff = pairwise(config, region, &[lhs, rhs], BaseOp::Sub)?;
let (lhs, rhs) = (values[0].clone(), values[1].clone());

nonlinearity(
// add 1 to lhs
let lhs_plus_one = pairwise(
config,
region,
&[diff],
&LookupOp::GreaterThanEqual { a: utils::F32(0.) },
)
&[lhs.clone(), create_unit_tensor(1)],
BaseOp::Add,
)?;

greater(config, region, &[lhs_plus_one, rhs])
}

/// Less than to operation.
Expand Down
38 changes: 1 addition & 37 deletions src/circuit/ops/lookup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,18 +102,6 @@ pub enum LookupOp {
Erf {
scale: utils::F32,
},
GreaterThan {
a: utils::F32,
},
LessThan {
a: utils::F32,
},
GreaterThanEqual {
a: utils::F32,
},
LessThanEqual {
a: utils::F32,
},
KroneckerDelta,
Pow {
scale: utils::F32,
Expand Down Expand Up @@ -143,10 +131,6 @@ impl LookupOp {
LookupOp::KroneckerDelta => "kronecker_delta".into(),
LookupOp::Max { scale, a } => format!("max_{}_{}", scale, a),
LookupOp::Min { scale, a } => format!("min_{}_{}", scale, a),
LookupOp::LessThan { a } => format!("less_than_{}", a),
LookupOp::LessThanEqual { a } => format!("less_than_equal_{}", a),
LookupOp::GreaterThan { a } => format!("greater_than_{}", a),
LookupOp::GreaterThanEqual { a } => format!("greater_than_equal_{}", a),
LookupOp::Div { denom } => format!("div_{}", denom),
LookupOp::Cast { scale } => format!("cast_{}", scale),
LookupOp::Recip {
Expand Down Expand Up @@ -208,18 +192,6 @@ impl LookupOp {
LookupOp::Min { scale, a } => Ok::<_, TensorError>(
tensor::ops::nonlinearities::min(&x, scale.0.into(), a.0.into()),
),
LookupOp::LessThan { a } => Ok::<_, TensorError>(
tensor::ops::nonlinearities::less_than(&x, f32::from(*a).into()),
),
LookupOp::LessThanEqual { a } => Ok::<_, TensorError>(
tensor::ops::nonlinearities::less_than_equal(&x, f32::from(*a).into()),
),
LookupOp::GreaterThan { a } => Ok::<_, TensorError>(
tensor::ops::nonlinearities::greater_than(&x, f32::from(*a).into()),
),
LookupOp::GreaterThanEqual { a } => Ok::<_, TensorError>(
tensor::ops::nonlinearities::greater_than_equal(&x, f32::from(*a).into()),
),
LookupOp::Div { denom } => Ok::<_, TensorError>(
tensor::ops::nonlinearities::const_div(&x, f32::from(*denom).into()),
),
Expand Down Expand Up @@ -319,10 +291,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
LookupOp::KroneckerDelta => "K_DELTA".into(),
LookupOp::Max { scale, a } => format!("MAX(scale={}, a={})", scale, a),
LookupOp::Min { scale, a } => format!("MIN(scale={}, a={})", scale, a),
LookupOp::GreaterThan { a } => format!("GREATER_THAN(a={})", a),
LookupOp::GreaterThanEqual { a } => format!("GREATER_THAN_EQUAL(a={})", a),
LookupOp::LessThan { a } => format!("LESS_THAN(a={})", a),
LookupOp::LessThanEqual { a } => format!("LESS_THAN_EQUAL(a={})", a),
LookupOp::Recip {
input_scale,
output_scale,
Expand Down Expand Up @@ -377,11 +345,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
in_scale + multiplier_to_scale(1. / scale.0 as f64)
}
LookupOp::Recip { output_scale, .. } => multiplier_to_scale(output_scale.into()),
LookupOp::GreaterThan { .. }
| LookupOp::LessThan { .. }
| LookupOp::GreaterThanEqual { .. }
| LookupOp::LessThanEqual { .. }
| LookupOp::KroneckerDelta => 0,
LookupOp::KroneckerDelta => 0,
_ => inputs_scale[0],
};
Ok(scale)
Expand Down
1 change: 1 addition & 0 deletions src/circuit/ops/poly.rs
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,7 @@ impl<
PolyOp::Reshape(_) | PolyOp::Flatten(_) => in_scales[0],
PolyOp::Pow(pow) => in_scales[0] * (*pow as crate::Scale),
PolyOp::Identity { out_scale } => out_scale.unwrap_or(in_scales[0]),
PolyOp::Sign { .. } => 0,
_ => in_scales[0],
};
Ok(scale)
Expand Down
1 change: 1 addition & 0 deletions src/circuit/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use halo2_proofs::{
};
use halo2curves::bn256::Fr as F;
use halo2curves::ff::{Field, PrimeField};
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
use ops::lookup::LookupOp;
use ops::region::RegionCtx;
use rand::rngs::OsRng;
Expand Down
6 changes: 6 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,12 @@ lazy_static! {
.unwrap();
}

#[cfg(target_arch = "wasm32")]
const EZKL_DECOMP_BASE: &usize = &16384;

#[cfg(target_arch = "wasm32")]
const EZKL_DECOMP_LEN: &usize = &2;

#[cfg(target_arch = "wasm32")]
const EZKL_KEY_FORMAT: &str = "raw-bytes";

Expand Down

0 comments on commit 34e814c

Please sign in to comment.