Skip to content

Commit

Permalink
[SYSTEMDS-3729] Add roll reorg operations in SP
Browse files Browse the repository at this point in the history
  • Loading branch information
min-guk committed Sep 23, 2024
1 parent b6adcca commit 9a193a8
Show file tree
Hide file tree
Showing 5 changed files with 287 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ public class SPInstructionParser extends InstructionParser
// Reorg Instruction Opcodes (repositioning of existing values)
String2SPInstructionType.put( "r'", SPType.Reorg);
String2SPInstructionType.put( "rev", SPType.Reorg);
String2SPInstructionType.put( "roll", SPType.Reorg);
String2SPInstructionType.put( "rdiag", SPType.Reorg);
String2SPInstructionType.put( "rshape", SPType.MatrixReshape);
String2SPInstructionType.put( "rsort", SPType.Reorg);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.functionobjects.DiagIndex;
import org.apache.sysds.runtime.functionobjects.RevIndex;
import org.apache.sysds.runtime.functionobjects.RollIndex;
import org.apache.sysds.runtime.functionobjects.SortIndex;
import org.apache.sysds.runtime.functionobjects.SwapIndex;
import org.apache.sysds.runtime.instructions.InstructionUtils;
Expand Down Expand Up @@ -68,6 +69,7 @@ public class ReorgSPInstruction extends UnarySPInstruction {
private CPOperand _desc = null;
private CPOperand _ixret = null;
private boolean _bSortIndInMem = false;
private CPOperand _shift = null;

private ReorgSPInstruction(Operator op, CPOperand in, CPOperand out, String opcode, String istr) {
super(SPType.Reorg, op, in, out, opcode, istr);
Expand All @@ -82,6 +84,11 @@ private ReorgSPInstruction(Operator op, CPOperand in, CPOperand col, CPOperand d
_bSortIndInMem = bSortIndInMem;
}

private ReorgSPInstruction(Operator op, CPOperand in, CPOperand out, CPOperand shift, String opcode, String istr) {
this(op, in, out, opcode, istr);
_shift = shift;
}

public static ReorgSPInstruction parseInstruction ( String str ) {
CPOperand in = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
CPOperand out = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
Expand All @@ -95,6 +102,15 @@ else if ( opcode.equalsIgnoreCase("rev") ) {
parseUnaryInstruction(str, in, out); //max 2 operands
return new ReorgSPInstruction(new ReorgOperator(RevIndex.getRevIndexFnObject()), in, out, opcode, str);
}
else if (opcode.equalsIgnoreCase("roll")) {
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
InstructionUtils.checkNumFields(str, 3);
in.split(parts[1]);
out.split(parts[3]);
CPOperand shift = new CPOperand(parts[2]);
return new ReorgSPInstruction(new ReorgOperator(new RollIndex(0)),
in, out, shift, opcode, str);
}
else if ( opcode.equalsIgnoreCase("rdiag") ) {
parseUnaryInstruction(str, in, out); //max 2 operands
return new ReorgSPInstruction(new ReorgOperator(DiagIndex.getDiagIndexFnObject()), in, out, opcode, str);
Expand Down Expand Up @@ -141,6 +157,14 @@ else if( opcode.equalsIgnoreCase("rev") ) //REVERSE
if( mcIn.getRows() % mcIn.getBlocksize() != 0 )
out = RDDAggregateUtils.mergeByKey(out, false);
}
else if (opcode.equalsIgnoreCase("roll")) // ROLL
{
int shift = (int) ec.getScalarInput(_shift).getLongValue();

//execute roll reorg operation
out = in1.flatMapToPair(new RDDRollFunction(mcIn, shift));
out = RDDAggregateUtils.mergeByKey(out, false);
}
else if ( opcode.equalsIgnoreCase("rdiag") ) // DIAG
{
if(mcIn.getCols() == 1) { // diagV2M
Expand Down Expand Up @@ -233,7 +257,7 @@ else if ( getOpcode().equalsIgnoreCase("rsort") ) {
boolean ixret = sec.getScalarInput(_ixret).getBooleanValue();
mcOut.set(mc1.getRows(), ixret?1:mc1.getCols(), mc1.getBlocksize(), mc1.getBlocksize());
}
else { //e.g., rev
else { //e.g., rev, roll
mcOut.set(mc1);
}
}
Expand All @@ -243,7 +267,7 @@ else if ( getOpcode().equalsIgnoreCase("rsort") ) {
boolean sortIx = getOpcode().equalsIgnoreCase("rsort") && sec.getScalarInput(_ixret).getBooleanValue();
if( sortIx )
mcOut.setNonZeros(mc1.getRows());
else //default (r', rdiag, rev, rsort data)
else //default (r', rdiag, rev, roll, rsort data)
mcOut.setNonZeros(mc1.getNonZeros());
}
}
Expand Down Expand Up @@ -315,6 +339,31 @@ public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call( Tuple2<MatrixIndexes,
}
}

private static class RDDRollFunction implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
private static final long serialVersionUID = 1183373828539843938L;

private DataCharacteristics _mcIn = null;
private int _shift = 0;

public RDDRollFunction(DataCharacteristics mcIn, int shift) {
_mcIn = mcIn;
_shift = shift;
}

@Override
public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) {
//construct input
IndexedMatrixValue in = SparkUtils.toIndexedMatrixBlock(arg0);

//execute roll operation
ArrayList<IndexedMatrixValue> out = new ArrayList<>();
LibMatrixReorg.roll(in, _mcIn.getRows(), _mcIn.getBlocksize(), _shift, out);

//construct output
return SparkUtils.fromIndexedMatrixBlock(out).iterator();
}
}

private static class ExtractColumn implements Function<MatrixBlock, MatrixBlock>
{
private static final long serialVersionUID = -1472164797288449559L;
Expand Down
138 changes: 88 additions & 50 deletions src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,36 @@ public static MatrixBlock roll(MatrixBlock in, MatrixBlock out, int shift) {
return out;
}

public static void roll(IndexedMatrixValue in, long rlen, int blen, int shift, ArrayList<IndexedMatrixValue> out) {
MatrixIndexes inMtxIdx = in.getIndexes();
MatrixBlock inMtxBlk = (MatrixBlock) in.getValue();
shift %= ((rlen != 0) ? (int) rlen : 1); // Handle row length boundaries for shift

long inRowIdx = UtilFunctions.computeCellIndex(inMtxIdx.getRowIndex(), blen, 0) - 1;

int totalCopyLen = 0;
while (totalCopyLen < inMtxBlk.getNumRows()) {
// Calculate row and block index for the current part
long outRowIdx = (inRowIdx + shift) % rlen;
long outBlkIdx = UtilFunctions.computeBlockIndex(outRowIdx + 1, blen);
int outBlkLen = UtilFunctions.computeBlockSize(rlen, outBlkIdx, blen);
int outRowIdxInBlk = (int) (outRowIdx % blen);

// Calculate copy length
int copyLen = Math.min((int) (outBlkLen - outRowIdxInBlk), inMtxBlk.getNumRows() - totalCopyLen);

// Create the output block and copy data
MatrixIndexes outMtxIdx = new MatrixIndexes(outBlkIdx, inMtxIdx.getColumnIndex());
MatrixBlock outMtxBlk = new MatrixBlock(outBlkLen, inMtxBlk.getNumColumns(), inMtxBlk.isInSparseFormat());
copyMtx(inMtxBlk, outMtxBlk, totalCopyLen, outRowIdxInBlk, copyLen, false, false);
out.add(new IndexedMatrixValue(outMtxIdx, outMtxBlk));

// Update counters for next iteration
totalCopyLen += copyLen;
inRowIdx += totalCopyLen;
}
}

public static MatrixBlock diag( MatrixBlock in, MatrixBlock out ) {
//Timing time = new Timing(true);

Expand Down Expand Up @@ -2274,77 +2304,85 @@ private static void reverseSparse(MatrixBlock in, MatrixBlock out) {

private static void rollDense(MatrixBlock in, MatrixBlock out, int shift) {
final int m = in.rlen;
final int n = in.clen;
shift %= (m != 0 ? m : 1); // roll matrix with axis=none

//set basic meta data and allocate output
out.sparse = false;
out.nonZeros = in.nonZeros;
out.allocateDenseBlock(false);
copyDenseMtx(in, out, 0, shift, m - shift, false, true);
copyDenseMtx(in, out, m - shift, 0, shift, true, true);
}

//copy all rows into target positions
if (n == 1) { //column vector
private static void rollSparse(MatrixBlock in, MatrixBlock out, int shift) {
final int m = in.rlen;
shift %= (m != 0 ? m : 1); // roll matrix with axis=0

copySparseMtx(in, out, 0, shift, m - shift, false, true);
copySparseMtx(in, out, m-shift, 0, shift, false, true);
}

public static void copyMtx(MatrixBlock in, MatrixBlock out, int inStart, int outStart, int copyLen,
boolean isAllocated, boolean copyTotalNonZeros) {
if (in.isInSparseFormat()){
copySparseMtx(in, out, inStart, outStart, copyLen, isAllocated, copyTotalNonZeros);
} else {
copyDenseMtx(in, out, inStart, outStart, copyLen, isAllocated, copyTotalNonZeros);
}
}

public static void copyDenseMtx(MatrixBlock in, MatrixBlock out, int inIdx, int outIdx, int copyLen,
boolean isAllocated, boolean copyTotalNonZeros) {
int clen = in.clen;

// set basic meta data and allocate output
if (!isAllocated){
out.sparse = false;
if (copyTotalNonZeros) out.nonZeros = in.nonZeros;
out.allocateDenseBlock(false);
}

// copy all rows into target positions
if (clen == 1) { //column vector
double[] a = in.getDenseBlockValues();
double[] c = out.getDenseBlockValues();

// roll matrix with axis=none
shift %= (m != 0 ? m : 1);

System.arraycopy(a, 0, c, shift, m - shift);
System.arraycopy(a, m - shift, c, 0, shift);
} else { //general matrix case
System.arraycopy(a, inIdx, c, outIdx, copyLen);
} else {
DenseBlock a = in.getDenseBlock();
DenseBlock c = out.getDenseBlock();

// roll matrix with axis=0
shift %= (m != 0 ? m : 1);
while (copyLen > 0) {
System.arraycopy(a.values(inIdx), a.pos(inIdx),
c.values(outIdx), c.pos(outIdx), clen);

for (int i = 0; i < m - shift; i++) {
System.arraycopy(a.values(i), a.pos(i), c.values(i + shift), c.pos(i + shift), n);
}

for (int i = m - shift; i < m; i++) {
System.arraycopy(a.values(i), a.pos(i), c.values(i + shift - m), c.pos(i + shift - m), n);
inIdx++; outIdx++; copyLen--;
}
}
}

private static void rollSparse(MatrixBlock in, MatrixBlock out, int shift) {
final int m = in.rlen;

private static void copySparseMtx(MatrixBlock in, MatrixBlock out, int inIdx, int outIdx, int copyLen,
boolean isAllocated, boolean copyTotalNonZeros) {
//set basic meta data and allocate output
out.sparse = true;
out.nonZeros = in.nonZeros;
out.allocateSparseRowsBlock(false);
if (!isAllocated){
out.sparse = true;
if (copyTotalNonZeros) out.nonZeros = in.nonZeros;
out.allocateSparseRowsBlock(false);
}

//copy all rows into target positions
SparseBlock a = in.getSparseBlock();
SparseBlock c = out.getSparseBlock();

// roll matrix with axis=0
shift %= (m != 0 ? m : 1);

for (int i = 0; i < m - shift; i++) {
if (a.isEmpty(i)) continue; // skip empty rows
while (copyLen > 0) {
if (a.isEmpty(inIdx)) continue; // skip empty rows

rollSparseRow(a, c, i, i + shift);
}

for (int i = m - shift; i < m; i++) {
if (a.isEmpty(i)) continue; // skip empty rows

rollSparseRow(a, c, i, i + shift - m);
}
}
final int apos = a.pos(inIdx);
final int alen = a.size(inIdx) + apos;
final int[] aix = a.indexes(inIdx);
final double[] avals = a.values(inIdx);

private static void rollSparseRow(SparseBlock a, SparseBlock c, int oriIdx, int shiftIdx) {
final int apos = a.pos(oriIdx);
final int alen = a.size(oriIdx) + apos;
final int[] aix = a.indexes(oriIdx);
final double[] avals = a.values(oriIdx);
// copy only non-zero elements
for (int k = apos; k < alen; k++) {
c.set(outIdx, aix[k], avals[k]);
}

// copy only non-zero elements
for (int k = apos; k < alen; k++) {
c.set(shiftIdx, aix[k], avals[k]);
inIdx++; outIdx++; copyLen--;
}
}

Expand Down
Loading

0 comments on commit 9a193a8

Please sign in to comment.