Skip to content

Commit

Permalink
[SYSTEMDS-3729] New roll reorg operations in CP, incl tests
Browse files Browse the repository at this point in the history
Closes #2103.
  • Loading branch information
min-guk authored and mboehm7 committed Sep 23, 2024
1 parent c940502 commit edfce10
Show file tree
Hide file tree
Showing 16 changed files with 473 additions and 39 deletions.
1 change: 1 addition & 0 deletions .github/workflows/javaTests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,4 @@ jobs:
name: Java Code Coverage (Jacoco)
path: target/site/jacoco
retention-days: 3

1 change: 1 addition & 0 deletions src/main/java/org/apache/sysds/common/Builtins.java
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ public enum Builtins {
RCM("rowClassMeet", "rcm", false, false, ReturnType.MULTI_RETURN),
REMOVE("remove", false, ReturnType.MULTI_RETURN),
REV("rev", false),
ROLL("roll", false),
ROUND("round", false),
ROW_COUNT_DISTINCT("rowCountDistinct",false),
ROWINDEXMAX("rowIndexMax", false),
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/org/apache/sysds/common/Types.java
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,7 @@ public boolean isCellOp() {
/** Operations that perform internal reorganization of an allocation */
public enum ReOrgOp {
DIAG, //DIAG_V2M and DIAG_M2V could not be distinguished if sizes unknown
RESHAPE, REV, SORT, TRANS;
RESHAPE, REV, ROLL, SORT, TRANS;

@Override
public String toString() {
Expand Down
22 changes: 20 additions & 2 deletions src/main/java/org/apache/sysds/hops/ReorgOp.java
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ public void checkArity() {
case REV:
HopsException.check(sz == 1, this, "should have arity 1 for op %s but has arity %d", _op, sz);
break;
case ROLL:
HopsException.check(sz == 2, this, "should have arity 2 for op %s but has arity %d", _op, sz);
break;
case RESHAPE:
case SORT:
HopsException.check(sz == 5, this, "should have arity 5 for op %s but has arity %d", _op, sz);
Expand Down Expand Up @@ -125,6 +128,7 @@ public boolean isGPUEnabled() {
}
case DIAG:
case REV:
case ROLL:
case SORT:
return false;
default:
Expand Down Expand Up @@ -175,6 +179,18 @@ else if( getDim1()==1 && getDim2()==1 )
setLops(transform1);
break;
}
case ROLL: {
Lop[] linputs = new Lop[2]; //input, shift
for (int i = 0; i < 2; i++)
linputs[i] = getInput().get(i).constructLops();

Transform transform1 = new Transform(linputs, _op, getDataType(), getValueType(), et, 1);

setOutputDimensions(transform1);
setLineNumbers(transform1);
setLops(transform1);
break;
}
case RESHAPE: {
Lop[] linputs = new Lop[5]; //main, rows, cols, dims, byrow
for (int i = 0; i < 5; i++)
Expand Down Expand Up @@ -279,9 +295,10 @@ protected DataCharacteristics inferOutputCharacteristics( MemoTable memo )
ret = new MatrixCharacteristics(dc.getCols(), dc.getRows(), -1, dc.getNonZeros());
break;
}
case REV: {
case REV:
case ROLL: {
// dims and nnz are exactly the same as in input
if( dc.dimsKnown() )
if (dc.dimsKnown())
ret = new MatrixCharacteristics(dc.getRows(), dc.getCols(), -1, dc.getNonZeros());
break;
}
Expand Down Expand Up @@ -397,6 +414,7 @@ public void refreshSizeInformation()
break;
}
case REV:
case ROLL:
{
// dims and nnz are exactly the same as in input
setDim1(input1.getDim1());
Expand Down
11 changes: 10 additions & 1 deletion src/main/java/org/apache/sysds/lops/Transform.java
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,10 @@ private String getOpcode() {
case REV:
// Transpose a matrix
return "rev";


case ROLL:
return "roll";

case DIAG:
// Transform a vector into a diagonal matrix
return "rdiag";
Expand All @@ -138,6 +141,12 @@ public String getInstructions(String input1, String output) {
return getInstructions(input1, 1, output);
}

@Override
public String getInstructions(String input1, String input2, String output) {
//opcodes: roll
return getInstructions(input1, 2, output);
}

@Override
public String getInstructions(String input1, String input2, String input3, String input4, String output) {
//opcodes: rsort
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1266,7 +1266,17 @@ else if( getOpCode() == Builtins.RBIND ) {
output.setBlocksize (id.getBlocksize());
output.setValueType(id.getValueType());
break;


case ROLL:
checkNumParameters(2);
checkMatrixParam(getFirstExpr());
checkScalarParam(getSecondExpr());
output.setDataType(DataType.MATRIX);
output.setDimensions(id.getDim1(), id.getDim2());
output.setBlocksize(id.getBlocksize());
output.setValueType(id.getValueType());
break;

case DIAG:
checkNumParameters(1);
checkMatrixParam(getFirstExpr());
Expand Down
8 changes: 8 additions & 0 deletions src/main/java/org/apache/sysds/parser/DMLTranslator.java
Original file line number Diff line number Diff line change
Expand Up @@ -2481,6 +2481,14 @@ else if ( sop.equalsIgnoreCase("!=") )
target.getValueType(), ReOrgOp.valueOf(source.getOpCode().name()), expr);
break;

case ROLL:
ArrayList<Hop> inputs = new ArrayList<>();
inputs.add(expr);
inputs.add(expr2);
currBuiltinOp = new ReorgOp(target.getName(), DataType.MATRIX,
target.getValueType(), ReOrgOp.valueOf(source.getOpCode().name()), inputs);
break;

case CBIND:
case RBIND:
OpOp2 appendOp2 = (source.getOpCode()==Builtins.CBIND) ? OpOp2.CBIND : OpOp2.RBIND;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.runtime.functionobjects;

import org.apache.commons.lang3.NotImplementedException;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
import org.apache.sysds.runtime.meta.DataCharacteristics;

/**
* This index function is NOT used for actual sorting but just as a reference
* in ReorgOperator in order to identify sort operations.
*/
public class RollIndex extends IndexFunction {
private static final long serialVersionUID = -8446389232078905200L;

private final int _shift;

public RollIndex(int shift) {
_shift = shift;
}

public int getShift() {
return _shift;
}

@Override
public boolean computeDimension(int row, int col, CellIndex retDim) {
retDim.set(row, col);
return false;
}

@Override
public boolean computeDimension(DataCharacteristics in, DataCharacteristics out) {
out.set(in.getRows(), in.getCols(), in.getBlocksize(), in.getNonZeros());
return false;
}

@Override
public void execute(MatrixIndexes in, MatrixIndexes out) {
throw new NotImplementedException();
}

@Override
public void execute(CellIndex in, CellIndex out) {
throw new NotImplementedException();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ public class CPInstructionParser extends InstructionParser {
// Reorg Instruction Opcodes (repositioning of existing values)
String2CPInstructionType.put( "r'" , CPType.Reorg);
String2CPInstructionType.put( "rev" , CPType.Reorg);
String2CPInstructionType.put( "roll" , CPType.Reorg);
String2CPInstructionType.put( "rdiag" , CPType.Reorg);
String2CPInstructionType.put( "rshape" , CPType.Reshape);
String2CPInstructionType.put( "rsort" , CPType.Reorg);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.functionobjects.DiagIndex;
import org.apache.sysds.runtime.functionobjects.RevIndex;
import org.apache.sysds.runtime.functionobjects.RollIndex;
import org.apache.sysds.runtime.functionobjects.SortIndex;
import org.apache.sysds.runtime.functionobjects.SwapIndex;
import org.apache.sysds.runtime.instructions.InstructionUtils;
Expand All @@ -38,51 +39,58 @@ public class ReorgCPInstruction extends UnaryCPInstruction {
private final CPOperand _col;
private final CPOperand _desc;
private final CPOperand _ixret;
private final CPOperand _shift;

/**
* for opcodes r' and rdiag
*
* @param op
* operator
* @param in
* cp input operand
* @param out
* cp output operand
* @param opcode
* the opcode
* @param istr
* ?
*
* @param op operator
* @param in cp input operand
* @param out cp output operand
* @param opcode the opcode
* @param istr ?
*/
private ReorgCPInstruction(Operator op, CPOperand in, CPOperand out, String opcode, String istr) {
this(op, in, out, null, null, null, opcode, istr);
}

/**
* for opcode rsort
*
* @param op
* operator
* @param in
* cp input operand
* @param col
* ?
* @param desc
* ?
* @param ixret
* ?
* @param out
* cp output operand
* @param opcode
* the opcode
* @param istr
* ?
*
* @param op operator
* @param in cp input operand
* @param col ?
* @param desc ?
* @param ixret ?
* @param out cp output operand
* @param opcode the opcode
* @param istr ?
*/
private ReorgCPInstruction(Operator op, CPOperand in, CPOperand out, CPOperand col, CPOperand desc, CPOperand ixret,
String opcode, String istr) {
String opcode, String istr) {
super(CPType.Reorg, op, in, out, opcode, istr);
_col = col;
_desc = desc;
_ixret = ixret;
_shift = null;
}

/**
* for opcode roll
*
* @param op operator
* @param in cp input operand
* @param shift ?
* @param out cp output operand
* @param opcode the opcode
* @param istr ?
*/
private ReorgCPInstruction(Operator op, CPOperand in, CPOperand out, CPOperand shift, String opcode, String istr) {
super(CPType.Reorg, op, in, out, opcode, istr);
_col = null;
_desc = null;
_ixret = null;
_shift = shift;
}

public static ReorgCPInstruction parseInstruction ( String str ) {
Expand All @@ -103,6 +111,13 @@ else if ( opcode.equalsIgnoreCase("rev") ) {
parseUnaryInstruction(str, in, out); //max 2 operands
return new ReorgCPInstruction(new ReorgOperator(RevIndex.getRevIndexFnObject()), in, out, opcode, str);
}
else if (opcode.equalsIgnoreCase("roll")) {
InstructionUtils.checkNumFields(str, 3);
in.split(parts[1]);
out.split(parts[3]);
CPOperand shift = new CPOperand(parts[2]);
return new ReorgCPInstruction(new ReorgOperator(new RollIndex(0)), in, out, shift, opcode, str);
}
else if ( opcode.equalsIgnoreCase("rdiag") ) {
parseUnaryInstruction(str, in, out); //max 2 operands
return new ReorgCPInstruction(new ReorgOperator(DiagIndex.getDiagIndexFnObject()), in, out, opcode, str);
Expand Down Expand Up @@ -136,7 +151,12 @@ public void processInstruction(ExecutionContext ec) {
boolean ixret = ec.getScalarInput(_ixret).getBooleanValue();
r_op = r_op.setFn(new SortIndex(cols, desc, ixret));
}


if (r_op.fn instanceof RollIndex) {
int shift = (int) ec.getScalarInput(_shift).getLongValue();
r_op = r_op.setFn(new RollIndex(shift));
}

//execute operation
MatrixBlock soresBlock = matBlock.reorgOperations(r_op, new MatrixBlock(), 0, 0, 0);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ public class LineageCacheConfig

// Relatively expensive instructions. Most include shuffles.
private static final String[] PERSIST_OPCODES1 = new String[] {
"cpmm", "rmm", "pmm", "zipmm", "rev", "rshape", "rsort", "-", "*", "+",
"cpmm", "rmm", "pmm", "zipmm", "rev", "roll", "rshape", "rsort", "-", "*", "+",
"/", "%%", "%/%", "1-*", "^", "^2", "*2", "==", "!=", "<", ">",
"<=", ">=", "&&", "||", "xor", "max", "min", "rmempty", "rappend",
"gappend", "galignedappend", "rbind", "cbind", "nmin", "nmax",
Expand Down
Loading

0 comments on commit edfce10

Please sign in to comment.