Skip to content

Commit a94575f

Browse files
committed
Some improvements
1 parent 4067c7e commit a94575f

File tree

5 files changed

+145
-23
lines changed

5 files changed

+145
-23
lines changed

src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleSet.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -305,9 +305,9 @@ public static RewriterRuleSet deserialize(String[] data, final RuleContext ctx)
305305
return new RewriterRuleSet(ctx, rules);
306306
}
307307

308-
public String toJavaCode(String className, boolean printErrors) {
308+
public String toJavaCode(String className, boolean optimize, boolean includePackageInfo, boolean printErrors) {
309309
List<Tuple2<String, RewriterRule>> mRules = IntStream.range(0, rules.size()).mapToObj(i -> new Tuple2<>("_applyRewrite" + i, rules.get(i))).collect(Collectors.toList());
310-
return RewriterCodeGen.generateClass(className, mRules, ctx, true, printErrors);
310+
return RewriterCodeGen.generateClass(className, mRules, optimize, includePackageInfo, ctx, true, printErrors);
311311
}
312312

313313
public Function<Hop, Hop> compile(String className, boolean printErrors) {

src/main/java/org/apache/sysds/hops/rewriter/codegen/CodeGenCondition.java

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,25 @@ private static List<Object> populateLayerRecursively(List<Object> rules, List<In
5858

5959
for (int i = 0; i < out.size(); i++) {
6060
CodeGenCondition c = (CodeGenCondition) out.get(i);
61+
62+
if (c.rulesIf.size() <= maxNumRules)
63+
continue;
64+
6165
c.rulesIf = populateOpClassLayer(c.rulesIf, relativeChildPath, ctx);
6266

6367
for (int j = 0; j < c.rulesIf.size(); j++) {
6468
CodeGenCondition c2 = (CodeGenCondition) c.rulesIf.get(j);
6569
c2.rulesIf = populateOpCodeLayer(c2.rulesIf, relativeChildPath, ctx);
6670

71+
if (c.rulesIf.size() <= maxNumRules)
72+
continue;
73+
6774
for (int k = 0; k < c2.rulesIf.size(); k++) {
6875
CodeGenCondition c3 = (CodeGenCondition) c2.rulesIf.get(k);
6976
c3.rulesIf = populateInputSizeLayer(c3.rulesIf, relativeChildPath, ctx);
77+
78+
if (c.rulesIf.size() <= maxNumRules)
79+
continue;
7080
//int maxChildSize = c3.rulesIf.stream().flatMap(o -> ((CodeGenCondition)o).rulesIf.stream()).mapToInt(o -> ((Tuple2<RewriterRule, RewriterStatement>) o)._2.getOperands().size()).max().getAsInt();
7181

7282
for (int l = 0; l < c3.rulesIf.size(); l++) {
@@ -227,11 +237,17 @@ public void buildConditionCheck(StringBuilder sb, final RuleContext ctx) {
227237
hopVar += "_";
228238
hopVar += relativeChildPath.stream().map(Object::toString).collect(Collectors.joining("_"));
229239
}
240+
230241
String specialInstr = CodeGenUtils.getSpecialOpCheck(representant, ctx, hopVar);
231242
if (specialInstr != null) {
232243
sb.append(specialInstr);
233244
} else {
245+
// Some type casting
246+
sb.append("(( ");
247+
sb.append(CodeGenUtils.getOpClass(representant, ctx));
248+
sb.append(" ) ");
234249
sb.append(hopVar);
250+
sb.append(" )");
235251
sb.append(".getOp() == ");
236252
sb.append(CodeGenUtils.getOpCode(representant, ctx));
237253
}
@@ -367,7 +383,8 @@ public static void buildSelection(StringBuilder sb, List<CodeGenCondition> conds
367383
RewriterCodeGen.indent(indentation, sb);
368384
sb.append("hi = ");
369385
sb.append(fMapping);
370-
sb.append("(hi);");
386+
sb.append("(hi); // ");
387+
sb.append(t._1.toString());
371388
sb.append("\n");
372389
}
373390
}
@@ -379,6 +396,24 @@ public static void buildSelection(StringBuilder sb, List<CodeGenCondition> conds
379396
sb.append("if ( ");
380397
firstCond.buildConditionCheck(sb, ctx);
381398
sb.append(" ) {\n");
399+
400+
if (firstCond.conditionType == ConditionType.NUM_INPUTS) {
401+
int numInputs = (int)firstCond.conditionValue;
402+
403+
for (int i = 0; i < numInputs; i++) {
404+
RewriterCodeGen.indent(indentation + 1, sb);
405+
sb.append("Hop ");
406+
sb.append(firstCond.getVarName());
407+
sb.append("_");
408+
sb.append(i);
409+
sb.append(" = ");
410+
sb.append(firstCond.getVarName());
411+
sb.append(".getInput(");
412+
sb.append(i);
413+
sb.append(");\n");
414+
}
415+
}
416+
382417
List<CodeGenCondition> nestedCondition = firstCond.rulesIf.stream().filter(o -> o instanceof CodeGenCondition).map(o -> (CodeGenCondition)o).collect(Collectors.toList());
383418
buildSelection(sb, nestedCondition, indentation + 1, ruleFunctionMappings, ctx);
384419

@@ -388,10 +423,11 @@ public static void buildSelection(StringBuilder sb, List<CodeGenCondition> conds
388423
for (Tuple2<RewriterRule, RewriterStatement> t : cur) {
389424
String fMapping = ruleFunctionMappings.get(t._1);
390425
if (fMapping != null) {
391-
RewriterCodeGen.indent(indentation, sb);
426+
RewriterCodeGen.indent(indentation + 1, sb);
392427
sb.append("hi = ");
393428
sb.append(fMapping);
394-
sb.append("(hi);");
429+
sb.append("(hi); // ");
430+
sb.append(t._1.toString());
395431
sb.append("\n");
396432
}
397433
}
@@ -409,6 +445,23 @@ public static void buildSelection(StringBuilder sb, List<CodeGenCondition> conds
409445
sb.append(" ) {\n");
410446
}
411447

448+
if (cond.conditionType == ConditionType.NUM_INPUTS) {
449+
int numInputs = (int)cond.conditionValue;
450+
451+
for (int i = 0; i < numInputs; i++) {
452+
RewriterCodeGen.indent(indentation + 1, sb);
453+
sb.append("Hop ");
454+
sb.append(cond.getVarName());
455+
sb.append("_");
456+
sb.append(i);
457+
sb.append(" = ");
458+
sb.append(cond.getVarName());
459+
sb.append(".getInput(");
460+
sb.append(i);
461+
sb.append(");");
462+
}
463+
}
464+
412465
List<CodeGenCondition> mNestedCondition = cond.rulesIf.stream().filter(o -> o instanceof CodeGenCondition).map(o -> (CodeGenCondition)o).collect(Collectors.toList());
413466
buildSelection(sb, mNestedCondition, indentation + 1, ruleFunctionMappings, ctx);
414467

@@ -421,7 +474,8 @@ public static void buildSelection(StringBuilder sb, List<CodeGenCondition> conds
421474
RewriterCodeGen.indent(indentation, sb);
422475
sb.append("hi = ");
423476
sb.append(fMapping);
424-
sb.append("(hi);");
477+
sb.append("(hi); // ");
478+
sb.append(t._1.toString());
425479
sb.append("\n");
426480
}
427481
}

src/main/java/org/apache/sysds/hops/rewriter/codegen/RewriterCodeGen.java

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,13 @@
1717
import java.util.Map;
1818
import java.util.Set;
1919
import java.util.function.Function;
20+
import java.util.stream.Collectors;
2021

2122
public class RewriterCodeGen {
2223
public static boolean DEBUG = true;
2324

2425
public static Function<Hop, Hop> compileRewrites(String className, List<Tuple2<String, RewriterRule>> rewrites, final RuleContext ctx, boolean ignoreErrors, boolean printErrors) throws Exception {
25-
String code = generateClass(className, rewrites, ctx, ignoreErrors, printErrors);
26+
String code = generateClass(className, rewrites, false, false, ctx, ignoreErrors, printErrors);
2627
System.out.println("Compiling code:\n" + code);
2728
SimpleCompiler compiler = new SimpleCompiler();
2829
compiler.cook(code);
@@ -31,8 +32,12 @@ public static Function<Hop, Hop> compileRewrites(String className, List<Tuple2<S
3132
return (Function<Hop, Hop>) instance;
3233
}
3334

34-
public static String generateClass(String className, List<Tuple2<String, RewriterRule>> rewrites, final RuleContext ctx, boolean ignoreErrors, boolean printErrors) {
35+
public static String generateClass(String className, List<Tuple2<String, RewriterRule>> rewrites, boolean optimize, boolean includePackageInfo, final RuleContext ctx, boolean ignoreErrors, boolean printErrors) {
3536
StringBuilder msb = new StringBuilder();
37+
38+
if (includePackageInfo)
39+
msb.append("package org.apache.sysds.hops.rewriter;\n\n");
40+
3641
msb.append("import java.util.ArrayList;\n");
3742
msb.append("import java.util.function.Function;\n");
3843
msb.append("\n");
@@ -75,18 +80,33 @@ public static String generateClass(String className, List<Tuple2<String, Rewrite
7580
indent(1, msb);
7681
msb.append("@Override\n");
7782
indent(1, msb);
78-
msb.append("public Object apply( Object hi ) {\n");
83+
msb.append("public Object apply( Object _hi ) {\n");
7984
indent(2, msb);
80-
msb.append("if ( hi == null )\n");
85+
msb.append("if ( _hi == null )\n");
8186
indent(3, msb);
8287
msb.append("return null;\n\n");
88+
indent(2, msb);
89+
msb.append("Hop hi = (Hop) _hi;\n\n");
8390

84-
for (Tuple2<String, RewriterRule> appliedRewrites : rewrites) {
85-
if (implemented.contains(appliedRewrites._1)) {
86-
indent(2, msb);
87-
msb.append("hi = " + appliedRewrites._1 + "((Hop) hi);\t\t// ");
88-
msb.append(appliedRewrites._2.toString());
89-
msb.append('\n');
91+
if (optimize) {
92+
List<Tuple2<String, RewriterRule>> implementedRewrites = rewrites.stream().filter(t -> implemented.contains(t._1)).collect(Collectors.toList());
93+
94+
List<RewriterRule> rules = rewrites.stream().map(t -> t._2).collect(Collectors.toList());
95+
Map<RewriterRule, String> ruleNames = new HashMap<>();
96+
97+
for (Tuple2<String, RewriterRule> t : implementedRewrites)
98+
ruleNames.put(t._2, t._1);
99+
100+
List<CodeGenCondition> conditions = CodeGenCondition.buildCondition(rules, 5, ctx);
101+
CodeGenCondition.buildSelection(msb, conditions, 2, ruleNames, ctx);
102+
} else {
103+
for (Tuple2<String, RewriterRule> appliedRewrites : rewrites) {
104+
if (implemented.contains(appliedRewrites._1)) {
105+
indent(2, msb);
106+
msb.append("hi = " + appliedRewrites._1 + "((Hop) hi);\t\t// ");
107+
msb.append(appliedRewrites._2.toString());
108+
msb.append('\n');
109+
}
90110
}
91111
}
92112

src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CodeGenConditionTests.java

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,44 @@ public void test2() {
8080
}
8181

8282
@Test
83+
public void test3() {
84+
String ruleStr = "MATRIX:A\nFLOAT:b\n" +
85+
"\n" +
86+
"!=(-(b,rev(A)),A)\n" +
87+
"=>\n" +
88+
"!=(A,-(b,A))";
89+
90+
RewriterRule rule = RewriterUtils.parseRule(ruleStr, ctx);
91+
92+
String ruleStr2 = "MATRIX:A,B\n" +
93+
"\n" +
94+
"!=(-(B,rev(A)),A)\n" +
95+
"=>\n" +
96+
"!=(A,-(B,A))";
97+
98+
RewriterRule rule2 = RewriterUtils.parseRule(ruleStr2, ctx);
99+
100+
String ruleStr3 = "MATRIX:A,B\n" +
101+
"\n" +
102+
"%*%(t(A), t(B))\n" +
103+
"=>\n" +
104+
"t(%*%(B, A))";
105+
106+
RewriterRule rule3 = RewriterUtils.parseRule(ruleStr3, ctx);
107+
108+
Map<RewriterRule, String> fNames = new HashMap<>();
109+
fNames.put(rule, "rule1");
110+
fNames.put(rule2, "rule2");
111+
fNames.put(rule3, "rule3");
112+
113+
List<CodeGenCondition> cgcs = CodeGenCondition.buildCondition(List.of(rule, rule2, rule3), 1, ctx);
114+
System.out.println(cgcs);
115+
System.out.println(CodeGenCondition.getSelectionString(cgcs, 0, fNames, ctx));
116+
}
117+
118+
119+
120+
/*@Test
83121
public void codeGen() {
84122
try {
85123
List<String> lines = Files.readAllLines(Paths.get(RewriteAutomaticallyGenerated.FILE_PATH));
@@ -95,6 +133,6 @@ public void codeGen() {
95133
} catch (IOException e) {
96134
e.printStackTrace();
97135
}
98-
}
136+
}*/
99137

100138
}

src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CodeGenTests.java

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.junit.Test;
2323
import scala.Tuple2;
2424

25+
import java.io.FileWriter;
2526
import java.io.IOException;
2627
import java.nio.file.Files;
2728
import java.nio.file.Paths;
@@ -50,7 +51,7 @@ public void test1() {
5051
.completeRule(stmt1, stmt2)
5152
.build();
5253

53-
System.out.println(RewriterCodeGen.generateClass("MRuleTest", List.of(new Tuple2<>("testRule", rule)), ctx, false, false));
54+
System.out.println(RewriterCodeGen.generateClass("MRuleTest", List.of(new Tuple2<>("testRule", rule)), false, false, ctx, false, false));
5455

5556
try {
5657
Function<Hop, Hop> f = RewriterCodeGen.compileRewrites("MRuleTest", List.of(new Tuple2<>("testRule", rule)), ctx, false, false);
@@ -78,7 +79,7 @@ public void test2() {
7879
.completeRule(stmt1, stmt2)
7980
.build();
8081

81-
System.out.println(RewriterCodeGen.generateClass("MRuleTest", List.of(new Tuple2<>("testRule", rule)), ctx, false, false));
82+
System.out.println(RewriterCodeGen.generateClass("MRuleTest", List.of(new Tuple2<>("testRule", rule)), false, false, ctx, false, false));
8283

8384
try {
8485
Function<Hop, Hop> f = RewriterCodeGen.compileRewrites("MRuleTest", List.of(new Tuple2<>("testRule", rule)), ctx, false, false);
@@ -114,7 +115,7 @@ public void test3() {
114115
.completeRule(stmt1, stmt2)
115116
.build();
116117

117-
System.out.println(RewriterCodeGen.generateClass("MRuleTest", List.of(new Tuple2<>("testRule", rule)), ctx, false, false));
118+
System.out.println(RewriterCodeGen.generateClass("MRuleTest", List.of(new Tuple2<>("testRule", rule)), false, false, ctx, false, false));
118119

119120
try {
120121
Function<Hop, Hop> f = RewriterCodeGen.compileRewrites("MRuleTest", List.of(new Tuple2<>("testRule", rule)), ctx, false, false);
@@ -150,7 +151,7 @@ public void test4() {
150151
.completeRule(stmt1, stmt2)
151152
.build();
152153

153-
System.out.println(RewriterCodeGen.generateClass("MRuleTest", List.of(new Tuple2<>("testRule", rule)), ctx, false, false));
154+
System.out.println(RewriterCodeGen.generateClass("MRuleTest", List.of(new Tuple2<>("testRule", rule)), false, false, ctx, false, false));
154155

155156
try {
156157
Function<Hop, Hop> f = RewriterCodeGen.compileRewrites("MRuleTest", List.of(new Tuple2<>("testRule", rule)), ctx, false, false);
@@ -186,7 +187,7 @@ public void test5() {
186187
.completeRule(stmt1, stmt2)
187188
.build();
188189

189-
System.out.println(RewriterCodeGen.generateClass("MRuleTest", List.of(new Tuple2<>("testRule", rule)), ctx, false, false));
190+
System.out.println(RewriterCodeGen.generateClass("MRuleTest", List.of(new Tuple2<>("testRule", rule)), false, false, ctx, false, false));
190191

191192
try {
192193
Function<Hop, Hop> f = RewriterCodeGen.compileRewrites("MRuleTest", List.of(new Tuple2<>("testRule", rule)), ctx, false, false);
@@ -212,7 +213,16 @@ public void codeGen() {
212213
try {
213214
List<String> lines = Files.readAllLines(Paths.get(RewriteAutomaticallyGenerated.FILE_PATH));
214215
RewriterRuleSet ruleSet = RewriterRuleSet.deserialize(lines, ctx);
215-
System.out.println(ruleSet.toJavaCode("GeneratedRewriteClass", true));
216+
String javaCode = ruleSet.toJavaCode("GeneratedRewriteClass", true, true, true);
217+
String filePath = "/Users/janniklindemann/Dev/MScThesis/other/GeneratedRewriteClass.java";
218+
219+
try (FileWriter writer = new FileWriter(filePath)) {
220+
writer.write(javaCode);
221+
} catch (IOException e) {
222+
e.printStackTrace();
223+
}
224+
225+
System.out.println(javaCode);
216226
} catch (IOException e) {
217227
e.printStackTrace();
218228
}

0 commit comments

Comments
 (0)