|
1 | 1 | use std::collections::hash_map::Entry; |
2 | 2 | use std::collections::HashMap; |
| 3 | +use std::num::NonZeroUsize; |
3 | 4 | use std::str::FromStr; |
4 | 5 | use std::sync::Arc; |
5 | 6 |
|
@@ -275,6 +276,8 @@ pub enum CallName { |
275 | 276 | Custom(CustomFunction), |
276 | 277 | /// Fold of a bounded list with the given function. |
277 | 278 | Fold(CustomFunction, NonZeroPow2Usize), |
| 279 | + /// Fold of an array with the given function. |
| 280 | + ArrayFold(CustomFunction, NonZeroUsize), |
278 | 281 | /// Loop over the given function a bounded number of times until it returns success. |
279 | 282 | ForWhile(CustomFunction, Pow2Usize), |
280 | 283 | } |
@@ -1187,6 +1190,26 @@ impl AbstractSyntaxTree for Call { |
1187 | 1190 | check_output_type(out_ty, ty).with_span(from)?; |
1188 | 1191 | analyze_arguments(from.args(), &args_ty, scope)? |
1189 | 1192 | } |
| 1193 | + CallName::ArrayFold(function, size) => { |
| 1194 | + // An array fold has the signature: |
| 1195 | + // array_fold::<f, N>(array: [E; N], initial_accumulator: A) -> A |
| 1196 | + // where |
| 1197 | + // fn f(element: E, accumulator: A) -> A |
| 1198 | + let element_ty = function.params().first().expect("foldable function").ty(); |
| 1199 | + let array_ty = ResolvedType::array(element_ty.clone(), size.get()); |
| 1200 | + let accumulator_ty = function |
| 1201 | + .params() |
| 1202 | + .get(1) |
| 1203 | + .expect("foldable function") |
| 1204 | + .ty() |
| 1205 | + .clone(); |
| 1206 | + let args_ty = [array_ty, accumulator_ty]; |
| 1207 | + |
| 1208 | + check_argument_types(from.args(), &args_ty).with_span(from)?; |
| 1209 | + let out_ty = function.body().ty(); |
| 1210 | + check_output_type(out_ty, ty).with_span(from)?; |
| 1211 | + analyze_arguments(from.args(), &args_ty, scope)? |
| 1212 | + } |
1190 | 1213 | CallName::ForWhile(function, _bit_width) => { |
1191 | 1214 | // A for-while loop has the signature: |
1192 | 1215 | // for_while::<f>(initial_accumulator: A, readonly_context: C) -> Either<B, A> |
@@ -1262,6 +1285,21 @@ impl AbstractSyntaxTree for CallName { |
1262 | 1285 | .map(Self::Custom) |
1263 | 1286 | .ok_or(Error::FunctionUndefined(name.clone())) |
1264 | 1287 | .with_span(from), |
| 1288 | + parse::CallName::ArrayFold(name, size) => { |
| 1289 | + let function = scope |
| 1290 | + .get_function(name) |
| 1291 | + .cloned() |
| 1292 | + .ok_or(Error::FunctionUndefined(name.clone())) |
| 1293 | + .with_span(from)?; |
| 1294 | + // A function that is used in a array fold has the signature: |
| 1295 | + // fn f(element: E, accumulator: A) -> A |
| 1296 | + if function.params().len() != 2 || function.params()[1].ty() != function.body().ty() |
| 1297 | + { |
| 1298 | + Err(Error::FunctionNotFoldable(name.clone())).with_span(from) |
| 1299 | + } else { |
| 1300 | + Ok(Self::ArrayFold(function, *size)) |
| 1301 | + } |
| 1302 | + } |
1265 | 1303 | parse::CallName::Fold(name, bound) => { |
1266 | 1304 | let function = scope |
1267 | 1305 | .get_function(name) |
|
0 commit comments