Skip to content

Commit

Permalink
simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-camuto committed Sep 14, 2024
1 parent b4bfa60 commit 3d6ed0b
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 42 deletions.
77 changes: 37 additions & 40 deletions src/circuit/ops/layouts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4164,7 +4164,7 @@ pub(crate) fn decompose<F: PrimeField + TensorType + PartialOrd + std::hash::Has
) -> Result<ValTensor<F>, CircuitError> {
let input = values[0].clone();

let is_assigned = !input.any_unknowns()?;
let is_assigned = !input.all_prev_assigned();

let bases: ValTensor<F> = Tensor::from(
(0..*n)
Expand All @@ -4173,63 +4173,60 @@ pub(crate) fn decompose<F: PrimeField + TensorType + PartialOrd + std::hash::Has
)
.into();

let mut claimed_output: ValTensor<F> = if is_assigned {
let input_evals = input.int_evals()?;
tensor::ops::decompose(&input_evals, *base, *n)?
.par_enum_map(|_, x| Ok::<_, TensorError>(Value::known(integer_rep_to_felt::<F>(x))))?
.into()
} else {
let mut dims = input.dims().to_vec();
dims.push(n + 1);

Tensor::new(
Some(&vec![Value::<F>::unknown(); input.len() * (n + 1)]),
&dims,
)?
.into()
};

claimed_output = region.assign(&config.custom_gates.inputs[0], &claimed_output)?;
region.increment(claimed_output.len());

let cartesian_coord = input
.dims()
.iter()
.map(|x| 0..*x)
.multi_cartesian_product()
.collect::<Vec<_>>();

let mut dummy_iterator = Tensor::from(0..cartesian_coord.len());
let mut output: Tensor<Tensor<ValType<F>>> = Tensor::new(None, input.dims())?;

let inner_loop_function = |i: usize, region: &mut RegionCtx<F>| {
let coord = cartesian_coord[i].clone();
let slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
let inner_loop_function =
|i: usize, region: &mut RegionCtx<F>| -> Result<Tensor<ValType<F>>, CircuitError> {
let coord = cartesian_coord[i].clone();
let slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
let mut sliced_input = input.get_slice(&slice)?;
sliced_input.flatten();

if !is_assigned {
sliced_input = region.assign(&config.custom_gates.inputs[0], &sliced_input)?;
}

let mut claimed_output_slice = claimed_output.get_slice(&slice)?;
claimed_output_slice.flatten();
let mut claimed_output_slice = sliced_input.decompose(*base, *n)?;

let sliced_input = input.get_slice(&slice)?;
// get the sign bit and make sure it is valid
let sign = claimed_output_slice.first()?;
let sign = range_check(config, region, &[sign], &(-1, 1))?;
claimed_output_slice =
region.assign(&config.custom_gates.inputs[1], &claimed_output_slice)?;
claimed_output_slice.flatten();

// get the rest of the thing and make sure it is in the correct range
let rest = claimed_output_slice.get_slice(&[1..claimed_output_slice.len()])?;
region.increment(claimed_output_slice.len());

let rest = range_check(config, region, &[rest], &(0, (base - 1) as i128))?;
// get the sign bit and make sure it is valid
let sign = claimed_output_slice.first()?;
let sign = range_check(config, region, &[sign], &(-1, 1))?;

let prod_decomp = dot(config, region, &[rest, bases.clone()])?;
// get the rest of the thing and make sure it is in the correct range
let rest = claimed_output_slice.get_slice(&[1..claimed_output_slice.len()])?;

let signed_decomp = pairwise(config, region, &[prod_decomp, sign], BaseOp::Mult)?;
let rest = range_check(config, region, &[rest], &(0, (base - 1) as i128))?;

enforce_equality(config, region, &[sliced_input, signed_decomp])?;
let prod_decomp = dot(config, region, &[rest, bases.clone()])?;

Ok(usize::default())
};
let signed_decomp = pairwise(config, region, &[prod_decomp, sign], BaseOp::Mult)?;

region.apply_in_loop(&mut dummy_iterator, inner_loop_function)?;
enforce_equality(config, region, &[sliced_input, signed_decomp])?;

Ok(claimed_output)
Ok(claimed_output_slice.get_inner_tensor()?.clone())
};

region.apply_in_loop(&mut output, inner_loop_function)?;

let mut combined_output = output.combine()?;
let mut output_dims = input.dims().to_vec();
output_dims.push(*n + 1);
combined_output.reshape(&output_dims)?;

Ok(combined_output.into())
}

pub(crate) fn sign<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
Expand Down
8 changes: 6 additions & 2 deletions src/tensor/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,13 @@ pub enum DecompositionError {
/// * `n` - usize
/// * `base` - usize
///
fn get_rep(x: &IntegerRep, base: usize, n: usize) -> Result<Vec<IntegerRep>, DecompositionError> {
pub fn get_rep(
x: &IntegerRep,
base: usize,
n: usize,
) -> Result<Vec<IntegerRep>, DecompositionError> {
// check if x is too large
if *x > (base.pow(n as u32) as IntegerRep) {
if x.abs() > (base.pow(n as u32) as IntegerRep) {
return Err(DecompositionError::TooLarge(*x, base, n));
}
let mut rep = vec![0; n + 1];
Expand Down
36 changes: 36 additions & 0 deletions src/tensor/val.rs
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,42 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
.into())
}

/// Decompose the inner values into base `base` and `n` legs.
pub fn decompose(&self, base: usize, n: usize) -> Result<Self, TensorError> {
let res = self
.get_inner()?
.par_iter()
.map(|x| {
let mut is_empty = true;
x.map(|_| is_empty = false);
if is_empty {
return Ok::<_, TensorError>(vec![Value::<F>::unknown(); n + 1]);
} else {
let mut res = vec![Value::unknown(); n + 1];
let mut int_rep = 0;

x.map(|f| {
int_rep = crate::fieldutils::felt_to_integer_rep(f);
});
let decompe = crate::tensor::ops::get_rep(&int_rep, base, n)?;

for (i, x) in decompe.iter().enumerate() {
res[i] = Value::known(crate::fieldutils::integer_rep_to_felt(*x));
}
Ok(res)
}
})
.collect::<Result<Vec<_>, _>>();

let mut tensor = Tensor::from(res?.into_iter().flatten().collect::<Vec<_>>().into_iter());
let mut dims = self.dims().to_vec();
dims.push(n + 1);

tensor.reshape(&dims)?;

Ok(tensor.into())
}

/// Calls `int_evals` on the inner tensor.
pub fn int_evals(&self) -> Result<Tensor<IntegerRep>, TensorError> {
// finally convert to vector of integers
Expand Down

0 comments on commit 3d6ed0b

Please sign in to comment.