Skip to content

Commit 970149f

Browse files
committed
Some improvements
1 parent 47ee20f commit 970149f

File tree

8 files changed

+87
-11
lines changed

8 files changed

+87
-11
lines changed

src/main/java/org/apache/sysds/hops/rewriter/MetaPropagator.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,10 +300,11 @@ private RewriterStatement propagateDims(RewriterStatement root, RewriterStatemen
300300

301301
// Fused ops
302302
case "1-*(MATRIX,MATRIX)":
303+
case "log_nz(MATRIX)":
303304
root.unsafePutMeta("nrow", root.getOperands().get(0).getMeta("nrow"));
304305
root.unsafePutMeta("ncol", root.getOperands().get(0).getMeta("ncol"));
305306
return null;
306-
case "log_nz(MATRIX)":
307+
case "const(MATRIX,FLOAT)":
307308
root.unsafePutMeta("nrow", root.getOperands().get(0).getMeta("nrow"));
308309
root.unsafePutMeta("ncol", root.getOperands().get(0).getMeta("ncol"));
309310
return null;

src/main/java/org/apache/sysds/hops/rewriter/RewriterAlphabetEncoder.java

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,11 @@ public class RewriterAlphabetEncoder {
4040

4141
// Fused operators
4242
new Operand("1-*", 2, MATRIX), // TODO: We have to include literals in the search
43-
new Operand("log_nz", 1, MATRIX) // TODO: We have to include literals in the search
43+
new Operand("log_nz", 1, MATRIX), // TODO: We have to include literals in the search
44+
45+
// Placeholder operators
46+
new Operand("zero", 0, ALL_TYPES),
47+
new Operand("one", 0, ALL_TYPES)
4448
};
4549

4650
private static String[] varNames = new String[] {
@@ -67,6 +71,14 @@ public class RewriterAlphabetEncoder {
6771
throw new NotImplementedException();
6872
}*/
6973

74+
public static int getMaxSearchNumberForNumOps(int numOps) {
75+
int out = 1;
76+
for (int i = 0; i < numOps; i++)
77+
out *= instructionAlphabet.length;
78+
79+
return out;
80+
}
81+
7082
public static void rename(RewriterStatement stmt) {
7183
Set<RewriterStatement> namedVars = new HashSet<>();
7284

@@ -230,7 +242,18 @@ public static List<RewriterStatement> buildAllPossibleDAGs(List<Operand> operand
230242

231243
private static List<RewriterStatement> recursivelyFindAllCombinations(List<Operand> operands) {
232244
if (operands.isEmpty())
233-
return Stream.concat(ALL_TYPES.stream().map(t -> new RewriterDataType().as(UUID.randomUUID().toString()).ofType(t).consolidate(ctx)), Stream.of(RewriterStatement.literal(ctx, 1.0D), RewriterStatement.literal(ctx, 0.0D))).collect(Collectors.toList());
245+
return ALL_TYPES.stream().map(t -> new RewriterDataType().as(UUID.randomUUID().toString()).ofType(t).consolidate(ctx)).collect(Collectors.toList());
246+
247+
// Check if op is a placeholder
248+
Operand op = operands.get(0);
249+
if (op.op.equals("zero") || op.op.equals("one")) {
250+
List<RewriterStatement> l = new ArrayList<>(4);
251+
l.add(RewriterStatement.literal(ctx, 1.0D));
252+
l.add(RewriterStatement.literal(ctx, 0.0D));
253+
l.add(new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("const").withOps(new RewriterDataType().as(UUID.randomUUID().toString()).ofType("MATRIX").consolidate(ctx), RewriterStatement.literal(ctx, 0.0D)).consolidate(ctx));
254+
l.add(new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("const").withOps(new RewriterDataType().as(UUID.randomUUID().toString()).ofType("MATRIX").consolidate(ctx), RewriterStatement.literal(ctx, 1.0D)).consolidate(ctx));
255+
return l;
256+
}
234257

235258
int nOps = operands.get(0).numArgs;
236259
int[] slices = new int[nOps-1];

src/main/java/org/apache/sysds/hops/rewriter/RewriterContextSettings.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,8 @@ public static String getDefaultContextString() {
345345
builder.append("log_nz(MATRIX," + t + ")::MATRIX\n");
346346
});
347347

348+
builder.append("const(MATRIX,FLOAT)::MATRIX\n");
349+
348350
builder.append("_m(INT,INT,FLOAT)::MATRIX\n");
349351
builder.append("_m(INT,INT,BOOL)::MATRIX\n");
350352
builder.append("_m(INT,INT,INT)::MATRIX\n");

src/main/java/org/apache/sysds/hops/rewriter/RewriterCostEstimator.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,8 @@ private static RewriterStatement computeScalarOpCost(RewriterInstruction instr,
323323
assertions.addEqualityAssertion(map.get("nrowA"), map.get("ncolA"));
324324
assertions.addEqualityAssertion(map.get("nrowA"), RewriterStatement.literal(ctx, 1L));
325325
return uniqueCosts.get(uniqueCosts.size()-1);
326+
case "const(MATRIX,FLOAT)":
327+
return RewriterStatement.literal(ctx, 0L);
326328
}
327329

328330
long opCost = atomicOpCost(instr.trueInstruction());

src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleCollection.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -649,6 +649,25 @@ public static void canonicalizeBooleanStatements(final List<RewriterRule> rules,
649649
public static void expandStreamingExpressions(final List<RewriterRule> rules, final RuleContext ctx) {
650650
HashMap<Integer, RewriterStatement> hooks = new HashMap<>();
651651

652+
// Const
653+
rules.add(new RewriterRuleBuilder(ctx, "Expand const matrix")
654+
.setUnidirectional(true)
655+
.parseGlobalVars("MATRIX:A")
656+
.parseGlobalVars("FLOAT:a")
657+
.parseGlobalVars("LITERAL_INT:1")
658+
.withParsedStatement("const(A, a)", hooks)
659+
.toParsedStatement("$4:_m($1:_idx(1, nrow(A)), $2:_idx(1, ncol(A)), a)", hooks)
660+
.apply(hooks.get(1).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) // Assumes it will never collide
661+
.apply(hooks.get(2).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) // Assumes it will never collide
662+
.apply(hooks.get(4).getId(), (stmt, match) -> {
663+
UUID id = UUID.randomUUID();
664+
stmt.unsafePutMeta("ownerId", id);
665+
stmt.getChild(0).unsafePutMeta("ownerId", id);
666+
stmt.getChild(1).unsafePutMeta("ownerId", id);
667+
}, true) // Assumes it will never collide
668+
.build()
669+
);
670+
652671

653672
// Matrix Multiplication
654673
rules.add(new RewriterRuleBuilder(ctx, "Expand matrix product")

src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterClusteringTest.java

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -149,12 +149,13 @@ public void testExpressionClustering() {
149149
}
150150

151151
if (useRandomized) {
152-
long MAX_MILLIS = 300000;
152+
long MAX_MILLIS = 100000000; // Should be bound by number of ops
153153
int BATCH_SIZE = 200;
154+
int maxN = RewriterAlphabetEncoder.getMaxSearchNumberForNumOps(2);
154155
long startMillis = System.currentTimeMillis();
155156

156-
for (int batch = 0; batch < 100 && System.currentTimeMillis() - startMillis < MAX_MILLIS; batch++) {
157-
List<Integer> indices = IntStream.range(batch * BATCH_SIZE, (batch + 1) * BATCH_SIZE - 1).boxed().collect(Collectors.toList());
157+
for (int batch = 0; batch < 100 && System.currentTimeMillis() - startMillis < MAX_MILLIS && batch * BATCH_SIZE < maxN; batch++) {
158+
List<Integer> indices = IntStream.range(batch * BATCH_SIZE, Math.min((batch + 1) * BATCH_SIZE - 1, maxN)).boxed().collect(Collectors.toList());
158159
Collections.shuffle(indices);
159160
MutableInt ctr2 = new MutableInt(0);
160161
int maxSize = indices.size();
@@ -176,17 +177,23 @@ public void testExpressionClustering() {
176177

177178
List<RewriterStatement> equivalentExpressions = new ArrayList<>();
178179
equivalentExpressions.add(stmt);
179-
canonicalForm.unsafePutMeta("equivalentExpressions", equivalentExpressions);
180+
181+
// TODO: Better handling
182+
if (!canonicalForm.isLiteral())
183+
canonicalForm.unsafePutMeta("equivalentExpressions", equivalentExpressions);
180184

181185
// Insert the canonical form or retrieve the existing entry
182186
RewriterStatement existingEntry = canonicalExprDB.insertOrReturn(ctx, canonicalForm);
183187

184188
if (existingEntry != null) {
185189
equivalentExpressions = (List<RewriterStatement>) existingEntry.getMeta("equivalentExpressions");
186-
equivalentExpressions.add(stmt);
190+
// TODO: Better handling
191+
if (equivalentExpressions != null) {
192+
equivalentExpressions.add(stmt);
187193

188-
if (equivalentExpressions.size() == 2)
189-
foundEquivalences.add(existingEntry);
194+
if (equivalentExpressions.size() == 2)
195+
foundEquivalences.add(existingEntry);
196+
}
190197

191198
//System.out.println("Found equivalent statement!");
192199
}
@@ -256,6 +263,9 @@ public void testExpressionClustering() {
256263
}
257264

258265
private void computeCost(RewriterStatement subExpr, final RuleContext ctx) {
266+
if (subExpr.isLiteral())
267+
return;
268+
259269
if (subExpr.getMeta("_cost") == null) {
260270
long cost = -1;
261271
try {

src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterStreamTests.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1015,6 +1015,25 @@ public void testColSumEquivalence5() {
10151015
assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1));
10161016
}
10171017

1018+
@Test
1019+
public void testZeroElimination() {
1020+
RewriterStatement stmt1 = RewriterUtils.parse("*(A,0.0)", ctx, "MATRIX:A,B", "FLOAT:b", "LITERAL_INT:1", "LITERAL_FLOAT:0.0");
1021+
RewriterStatement stmt2 = RewriterUtils.parse("0.0", ctx, "MATRIX:A,B", "FLOAT:b", "LITERAL_INT:1", "LITERAL_FLOAT:0.0");
1022+
1023+
System.out.println("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx));
1024+
System.out.println("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx));
1025+
1026+
stmt1 = canonicalConverter.apply(stmt1);
1027+
stmt2 = canonicalConverter.apply(stmt2);
1028+
1029+
System.out.println("==========");
1030+
System.out.println(stmt1.toParsableString(ctx, true));
1031+
System.out.println("==========");
1032+
System.out.println(stmt2.toParsableString(ctx, true));
1033+
1034+
assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1));
1035+
}
1036+
10181037
@Test
10191038
public void testFused1() {
10201039
RewriterStatement stmt1 = RewriterUtils.parse("1-*(A, B)", ctx, "MATRIX:A,B");

src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/RewriterAlphabetTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ public void testEncode1() {
5454
@Test
5555
public void testRandomStatementGeneration() {
5656
int ctr = 0;
57-
for (int i = 1; i < 16; i++) {
57+
for (int i = 0; i < 20; i++) {
5858
List<RewriterAlphabetEncoder.Operand> ops = RewriterAlphabetEncoder.decodeOrderedStatements(i);
5959
//System.out.println("Idx: " + i);
6060
//System.out.println(ops);

0 commit comments

Comments
 (0)