Skip to content

Commit c13a1a8

Browse files
committed
Further Code Cleanup
1 parent 3b07f75 commit c13a1a8

File tree

4 files changed

+116
-726
lines changed

4 files changed

+116
-726
lines changed

src/main/java/org/apache/sysds/hops/rewriter/utils/RewriterSearchUtils.java

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import java.util.HashSet;
3131
import java.util.List;
3232
import java.util.Map;
33+
import java.util.Random;
3334
import java.util.Set;
3435
import java.util.UUID;
3536
import java.util.stream.Collectors;
@@ -495,6 +496,105 @@ public static int toBaseNNumber(int[] digits, int n) {
495496
return out;
496497
}
497498

499+
public static List<RewriterStatement> mergeSubtreeCombinations(RewriterStatement stmt, List<Integer> indices, List<List<RewriterStatement>> mList, final RuleContext ctx, int maximumCombinations) {
500+
if (indices.isEmpty())
501+
return List.of(stmt);
502+
503+
List<RewriterStatement> mergedTreeCombinations = new ArrayList<>();
504+
RewriterUtils.cartesianProduct(mList, new RewriterStatement[mList.size()], stack -> {
505+
RewriterStatement cpy = stmt.copyNode();
506+
for (int i = 0; i < stack.length; i++)
507+
cpy.getOperands().set(indices.get(i), stack[i]);
508+
cpy.consolidate(ctx);
509+
cpy.prepareForHashing();
510+
cpy.recomputeHashCodes(ctx);
511+
mergedTreeCombinations.add(cpy);
512+
return mergedTreeCombinations.size() < maximumCombinations;
513+
});
514+
515+
return mergedTreeCombinations;
516+
}
517+
518+
public static List<RewriterStatement> generateSubtrees(RewriterStatement stmt, final RuleContext ctx, int maximumCombinations) {
519+
List<RewriterStatement> l = generateSubtrees(stmt, new HashMap<>(), ctx, maximumCombinations);
520+
521+
if (ctx.metaPropagator != null)
522+
l.forEach(subtree -> ctx.metaPropagator.apply(subtree));
523+
524+
return l.stream().map(subtree -> {
525+
if (ctx.metaPropagator != null)
526+
subtree = ctx.metaPropagator.apply(subtree);
527+
528+
subtree.prepareForHashing();
529+
subtree.recomputeHashCodes(ctx);
530+
return subtree;
531+
}).collect(Collectors.toList());
532+
}
533+
534+
private static Random rd = new Random();
535+
536+
private static List<RewriterStatement> generateSubtrees(RewriterStatement stmt, Map<RewriterStatement, List<RewriterStatement>> visited, final RuleContext ctx, int maxCombinations) {
537+
if (stmt == null)
538+
return Collections.emptyList();
539+
540+
RewriterStatement is = stmt;
541+
List<RewriterStatement> alreadyVisited = visited.get(is);
542+
543+
if (alreadyVisited != null)
544+
return alreadyVisited;
545+
546+
if (stmt.getOperands().size() == 0)
547+
return List.of(stmt);
548+
549+
// Scan if operand is not a DataType
550+
List<Integer> indices = new ArrayList<>();
551+
for (int i = 0; i < stmt.getOperands().size(); i++) {
552+
if (stmt.getChild(i).isInstruction() || stmt.getChild(i).isLiteral())
553+
indices.add(i);
554+
}
555+
556+
int n = indices.size();
557+
int totalSubsets = 1 << n;
558+
559+
List<RewriterStatement> mList = new ArrayList<>();
560+
561+
visited.put(is, mList);
562+
563+
List<List<RewriterStatement>> mOptions = indices.stream().map(i -> generateSubtrees(stmt.getOperands().get(i), visited, ctx, maxCombinations)).collect(Collectors.toList());
564+
List<RewriterStatement> out = new ArrayList<>();
565+
566+
for (int subsetMask = 0; subsetMask < totalSubsets; subsetMask++) {
567+
List<List<RewriterStatement>> mOptionCpy = new ArrayList<>(mOptions);
568+
569+
for (int i = 0; i < n; i++) {
570+
// Check if the i-th child is included in the current subset
571+
if ((subsetMask & (1 << i)) == 0) {
572+
String dt = stmt.getOperands().get(indices.get(i)).getResultingDataType(ctx);
573+
String namePrefix = "tmp";
574+
if (dt.equals("MATRIX"))
575+
namePrefix = "M";
576+
else if (dt.equals("FLOAT"))
577+
namePrefix = "f";
578+
else if (dt.equals("INT"))
579+
namePrefix = "i";
580+
else if (dt.equals("BOOL"))
581+
namePrefix = "b";
582+
RewriterDataType mT = new RewriterDataType().as(namePrefix + rd.nextInt(100000)).ofType(dt);
583+
mT.consolidate(ctx);
584+
mOptionCpy.set(i, List.of(mT));
585+
}
586+
}
587+
588+
out.addAll(mergeSubtreeCombinations(stmt, indices, mOptionCpy, ctx, maxCombinations));
589+
if (out.size() > maxCombinations) {
590+
System.out.println("Aborting early due to too many combinations");
591+
return out;
592+
}
593+
}
594+
595+
return out;
596+
}
597+
498598
public static final class Operand {
499599
public final String op;
500600
public final int numArgs;

0 commit comments

Comments
 (0)