Skip to content

Commit

Permalink
Merge pull request #117 from 0xPolygonMiden/grjte-refactor-constraint…
Browse files Browse the repository at this point in the history
…-domain

refactor: replace row_offset with ConstraintDomain
  • Loading branch information
grjte authored Jan 23, 2023
2 parents 93143a0 + a5812b7 commit ccb90ba
Show file tree
Hide file tree
Showing 12 changed files with 200 additions and 125 deletions.
2 changes: 1 addition & 1 deletion ir/src/boundary_stmts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::{BoundaryConstraintsMap, TraceSegment};
use super::{BTreeMap, Expression, IdentifierType, SemanticError, SymbolTable};
use parser::ast::{self, BoundaryStmt};

// BOUNDARY CONSTRAINTS
// BOUNDARY STATEMENTS
// ================================================================================================

/// A struct containing all of the boundary constraints to be applied at each of the 2 allowed
Expand Down
77 changes: 77 additions & 0 deletions ir/src/integrity_stmts/constraint.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
use super::NodeIndex;

/// A [ConstraintRoot] represents the entry node of a subgraph representing an integrity constraint
/// within the [AlgebraicGraph]. It also contains the row offset for the constraint which is the
/// maximum of all row offsets accessed by the constraint. For example, if a constraint only
/// accesses the trace in the current row then the row offset will be 0, but if it accesses the
/// trace in both the current and the next rows then the row offset will be 1.
#[derive(Debug, Clone)]
pub struct ConstraintRoot {
index: NodeIndex,
domain: ConstraintDomain,
}

impl ConstraintRoot {
/// Creates a new [ConstraintRoot] with the specified entry index and row offset.
pub fn new(index: NodeIndex, domain: ConstraintDomain) -> Self {
Self { index, domain }
}

/// Returns the index of the entry node of the subgraph representing the constraint.
pub fn node_index(&self) -> &NodeIndex {
&self.index
}

/// Returns the [ConstraintDomain] for this constraint, which specifies the rows against which
/// the constraint should be applied.
pub fn domain(&self) -> ConstraintDomain {
self.domain
}
}

/// The domain to which the constraint is applied, which is either the first or last row (for
/// boundary constraints), every row (for validity constraints), or every frame (for transition
/// constraints). When the constraint is applied to a frame the inner value specifies the size of
/// the frame. For example, for a transition constraint that is applied against the current and next
/// rows, the frame size will be 2.
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub enum ConstraintDomain {
FirstRow, // for boundary constraints against the first row
LastRow, // for boundary constraints against the last row
EveryRow, // for validity constraints
EveryFrame(usize), // for transition constraints
}

impl ConstraintDomain {
/// Combines the two [ConstraintDomain]s into a single [ConstraintDomain] that represents the
/// maximum constraint domain. For example, if one domain is [ConstraintDomain::EveryFrame(2)]
/// and the other is [ConstraintDomain::EveryFrame(3)] then the result will be
/// [ConstraintDomain::EveryFrame(3)].
pub fn merge(&self, other: &ConstraintDomain) -> ConstraintDomain {
if self == other {
return *other;
}

match (self, other) {
(ConstraintDomain::EveryFrame(a), ConstraintDomain::EveryFrame(b)) => {
ConstraintDomain::EveryFrame(*a.max(b))
}
(ConstraintDomain::EveryFrame(a), _) => ConstraintDomain::EveryFrame(*a),
(_, ConstraintDomain::EveryFrame(b)) => ConstraintDomain::EveryFrame(*b),
// for any other pair of constraints which are not equal, the result of combining the
// domains is to apply the constraint at every row.
_ => ConstraintDomain::EveryRow,
}
}
}

impl From<usize> for ConstraintDomain {
/// Creates a [ConstraintDomain] from the specified row offset.
fn from(row_offset: usize) -> Self {
if row_offset == 0 {
ConstraintDomain::EveryRow
} else {
ConstraintDomain::EveryFrame(row_offset + 1)
}
}
}
113 changes: 63 additions & 50 deletions ir/src/integrity_stmts/graph.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,22 @@
use super::{super::BTreeMap, degree::IntegrityConstraintDegree, SemanticError, SymbolTable};
use super::{
super::BTreeMap, degree::IntegrityConstraintDegree, ConstraintDomain, ExprDetails,
SemanticError, SymbolTable, VariableRoots,
};
use crate::{
symbol_table::IdentifierType, ConstantType, ExprDetails, Expression, Identifier,
IndexedTraceAccess, MatrixAccess, NamedTraceAccess, VariableRoots, VariableType, VectorAccess,
CURRENT_ROW,
symbol_table::IdentifierType, ConstantType, Expression, Identifier, IndexedTraceAccess,
MatrixAccess, NamedTraceAccess, TraceSegment, VariableType, VectorAccess,
};

// CONSTANTS
// ================================================================================================

/// The offset of the "current" row during constraint evaluation.
const CURRENT_ROW: usize = 0;
/// The default segment against which a constraint is applied is the main trace segment.
const DEFAULT_SEGMENT: TraceSegment = 0;
/// The default constraint domain is every row.
const DEFAULT_DOMAIN: ConstraintDomain = ConstraintDomain::EveryRow;

// ALGEBRAIC GRAPH
// ================================================================================================

Expand Down Expand Up @@ -91,10 +103,8 @@ impl AlgebraicGraph {
) -> Result<ExprDetails, SemanticError> {
match expr {
Expression::Const(value) => {
// constraint target defaults to Main trace.
let trace_segment = 0;
let node_index = self.insert_op(Operation::Constant(ConstantValue::Inline(value)));
Ok((trace_segment, node_index, CURRENT_ROW))
Ok((node_index, DEFAULT_SEGMENT, DEFAULT_DOMAIN))
}
Expression::Elem(Identifier(ident)) => {
self.insert_symbol_access(symbol_table, &ident, variable_roots)
Expand All @@ -112,7 +122,7 @@ impl AlgebraicGraph {
// the AirScript syntax.
let trace_segment = 1;
let node_index = self.insert_op(Operation::RandomValue(index));
Ok((trace_segment, node_index, CURRENT_ROW))
Ok((node_index, trace_segment, DEFAULT_DOMAIN))
}
Expression::IndexedTraceAccess(column_access) => {
self.insert_indexed_trace_access(symbol_table, column_access)
Expand All @@ -122,47 +132,47 @@ impl AlgebraicGraph {
}
Expression::Add(lhs, rhs) => {
// add both subexpressions.
let (lhs_segment, lhs, lhs_row_offset) =
let (lhs, lhs_segment, lhs_domain) =
self.insert_expr(symbol_table, *lhs, variable_roots)?;
let (rhs_segment, rhs, rhs_row_offset) =
let (rhs, rhs_segment, rhs_domain) =
self.insert_expr(symbol_table, *rhs, variable_roots)?;
// add the expression.
let trace_segment = lhs_segment.max(rhs_segment);
let node_index = self.insert_op(Operation::Add(lhs, rhs));
let row_offset = lhs_row_offset.max(rhs_row_offset);
Ok((trace_segment, node_index, row_offset))
let domain = lhs_domain.merge(&rhs_domain);
Ok((node_index, trace_segment, domain))
}
Expression::Sub(lhs, rhs) => {
// add both subexpressions.
let (lhs_segment, lhs, lhs_row_offset) =
let (lhs, lhs_segment, lhs_domain) =
self.insert_expr(symbol_table, *lhs, variable_roots)?;
let (rhs_segment, rhs, rhs_row_offset) =
let (rhs, rhs_segment, rhs_domain) =
self.insert_expr(symbol_table, *rhs, variable_roots)?;
// add the expression.
let trace_segment = lhs_segment.max(rhs_segment);
let node_index = self.insert_op(Operation::Sub(lhs, rhs));
let row_offset = lhs_row_offset.max(rhs_row_offset);
Ok((trace_segment, node_index, row_offset))
let domain = lhs_domain.merge(&rhs_domain);
Ok((node_index, trace_segment, domain))
}
Expression::Mul(lhs, rhs) => {
// add both subexpressions.
let (lhs_segment, lhs, lhs_row_offset) =
let (lhs, lhs_segment, lhs_domain) =
self.insert_expr(symbol_table, *lhs, variable_roots)?;
let (rhs_segment, rhs, rhs_row_offset) =
let (rhs, rhs_segment, rhs_domain) =
self.insert_expr(symbol_table, *rhs, variable_roots)?;
// add the expression.
let trace_segment = lhs_segment.max(rhs_segment);
let node_index = self.insert_op(Operation::Mul(lhs, rhs));
let row_offset = lhs_row_offset.max(rhs_row_offset);
Ok((trace_segment, node_index, row_offset))
let domain = lhs_domain.merge(&rhs_domain);
Ok((node_index, trace_segment, domain))
}
Expression::Exp(lhs, rhs) => {
// add base subexpression.
let (trace_segment, lhs, row_offset) =
let (lhs, trace_segment, domain) =
self.insert_expr(symbol_table, *lhs, variable_roots)?;
// add exponent subexpression.
let node_index = self.insert_op(Operation::Exp(lhs, rhs as usize));
Ok((trace_segment, node_index, row_offset))
Ok((node_index, trace_segment, domain))
}
}
}
Expand Down Expand Up @@ -199,7 +209,11 @@ impl AlgebraicGraph {
let trace_segment = trace_access.trace_segment();
let row_offset = trace_access.row_offset();
let node_index = self.insert_op(Operation::TraceElement(trace_access));
Ok((trace_segment, node_index, row_offset))
Ok((
node_index,
trace_segment,
ConstraintDomain::from(row_offset),
))
}

/// Adds a named trace element access to the graph and returns the node index, trace segment,
Expand All @@ -225,7 +239,11 @@ impl AlgebraicGraph {
row_offset,
);
let node_index = self.insert_op(Operation::TraceElement(trace_access));
Ok((trace_segment, node_index, row_offset))
Ok((
node_index,
trace_segment,
ConstraintDomain::from(row_offset),
))
}
_ => Err(SemanticError::InvalidUsage(format!(
"Identifier {} was declared as a {} not as a trace column",
Expand Down Expand Up @@ -254,35 +272,32 @@ impl AlgebraicGraph {
let trace_access =
IndexedTraceAccess::new(trace_segment, columns.offset(), CURRENT_ROW);
let node_index = self.insert_op(Operation::TraceElement(trace_access));
Ok((trace_segment, node_index, CURRENT_ROW))
Ok((node_index, trace_segment, DEFAULT_DOMAIN))
}
IdentifierType::PeriodicColumn(index, cycle_len) => {
// constraint target defaults to Main trace.
let trace_segment = 0;
let node_index = self.insert_op(Operation::PeriodicColumn(*index, *cycle_len));
Ok((trace_segment, node_index, CURRENT_ROW))
Ok((node_index, DEFAULT_SEGMENT, DEFAULT_DOMAIN))
}
IdentifierType::Constant(ConstantType::Scalar(_)) => {
let trace_segment = 0;
let node_index = self.insert_op(Operation::Constant(ConstantValue::Scalar(
ident.to_string(),
)));
Ok((trace_segment, node_index, CURRENT_ROW))
Ok((node_index, DEFAULT_SEGMENT, DEFAULT_DOMAIN))
}
IdentifierType::IntegrityVariable(integrity_variable) => {
if let VariableType::Scalar(expr) = integrity_variable.value() {
if let Some((trace_segment, node_index, row_offset)) =
if let Some((node_index, trace_segment, domain)) =
variable_roots.get(&VariableValue::Scalar(ident.to_string()))
{
Ok((*trace_segment, *node_index, *row_offset))
Ok((*node_index, *trace_segment, *domain))
} else {
let (trace_segment, node_index, row_offset) =
let (node_index, trace_segment, domain) =
self.insert_expr(symbol_table, expr.clone(), variable_roots)?;
variable_roots.insert(
VariableValue::Scalar(ident.to_string()),
(trace_segment, node_index, row_offset),
(node_index, trace_segment, domain),
);
Ok((trace_segment, node_index, row_offset))
Ok((node_index, trace_segment, domain))
}
} else {
Err(SemanticError::InvalidUsage(format!(
Expand Down Expand Up @@ -312,27 +327,26 @@ impl AlgebraicGraph {
let symbol_type = symbol_table.access_vector_element(vector_access)?;
match symbol_type {
IdentifierType::Constant(ConstantType::Vector(_)) => {
let trace_segment = 0;
let node_index = self.insert_op(Operation::Constant(ConstantValue::Vector(
vector_access.clone(),
)));
Ok((trace_segment, node_index, CURRENT_ROW))
Ok((node_index, DEFAULT_SEGMENT, DEFAULT_DOMAIN))
}
IdentifierType::IntegrityVariable(integrity_variable) => {
if let VariableType::Vector(vector) = integrity_variable.value() {
let expr = &vector[vector_access.idx()];
if let Some((trace_segment, node_index, row_offset)) =
if let Some((node_index, trace_segment, domain)) =
variable_roots.get(&VariableValue::Vector(vector_access.clone()))
{
Ok((*trace_segment, *node_index, *row_offset))
Ok((*node_index, *trace_segment, *domain))
} else {
let (trace_segment, node_index, row_offset) =
let (node_index, trace_segment, domain) =
self.insert_expr(symbol_table, expr.clone(), variable_roots)?;
variable_roots.insert(
VariableValue::Vector(vector_access.clone()),
(trace_segment, node_index, row_offset),
(node_index, trace_segment, domain),
);
Ok((trace_segment, node_index, row_offset))
Ok((node_index, trace_segment, domain))
}
} else {
Err(SemanticError::InvalidUsage(format!(
Expand All @@ -350,7 +364,7 @@ impl AlgebraicGraph {
col_idx,
CURRENT_ROW,
)));
Ok((trace_segment, node_index, CURRENT_ROW))
Ok((node_index, trace_segment, DEFAULT_DOMAIN))
}
_ => Err(SemanticError::invalid_vector_access(
vector_access,
Expand All @@ -373,27 +387,26 @@ impl AlgebraicGraph {
let symbol_type = symbol_table.access_matrix_element(matrix_access)?;
match symbol_type {
IdentifierType::Constant(ConstantType::Matrix(_)) => {
let trace_segment = 0;
let node_index = self.insert_op(Operation::Constant(ConstantValue::Matrix(
matrix_access.clone(),
)));
Ok((trace_segment, node_index, CURRENT_ROW))
Ok((node_index, DEFAULT_SEGMENT, DEFAULT_DOMAIN))
}
IdentifierType::IntegrityVariable(integrity_variable) => {
if let VariableType::Matrix(matrix) = integrity_variable.value() {
let expr = &matrix[matrix_access.row_idx()][matrix_access.col_idx()];
if let Some((trace_segment, node_index, row_offset)) =
if let Some((node_index, trace_segment, domain)) =
variable_roots.get(&VariableValue::Matrix(matrix_access.clone()))
{
Ok((*trace_segment, *node_index, *row_offset))
Ok((*node_index, *trace_segment, *domain))
} else {
let (trace_segment, node_index, row_offset) =
let (node_index, trace_segment, domain) =
self.insert_expr(symbol_table, expr.clone(), variable_roots)?;
variable_roots.insert(
VariableValue::Matrix(matrix_access.clone()),
(trace_segment, node_index, row_offset),
(node_index, trace_segment, domain),
);
Ok((trace_segment, node_index, row_offset))
Ok((node_index, trace_segment, domain))
}
} else {
Err(SemanticError::invalid_matrix_access(
Expand Down
Loading

0 comments on commit ccb90ba

Please sign in to comment.