|
30 | 30 | import java.util.HashSet;
|
31 | 31 | import java.util.List;
|
32 | 32 | import java.util.Map;
|
| 33 | +import java.util.Random; |
33 | 34 | import java.util.Set;
|
34 | 35 | import java.util.UUID;
|
35 | 36 | import java.util.stream.Collectors;
|
@@ -495,6 +496,105 @@ public static int toBaseNNumber(int[] digits, int n) {
|
495 | 496 | return out;
|
496 | 497 | }
|
497 | 498 |
|
| 499 | + public static List<RewriterStatement> mergeSubtreeCombinations(RewriterStatement stmt, List<Integer> indices, List<List<RewriterStatement>> mList, final RuleContext ctx, int maximumCombinations) { |
| 500 | + if (indices.isEmpty()) |
| 501 | + return List.of(stmt); |
| 502 | + |
| 503 | + List<RewriterStatement> mergedTreeCombinations = new ArrayList<>(); |
| 504 | + RewriterUtils.cartesianProduct(mList, new RewriterStatement[mList.size()], stack -> { |
| 505 | + RewriterStatement cpy = stmt.copyNode(); |
| 506 | + for (int i = 0; i < stack.length; i++) |
| 507 | + cpy.getOperands().set(indices.get(i), stack[i]); |
| 508 | + cpy.consolidate(ctx); |
| 509 | + cpy.prepareForHashing(); |
| 510 | + cpy.recomputeHashCodes(ctx); |
| 511 | + mergedTreeCombinations.add(cpy); |
| 512 | + return mergedTreeCombinations.size() < maximumCombinations; |
| 513 | + }); |
| 514 | + |
| 515 | + return mergedTreeCombinations; |
| 516 | + } |
| 517 | + |
| 518 | + public static List<RewriterStatement> generateSubtrees(RewriterStatement stmt, final RuleContext ctx, int maximumCombinations) { |
| 519 | + List<RewriterStatement> l = generateSubtrees(stmt, new HashMap<>(), ctx, maximumCombinations); |
| 520 | + |
| 521 | + if (ctx.metaPropagator != null) |
| 522 | + l.forEach(subtree -> ctx.metaPropagator.apply(subtree)); |
| 523 | + |
| 524 | + return l.stream().map(subtree -> { |
| 525 | + if (ctx.metaPropagator != null) |
| 526 | + subtree = ctx.metaPropagator.apply(subtree); |
| 527 | + |
| 528 | + subtree.prepareForHashing(); |
| 529 | + subtree.recomputeHashCodes(ctx); |
| 530 | + return subtree; |
| 531 | + }).collect(Collectors.toList()); |
| 532 | + } |
| 533 | + |
| 534 | + private static Random rd = new Random(); |
| 535 | + |
| 536 | + private static List<RewriterStatement> generateSubtrees(RewriterStatement stmt, Map<RewriterStatement, List<RewriterStatement>> visited, final RuleContext ctx, int maxCombinations) { |
| 537 | + if (stmt == null) |
| 538 | + return Collections.emptyList(); |
| 539 | + |
| 540 | + RewriterStatement is = stmt; |
| 541 | + List<RewriterStatement> alreadyVisited = visited.get(is); |
| 542 | + |
| 543 | + if (alreadyVisited != null) |
| 544 | + return alreadyVisited; |
| 545 | + |
| 546 | + if (stmt.getOperands().size() == 0) |
| 547 | + return List.of(stmt); |
| 548 | + |
| 549 | + // Scan if operand is not a DataType |
| 550 | + List<Integer> indices = new ArrayList<>(); |
| 551 | + for (int i = 0; i < stmt.getOperands().size(); i++) { |
| 552 | + if (stmt.getChild(i).isInstruction() || stmt.getChild(i).isLiteral()) |
| 553 | + indices.add(i); |
| 554 | + } |
| 555 | + |
| 556 | + int n = indices.size(); |
| 557 | + int totalSubsets = 1 << n; |
| 558 | + |
| 559 | + List<RewriterStatement> mList = new ArrayList<>(); |
| 560 | + |
| 561 | + visited.put(is, mList); |
| 562 | + |
| 563 | + List<List<RewriterStatement>> mOptions = indices.stream().map(i -> generateSubtrees(stmt.getOperands().get(i), visited, ctx, maxCombinations)).collect(Collectors.toList()); |
| 564 | + List<RewriterStatement> out = new ArrayList<>(); |
| 565 | + |
| 566 | + for (int subsetMask = 0; subsetMask < totalSubsets; subsetMask++) { |
| 567 | + List<List<RewriterStatement>> mOptionCpy = new ArrayList<>(mOptions); |
| 568 | + |
| 569 | + for (int i = 0; i < n; i++) { |
| 570 | + // Check if the i-th child is included in the current subset |
| 571 | + if ((subsetMask & (1 << i)) == 0) { |
| 572 | + String dt = stmt.getOperands().get(indices.get(i)).getResultingDataType(ctx); |
| 573 | + String namePrefix = "tmp"; |
| 574 | + if (dt.equals("MATRIX")) |
| 575 | + namePrefix = "M"; |
| 576 | + else if (dt.equals("FLOAT")) |
| 577 | + namePrefix = "f"; |
| 578 | + else if (dt.equals("INT")) |
| 579 | + namePrefix = "i"; |
| 580 | + else if (dt.equals("BOOL")) |
| 581 | + namePrefix = "b"; |
| 582 | + RewriterDataType mT = new RewriterDataType().as(namePrefix + rd.nextInt(100000)).ofType(dt); |
| 583 | + mT.consolidate(ctx); |
| 584 | + mOptionCpy.set(i, List.of(mT)); |
| 585 | + } |
| 586 | + } |
| 587 | + |
| 588 | + out.addAll(mergeSubtreeCombinations(stmt, indices, mOptionCpy, ctx, maxCombinations)); |
| 589 | + if (out.size() > maxCombinations) { |
| 590 | + System.out.println("Aborting early due to too many combinations"); |
| 591 | + return out; |
| 592 | + } |
| 593 | + } |
| 594 | + |
| 595 | + return out; |
| 596 | + } |
| 597 | + |
498 | 598 | public static final class Operand {
|
499 | 599 | public final String op;
|
500 | 600 | public final int numArgs;
|
|
0 commit comments