diff --git a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java index 5e4dbaedebc..5c72b854362 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java @@ -166,6 +166,7 @@ public class SPInstructionParser extends InstructionParser // Reorg Instruction Opcodes (repositioning of existing values) String2SPInstructionType.put( "r'", SPType.Reorg); String2SPInstructionType.put( "rev", SPType.Reorg); + String2SPInstructionType.put( "roll", SPType.Reorg); String2SPInstructionType.put( "rdiag", SPType.Reorg); String2SPInstructionType.put( "rshape", SPType.MatrixReshape); String2SPInstructionType.put( "rsort", SPType.Reorg); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/ReorgSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/ReorgSPInstruction.java index de01a71ca83..b096405959b 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/spark/ReorgSPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/ReorgSPInstruction.java @@ -36,6 +36,7 @@ import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext; import org.apache.sysds.runtime.functionobjects.DiagIndex; import org.apache.sysds.runtime.functionobjects.RevIndex; +import org.apache.sysds.runtime.functionobjects.RollIndex; import org.apache.sysds.runtime.functionobjects.SortIndex; import org.apache.sysds.runtime.functionobjects.SwapIndex; import org.apache.sysds.runtime.instructions.InstructionUtils; @@ -68,6 +69,7 @@ public class ReorgSPInstruction extends UnarySPInstruction { private CPOperand _desc = null; private CPOperand _ixret = null; private boolean _bSortIndInMem = false; + private CPOperand _shift = null; private ReorgSPInstruction(Operator op, CPOperand in, CPOperand out, String opcode, String istr) { super(SPType.Reorg, op, in, out, opcode, istr); @@ -82,6 +84,11 @@ private ReorgSPInstruction(Operator op, CPOperand in, CPOperand col, CPOperand d _bSortIndInMem = bSortIndInMem; } + private ReorgSPInstruction(Operator op, CPOperand in, CPOperand out, CPOperand shift, String opcode, String istr) { + this(op, in, out, opcode, istr); + _shift = shift; + } + public static ReorgSPInstruction parseInstruction ( String str ) { CPOperand in = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN); CPOperand out = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN); @@ -95,6 +102,15 @@ else if ( opcode.equalsIgnoreCase("rev") ) { parseUnaryInstruction(str, in, out); //max 2 operands return new ReorgSPInstruction(new ReorgOperator(RevIndex.getRevIndexFnObject()), in, out, opcode, str); } + else if (opcode.equalsIgnoreCase("roll")) { + String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); + InstructionUtils.checkNumFields(str, 3); + in.split(parts[1]); + out.split(parts[3]); + CPOperand shift = new CPOperand(parts[2]); + return new ReorgSPInstruction(new ReorgOperator(new RollIndex(0)), + in, out, shift, opcode, str); +} else if ( opcode.equalsIgnoreCase("rdiag") ) { parseUnaryInstruction(str, in, out); //max 2 operands return new ReorgSPInstruction(new ReorgOperator(DiagIndex.getDiagIndexFnObject()), in, out, opcode, str); @@ -141,6 +157,14 @@ else if( opcode.equalsIgnoreCase("rev") ) //REVERSE if( mcIn.getRows() % mcIn.getBlocksize() != 0 ) out = RDDAggregateUtils.mergeByKey(out, false); } + else if (opcode.equalsIgnoreCase("roll")) // ROLL + { + int shift = (int) ec.getScalarInput(_shift).getLongValue(); + + //execute roll reorg operation + out = in1.flatMapToPair(new RDDRollFunction(mcIn, shift)); + out = RDDAggregateUtils.mergeByKey(out, false); + } else if ( opcode.equalsIgnoreCase("rdiag") ) // DIAG { if(mcIn.getCols() == 1) { // diagV2M @@ -233,7 +257,7 @@ else if ( getOpcode().equalsIgnoreCase("rsort") ) { boolean ixret = sec.getScalarInput(_ixret).getBooleanValue(); mcOut.set(mc1.getRows(), ixret?1:mc1.getCols(), mc1.getBlocksize(), mc1.getBlocksize()); } - else { //e.g., rev + else { //e.g., rev, roll mcOut.set(mc1); } } @@ -243,7 +267,7 @@ else if ( getOpcode().equalsIgnoreCase("rsort") ) { boolean sortIx = getOpcode().equalsIgnoreCase("rsort") && sec.getScalarInput(_ixret).getBooleanValue(); if( sortIx ) mcOut.setNonZeros(mc1.getRows()); - else //default (r', rdiag, rev, rsort data) + else //default (r', rdiag, rev, roll, rsort data) mcOut.setNonZeros(mc1.getNonZeros()); } } @@ -315,6 +339,31 @@ public Iterator> call( Tuple2, MatrixIndexes, MatrixBlock> { + private static final long serialVersionUID = 1183373828539843938L; + + private DataCharacteristics _mcIn = null; + private int _shift = 0; + + public RDDRollFunction(DataCharacteristics mcIn, int shift) { + _mcIn = mcIn; + _shift = shift; + } + + @Override + public Iterator> call(Tuple2 arg0) { + //construct input + IndexedMatrixValue in = SparkUtils.toIndexedMatrixBlock(arg0); + + //execute roll operation + ArrayList out = new ArrayList<>(); + LibMatrixReorg.roll(in, _mcIn.getRows(), _mcIn.getBlocksize(), _shift, out); + + //construct output + return SparkUtils.fromIndexedMatrixBlock(out).iterator(); + } + } + private static class ExtractColumn implements Function { private static final long serialVersionUID = -1472164797288449559L; diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java index 53e8a888326..82defddca87 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java @@ -445,6 +445,36 @@ public static MatrixBlock roll(MatrixBlock in, MatrixBlock out, int shift) { return out; } + public static void roll(IndexedMatrixValue in, long rlen, int blen, int shift, ArrayList out) { + MatrixIndexes inMtxIdx = in.getIndexes(); + MatrixBlock inMtxBlk = (MatrixBlock) in.getValue(); + shift %= ((rlen != 0) ? (int) rlen : 1); // Handle row length boundaries for shift + + long inRowIdx = UtilFunctions.computeCellIndex(inMtxIdx.getRowIndex(), blen, 0) - 1; + + int totalCopyLen = 0; + while (totalCopyLen < inMtxBlk.getNumRows()) { + // Calculate row and block index for the current part + long outRowIdx = (inRowIdx + shift) % rlen; + long outBlkIdx = UtilFunctions.computeBlockIndex(outRowIdx + 1, blen); + int outBlkLen = UtilFunctions.computeBlockSize(rlen, outBlkIdx, blen); + int outRowIdxInBlk = (int) (outRowIdx % blen); + + // Calculate copy length + int copyLen = Math.min((int) (outBlkLen - outRowIdxInBlk), inMtxBlk.getNumRows() - totalCopyLen); + + // Create the output block and copy data + MatrixIndexes outMtxIdx = new MatrixIndexes(outBlkIdx, inMtxIdx.getColumnIndex()); + MatrixBlock outMtxBlk = new MatrixBlock(outBlkLen, inMtxBlk.getNumColumns(), inMtxBlk.isInSparseFormat()); + copyMtx(inMtxBlk, outMtxBlk, totalCopyLen, outRowIdxInBlk, copyLen, false, false); + out.add(new IndexedMatrixValue(outMtxIdx, outMtxBlk)); + + // Update counters for next iteration + totalCopyLen += copyLen; + inRowIdx += totalCopyLen; + } + } + public static MatrixBlock diag( MatrixBlock in, MatrixBlock out ) { //Timing time = new Timing(true); @@ -2274,77 +2304,85 @@ private static void reverseSparse(MatrixBlock in, MatrixBlock out) { private static void rollDense(MatrixBlock in, MatrixBlock out, int shift) { final int m = in.rlen; - final int n = in.clen; + shift %= (m != 0 ? m : 1); // roll matrix with axis=none - //set basic meta data and allocate output - out.sparse = false; - out.nonZeros = in.nonZeros; - out.allocateDenseBlock(false); + copyDenseMtx(in, out, 0, shift, m - shift, false, true); + copyDenseMtx(in, out, m - shift, 0, shift, true, true); + } - //copy all rows into target positions - if (n == 1) { //column vector + private static void rollSparse(MatrixBlock in, MatrixBlock out, int shift) { + final int m = in.rlen; + shift %= (m != 0 ? m : 1); // roll matrix with axis=0 + + copySparseMtx(in, out, 0, shift, m - shift, false, true); + copySparseMtx(in, out, m-shift, 0, shift, false, true); + } + + public static void copyMtx(MatrixBlock in, MatrixBlock out, int inStart, int outStart, int copyLen, + boolean isAllocated, boolean copyTotalNonZeros) { + if (in.isInSparseFormat()){ + copySparseMtx(in, out, inStart, outStart, copyLen, isAllocated, copyTotalNonZeros); + } else { + copyDenseMtx(in, out, inStart, outStart, copyLen, isAllocated, copyTotalNonZeros); + } + } + + public static void copyDenseMtx(MatrixBlock in, MatrixBlock out, int inIdx, int outIdx, int copyLen, + boolean isAllocated, boolean copyTotalNonZeros) { + int clen = in.clen; + + // set basic meta data and allocate output + if (!isAllocated){ + out.sparse = false; + if (copyTotalNonZeros) out.nonZeros = in.nonZeros; + out.allocateDenseBlock(false); + } + + // copy all rows into target positions + if (clen == 1) { //column vector double[] a = in.getDenseBlockValues(); double[] c = out.getDenseBlockValues(); - // roll matrix with axis=none - shift %= (m != 0 ? m : 1); - - System.arraycopy(a, 0, c, shift, m - shift); - System.arraycopy(a, m - shift, c, 0, shift); - } else { //general matrix case + System.arraycopy(a, inIdx, c, outIdx, copyLen); + } else { DenseBlock a = in.getDenseBlock(); DenseBlock c = out.getDenseBlock(); - // roll matrix with axis=0 - shift %= (m != 0 ? m : 1); + while (copyLen > 0) { + System.arraycopy(a.values(inIdx), a.pos(inIdx), + c.values(outIdx), c.pos(outIdx), clen); - for (int i = 0; i < m - shift; i++) { - System.arraycopy(a.values(i), a.pos(i), c.values(i + shift), c.pos(i + shift), n); - } - - for (int i = m - shift; i < m; i++) { - System.arraycopy(a.values(i), a.pos(i), c.values(i + shift - m), c.pos(i + shift - m), n); + inIdx++; outIdx++; copyLen--; } } } - private static void rollSparse(MatrixBlock in, MatrixBlock out, int shift) { - final int m = in.rlen; - + private static void copySparseMtx(MatrixBlock in, MatrixBlock out, int inIdx, int outIdx, int copyLen, + boolean isAllocated, boolean copyTotalNonZeros) { //set basic meta data and allocate output - out.sparse = true; - out.nonZeros = in.nonZeros; - out.allocateSparseRowsBlock(false); + if (!isAllocated){ + out.sparse = true; + if (copyTotalNonZeros) out.nonZeros = in.nonZeros; + out.allocateSparseRowsBlock(false); + } - //copy all rows into target positions SparseBlock a = in.getSparseBlock(); SparseBlock c = out.getSparseBlock(); - // roll matrix with axis=0 - shift %= (m != 0 ? m : 1); - - for (int i = 0; i < m - shift; i++) { - if (a.isEmpty(i)) continue; // skip empty rows + while (copyLen > 0) { + if (a.isEmpty(inIdx)) continue; // skip empty rows - rollSparseRow(a, c, i, i + shift); - } - - for (int i = m - shift; i < m; i++) { - if (a.isEmpty(i)) continue; // skip empty rows - - rollSparseRow(a, c, i, i + shift - m); - } - } + final int apos = a.pos(inIdx); + final int alen = a.size(inIdx) + apos; + final int[] aix = a.indexes(inIdx); + final double[] avals = a.values(inIdx); - private static void rollSparseRow(SparseBlock a, SparseBlock c, int oriIdx, int shiftIdx) { - final int apos = a.pos(oriIdx); - final int alen = a.size(oriIdx) + apos; - final int[] aix = a.indexes(oriIdx); - final double[] avals = a.values(oriIdx); + // copy only non-zero elements + for (int k = apos; k < alen; k++) { + c.set(outIdx, aix[k], avals[k]); + } - // copy only non-zero elements - for (int k = apos; k < alen; k++) { - c.set(shiftIdx, aix[k], avals[k]); + inIdx++; outIdx++; copyLen--; } } diff --git a/src/test/java/org/apache/sysds/test/functions/reorg/FullRollTest.java b/src/test/java/org/apache/sysds/test/functions/reorg/FullRollTest.java new file mode 100644 index 00000000000..26d0ef8bbaa --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/reorg/FullRollTest.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.reorg; + +import org.apache.sysds.api.DMLScript; +import org.apache.sysds.common.Types.ExecMode; +import org.apache.sysds.common.Types.ExecType; +import org.apache.sysds.runtime.instructions.Instruction; +import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.apache.sysds.utils.Statistics; +import org.junit.Assert; +import org.junit.Test; + +import java.util.HashMap; + + +public class FullRollTest extends AutomatedTestBase { + private final static String TEST_NAME1 = "Roll1"; + private final static String TEST_NAME2 = "Roll2"; + + private final static String TEST_DIR = "functions/reorg/"; + private static final String TEST_CLASS_DIR = TEST_DIR + FullRollTest.class.getSimpleName() + "/"; + + private final static int rows1 = 2017; + private final static int cols1 = 1001; + private final static double sparsity1 = 0.7; + private final static double sparsity2 = 0.1; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[]{"B", "C"})); + } + + @Test + public void testRollVectorDense() { + runRollTest(TEST_NAME1, false, false); + } + + @Test + public void testRollVectorSparse() { + runRollTest(TEST_NAME1, false, true); + } + + @Test + public void testRollMatrixDense() { + runRollTest(TEST_NAME1, true, false); + } + + @Test + public void testRollMatrixSparse() { + runRollTest(TEST_NAME1, true, true); + } + + private void runRollTest(String testname, boolean matrix, boolean sparse) { + //rtplatform for MR + ExecMode platformOld = rtplatform; + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + + String TEST_NAME = testname; + + try { + int cols = matrix ? cols1 : 1; + double sparsity = sparse ? sparsity2 : sparsity1; + getAndLoadTestConfiguration(TEST_NAME); + + /* This is for running the junit test the new way, i.e., construct the arguments directly */ + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + + //generate actual dataset + double[][] A = getRandomMatrix(rows1, cols, -1, 1, sparsity, 7); + writeInputMatrixWithMTD("A", A, true); + + // Run test CP + rtplatform = ExecMode.HYBRID; + DMLScript.USE_LOCAL_SPARK_CONFIG = false; + programArgs = new String[]{"-stats", "-explain", "-args", input("A"), output("B")}; + runTest(true, false, null, -1); + + // Run test SP + rtplatform = ExecMode.SPARK; + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + programArgs = new String[]{"-stats", "-explain", "-args", input("A"), output("C")}; + runTest(true, false, null, -1); + + //compare matrices + HashMap dmlfileCP = readDMLMatrixFromOutputDir("B"); + HashMap dmlfileSP = readDMLMatrixFromOutputDir("C"); + + TestUtils.compareMatrices(dmlfileCP, dmlfileSP, 0, "Stat-DML-CP", "Stat-DML-SP"); + + Assert.assertTrue("Missing opcode: roll", Statistics.getCPHeavyHitterOpCodes().contains("roll")); + Assert.assertTrue("Missing opcode: " + Instruction.SP_INST_PREFIX + + "roll", Statistics.getCPHeavyHitterOpCodes().contains(Instruction.SP_INST_PREFIX + "sp_roll")); + } finally { + //reset flags + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + } + } +} diff --git a/src/test/scripts/functions/reorg/Roll1.dml b/src/test/scripts/functions/reorg/Roll1.dml new file mode 100644 index 00000000000..8928fdf2192 --- /dev/null +++ b/src/test/scripts/functions/reorg/Roll1.dml @@ -0,0 +1,25 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + +A = read($1); +B = roll(A, 1); +write(B, $2); \ No newline at end of file