Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions hello.dml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
print("Hello SystemDS")
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ protected byte toID() {

//transform methods
public enum TfMethod {
IMPUTE, RECODE, HASH, BIN, DUMMYCODE, UDF, OMIT, WORD_EMBEDDING, BAG_OF_WORDS;
IMPUTE, RECODE, HASH, BIN, DUMMYCODE, UDF, OMIT, WORD_EMBEDDING, BAG_OF_WORDS, RAGGED;
@Override
public String toString() {
return name().toLowerCase();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ public void initEmbeddings(MatrixBlock embeddings){
}

protected enum TransformType{
BIN, RECODE, DUMMYCODE, FEATURE_HASH, PASS_THROUGH, UDF, WORD_EMBEDDING, BAG_OF_WORDS, N_A
BIN, RECODE, DUMMYCODE, FEATURE_HASH, PASS_THROUGH, UDF, WORD_EMBEDDING, BAG_OF_WORDS, RAGGED, N_A
}

protected ColumnEncoder(int colID) {
Expand Down Expand Up @@ -447,7 +447,7 @@ protected void setApplyRowBlocksPerColumn(int nPart) {
}

public enum EncoderType {
Recode, FeatureHash, PassThrough, Bin, Dummycode, Omit, MVImpute, Composite, WordEmbedding, BagOfWords
Recode, FeatureHash, PassThrough, Bin, Dummycode, Omit, MVImpute, Composite, WordEmbedding, BagOfWords, Ragged
}

/*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
package org.apache.sysds.runtime.transform.encode;

import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.transform.TfUtils;
import org.apache.sysds.runtime.transform.TfUtils.TfMethod;
import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.sysds.common.Types.ValueType;

import java.util.HashMap;
import java.util.Map;

/**
* Encodes a column using ragged array/dictionary representation to optimize memory usage.
* Stores unique values in a dictionary and replaces occurrences with indices.
*/
public class ColumnEncoderRagged extends ColumnEncoder {
private static final long serialVersionUID = 2291732648968734088L;

// Dictionary storage
private Object[] _dict;
private int _dictSize;
private int _nullIndex = -1;

// Reverse mapping for fast lookups
private transient Map<Object, Integer> _valueToIndex;

private static final String[] DEFAULT_NA_STRINGS = new String[]{"NA", "NaN", ""};

public ColumnEncoderRagged() {
super(-1); // ID will be set during construction
}

public ColumnEncoderRagged(int colID) {
super(colID);
}

@Override
protected TransformType getTransformType() {
return TransformType.RAGGED;
}

// Helper method to check NA values
private boolean isNAValue(String val) {
if(val == null) return true;
for(String na : DEFAULT_NA_STRINGS) {
if(val.equals(na)) return true;
}
return false;
}

@Override
public void build(CacheBlock<?> in) {
if (!(in instanceof FrameBlock))
throw new IllegalArgumentException("Ragged encoding only supports FrameBlock input");

FrameBlock fin = (FrameBlock) in;
if (_colID < 1 || _colID > fin.getNumColumns())
throw new IllegalArgumentException("Invalid column ID: " + _colID);

_valueToIndex = new HashMap<>();
_dict = new String[Math.min(1024, fin.getNumRows())];
_dictSize = 0;

for (int i = 0; i < fin.getNumRows(); i++) {
Object valObj = fin.get(i, _colID - 1);
// Convert all values to strings safely
String val = (valObj != null) ? valObj.toString() : null;

if (isNAValue(val)) {
if (_nullIndex == -1) {
_nullIndex = _dictSize;
_dict[_dictSize++] = null;
}
continue;
}

if (!_valueToIndex.containsKey(val)) {
if (_dictSize == _dict.length) {
String[] newDict = new String[_dict.length * 2];
System.arraycopy(_dict, 0, newDict, 0, _dictSize);
_dict = newDict;
}
_dict[_dictSize] = val;
_valueToIndex.put(val, _dictSize);
_dictSize++;
}
}
}

@Override
public MatrixBlock apply(CacheBlock<?> in, MatrixBlock out, int outputCol) {
// Validate input type
if (!(in instanceof FrameBlock)) {
throw new IllegalArgumentException("Ragged encoding only supports FrameBlock input");
}

FrameBlock fin = (FrameBlock) in;
final int numRows = fin.getNumRows();

// Create new matrix if needed
if (out == null) {
out = new MatrixBlock(numRows, outputCol + 1, false);
}

// Encode each value
for (int i = 0; i < numRows; i++) {
String val = fin.get(i, _colID - 1).toString();
int index = isNAValue(val) ? _nullIndex : _valueToIndex.getOrDefault(val, _nullIndex);

// Use the standard set method
out.set(i, outputCol, (double) index);
}

return out;
}

@Override
public double[] getCodeCol(CacheBlock<?> in, int outputCol, int rowStart, double[] tmp) {
if (!(in instanceof FrameBlock))
throw new IllegalArgumentException("Ragged encoding only supports FrameBlock input");
FrameBlock fin = (FrameBlock) in;

if (tmp == null)
tmp = new double[fin.getNumRows() - rowStart];

for (int i = rowStart; i < fin.getNumRows(); i++) {
String val = fin.get(i, _colID - 1).toString();
tmp[i - rowStart] = isNAValue(val) ? _nullIndex : _valueToIndex.getOrDefault(val, _nullIndex);
}
return tmp;
}

@Override
public double getCode(CacheBlock<?> in, int row) {
if (!(in instanceof FrameBlock))
throw new IllegalArgumentException("Ragged encoding only supports FrameBlock input");
FrameBlock fin = (FrameBlock) in;

String val = fin.get(row, _colID - 1).toString();
return isNAValue(val) ? _nullIndex : _valueToIndex.getOrDefault(val, _nullIndex);
}

@Override
public FrameBlock getMetaData(FrameBlock out) {
if (out == null)
out = new FrameBlock(1, ValueType.STRING);

// Store dictionary in meta frame
out.ensureAllocatedColumns(_dictSize);
for (int i = 0; i < _dictSize; i++) {
out.set(i, 0, _dict[i]);
}

return out;
}

@Override
public void initMetaData(FrameBlock meta) {
if (meta == null || meta.getNumRows() == 0)
return;

// Reconstruct dictionary from meta data
_dictSize = meta.getNumRows();
_dict = new Object[_dictSize];
_valueToIndex = new HashMap<>();

for (int i = 0; i < _dictSize; i++) {
_dict[i] = meta.get(i, 0);
if (_dict[i] == null) {
_nullIndex = i;
} else {
_valueToIndex.put(_dict[i], i);
}
}
}

// Other required methods with default implementations
@Override public void allocateMetaData(FrameBlock meta) {}
@Override public void prepareBuildPartial() {}
@Override public void buildPartial(FrameBlock in) { build(in); }
@Override public void updateIndexRanges(long[] beginDims, long[] endDims, int offset) {}

// Additional helper methods
public Object[] getDictionary() {
return _dict;
}

public int getDictionarySize() {
return _dictSize;
}

public int getNullIndex() {
return _nullIndex;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,15 @@ public static MultiColumnEncoder createEncoder(String spec, String[] colnames, i
.toObject(TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.WORD_EMBEDDING.toString(), minCol, maxCol)));
List<Integer> bowIDs = Arrays.asList(ArrayUtils
.toObject(TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.BAG_OF_WORDS.toString(), minCol, maxCol)));
List<Integer> ragIDs = Arrays.asList(ArrayUtils
.toObject(TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.RAGGED.toString(), minCol, maxCol)));

// NOTE: any dummycode column requires recode as preparation, unless the dummycode
// column follows binning or feature hashing
rcIDs = unionDistinct(rcIDs, except(except(dcIDs, binIDs), haIDs));
// Error out if the first level encoders have overlaps
if (intersect(rcIDs, binIDs, haIDs, weIDs, bowIDs))
throw new DMLRuntimeException("More than one encoders (recode, binning, hashing, word_embedding, bag_of_words) on one column is not allowed:\n" + spec);
if (intersect(rcIDs, binIDs, haIDs, weIDs, bowIDs, ragIDs))
throw new DMLRuntimeException("More than one encoders (recode, binning, hashing, word_embedding, bag_of_words, ragIDs) on one column is not allowed:\n" + spec);

List<Integer> ptIDs = except(UtilFunctions.getSeqList(1, clen, 1), naryUnionDistinct(rcIDs, haIDs, binIDs, weIDs, bowIDs));
List<Integer> oIDs = new ArrayList<>(Arrays.asList(ArrayUtils
Expand Down Expand Up @@ -158,6 +160,9 @@ public static MultiColumnEncoder createEncoder(String spec, String[] colnames, i
if(!weIDs.isEmpty())
for(Integer id : weIDs)
addEncoderToMap(new ColumnEncoderWordEmbedding(id), colEncoders);
if(!ragIDs.isEmpty())
for(Integer id : ragIDs)
addEncoderToMap(new ColumnEncoderRagged(id), colEncoders);
if(!bowIDs.isEmpty())
for(Integer id : bowIDs)
addEncoderToMap(new ColumnEncoderBagOfWords(id), colEncoders);
Expand Down Expand Up @@ -287,6 +292,8 @@ public static ColumnEncoder createInstance(int type) {
return new ColumnEncoderWordEmbedding();
case BagOfWords:
return new ColumnEncoderBagOfWords();
case Ragged:
return new ColumnEncoderRagged();
default:
throw new DMLRuntimeException("Unsupported encoder type: " + etype);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,10 @@ public void test(String spec) {
try {

FrameBlock meta = null;
MultiColumnEncoder encoder = EncoderFactory.createEncoder(spec, data.getColumnNames(), data.getNumColumns(),
meta);
MultiColumnEncoder encoder = EncoderFactory.createEncoder(spec, data.getColumnNames(), data.getNumColumns(), meta);
MatrixBlock out = encoder.encode(data);
meta = encoder.getMetaData(meta); //I added this just to have the frame stored somewhere
System.out.println(meta);
MatrixBlock out2 = encoder.apply(data);

TestUtils.compareMatrices(out, out2, 0, "Not Equal after apply");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package org.apache.sysds.test.component.frame.transform;

import static org.junit.Assert.fail;

import java.util.ArrayList;
import java.util.List;

import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.transform.encode.ColumnEncoder;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderComposite;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderDummycode;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderPassThrough;
import org.apache.sysds.runtime.transform.encode.CompressedEncode;
import org.apache.sysds.runtime.transform.encode.EncoderFactory;
import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestUtils;
import org.junit.Test;

public class TransformDummySeparatedTest extends AutomatedTestBase {
protected static final Log LOG = LogFactory.getLog(TransformDummySeparatedTest.class.getName());

final FrameBlock data;

public TransformDummySeparatedTest() {
data = TestUtils.generateRandomFrameBlock(100, new org.apache.sysds.common.Types.ValueType[] {
org.apache.sysds.common.Types.ValueType.UINT8 }, 231);
data.setSchema(new org.apache.sysds.common.Types.ValueType[] {
org.apache.sysds.common.Types.ValueType.INT32 });
}

@Test
public void testDummySeparatedBasic() {

test("{ids:true, dummycode:[1]}");

}

public void test(String spec) {
try {
FrameBlock meta = null;
MultiColumnEncoder encoder = EncoderFactory.createEncoder(spec, data.getColumnNames(), data.getNumColumns(), meta);

MatrixBlock out = encoder.encode(data);
meta = encoder.getMetaData(new FrameBlock(data.getNumColumns(), org.apache.sysds.common.Types.ValueType.STRING));
MatrixBlock out2 = encoder.apply(data);

// Compare consistency
TestUtils.compareMatrices(out, out2, 0, "Not Equal after apply");

// Print output
System.out.println("== Encoded MatrixBlock ==");
System.out.println(out.toString());

System.out.println("== Metadata FrameBlock ==");
System.out.println(meta.toString());

} catch (Exception e) {
e.printStackTrace();
fail(e.getMessage());
}
}

@Override
public void setUp() {
// TODO Auto-generated method stub
//throw new UnsupportedOperationException("Unimplemented method 'setUp'");
}
}
Loading
Loading