Skip to content

Commit

Permalink
[SYSTEMDS-3771] Compressed Identity Dictionary and Selection Multiply
Browse files Browse the repository at this point in the history
This commit contains the implementation details on
LLM refinements for supporting the new Identity dictionaries,
that remove the need for many of the matrix multiplications.

Furthermore it also contains the implementation details and optimizations
for selective Matrix Multiplications of matrices in the left side
containing only a single 1 in each row. The implementation there
simply decompress the rows associated with the 1, making the overall
compressed operation very efficient.

The overall implementation further improves the code-coverage of the
project by 0.23%

Closes #2084
  • Loading branch information
Baunsgaard committed Sep 23, 2024
1 parent eea4afc commit 2ce1910
Show file tree
Hide file tree
Showing 52 changed files with 3,947 additions and 591 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1245,6 +1245,11 @@ public void allocateAndResetSparseBlock(boolean clearNNZ, SparseBlock.Type stype
throw new DMLCompressionException("Invalid to allocate block on a compressed MatrixBlock");
}

@Override
public MatrixBlock transpose(int k) {
return getUncompressed().transpose(k);
}

@Override
public String toString() {
StringBuilder sb = new StringBuilder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex.SliceResult;
import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme;
Expand Down Expand Up @@ -728,6 +729,44 @@ public AColGroup sortColumnIndexes() {
*/
public abstract AColGroup reduceCols();

/**
* Selection (left matrix multiply)
*
* @param selection A sparse matrix with "max" a single one in each row all other values are zero.
* @param points The coordinates in the selection matrix to extract.
* @param ret The MatrixBlock to decompress the selected rows into
* @param rl The row to start at in the selection matrix
* @param ru the row to end at in the selection matrix (not inclusive)
*/
public final void selectionMultiply(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) {
if(ret.isInSparseFormat())
sparseSelection(selection, points, ret, rl, ru);
else
denseSelection(selection, points, ret, rl, ru);
}

/**
* Sparse selection (left matrix multiply)
*
* @param selection A sparse matrix with "max" a single one in each row all other values are zero.
* @param points The coordinates in the selection matrix to extract.
* @param ret The Sparse MatrixBlock to decompress the selected rows into
* @param rl The row to start at in the selection matrix
* @param ru the row to end at in the selection matrix (not inclusive)
*/
protected abstract void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru);

/**
* Dense selection (left matrix multiply)
*
* @param selection A sparse matrix with "max" a single one in each row all other values are zero.
* @param points The coordinates in the selection matrix to extract.
* @param ret The Dense MatrixBlock to decompress the selected rows into
* @param rl The row to start at in the selection matrix
* @param ru the row to end at in the selection matrix (not inclusive)
*/
protected abstract void denseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru);

@Override
public String toString() {
StringBuilder sb = new StringBuilder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,14 @@ else if(lhs instanceof ColGroupUncompressed)
* @return A aggregate dictionary
*/
public final IDictionary preAggregateThatIndexStructure(APreAgg that) {
final long outputLength = (long)that._colIndexes.size() * this.getNumValues();
final long outputLength = (long) that._colIndexes.size() * this.getNumValues();
if(outputLength > Integer.MAX_VALUE)
throw new NotImplementedException("Not supported pre aggregate of above integer length");
if(outputLength <= 0) // if the pre aggregate output is empty or nothing, return null
return null;

// create empty Dictionary that we slowly fill, hence the dictionary is empty and no check
final Dictionary ret = Dictionary.createNoCheck(new double[(int)outputLength]);
final Dictionary ret = Dictionary.createNoCheck(new double[(int) outputLength]);

if(that instanceof ColGroupDDC)
preAggregateThatDDCStructure((ColGroupDDC) that, ret);
Expand All @@ -119,7 +119,7 @@ else if(that instanceof ColGroupRLE)
*/
public final void preAggregate(MatrixBlock m, double[] preAgg, int rl, int ru) {
if(m.isInSparseFormat())
preAggregateSparse(m.getSparseBlock(), preAgg, rl, ru);
preAggregateSparse(m.getSparseBlock(), preAgg, rl, ru, 0, m.getNumColumns());
else
preAggregateDense(m, preAgg, rl, ru, 0, m.getNumColumns());
}
Expand All @@ -136,7 +136,7 @@ public final void preAggregate(MatrixBlock m, double[] preAgg, int rl, int ru) {
*/
public abstract void preAggregateDense(MatrixBlock m, double[] preAgg, int rl, int ru, int cl, int cu);

public abstract void preAggregateSparse(SparseBlock sb, double[] preAgg, int rl, int ru);
public abstract void preAggregateSparse(SparseBlock sb, double[] preAgg, int rl, int ru, int cl, int cu);

protected abstract void preAggregateThatDDCStructure(ColGroupDDC that, Dictionary ret);

Expand All @@ -160,11 +160,13 @@ private void tsmmAPreAgg(APreAgg lg, MatrixBlock result) {
final boolean left = shouldPreAggregateLeft(lg);
if(!loggedWarningForDirect && shouldDirectMultiply(lg, leftIdx.size(), rightIdx.size(), left)) {
loggedWarningForDirect = true;
LOG.warn("Not implemented direct tsmm colgroup: " + lg.getClass().getSimpleName() + " %*% " + this.getClass().getSimpleName() );
LOG.warn("Not implemented direct tsmm colgroup: " + lg.getClass().getSimpleName() + " %*% "
+ this.getClass().getSimpleName());
}

if(left) {
final IDictionary lpa = this.preAggregateThatIndexStructure(lg);

if(lpa != null)
DictLibMatrixMult.TSMMToUpperTriangle(lpa, _dict, leftIdx, rightIdx, result);
}
Expand Down Expand Up @@ -222,7 +224,7 @@ else if(shouldPreAggregateLeft(lhs)) {// left preAgg
DictLibMatrixMult.MMDicts(lDict, lhsPA, leftIdx, rightIdx, result);
}
else {// right preAgg
final IDictionary rhsPA = preAggregateThatIndexStructure(lhs);
final IDictionary rhsPA = this.preAggregateThatIndexStructure(lhs);
if(rhsPA != null)
DictLibMatrixMult.MMDicts(rhsPA, rDict, leftIdx, rightIdx, result);
}
Expand Down Expand Up @@ -311,17 +313,20 @@ public void mmWithDictionary(MatrixBlock preAgg, MatrixBlock tmpRes, MatrixBlock
// Shallow copy the preAgg to allow sparse PreAgg multiplication but do not remove the original dense allocation
// since the dense allocation is reused.
final MatrixBlock preAggCopy = new MatrixBlock();
preAggCopy.copy(preAgg);
preAggCopy.copyShallow(preAgg);
final MatrixBlock tmpResCopy = new MatrixBlock();
tmpResCopy.copy(tmpRes);
tmpResCopy.copyShallow(tmpRes);
// Get dictionary matrixBlock
final MatrixBlock dict = getDictionary().getMBDict(_colIndexes.size()).getMatrixBlock();
if(dict != null) {
// Multiply
LibMatrixMult.matrixMult(preAggCopy, dict, tmpResCopy, k);
ColGroupUtils.addMatrixToResult(tmpResCopy, ret, _colIndexes, rl, ru);
LibMatrixMult.matrixMult(preAggCopy, dict, tmpRes, k);
ColGroupUtils.addMatrixToResult(tmpRes, ret, _colIndexes, rl, ru);
}
}

protected abstract int numRowsToMultiply();

public abstract void leftMMIdentityPreAggregateDense(MatrixBlock that, MatrixBlock ret, int rl, int ru, int cl,
int cu);
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

import org.apache.commons.lang3.NotImplementedException;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P;
import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
Expand Down Expand Up @@ -647,4 +648,14 @@ public AMapToData getMapToData() {
return MapToFactory.create(0, 0);
}

@Override
protected void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) {
throw new NotImplementedException();
}

@Override
protected void denseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) {
throw new NotImplementedException();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,15 @@
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P;
import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IdentityDictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary;
import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.compress.colgroup.indexes.RangeIndex;
import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData;
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToByte;
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToChar;
Expand Down Expand Up @@ -398,7 +401,10 @@ public void preAggregateDense(MatrixBlock m, double[] preAgg, int rl, int ru, in
}

@Override
public void preAggregateSparse(SparseBlock sb, double[] preAgg, int rl, int ru) {
public void preAggregateSparse(SparseBlock sb, double[] preAgg, int rl, int ru, int cl, int cu) {
if(cl != 0 || cu != _data.size()) {
throw new NotImplementedException();
}
_data.preAggregateSparse(sb, preAgg, rl, ru);
}

Expand Down Expand Up @@ -628,6 +634,90 @@ protected AColGroup fixColIndexes(IColIndex newColIndex, int[] reordering) {
return ColGroupDDC.create(newColIndex, _dict.reorder(reordering), _data, getCachedCounts());
}

@Override
public void sparseSelection(MatrixBlock selection,P[] points, MatrixBlock ret, int rl, int ru) {
// morph(CompressionType.UNCOMPRESSED, _data.size()).sparseSelection(selection, ret, rl, ru);;
final SparseBlock sb = selection.getSparseBlock();
final SparseBlock retB = ret.getSparseBlock();
for(int r = rl; r < ru; r++) {
if(sb.isEmpty(r))
continue;
final int sPos = sb.pos(r);
final int rowCompressed = sb.indexes(r)[sPos]; // column index with 1
decompressToSparseBlock(retB, rowCompressed, rowCompressed + 1, r - rowCompressed, 0);
}
}


@Override
protected void denseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) {
// morph(CompressionType.UNCOMPRESSED, _data.size()).sparseSelection(selection, ret, rl, ru);;
final SparseBlock sb = selection.getSparseBlock();
final DenseBlock retB = ret.getDenseBlock();
for(int r = rl; r < ru; r++) {
if(sb.isEmpty(r))
continue;
final int sPos = sb.pos(r);
final int rowCompressed = sb.indexes(r)[sPos]; // column index with 1
decompressToDenseBlock(retB, rowCompressed, rowCompressed + 1, r - rowCompressed, 0);
}
}

@Override
public void leftMMIdentityPreAggregateDense(MatrixBlock that, MatrixBlock ret, int rl, int ru, int cl, int cu) {
DenseBlock db = that.getDenseBlock();
DenseBlock retDB = ret.getDenseBlock();
if(rl == ru - 1)
leftMMIdentityPreAggregateDenseSingleRow(db.values(rl), db.pos(rl), retDB.values(rl), retDB.pos(rl), cl, cu);
else
throw new NotImplementedException();
}


private void leftMMIdentityPreAggregateDenseSingleRow(double[] values, int pos, double[] values2, int pos2, int cl,
int cu) {
IdentityDictionary a = (IdentityDictionary) _dict;
if(_colIndexes instanceof RangeIndex)
leftMMIdentityPreAggregateDenseSingleRowRangeIndex(values, pos, values2, pos2, cl, cu);
else {

pos += cl; // left side matrix position offset.
if(a.withEmpty()) {
final int nVal = _dict.getNumberOfValues(_colIndexes.size()) - 1;
for(int rc = cl; rc < cu; rc++, pos++) {
final int idx = _data.getIndex(rc);
if(idx != nVal)
values2[_colIndexes.get(idx)] += values[pos];
}
}
else {
for(int rc = cl; rc < cu; rc++, pos++)
values2[_colIndexes.get(_data.getIndex(rc))] += values[pos];
}
}
}


private void leftMMIdentityPreAggregateDenseSingleRowRangeIndex(double[] values, int pos, double[] values2, int pos2,
int cl, int cu) {
IdentityDictionary a = (IdentityDictionary) _dict;

final int firstCol = _colIndexes.get(0);
pos += cl; // left side matrix position offset.
if(a.withEmpty()) {
final int nVal = _dict.getNumberOfValues(_colIndexes.size()) - 1;
for(int rc = cl; rc < cu; rc++, pos++) {
final int idx = _data.getIndex(rc);
if(idx != nVal)
values2[firstCol + idx] += values[pos];
}
}
else {
for(int rc = cl; rc < cu; rc++, pos++)
values2[firstCol + _data.getIndex(rc)] += values[pos];
}
}

@Override
public String toString() {
StringBuilder sb = new StringBuilder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

import org.apache.commons.lang3.NotImplementedException;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P;
import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary;
Expand All @@ -40,6 +41,8 @@
import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory;
import org.apache.sysds.runtime.compress.estim.encoding.IEncode;
import org.apache.sysds.runtime.compress.utils.Util;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.Divide;
import org.apache.sysds.runtime.functionobjects.Minus;
Expand Down Expand Up @@ -252,7 +255,7 @@ public AColGroup replace(double pattern, double replace) {
if(patternInReference) {
double[] nRef = new double[_reference.length];
for(int i = 0; i < _reference.length; i++)
if(Util.eq(pattern ,_reference[i]))
if(Util.eq(pattern, _reference[i]))
nRef[i] = replace;
else
nRef[i] = _reference[i];
Expand Down Expand Up @@ -489,6 +492,34 @@ protected AColGroup fixColIndexes(IColIndex newColIndex, int[] reordering) {
throw new NotImplementedException();
}

@Override
protected void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) {
final SparseBlock sb = selection.getSparseBlock();
final SparseBlock retB = ret.getSparseBlock();
for(int r = rl; r < ru; r++) {
if(sb.isEmpty(r))
continue;

final int sPos = sb.pos(r);
final int rowCompressed = sb.indexes(r)[sPos];
decompressToSparseBlock(retB, rowCompressed, rowCompressed + 1, r - rowCompressed, 0);
}
}

@Override
protected void denseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) {
final SparseBlock sb = selection.getSparseBlock();
final DenseBlock retB = ret.getDenseBlock();
for(int r = rl; r < ru; r++) {
if(sb.isEmpty(r))
continue;

final int sPos = sb.pos(r);
final int rowCompressed = sb.indexes(r)[sPos];
decompressToDenseBlock(retB, rowCompressed, rowCompressed + 1, r - rowCompressed, 0);
}
}

@Override
public String toString() {
StringBuilder sb = new StringBuilder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@
import java.io.IOException;
import java.util.Arrays;

import org.apache.commons.lang3.NotImplementedException;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P;
import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
Expand Down Expand Up @@ -53,7 +55,7 @@
import org.apache.sysds.runtime.matrix.operators.UnaryOperator;

public class ColGroupEmpty extends AColGroupCompressed
implements IContainADictionary, IContainDefaultTuple, AOffsetsGroup ,IMapToDataGroup{
implements IContainADictionary, IContainDefaultTuple, AOffsetsGroup, IMapToDataGroup {
private static final long serialVersionUID = -2307677253622099958L;

/**
Expand Down Expand Up @@ -403,9 +405,18 @@ public AMapToData getMapToData() {
return MapToFactory.create(0, 0);
}

@Override
public AColGroup reduceCols(){
@Override
public AColGroup reduceCols() {
return null;
}

@Override
protected void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) {
throw new NotImplementedException();
}

@Override
protected void denseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) {
throw new NotImplementedException();
}
}
Loading

0 comments on commit 2ce1910

Please sign in to comment.