Skip to content

Commit

Permalink
refactor(ir): update symbol lookup error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
grjte committed Apr 18, 2023
1 parent 3004030 commit bf0cbdc
Show file tree
Hide file tree
Showing 8 changed files with 145 additions and 176 deletions.
12 changes: 9 additions & 3 deletions air-script-core/src/access.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ impl Display for AccessType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Default => write!(f, "direct reference by name"),
Self::Slice(_) => write!(f, "slice"),
Self::Vector(_) => write!(f, "vector"),
Self::Matrix(_, _) => write!(f, "matrix"),
Self::Slice(range) => write!(f, "slice in range {range}"),
Self::Vector(idx) => write!(f, "vector at index {idx}"),
Self::Matrix(row, col) => write!(f, "matrix at [{row}][{col}]"),
}
}
}
Expand Down Expand Up @@ -94,6 +94,12 @@ impl Range {
}
}

impl Display for Range {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}..{}", self.start(), self.end())
}
}

/// Contains values to be iterated over in a comprehension such as list comprehension or constraint
/// comprehension.
///
Expand Down
12 changes: 4 additions & 8 deletions ir/src/constraint_builder/variables.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,9 @@ pub(crate) fn get_variable_expr(
if *idx < expr_vector.len() {
&expr_vector[*idx]
} else {
return Err(SemanticError::vector_access_out_of_bounds(
return Err(SemanticError::invalid_variable_access_type(
ident.name(),
*idx,
expr_vector.len(),
&access_type,
));
}
}
Expand All @@ -48,12 +47,9 @@ pub(crate) fn get_variable_expr(
if *row_idx < expr_matrix.len() && *col_idx < expr_matrix[0].len() {
&expr_matrix[*row_idx][*col_idx]
} else {
return Err(SemanticError::matrix_access_out_of_bounds(
return Err(SemanticError::invalid_variable_access_type(
ident.name(),
*row_idx,
*col_idx,
expr_matrix.len(),
expr_matrix[0].len(),
&access_type,
));
}
}
Expand Down
53 changes: 27 additions & 26 deletions ir/src/symbol_table/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,34 +220,35 @@ impl SymbolTable {
) -> Result<TraceAccess, SemanticError> {
let symbol = self.get_symbol(symbol_access.name())?;

match symbol.binding() {
SymbolBinding::Trace(columns) => {
let col_offset = match symbol_access.access_type() {
AccessType::Default => columns.offset(),
AccessType::Vector(idx) => {
if *idx >= columns.size() {
todo!("invalid trace access");
}
columns.offset() + *idx
}
_ => {
todo!("invalid trace access")
}
};
Ok(TraceAccess::new(
columns.trace_segment(),
col_offset,
1,
symbol_access.offset(),
))
let columns = match symbol.binding() {
SymbolBinding::Trace(columns) => columns,
_ => return Err(SemanticError::not_a_trace_column_identifier(symbol)),
};

let col_offset = match symbol_access.access_type() {
AccessType::Default => columns.offset(),
AccessType::Vector(idx) => {
if *idx >= columns.size() {
return Err(SemanticError::invalid_access_type(
symbol,
symbol_access.access_type(),
));
}
columns.offset() + *idx
}
_ => {
return Err(SemanticError::not_a_trace_column_identifier(
symbol.name(),
symbol.binding(),
))
return Err(SemanticError::invalid_access_type(
symbol,
symbol_access.access_type(),
));
}
}
};
Ok(TraceAccess::new(
columns.trace_segment(),
col_offset,
1,
symbol_access.offset(),
))
}

/// Gets the number of trace segments that were specified for this AIR.
Expand All @@ -271,7 +272,7 @@ impl SymbolTable {
let trace_segment = usize::from(trace_access.trace_segment());
let trace_segment_width = self.declarations.trace_segment_width(trace_segment)?;
if trace_access.col_idx() as u16 >= trace_segment_width {
return Err(SemanticError::indexed_trace_column_access_out_of_bounds(
return Err(SemanticError::trace_access_out_of_bounds(
trace_access,
trace_segment_width,
));
Expand Down
143 changes: 78 additions & 65 deletions ir/src/symbol_table/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,60 +52,55 @@ impl Symbol {
symbol_access: SymbolAccess,
) -> Result<Value, SemanticError> {
if symbol_access.offset() != 0 {
return Err(SemanticError::invalid_constant_access_type(
symbol_access.name(),
symbol_access.access_type(),
return Err(SemanticError::invalid_access_offset(
self,
symbol_access.offset(),
));
}
match symbol_access.access_type() {
AccessType::Default => return Ok(Value::BoundConstant(symbol_access)),
AccessType::Slice(_) => {
return Err(SemanticError::invalid_constant_access_type(
symbol_access.name(),
return Err(SemanticError::invalid_access_type(
self,
symbol_access.access_type(),
));
}
AccessType::Vector(idx) => match constant_type {
ConstantValueExpr::Scalar(_) => {
return Err(SemanticError::invalid_constant_access_type(
symbol_access.name(),
return Err(SemanticError::invalid_access_type(
self,
symbol_access.access_type(),
))
));
}
ConstantValueExpr::Vector(vector) => {
if *idx >= vector.len() {
return Err(SemanticError::vector_access_out_of_bounds(
symbol_access.name(),
*idx,
vector.len(),
return Err(SemanticError::invalid_access_type(
self,
symbol_access.access_type(),
));
}
}
ConstantValueExpr::Matrix(matrix) => {
if *idx >= matrix.len() {
return Err(SemanticError::vector_access_out_of_bounds(
symbol_access.name(),
*idx,
matrix.len(),
return Err(SemanticError::invalid_access_type(
self,
symbol_access.access_type(),
));
}
}
},
AccessType::Matrix(row_idx, col_idx) => match constant_type {
ConstantValueExpr::Scalar(_) | ConstantValueExpr::Vector(_) => {
return Err(SemanticError::invalid_constant_access_type(
symbol_access.name(),
return Err(SemanticError::invalid_access_type(
self,
symbol_access.access_type(),
))
));
}
ConstantValueExpr::Matrix(matrix) => {
if *row_idx >= matrix.len() || *col_idx >= matrix[0].len() {
return Err(SemanticError::matrix_access_out_of_bounds(
symbol_access.name(),
*row_idx,
*col_idx,
matrix.len(),
matrix[0].len(),
return Err(SemanticError::invalid_access_type(
self,
symbol_access.access_type(),
));
}
}
Expand All @@ -121,12 +116,20 @@ impl Symbol {
cycle_len: usize,
symbol_access: SymbolAccess,
) -> Result<Value, SemanticError> {
let (_, access_type, offset) = symbol_access.into_parts();
match (access_type, offset) {
(AccessType::Default, 0) => Ok(Value::PeriodicColumn(index, cycle_len)),
_ => Err(SemanticError::invalid_periodic_column_access_type(
self.name(),
)),
if symbol_access.offset() != 0 {
return Err(SemanticError::invalid_access_offset(
self,
symbol_access.offset(),
));
}
match symbol_access.access_type() {
AccessType::Default => Ok(Value::PeriodicColumn(index, cycle_len)),
_ => {
return Err(SemanticError::invalid_access_type(
self,
symbol_access.access_type(),
))
}
}
}

Expand All @@ -135,19 +138,29 @@ impl Symbol {
size: usize,
symbol_access: SymbolAccess,
) -> Result<Value, SemanticError> {
let (_, access_type, offset) = symbol_access.into_parts();
match (access_type, offset) {
(AccessType::Vector(index), 0) => {
if index >= size {
return Err(SemanticError::vector_access_out_of_bounds(
self.name(),
index,
size,
if symbol_access.offset() != 0 {
return Err(SemanticError::invalid_access_offset(
self,
symbol_access.offset(),
));
}

match symbol_access.access_type() {
AccessType::Vector(index) => {
if *index >= size {
return Err(SemanticError::invalid_access_type(
self,
symbol_access.access_type(),
));
}
return Ok(Value::PublicInput(self.name().to_string(), index));
return Ok(Value::PublicInput(self.name().to_string(), *index));
}
_ => {
return Err(SemanticError::invalid_access_type(
self,
symbol_access.access_type(),
))
}
_ => return Err(SemanticError::invalid_public_input_access_type(self.name())),
}
}

Expand All @@ -157,29 +170,36 @@ impl Symbol {
binding_size: usize,
symbol_access: SymbolAccess,
) -> Result<Value, SemanticError> {
match (symbol_access.access_type(), symbol_access.offset()) {
(AccessType::Default, 0) => {
if symbol_access.offset() != 0 {
return Err(SemanticError::invalid_access_offset(
self,
symbol_access.offset(),
));
}

match symbol_access.access_type() {
AccessType::Default => {
if binding_size != 1 {
return Err(SemanticError::invalid_random_value_binding_access(
self.name(),
return Err(SemanticError::invalid_access_type(
self,
symbol_access.access_type(),
));
}
Ok(Value::RandomValue(binding_offset))
}
(AccessType::Vector(idx), 0) => {
AccessType::Vector(idx) => {
if *idx >= binding_size {
return Err(SemanticError::vector_access_out_of_bounds(
self.name(),
*idx,
binding_size,
return Err(SemanticError::invalid_access_type(
self,
symbol_access.access_type(),
));
}

let offset = binding_offset + idx;
Ok(Value::RandomValue(offset))
}
_ => Err(SemanticError::invalid_random_value_access_type(
self.name(),
_ => Err(SemanticError::invalid_access_type(
self,
symbol_access.access_type(),
)),
}
Expand All @@ -190,35 +210,28 @@ impl Symbol {
binding: &TraceBinding,
symbol_access: SymbolAccess,
) -> Result<Value, SemanticError> {
let (_, access_type, offset) = symbol_access.into_parts();
let (_, access_type, row_offset) = symbol_access.into_parts();
match access_type {
AccessType::Default => {
if binding.size() != 1 {
return Err(SemanticError::invalid_trace_binding_access(self.name()));
return Err(SemanticError::invalid_access_type(self, &access_type));
}
let trace_segment = binding.trace_segment();
let trace_access =
TraceAccess::new(trace_segment, binding.offset(), binding.size(), offset);
TraceAccess::new(trace_segment, binding.offset(), binding.size(), row_offset);
Ok(Value::TraceElement(trace_access))
}
AccessType::Vector(idx) => {
if idx >= binding.size() {
return Err(SemanticError::vector_access_out_of_bounds(
self.name(),
idx,
binding.size(),
));
return Err(SemanticError::invalid_access_type(self, &access_type));
}

let trace_segment = binding.trace_segment();
let trace_access =
TraceAccess::new(trace_segment, binding.offset() + idx, 1, offset);
TraceAccess::new(trace_segment, binding.offset() + idx, 1, row_offset);
Ok(Value::TraceElement(trace_access))
}
_ => Err(SemanticError::invalid_trace_access_type(
self.name(),
&access_type,
)),
_ => Err(SemanticError::invalid_access_type(self, &access_type)),
}
}
}
8 changes: 4 additions & 4 deletions ir/src/symbol_table/symbol_binding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ pub(crate) enum SymbolBinding {
impl Display for SymbolBinding {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Constant(_) => write!(f, "Constant"),
Self::Trace(binding) => {
write!(f, "Trace in segment {}", binding.trace_segment())
Self::Constant(_) => write!(f, "ConstantBinding"),
Self::Trace(_) => {
write!(f, "TraceBinding")
}
Self::PublicInput(_) => write!(f, "PublicInput"),
Self::PeriodicColumn(_, _) => write!(f, "PeriodicColumn"),
Self::Variable(_) => write!(f, "Variable"),
Self::Variable(_) => write!(f, "VariableBinding"),
Self::RandomValues(_, _) => write!(f, "RandomValues"),
}
}
Expand Down
5 changes: 4 additions & 1 deletion ir/src/tests/variables.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,10 @@ fn err_bc_variable_ref_next() {
integrity_constraints:
enf clk' = clk + 1";

assert!(parse(source).is_err());
let parsed = parse(source).expect("Parsing failed");

let result = AirIR::new(parsed);
assert!(result.is_err());
}

#[test]
Expand Down
Loading

0 comments on commit bf0cbdc

Please sign in to comment.