Skip to content

Commit

Permalink
[SYSTEMDS-3808] Dictionary Compressed Combine
Browse files Browse the repository at this point in the history
This commit speeds up the combining of dictionaries via custom hashmaps.

Closes  #2166

Signed-off-by: Sebastian Baunsgaard <baunsgaard@apache.org>
  • Loading branch information
Baunsgaard committed Dec 29, 2024
1 parent 809490f commit 96fd5da
Show file tree
Hide file tree
Showing 12 changed files with 888 additions and 351 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,11 @@

package org.apache.sysds.runtime.compress.estim.encoding;

import java.util.Map;

import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.runtime.compress.CompressionSettings;
import org.apache.sysds.runtime.compress.estim.EstimationFactors;
import org.apache.sysds.runtime.compress.utils.HashMapLongInt;

/** Const encoding for cases where the entire group of columns is the same value */
public class ConstEncoding extends AEncode {
Expand All @@ -41,7 +40,7 @@ public IEncode combine(IEncode e) {
}

@Override
public Pair<IEncode, Map<Integer, Integer>> combineWithMap(IEncode e) {
public Pair<IEncode, HashMapLongInt> combineWithMap(IEncode e) {
return new ImmutablePair<>(e, null);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,34 +19,40 @@

package org.apache.sysds.runtime.compress.estim.encoding;

import java.util.HashMap;
import java.util.Map;

import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.CompressionSettings;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData;
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToChar;
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToCharPByte;
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
import org.apache.sysds.runtime.compress.colgroup.offset.AIterator;
import org.apache.sysds.runtime.compress.estim.EstimationFactors;
import org.apache.sysds.runtime.compress.utils.HashMapLongInt;

/**
* An Encoding that contains a value on each row of the input.
*/
public class DenseEncoding extends AEncode {

private static boolean zeroWarn = false;

private final AMapToData map;

public DenseEncoding(AMapToData map) {
this.map = map;

if(CompressedMatrixBlock.debug) {
// if(!zeroWarn) {
int[] freq = map.getCounts();
for(int i = 0; i < freq.length; i++) {
if(freq[i] == 0)
throw new DMLCompressionException("Invalid counts in fact contains 0");
for(int i = 0; i < freq.length && !zeroWarn; i++) {
if(freq[i] == 0) {
LOG.warn("Dense encoding contains zero encoding, indicating not all dictionary entries are in use");
zeroWarn = true;

}
}
}
}
Expand All @@ -62,7 +68,7 @@ else if(e instanceof SparseEncoding)
}

@Override
public Pair<IEncode, Map<Integer, Integer>> combineWithMap(IEncode e) {
public Pair<IEncode, HashMapLongInt> combineWithMap(IEncode e) {
if(e instanceof EmptyEncoding || e instanceof ConstEncoding)
return new ImmutablePair<>(this, null);
else if(e instanceof SparseEncoding)
Expand Down Expand Up @@ -106,14 +112,14 @@ private AMapToData assignSparse(SparseEncoding e) {
return ret;
}

private final Pair<IEncode, Map<Integer, Integer>> combineSparseHashMap(final AMapToData ret) {
private final Pair<IEncode, HashMapLongInt> combineSparseHashMap(final AMapToData ret) {
final int size = ret.size();
final Map<Integer, Integer> m = new HashMap<>(size);
final HashMapLongInt m = new HashMapLongInt(100);
for(int r = 0; r < size; r++) {
final int prev = ret.getIndex(r);
final int v = m.size();
final Integer mv = m.putIfAbsent(prev, v);
if(mv == null)
final int mv = m.putIfAbsent(prev, v);
if(mv == -1)
ret.set(r, v);
else
ret.set(r, mv);
Expand Down Expand Up @@ -146,28 +152,44 @@ protected DenseEncoding combineDense(final DenseEncoding other) {
final int nVL = lm.getUnique();
final int nVR = rm.getUnique();
final int size = map.size();
final int maxUnique = nVL * nVR;

int maxUnique = nVL * nVR;
final DenseEncoding retE;
final AMapToData ret = MapToFactory.create(size, maxUnique);

if(maxUnique > size && maxUnique > 2048) {
if(maxUnique < Math.max(nVL, nVR)) {// overflow
final HashMapLongInt m = new HashMapLongInt(Math.max(100, size / 100));
retE = combineDenseWithHashMapLong(lm, rm, size, nVL, ret, m);
}
else if(maxUnique > size && maxUnique > 2048) {
// aka there is more maxUnique than rows.
final Map<Integer, Integer> m = new HashMap<>(size);
return combineDenseWithHashMap(lm, rm, size, nVL, ret, m);
final HashMapLongInt m = new HashMapLongInt(Math.max(100, maxUnique / 100));
retE = combineDenseWithHashMap(lm, rm, size, nVL, ret, m);
}
else {
final AMapToData m = MapToFactory.create(maxUnique, maxUnique + 1);
return combineDenseWithMapToData(lm, rm, size, nVL, ret, maxUnique, m);
retE = combineDenseWithMapToData(lm, rm, size, nVL, ret, maxUnique, m);
}

if(retE.getUnique() < 0) {
String th = this.toString();
String ot = other.toString();
String cm = retE.toString();

if(th.length() > 1000)
th = th.substring(0, 1000);
if(ot.length() > 1000)
ot = ot.substring(0, 1000);
if(cm.length() > 1000)
cm = cm.substring(0, 1000);
throw new DMLCompressionException(
"Failed to combine dense encodings correctly: Number unique values is lower than max input: \n\n" + th
+ "\n\n" + ot + "\n\n" + cm);
}
return retE;
}

private Pair<IEncode, Map<Integer, Integer>> combineDenseNoResize(final DenseEncoding other) {
if(map == other.map) {
LOG.warn("Constructing perfect mapping, this could be optimized to skip hashmap");
final Map<Integer, Integer> m = new HashMap<>(map.size());
for(int i = 0; i < map.getUnique(); i++)
m.put(i * i, i);
return new ImmutablePair<>(this, m); // same object
private Pair<IEncode, HashMapLongInt> combineDenseNoResize(final DenseEncoding other) {
if(map.equals(other.map)) {
return combineSameMapping();
}

final AMapToData lm = map;
Expand All @@ -176,40 +198,115 @@ private Pair<IEncode, Map<Integer, Integer>> combineDenseNoResize(final DenseEnc
final int nVL = lm.getUnique();
final int nVR = rm.getUnique();
final int size = map.size();
final int maxUnique = nVL * nVR;
final int maxUnique = (int) Math.min((long) nVL * nVR, (long) size);

final AMapToData ret = MapToFactory.create(size, maxUnique);

final Map<Integer, Integer> m = new HashMap<>(Math.min(size, maxUnique));
final HashMapLongInt m = new HashMapLongInt(Math.max(100, maxUnique / 1000));
return new ImmutablePair<>(combineDenseWithHashMap(lm, rm, size, nVL, ret, m), m);
}

// there can be less unique.

// return new DenseEncoding(ret);
private Pair<IEncode, HashMapLongInt> combineSameMapping() {
LOG.warn("Constructing perfect mapping, this could be optimized to skip hashmap");
final HashMapLongInt m = new HashMapLongInt(Math.max(100, map.size() / 100));
for(int i = 0; i < map.getUnique(); i++)
m.putIfAbsent(i * (map.getUnique() + 1), i);
return new ImmutablePair<>(this, m); // same object
}

private Pair<IEncode, Map<Integer, Integer>> combineSparseNoResize(final SparseEncoding other) {
private Pair<IEncode, HashMapLongInt> combineSparseNoResize(final SparseEncoding other) {
final AMapToData a = assignSparse(other);
return combineSparseHashMap(a);
}

/**
 * Combine two dense mappings whose combined key space requires 64-bit arithmetic (the product of unique counts
 * overflowed int), assigning compact new ids via a long-keyed hash map.
 *
 * @param lm   Left input mapping.
 * @param rm   Right input mapping.
 * @param size Number of rows in both mappings.
 * @param nVL  Number of unique values on the left, as long to force 64-bit key computation.
 * @param ret  Output mapping that receives the combined ids.
 * @param m    Hash map from combined long key to assigned id.
 * @return A dense encoding over the output mapping, resized to the number of distinct keys observed.
 */
protected final DenseEncoding combineDenseWithHashMapLong(final AMapToData lm, final AMapToData rm, final int size,
	final long nVL, final AMapToData ret, HashMapLongInt m) {
	// Specialize on the concrete output type so the JIT can devirtualize the set calls.
	if(ret instanceof MapToChar) {
		final MapToChar retChar = (MapToChar) ret;
		for(int row = 0; row < size; row++) {
			final long key = (long) lm.getIndex(row) + rm.getIndex(row) * nVL;
			addValHashMapChar(key, row, m, retChar);
		}
	}
	else {
		for(int row = 0; row < size; row++) {
			final long key = (long) lm.getIndex(row) + rm.getIndex(row) * nVL;
			addValHashMap(key, row, m, ret);
		}
	}
	return new DenseEncoding(ret.resize(m.size()));
}

protected final DenseEncoding combineDenseWithHashMap(final AMapToData lm, final AMapToData rm, final int size,
final int nVL, final AMapToData ret, Map<Integer, Integer> m) {
final int nVL, final AMapToData ret, HashMapLongInt m) {
// JIT compile instance checks.
if(ret instanceof MapToChar)
combineDenseWIthHashMapCharOut(lm, rm, size, nVL, (MapToChar) ret, m);
else if(ret instanceof MapToCharPByte)
combineDenseWIthHashMapPByteOut(lm, rm, size, nVL, (MapToCharPByte) ret, m);
else
combineDenseWithHashMapGeneric(lm, rm, size, nVL, ret, m);
ret.setUnique(m.size());
return new DenseEncoding(ret);

}

/**
 * Combine loop specialized for a {@code MapToCharPByte} output mapping.
 *
 * @param lm   Left input mapping.
 * @param rm   Right input mapping.
 * @param size Number of rows in both mappings.
 * @param nVL  Number of unique values on the left.
 * @param ret  Output mapping to fill.
 * @param m    Hash map from combined key to assigned id.
 */
private final void combineDenseWIthHashMapPByteOut(final AMapToData lm, final AMapToData rm, final int size,
	final int nVL, final MapToCharPByte ret, HashMapLongInt m) {
	for(int row = 0; row < size; row++) {
		final int key = lm.getIndex(row) + rm.getIndex(row) * nVL;
		addValHashMapCharByte(key, row, m, ret);
	}
}

/**
 * Dispatch for a {@code MapToChar} output mapping: pick the fully specialized loop when both inputs are also
 * {@code MapToChar}, otherwise fall back to the generic-input variant.
 *
 * @param lm   Left input mapping.
 * @param rm   Right input mapping.
 * @param size Number of rows in both mappings.
 * @param nVL  Number of unique values on the left.
 * @param ret  Output mapping to fill.
 * @param m    Hash map from combined key to assigned id.
 */
private final void combineDenseWIthHashMapCharOut(final AMapToData lm, final AMapToData rm, final int size,
	final int nVL, final MapToChar ret, HashMapLongInt m) {
	final boolean allChar = lm instanceof MapToChar && rm instanceof MapToChar;
	if(!allChar) // some other input combination
		combineDenseWIthHashMapCharOutGeneric(lm, rm, size, nVL, ret, m);
	else
		combineDenseWIthHashMapAllChar(lm, rm, size, nVL, ret, m);
}

/**
 * Combine loop with generic input mappings but a {@code MapToChar} output mapping.
 *
 * @param lm   Left input mapping.
 * @param rm   Right input mapping.
 * @param size Number of rows in both mappings.
 * @param nVL  Number of unique values on the left.
 * @param ret  Output mapping to fill.
 * @param m    Hash map from combined key to assigned id.
 */
private final void combineDenseWIthHashMapCharOutGeneric(final AMapToData lm, final AMapToData rm, final int size,
	final int nVL, final MapToChar ret, HashMapLongInt m) {
	for(int row = 0; row < size; row++) {
		final int key = lm.getIndex(row) + rm.getIndex(row) * nVL;
		addValHashMapChar(key, row, m, ret);
	}
}

/**
 * Fully specialized combine loop where both inputs and the output are {@code MapToChar}, avoiding virtual calls
 * on the hot path.
 *
 * @param lm   Left input mapping (must be a MapToChar).
 * @param rm   Right input mapping (must be a MapToChar).
 * @param size Number of rows in both mappings.
 * @param nVL  Number of unique values on the left.
 * @param ret  Output mapping to fill.
 * @param m    Hash map from combined key to assigned id.
 */
private final void combineDenseWIthHashMapAllChar(final AMapToData lm, final AMapToData rm, final int size,
	final int nVL, final MapToChar ret, HashMapLongInt m) {
	final MapToChar left = (MapToChar) lm;
	final MapToChar right = (MapToChar) rm;
	for(int row = 0; row < size; row++) {
		final int key = left.getIndex(row) + right.getIndex(row) * nVL;
		addValHashMapChar(key, row, m, ret);
	}
}

protected final void combineDenseWithHashMapGeneric(final AMapToData lm, final AMapToData rm, final int size,
final int nVL, final AMapToData ret, HashMapLongInt m) {
for(int r = 0; r < size; r++)
addValHashMap(lm.getIndex(r) + rm.getIndex(r) * nVL, r, m, ret);
return new DenseEncoding(ret.resize(m.size()));
}

/**
 * Combine two dense mappings using a dense {@code AMapToData} as the key-to-id translation table (used when the
 * key space is small enough to materialize), dispatching on the table's concrete type.
 *
 * @param lm        Left input mapping.
 * @param rm        Right input mapping.
 * @param size      Number of rows in both mappings.
 * @param nVL       Number of unique values on the left.
 * @param ret       Output mapping to fill.
 * @param maxUnique Upper bound on the number of combined unique values.
 * @param m         Dense translation table from combined key to assigned id.
 * @return A dense encoding over the output mapping.
 */
protected final DenseEncoding combineDenseWithMapToData(final AMapToData lm, final AMapToData rm, final int size,
	final int nVL, final AMapToData ret, final int maxUnique, final AMapToData m) {
	final boolean charTable = m instanceof MapToChar;
	if(!charTable)
		return combineDenseWithMapToDataGeneric(lm, rm, size, nVL, ret, maxUnique, m);
	return combineDenseWithMapToDataToChar(lm, rm, size, nVL, ret, maxUnique, (MapToChar) m);
}

/**
 * Combine loop using a {@code MapToChar} translation table; ids start at 1 so that 0 can mark "unassigned" in the
 * table, and the final unique count is the last id handed out.
 *
 * @param lm        Left input mapping.
 * @param rm        Right input mapping.
 * @param size      Number of rows in both mappings.
 * @param nVL       Number of unique values on the left.
 * @param ret       Output mapping to fill.
 * @param maxUnique Upper bound on the number of combined unique values.
 * @param m         Char-backed translation table from combined key to assigned id.
 * @return A dense encoding over the output mapping.
 */
protected final DenseEncoding combineDenseWithMapToDataToChar(final AMapToData lm, final AMapToData rm,
	final int size, final int nVL, final AMapToData ret, final int maxUnique, final MapToChar m) {
	int nextId = 1; // 0 is reserved as the "not yet assigned" marker inside m.
	for(int row = 0; row < size; row++) {
		final int key = lm.getIndex(row) + rm.getIndex(row) * nVL;
		nextId = addValMapToDataChar(key, row, m, nextId, ret);
	}
	ret.setUnique(nextId - 1);
	return new DenseEncoding(ret);
}

protected final DenseEncoding combineDenseWithMapToDataGeneric(final AMapToData lm, final AMapToData rm,
final int size, final int nVL, final AMapToData ret, final int maxUnique, final AMapToData m) {
int newUID = 1;
for(int r = 0; r < size; r++)
newUID = addValMapToData(lm.getIndex(r) + rm.getIndex(r) * nVL, r, m, newUID, ret);
return new DenseEncoding(ret.resize(newUID - 1));
ret.setUnique(newUID - 1);
return new DenseEncoding(ret);
}

protected static int addValMapToData(final int nv, final int r, final AMapToData map, int newId,
protected static int addValMapToDataChar(final int nv, final int r, final MapToChar map, int newId,
final AMapToData d) {
int mv = map.getIndex(nv);
if(mv == 0)
Expand All @@ -218,11 +315,56 @@ protected static int addValMapToData(final int nv, final int r, final AMapToData
return newId;
}

protected static void addValHashMap(final int nv, final int r, final Map<Integer, Integer> map,
/**
 * Record one row's combined key through a dense translation table: assign a fresh id if the key is unseen
 * (table stores id + 1, since 0 means unassigned), then write the zero-based id into the output.
 *
 * @param nv    Combined key for this row.
 * @param r     Row index to set in the output.
 * @param map   Dense translation table from key to (id + 1).
 * @param newId Next id to hand out (one past the ids used so far).
 * @param d     Output mapping receiving the zero-based id.
 * @return The possibly-incremented next id.
 */
protected static int addValMapToData(final int nv, final int r, final AMapToData map, int newId,
	final AMapToData d) {
	int assigned = map.getIndex(nv);
	if(assigned == 0) // first time this key is seen; claim the next id.
		assigned = map.setAndGet(nv, newId++);
	d.set(r, assigned - 1);
	return newId;
}

protected static void addValHashMap(final int nv, final int r, final HashMapLongInt map, final AMapToData d) {
final int v = map.size();
final Integer mv = map.putIfAbsent(nv, v);
if(mv == null)
final int mv = map.putIfAbsent(nv, v);
if(mv == -1)
d.set(r, v);
else
d.set(r, mv);
}

/**
 * Record one row's combined int key through the hash map into a {@code MapToChar} output: the map's current size
 * is the candidate id; putIfAbsent returns -1 when the key was new, otherwise the previously assigned id.
 *
 * @param nv  Combined key for this row.
 * @param r   Row index to set in the output.
 * @param map Hash map from key to assigned id.
 * @param d   Char-backed output mapping.
 */
protected static void addValHashMapChar(final int nv, final int r, final HashMapLongInt map, final MapToChar d) {
	final int candidate = map.size();
	final int existing = map.putIfAbsent(nv, candidate);
	d.set(r, existing == -1 ? candidate : existing);
}

/**
 * Record one row's combined key through the hash map into a {@code MapToCharPByte} output: the map's current size
 * is the candidate id; putIfAbsent returns -1 when the key was new, otherwise the previously assigned id.
 *
 * @param nv  Combined key for this row.
 * @param r   Row index to set in the output.
 * @param map Hash map from key to assigned id.
 * @param d   Char-plus-byte-backed output mapping.
 */
protected static void addValHashMapCharByte(final int nv, final int r, final HashMapLongInt map,
	final MapToCharPByte d) {
	final int candidate = map.size();
	final int existing = map.putIfAbsent(nv, candidate);
	d.set(r, existing == -1 ? candidate : existing);
}

/**
 * Record one row's combined long key (used when the key space overflows int) through the hash map into a
 * {@code MapToChar} output; putIfAbsent returns -1 when the key was new, otherwise the previously assigned id.
 *
 * @param nv  Combined 64-bit key for this row.
 * @param r   Row index to set in the output.
 * @param map Hash map from key to assigned id.
 * @param d   Char-backed output mapping.
 */
protected static void addValHashMapChar(final long nv, final int r, final HashMapLongInt map, final MapToChar d) {
	final int candidate = map.size();
	final int existing = map.putIfAbsent(nv, candidate);
	d.set(r, existing == -1 ? candidate : existing);
}

protected static void addValHashMap(final long nv, final int r, final HashMapLongInt map, final AMapToData d) {
final int v = map.size();
final int mv = map.putIfAbsent(nv, v);
if(mv == -1)
d.set(r, v);
else
d.set(r, mv);
Expand All @@ -237,13 +379,18 @@ public int getUnique() {
public EstimationFactors extractFacts(int nRows, double tupleSparsity, double matrixSparsity,
CompressionSettings cs) {
int largestOffs = 0;

int[] counts = map.getCounts();
for(int i = 0; i < counts.length; i++)
if(counts[i] > largestOffs)
largestOffs = counts[i];
else if(counts[i] == 0)
throw new DMLCompressionException("Invalid count of 0 all values should have at least one instance");
else if(counts[i] == 0) {
if(!zeroWarn) {
LOG.warn("Invalid count of 0 all values should have at least one instance index: " + i + " of "
+ counts.length);
zeroWarn = true;
}
counts[i] = 1;
}

if(cs.isRLEAllowed())
return new EstimationFactors(map.getUnique(), nRows, largestOffs, counts, 0, nRows, map.countRuns(), false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,11 @@

package org.apache.sysds.runtime.compress.estim.encoding;

import java.util.Map;

import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.runtime.compress.CompressionSettings;
import org.apache.sysds.runtime.compress.estim.EstimationFactors;
import org.apache.sysds.runtime.compress.utils.HashMapLongInt;

/**
* Empty encoding for cases where the entire group of columns is zero
Expand All @@ -41,7 +40,7 @@ public IEncode combine(IEncode e) {
}

@Override
public Pair<IEncode, Map<Integer, Integer>> combineWithMap(IEncode e) {
public Pair<IEncode, HashMapLongInt> combineWithMap(IEncode e) {
return new ImmutablePair<>(e, null);
}

Expand Down
Loading

0 comments on commit 96fd5da

Please sign in to comment.