Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/apache/systemds into main
Browse files Browse the repository at this point in the history
  • Loading branch information
min-guk committed Sep 23, 2024
2 parents 023c591 + c940502 commit db4c085
Show file tree
Hide file tree
Showing 110 changed files with 7,737 additions and 1,917 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/javaTests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ jobs:
tests: [
"org.apache.sysds.test.applications.**",
"**.test.usertest.**",
"**.component.c**.**",
"**.component.c**.** -Dtest-threadCount=1 -Dtest-forkCount=1",
"**.component.e**.**,**.component.f**.**,**.component.m**.**",
"**.component.p**.**,**.component.r**.**,**.component.s**.**,**.component.t**.**,**.component.u**.**",
"**.functions.a**.**,**.functions.binary.matrix.**,**.functions.binary.scalar.**,**.functions.binary.tensor.**",
Expand Down
524 changes: 300 additions & 224 deletions scripts/builtin/incSliceLine.dml

Large diffs are not rendered by default.

33 changes: 33 additions & 0 deletions scripts/perftest/resource/test_ops.dml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------

X = read($X);
Y = read($Y);
Z = read($Z);

A = X%*%Y;
B = A + Z;
C = B[1:1000,1:1000];

print(nrow(A));
print(nrow(B));
print(nrow(C));

4 changes: 2 additions & 2 deletions src/main/java/org/apache/sysds/common/Types.java
Original file line number Diff line number Diff line change
Expand Up @@ -739,10 +739,10 @@ public String toString() {

/** Operations that require a variable number of operands*/
public enum OpOpN {
PRINTF, CBIND, RBIND, MIN, MAX, PLUS, EVAL, LIST;
PRINTF, CBIND, RBIND, MIN, MAX, PLUS, MULT, EVAL, LIST;

public boolean isCellOp() {
return this == MIN || this == MAX || this == PLUS;
return this == MIN || this == MAX || this == PLUS || this == MULT;
}
}

Expand Down
149 changes: 86 additions & 63 deletions src/main/java/org/apache/sysds/hops/AggUnaryOp.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
package org.apache.sysds.hops;

import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.AggOp;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.Direction;
Expand All @@ -30,6 +31,7 @@
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.lops.Nary;
import org.apache.sysds.lops.PartialAggregate;
import org.apache.sysds.lops.TernaryAggregate;
import org.apache.sysds.lops.UAggOuterChain;
Expand All @@ -38,6 +40,8 @@
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;

import java.util.List;

// Aggregate unary (cell) operation: Sum (aij), col_sum, row_sum

public class AggUnaryOp extends MultiThreadedHop
Expand Down Expand Up @@ -475,6 +479,17 @@ else if (binput1.getOp() == OpOp2.MULT ) {
}
}
}
if (input1.getParent().size() == 1
&& input1 instanceof NaryOp) { //sum single consumer
NaryOp nop = (NaryOp) input1;
if(nop.getOp() == Types.OpOpN.MULT){
List<Hop> inputsN = nop.getInput();
if(inputsN.size() == 3){
ret = HopRewriteUtils.isEqualSize(inputsN.get(0), inputsN.get(1)) &&
HopRewriteUtils.isEqualSize(inputsN.get(1), inputsN.get(2));
}
}
}
}
return ret;
}
Expand Down Expand Up @@ -554,83 +569,91 @@ private boolean isUnaryAggregateOuterCPRewriteApplicable() {

private Lop constructLopsTernaryAggregateRewrite(ExecType et)
{
BinaryOp input1 = (BinaryOp)getInput().get(0);
Hop input11 = input1.getInput().get(0);
Hop input12 = input1.getInput().get(1);

Lop in1 = null, in2 = null, in3 = null;
boolean handled = false;

if (input1.getOp() == OpOp2.POW) {
assert(HopRewriteUtils.isLiteralOfValue(input12, 3)) : "this case can only occur with a power of 3";
in1 = input11.constructLops();
in2 = in1;
in3 = in1;
handled = true;
}
else if (HopRewriteUtils.isBinary(input11, OpOp2.MULT, OpOp2.POW) ) {
BinaryOp b11 = (BinaryOp)input11;
switch( b11.getOp() ) {
case MULT: // A*B*C case
in1 = input11.getInput().get(0).constructLops();
in2 = input11.getInput().get(1).constructLops();
in3 = input12.constructLops();
handled = true;
break;
case POW: // A*A*B case
Hop b112 = b11.getInput().get(1);
if ( !(input12 instanceof BinaryOp && ((BinaryOp)input12).getOp()==OpOp2.MULT)
&& HopRewriteUtils.isLiteralOfValue(b112, 2) ) {
in1 = b11.getInput().get(0).constructLops();
in2 = in1;
in3 = input12.constructLops();
handled = true;
}
break;
default: break;
}
}
else if( HopRewriteUtils.isBinary(input12, OpOp2.MULT, OpOp2.POW) ) {
BinaryOp b12 = (BinaryOp)input12;
switch (b12.getOp()) {
case MULT: // A*B*C case
Hop input = getInput().get(0);
if(input instanceof BinaryOp) {
BinaryOp input1 = (BinaryOp) input;
Hop input11 = input1.getInput().get(0);
Hop input12 = input1.getInput().get(1);

boolean handled = false;

if (input1.getOp() == OpOp2.POW) {
assert (HopRewriteUtils.isLiteralOfValue(input12, 3)) : "this case can only occur with a power of 3";
in1 = input11.constructLops();
in2 = input12.getInput().get(0).constructLops();
in3 = input12.getInput().get(1).constructLops();
in2 = in1;
in3 = in1;
handled = true;
break;
case POW: // A*B*B case
Hop b112 = b12.getInput().get(1);
if ( HopRewriteUtils.isLiteralOfValue(b112, 2) ) {
in1 = b12.getInput().get(0).constructLops();
in2 = in1;
in3 = input11.constructLops();
handled = true;
} else if (HopRewriteUtils.isBinary(input11, OpOp2.MULT, OpOp2.POW)) {
BinaryOp b11 = (BinaryOp) input11;
switch (b11.getOp()) {
case MULT: // A*B*C case
in1 = input11.getInput().get(0).constructLops();
in2 = input11.getInput().get(1).constructLops();
in3 = input12.constructLops();
handled = true;
break;
case POW: // A*A*B case
Hop b112 = b11.getInput().get(1);
if (!(input12 instanceof BinaryOp && ((BinaryOp) input12).getOp() == OpOp2.MULT)
&& HopRewriteUtils.isLiteralOfValue(b112, 2)) {
in1 = b11.getInput().get(0).constructLops();
in2 = in1;
in3 = input12.constructLops();
handled = true;
}
break;
default:
break;
}
} else if (HopRewriteUtils.isBinary(input12, OpOp2.MULT, OpOp2.POW)) {
BinaryOp b12 = (BinaryOp) input12;
switch (b12.getOp()) {
case MULT: // A*B*C case
in1 = input11.constructLops();
in2 = input12.getInput().get(0).constructLops();
in3 = input12.getInput().get(1).constructLops();
handled = true;
break;
case POW: // A*B*B case
Hop b112 = b12.getInput().get(1);
if (HopRewriteUtils.isLiteralOfValue(b112, 2)) {
in1 = b12.getInput().get(0).constructLops();
in2 = in1;
in3 = input11.constructLops();
handled = true;
}
break;
default:
break;
}
break;
default: break;
}
}

if (!handled) {
in1 = input11.constructLops();
in2 = input12.constructLops();
in3 = new LiteralOp(1).constructLops();
if (!handled) {
in1 = input11.constructLops();
in2 = input12.constructLops();
in3 = new LiteralOp(1).constructLops();
}
} else {
NaryOp input1 = (NaryOp) input;
in1 = input1.getInput().get(0).constructLops();
in2 = input1.getInput().get(1).constructLops();
in3 = input1.getInput().get(2).constructLops();
}

//create new ternary aggregate operator
int k = OptimizerUtils.getConstrainedNumThreads( _maxNumThreads );
//create new ternary aggregate operator
int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
// The execution type of a unary aggregate instruction should depend on the execution type of inputs to avoid OOM
// Since we only support matrix-vector and not vector-matrix, checking the execution type of input1 should suffice.
ExecType et_input = input1.optFindExecType();
ExecType et_input = input.optFindExecType();
// Because ternary aggregate are not supported on GPU
et_input = et_input == ExecType.GPU ? ExecType.CP : et_input;
et_input = et_input == ExecType.GPU ? ExecType.CP : et_input;
// If forced ExecType is FED, it means that the federated planner updated the ExecType and
// execution may fail if ExecType is not FED
et_input = (getForcedExecType() == ExecType.FED) ? ExecType.FED : et_input;
return new TernaryAggregate(in1, in2, in3, AggOp.SUM,
OpOp2.MULT, _direction, getDataType(), ValueType.FP64, et_input, k);

return new TernaryAggregate(in1, in2, in3, AggOp.SUM,
OpOp2.MULT, _direction, getDataType(), ValueType.FP64, et_input, k);
}

@Override
Expand Down
4 changes: 3 additions & 1 deletion src/main/java/org/apache/sysds/hops/NaryOp.java
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,8 @@ protected DataCharacteristics inferOutputCharacteristics(MemoTable memo) {
HopRewriteUtils.getSumValidInputNnz(dc, true));
case MIN:
case MAX:
case PLUS: return new MatrixCharacteristics(
case PLUS:
case MULT: return new MatrixCharacteristics(
HopRewriteUtils.getMaxInputDim(this, true),
HopRewriteUtils.getMaxInputDim(this, false), -1, -1);
case LIST:
Expand Down Expand Up @@ -230,6 +231,7 @@ public void refreshSizeInformation() {
case MIN:
case MAX:
case PLUS:
case MULT:
setDim1(getDataType().isScalar() ? 0 : HopRewriteUtils.getMaxInputDim(this, true));
setDim2(getDataType().isScalar() ? 0 : HopRewriteUtils.getMaxInputDim(this, false));
break;
Expand Down
11 changes: 9 additions & 2 deletions src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -745,8 +745,15 @@ public static LeftIndexingOp createLeftIndexingOp(Hop lhs, Hop rhs, Hop rl, Hop

public static NaryOp createNary(OpOpN op, Hop... inputs) {
Hop mainInput = inputs[0];
NaryOp nop = new NaryOp(mainInput.getName(), mainInput.getDataType(),
mainInput.getValueType(), op, inputs);
// safe for unordered inputs of Scalars and Matrices
// e.g.: S*M*S = M
// safe for Scalar with different value type
// e.g.: Scalar(Int) * Scalar(FP64) = Scalar(FP64)
boolean containsMatrix = Arrays.stream(inputs).anyMatch(Hop::isMatrix);
boolean containsFP64 = Arrays.stream(inputs).anyMatch(h -> h.getValueType() == ValueType.FP64);
DataType dtOut = containsMatrix ? DataType.MATRIX : mainInput.getDataType();
ValueType vtOut = containsFP64? ValueType.FP64 : mainInput.getValueType();
NaryOp nop = new NaryOp(mainInput.getName(), dtOut, vtOut, op, inputs);
nop.setBlocksize(mainInput.getBlocksize());
copyLineNumbers(mainInput, nop);
nop.refreshSizeInformation();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2801,8 +2801,8 @@ else if( HopRewriteUtils.isBasic1NSequence(second, first, true)

private static Hop foldMultipleMinMaxOperations(Hop hi)
{
if( (HopRewriteUtils.isBinary(hi, OpOp2.MIN, OpOp2.MAX, OpOp2.PLUS)
|| HopRewriteUtils.isNary(hi, OpOpN.MIN, OpOpN.MAX, OpOpN.PLUS))
if( (HopRewriteUtils.isBinary(hi, OpOp2.MIN, OpOp2.MAX, OpOp2.PLUS, OpOp2.MULT)
|| HopRewriteUtils.isNary(hi, OpOpN.MIN, OpOpN.MAX, OpOpN.PLUS, OpOpN.MULT))
&& hi.getValueType() != ValueType.STRING //exclude string concat
&& HopRewriteUtils.isNotMatrixVectorBinaryOperation(hi))
{
Expand Down Expand Up @@ -2839,7 +2839,7 @@ private static Hop foldMultipleMinMaxOperations(Hop hi)
for( Hop p : parents )
HopRewriteUtils.replaceChildReference(p, hi, hnew);
hi = hnew;
LOG.debug("Applied foldMultipleMinMaxPlusOperations (line "+hi.getBeginLine()+").");
LOG.debug("Applied foldMultipleMinMaxPlusMultOperations (line "+hi.getBeginLine()+").");
}
else {
converged = true;
Expand Down
2 changes: 2 additions & 0 deletions src/main/java/org/apache/sysds/lops/Nary.java
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ private String getOpcode() {
return "n"+operationType.name().toLowerCase();
case PLUS:
return "n+";
case MULT:
return "n*";
default:
throw new UnsupportedOperationException(
"Nary operation type (" + operationType + ") is not defined.");
Expand Down
8 changes: 4 additions & 4 deletions src/main/java/org/apache/sysds/parser/IfStatementBlock.java
Original file line number Diff line number Diff line change
Expand Up @@ -77,20 +77,20 @@ public VariableSet validate(DMLProgram dmlProg, VariableSet ids, HashMap<String,

/////////////////////////////////////////////////////////////////////////////////
// check data type and value type are same for updated variables in both
// if statement and else statement
// if statement and else statement
// (reject conditional data type change)
/////////////////////////////////////////////////////////////////////////////////
for (String updatedVar : this._updated.getVariableNames()){
DataIdentifier origVersion = idsOrigCopy.getVariable(updatedVar);
DataIdentifier ifVersion = idsIfCopy.getVariable(updatedVar);
DataIdentifier ifVersion = idsIfCopy.getVariable(updatedVar);
DataIdentifier elseVersion = idsElseCopy.getVariable(updatedVar);

//data type handling: reject conditional data type change
if( ifVersion != null && elseVersion != null ) //both branches exist
{
if (!ifVersion.getOutput().getDataType().equals(elseVersion.getOutput().getDataType())){
raiseValidateError("IfStatementBlock has unsupported conditional data type change of variable '"+updatedVar+"' in if/else branch.", conditional);
}
}
}
else if( origVersion !=null ) //only if branch exists
{
Expand All @@ -99,7 +99,7 @@ else if( origVersion !=null ) //only if branch exists
}
}

//value type handling
//value type handling
if (ifVersion != null && elseVersion != null && !ifVersion.getOutput().getValueType().equals(elseVersion.getOutput().getValueType())){
LOG.warn(elseVersion.printWarningLocation() + "Variable " + elseVersion.getName() + " defined with different value type in if and else clause.");
}
Expand Down
4 changes: 4 additions & 0 deletions src/main/java/org/apache/sysds/parser/StatementBlock.java
Original file line number Diff line number Diff line change
Expand Up @@ -1099,6 +1099,10 @@ else if ( source instanceof BuiltinFunctionExpression || source instanceof Param
ids.addVariable(targetList.get(j).getName(), (DataIdentifier)outputs[j]);
}
}

// remove updated constant vars (for correctness)
for(DataIdentifier target : targetList)
currConstVars.remove(target.getName());
}

public void setStatementFormatType(OutputStatement s, boolean conditionalValidate)
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/org/apache/sysds/resource/CloudUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ public static InstanceSize customValueOf(String name) {
}
}

public static final String SPARK_VERSION = "3.3.0";
public static final double MINIMAL_EXECUTION_TIME = 120; // seconds; NOTE: set always equal or higher than DEFAULT_CLUSTER_LAUNCH_TIME

public static final double DEFAULT_CLUSTER_LAUNCH_TIME = 120; // seconds; NOTE: set always to at least 60 seconds

public static long GBtoBytes(double gb) {
Expand Down
Loading

0 comments on commit db4c085

Please sign in to comment.