Skip to content
5 changes: 5 additions & 0 deletions src/main/java/org/apache/sysds/api/DMLOptions.java
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ public class DMLOptions {
public int[] statsNGramSizes = { 3 }; // Default n-gram tuple sizes
public int statsTopKNGrams = 10; // How many of the most heavy hitting n-grams are displayed
public boolean statsNGramsUseLineage = true; // If N-Grams use lineage for data-dependent tracking
public boolean applyGeneratedRewrites = false; // If generated rewrites should be applied
public boolean fedStats = false; // Whether to record and print the federated statistics
public int fedStatsCount = 10; // Default federated statistics count
public boolean memStats = false; // max memory statistics
Expand Down Expand Up @@ -246,6 +247,8 @@ else if (lineageType.equalsIgnoreCase("debugger"))
}
}

dmlOptions.applyGeneratedRewrites = line.hasOption("applyGeneratedRewrites");

dmlOptions.fedStats = line.hasOption("fedStats");
if (dmlOptions.fedStats) {
String fedStatsCount = line.getOptionValue("fedStats");
Expand Down Expand Up @@ -372,6 +375,7 @@ private static Options createCLIOptions() {
Option ngramsOpt = OptionBuilder//.withArgName("ngrams")
.withDescription("monitors and reports the most occurring n-grams; -ngrams <comma separated n's> <topK>")
.hasOptionalArgs(2).create("ngrams");
Option applyGeneratedRewritesOpt = OptionBuilder.withArgName("applyGeneratedRewrites").withDescription("if automatically generated rewrites should be applied").create("applyGeneratedRewrites");
Option fedStatsOpt = OptionBuilder.withArgName("count")
.withDescription("monitors and reports summary execution statistics of federated workers; heavy hitter <count> is 10 unless overridden; default off")
.hasOptionalArg().create("fedStats");
Expand Down Expand Up @@ -434,6 +438,7 @@ private static Options createCLIOptions() {
options.addOption(cleanOpt);
options.addOption(statsOpt);
options.addOption(ngramsOpt);
options.addOption(applyGeneratedRewritesOpt);
options.addOption(fedStatsOpt);
options.addOption(memOpt);
options.addOption(explainOpt);
Expand Down
15 changes: 14 additions & 1 deletion src/main/java/org/apache/sysds/api/DMLScript.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
import java.util.Date;
import java.util.Map;
import java.util.Scanner;
import java.util.function.BiConsumer;
import java.util.function.Function;

import org.apache.commons.cli.AlreadySelectedException;
import org.apache.commons.cli.HelpFormatter;
Expand Down Expand Up @@ -106,6 +108,7 @@ public class DMLScript
public static int STATISTICS_TOP_K_NGRAMS = DMLOptions.defaultOptions.statsTopKNGrams;
// Set if N-Grams use lineage for data-dependent tracking
public static boolean STATISTICS_NGRAMS_USE_LINEAGE = DMLOptions.defaultOptions.statsNGramsUseLineage;
public static boolean APPLY_GENERATED_REWRITES = DMLOptions.defaultOptions.applyGeneratedRewrites;
// Set statistics maximum wrap length
public static int STATISTICS_MAX_WRAP_LEN = 30;
// Enable/disable to print federated statistics
Expand Down Expand Up @@ -168,6 +171,9 @@ public class DMLScript
public static String _uuid = IDHandler.createDistributedUniqueID();
private static final Log LOG = LogFactory.getLog(DMLScript.class.getName());

// Optional hook applied to the compiled program right before the HOP DAG rewrite step.
// If set and it returns false, execution returns early and the rewrite step is not reached.
// Left as null (the default), the hook is skipped.
// NOTE(review): the Boolean result is unboxed by the caller, so a null return would throw.
public static Function<DMLProgram, Boolean> preHopInterceptor = null; // Intercepts HOPs before they are rewritten
// Optional hook applied right after the HOP DAG rewrite step and before LOP construction.
// If set and it returns false, execution returns early and no LOPs are constructed.
public static Function<DMLProgram, Boolean> hopInterceptor = null; // Intercepts HOPs after they are rewritten

///////////////////////////////
// public external interface
////////
Expand Down Expand Up @@ -261,6 +267,7 @@ public static boolean executeScript( String[] args )
STATISTICS_NGRAMS = dmlOptions.statsNGrams;
STATISTICS_NGRAM_SIZES = dmlOptions.statsNGramSizes;
STATISTICS_TOP_K_NGRAMS = dmlOptions.statsTopKNGrams;
APPLY_GENERATED_REWRITES = dmlOptions.applyGeneratedRewrites;
FED_STATISTICS = dmlOptions.fedStats;
FED_STATISTICS_COUNT = dmlOptions.fedStatsCount;
JMLC_MEM_STATISTICS = dmlOptions.memStats;
Expand Down Expand Up @@ -456,9 +463,15 @@ private static void execute(String dmlScriptStr, String fnameOptConfig, Map<Stri

//init working directories (before usage by following compilation steps)
initHadoopExecution( ConfigurationManager.getDMLConfig() );


if (preHopInterceptor != null && !preHopInterceptor.apply(prog))
return;

//Step 5: rewrite HOP DAGs (incl IPA and memory estimates)
dmlt.rewriteHopsDAG(prog);

if (hopInterceptor != null && !hopInterceptor.apply(prog))
return;

//Step 6: construct lops (incl exec type and op selection)
dmlt.constructLops(prog);
Expand Down
102 changes: 102 additions & 0 deletions src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,30 @@ public static DataGenOp copyDataGenOp( DataGenOp inputGen, double scale, double

return datagen;
}

/**
 * Creates a RAND data generator hop whose dimensions are supplied directly as hops.
 * Min and max are both set to the given value, the pdf is uniform, sparsity is 1.0
 * and the seed is left unspecified.
 *
 * @param rows  hop used as the number-of-rows parameter
 * @param cols  hop used as the number-of-columns parameter
 * @param value constant used for both the min and the max parameter
 * @return the new data generator hop (blocksize fixed to 1000; nnz set to 0 if value is 0)
 */
public static Hop createDataGenOpFromDims( Hop rows, Hop cols, double value ) {
	final Hop constant = new LiteralOp(value);

	// parameter map handed to the generator; min == max == constant
	final HashMap<String, Hop> genParams = new HashMap<>();
	genParams.put(DataExpression.RAND_ROWS, rows);
	genParams.put(DataExpression.RAND_COLS, cols);
	genParams.put(DataExpression.RAND_MIN, constant);
	genParams.put(DataExpression.RAND_MAX, constant);
	genParams.put(DataExpression.RAND_PDF, new LiteralOp(DataExpression.RAND_PDF_UNIFORM));
	genParams.put(DataExpression.RAND_LAMBDA, new LiteralOp(-1.0));
	genParams.put(DataExpression.RAND_SPARSITY, new LiteralOp(1.0));
	genParams.put(DataExpression.RAND_SEED, new LiteralOp(DataGenOp.UNSPECIFIED_SEED));

	//note internal refresh size information
	//(line numbers are intentionally not copied from the dimension hops here)
	Hop result = new DataGenOp(OpOpDG.RAND, new DataIdentifier("tmp"), genParams);
	result.setBlocksize(1000);

	// an all-zero matrix has no non-zeros
	if( value == 0 ) {
		result.setNnz(0);
	}

	return result;
}

public static Hop createDataGenOp( Hop rowInput, Hop colInput, double value )
{
Expand Down Expand Up @@ -661,6 +685,84 @@ public static BinaryOp createBinary(Hop input1, Hop input2, OpOp2 op, boolean ou
bop.refreshSizeInformation();
return bop;
}

/**
 * Variant of createBinary for automatically generated rewrites. It derives the output
 * data type and value type from both inputs, because createBinary does not always set
 * value types correctly (e.g. INT-MATRIX+FLOAT-SCALAR -> bop(+)::INT).
 *
 * @param input1 left operand
 * @param input2 right operand
 * @param op     binary operation
 * @return the new binary hop with refreshed size information
 */
public static BinaryOp createAutoGeneratedBinary(Hop input1, Hop input2, OpOp2 op) {
	// reference hop for name, blocksize and line numbers:
	// input2 only if it is a matrix while input1 is not, otherwise input1
	final Hop ref = ( !input1.getDataType().isMatrix() && input2.getDataType().isMatrix() )
		? input2 : input1;

	final String name = ref.getName();
	final DataType outDataType = getImplicitDataType(input1, input2);
	final ValueType outValueType = getImplicitValueType(input1, input2);
	final BinaryOp result = new BinaryOp(name, outDataType, outValueType, op, input1, input2);

	//cleanup value type for relational operations
	if( result.isPPredOperation() && result.getDataType().isScalar() ) {
		result.setValueType(ValueType.BOOLEAN);
	}

	result.setOuterVectorOperation(false);
	result.setBlocksize(ref.getBlocksize());
	copyLineNumbers(ref, result);
	result.refreshSizeInformation();
	return result;
}

/**
 * Determines the data type of an operation over the given inputs: the data type of
 * the first input that is a matrix, or the data type of the first input if none is.
 *
 * @param inputs operand hops (at least one is expected)
 * @return data type of the first matrix input, otherwise that of inputs[0]
 */
public static DataType getImplicitDataType(Hop... inputs) {
	for (Hop candidate : inputs) {
		final DataType type = candidate.getDataType();
		if (type.isMatrix()) {
			return type;
		}
	}

	// no matrix among the inputs: fall back to the first operand
	return inputs[0].getDataType();
}

/**
 * Determines the value type of an operation over the given inputs by picking the
 * highest-ranked type among FP64, FP32, INT64, INT32 and BOOLEAN (in that order).
 * Inputs with any other value type do not influence the result; if no input has one
 * of the listed types, the value type of the first input is returned.
 *
 * @param inputs operand hops (at least one is expected)
 * @return the implied value type
 */
public static ValueType getImplicitValueType(Hop... inputs) {
	ValueType best = null;

	for (Hop operand : inputs) {
		final ValueType current = operand.getValueType();
		switch (current) {
			case FP64:
				// nothing ranks higher, so stop immediately
				return current;
			case FP32:
				// only FP64 ranks higher, and that case returns directly
				best = current;
				break;
			case INT64:
			case INT32:
			case BOOLEAN:
				best = implicitValueType(best, current);
				break;
			default:
				// other value types are ignored
				break;
		}
	}

	return (best != null) ? best : inputs[0].getValueType();
}

/**
 * Returns whichever of the two value types has the higher rank according to
 * getTypeRank; on equal rank the second one is returned. Yields null when
 * neither type has a rank.
 */
private static ValueType implicitValueType(ValueType type1, ValueType type2) {
	final int firstRank = getTypeRank(type1);
	final int secondRank = getTypeRank(type2);

	final boolean neitherRanked =
		(firstRank == Integer.MIN_VALUE) && (secondRank == Integer.MIN_VALUE);
	if (neitherRanked) {
		return null;
	}

	if (firstRank > secondRank) {
		return type1;
	}
	return type2;
}

/**
 * Maps a value type to its rank: BOOLEAN=1, INT32=2, INT64=3, FP32=4, FP64=5.
 * Null and all other value types map to Integer.MIN_VALUE.
 */
private static int getTypeRank(ValueType vt) {
	if (vt != null) {
		switch (vt) {
			case BOOLEAN:
				return 1;
			case INT32:
				return 2;
			case INT64:
				return 3;
			case FP32:
				return 4;
			case FP64:
				return 5;
			default:
				break;
		}
	}

	// null or a type without a rank
	return Integer.MIN_VALUE;
}

public static AggUnaryOp createSum( Hop input ) {
return createAggUnaryOp(input, AggOp.SUM, Direction.RowCol);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
import org.apache.sysds.conf.CompilerConfig.ConfigType;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.rewriter.generated.GeneratedRewriteClass;
import org.apache.sysds.hops.rewriter.generated.RewriteAutomaticallyGenerated;
import org.apache.sysds.hops.rewriter.dml.DMLExecutor;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
Expand Down Expand Up @@ -83,6 +86,9 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites)
_dagRuleSet.add( new RewriteCommonSubexpressionElimination() );
if( OptimizerUtils.ALLOW_CONSTANT_FOLDING )
_dagRuleSet.add( new RewriteConstantFolding() ); //dependency: cse
if ( DMLScript.APPLY_GENERATED_REWRITES ) {
_dagRuleSet.add(new RewriteAutomaticallyGenerated(new GeneratedRewriteClass()));
}
if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION )
_dagRuleSet.add( new RewriteAlgebraicSimplificationStatic() ); //dependencies: cse
if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION ) //dependency: simplifications (no need to merge leafs again)
Expand Down Expand Up @@ -124,6 +130,9 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites)
if ( DMLScript.USE_ACCELERATOR ){
_dagRuleSet.add( new RewriteGPUSpecificOps() ); // gpu-specific rewrites
}
if ( DMLScript.APPLY_GENERATED_REWRITES ) {
_dagRuleSet.add(new RewriteAutomaticallyGenerated(new GeneratedRewriteClass()));
}
if ( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES) {
_dagRuleSet.add( new RewriteMatrixMultChainOptimization() ); //dependency: cse
if( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES2 )
Expand Down
Loading
Loading