Skip to content

Commit

Permalink
[SYSTEMDS-3774] Improved test coverage of simplification rewrites
Browse files Browse the repository at this point in the history
Closes #2109.
  • Loading branch information
ReneEnjilian authored and mboehm7 committed Sep 24, 2024
1 parent 2ce1910 commit c86aa0a
Show file tree
Hide file tree
Showing 70 changed files with 4,426 additions and 37 deletions.
6 changes: 6 additions & 0 deletions src/main/java/org/apache/sysds/hops/OptimizerUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,12 @@ public enum MemoryManager {
* all sum-product related rewrites.
*/
public static boolean ALLOW_SUM_PRODUCT_REWRITES = true;

/**
* Enables additional mmchain optimizations. in the future, this might be merged with
* ALLOW_SUM_PRODUCT_REWRITES.
*/
public static boolean ALLOW_ADVANCED_MMCHAIN_REWRITES = false;

/**
* Enables a specific hop dag rewrite that splits hop dags after csv persistent reads with
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites)
_dagRuleSet.add( new RewriteMatrixMultChainOptimization() ); //dependency: cse
_dagRuleSet.add( new RewriteElementwiseMultChainOptimization() ); //dependency: cse
}
if(OptimizerUtils.ALLOW_ADVANCED_MMCHAIN_REWRITES){
_dagRuleSet.add( new RewriteMatrixMultChainOptimizationTranspose() ); //dependency: cse
_dagRuleSet.add( new RewriteMatrixMultChainOptimizationSparse() ); //dependency: cse
}
if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION ) {
_dagRuleSet.add( new RewriteAlgebraicSimplificationDynamic() ); //dependencies: cse
_dagRuleSet.add( new RewriteAlgebraicSimplificationStatic() ); //dependencies: cse
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1635,7 +1635,7 @@ private static Hop simplifyWeightedSigmoidMMChains(Hop parent, Hop hi, int pos)
if( !HopRewriteUtils.isTransposeOperation(tX) ) {
tX = HopRewriteUtils.createTranspose(tX);
}
else
else
tX = tX.getInput().get(0);

hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64,
Expand Down Expand Up @@ -1664,7 +1664,7 @@ private static Hop simplifyWeightedSigmoidMMChains(Hop parent, Hop hi, int pos)
if( !HopRewriteUtils.isTransposeOperation(tX) ) {
tX = HopRewriteUtils.createTranspose(tX);
}
else
else
tX = tX.getInput().get(0);

hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64,
Expand All @@ -1690,7 +1690,7 @@ private static Hop simplifyWeightedSigmoidMMChains(Hop parent, Hop hi, int pos)
if( !HopRewriteUtils.isTransposeOperation(tX) ) {
tX = HopRewriteUtils.createTranspose(tX);
}
else
else
tX = tX.getInput().get(0);

hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64,
Expand Down Expand Up @@ -1722,7 +1722,7 @@ private static Hop simplifyWeightedSigmoidMMChains(Hop parent, Hop hi, int pos)
if( !HopRewriteUtils.isTransposeOperation(tX) ) {
tX = HopRewriteUtils.createTranspose(tX);
}
else
else
tX = tX.getInput().get(0);

hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64,
Expand Down Expand Up @@ -2157,7 +2157,7 @@ private static Hop simplifyWeightedUnaryMM(Hop parent, Hop hi, int pos) {

if( !HopRewriteUtils.isTransposeOperation(V) )
V = HopRewriteUtils.createTranspose(V);
else
else
V = V.getInput().get(0);

hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64,
Expand Down Expand Up @@ -2251,7 +2251,7 @@ else if( left.getDataType()==DataType.SCALAR && left instanceof LiteralOp

if( !HopRewriteUtils.isTransposeOperation(V) )
V = HopRewriteUtils.createTranspose(V);
else
else
V = V.getInput().get(0);

hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.test.functions.rewrite;

import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.junit.Assert;
import org.junit.Test;

import java.util.HashMap;

public class RewriteFuseBinarySubDAGToUnaryOperationTest extends AutomatedTestBase {

private static final String TEST_NAME = "RewriteFuseBinarySubDAGToUnaryOperation";
private static final String TEST_DIR = "functions/rewrite/";
private static final String TEST_CLASS_DIR =
TEST_DIR + RewriteFuseBinarySubDAGToUnaryOperationTest.class.getSimpleName() + "/";

private static final int rows = 300;
private static final int cols = 200;
private static final double eps = Math.pow(10, -10);

@Override
public void setUp() {
TestUtils.clearAssertionInformation();
addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"}));
}

@Test
public void testSampleProportionLeftNoRewrite(){
testSimplifyDistributiveBinaryOperation(1, false);
}

@Test
public void testSampleProportionLeftRewrite(){
testSimplifyDistributiveBinaryOperation(1, true); //pattern: (1-X)*X -> sprop(X)
}

@Test
public void testSampleProportionRightNoRewrite(){
testSimplifyDistributiveBinaryOperation(2, false);
}

@Test
public void testSampleProportionRightRewrite(){
testSimplifyDistributiveBinaryOperation(2, true); //pattern: X*(1-X) -> sprop(X)
}

@Test
public void testFuseBinarySubDAGToUnarySigmoidNoRewrite(){
testSimplifyDistributiveBinaryOperation(3, false);
}

@Test
public void testFuseBinarySubDAGToUnarySigmoidRewrite(){
testSimplifyDistributiveBinaryOperation(3, true); //pattern: 1/(1+exp(-X)) -> sigmoid(X)
}


private void testSimplifyDistributiveBinaryOperation(int ID, boolean rewrites) {
boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
try {
TestConfiguration config = getTestConfiguration(TEST_NAME);
loadTestConfiguration(config);

String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + TEST_NAME + ".dml";
programArgs = new String[] {"-stats", "-args", input("X"), String.valueOf(ID), output("R")};
fullRScriptName = HOME + TEST_NAME + ".R";
rCmd = getRCmd(inputDir(), String.valueOf(ID), expectedDir());

OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;

//create dense matrix so that rewrites are possible
double[][] X = getRandomMatrix(rows, cols, -1, 1, 0.80d, 3);
writeInputMatrixWithMTD("X", X, true);

runTest(true, false, null, -1);
runRScript(true);

//compare matrices
HashMap<MatrixValue.CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("R");
HashMap<MatrixValue.CellIndex, Double> rfile = readRMatrixFromExpectedDir("R");
TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");

if (rewrites)
Assert.assertTrue(heavyHittersContainsString("sprop") || heavyHittersContainsString("sigmoid"));
else
Assert.assertFalse(heavyHittersContainsString("sprop") || heavyHittersContainsString("sigmoid"));


}
finally {
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.test.functions.rewrite;

import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.junit.Assert;
import org.junit.Test;

import java.util.HashMap;

public class RewriteFuseLeftIndexingChainToAppendTest extends AutomatedTestBase {
private static final String TEST_NAME = "RewriteFuseLeftIndexingChainToAppend";
private static final String TEST_DIR = "functions/rewrite/";
private static final String TEST_CLASS_DIR =
TEST_DIR + RewriteFuseLeftIndexingChainToAppendTest.class.getSimpleName() + "/";

private static final int rows = 300;
private static final int cols = 1;
private static final double eps = Math.pow(10, -10);

@Override
public void setUp() {
TestUtils.clearAssertionInformation();
addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"}));
}

@Test
public void testFuseLeftIndexingChainColumnNoRewrite() {
testRewriteFuseLeftIndexingChainToAppend(1, false);
}

@Test
public void testFuseLeftIndexingChainColumnRewrite() {
testRewriteFuseLeftIndexingChainToAppend(1, true);
}

@Test
public void testFuseLeftIndexingChainRowNoRewrite() {
testRewriteFuseLeftIndexingChainToAppend(2, false);
}

@Test
public void testFuseLeftIndexingChainRowRewrite() {
testRewriteFuseLeftIndexingChainToAppend(2, true);
}

private void testRewriteFuseLeftIndexingChainToAppend(int ID, boolean rewrites) {
boolean oldFlag1 = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
boolean oldFlag2 = OptimizerUtils.ALLOW_OPERATOR_FUSION;
try {
TestConfiguration config = getTestConfiguration(TEST_NAME);
loadTestConfiguration(config);

String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + TEST_NAME + ".dml";
programArgs = new String[] {"-stats", "-args", input("A"), input("B"), String.valueOf(ID), output("R")};
fullRScriptName = HOME + TEST_NAME + ".R";
rCmd = getRCmd(inputDir(), String.valueOf(ID), expectedDir());

OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
OptimizerUtils.ALLOW_OPERATOR_FUSION = rewrites;

//create matrices
double[][] A = getRandomMatrix(rows, cols, -1, 1, 0.80d, 3);
double[][] B = getRandomMatrix(rows, cols, -1, 1, 0.80d, 5);
writeInputMatrixWithMTD("A", A, true);
writeInputMatrixWithMTD("B", B, true);

runTest(true, false, null, -1);
runRScript(true);

//compare matrices
HashMap<MatrixValue.CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("R");
HashMap<MatrixValue.CellIndex, Double> rfile = readRMatrixFromExpectedDir("R");
TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");

if(rewrites)
Assert.assertTrue(heavyHittersContainsString("append"));
else
Assert.assertTrue(heavyHittersContainsString("leftIndex"));

}
finally {
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag1;
OptimizerUtils.ALLOW_OPERATOR_FUSION = oldFlag2;
}

}
}
Loading

0 comments on commit c86aa0a

Please sign in to comment.