Skip to content

Commit

Permalink
Add RollingAvg feature to UpdateBy (#3503)
Browse files Browse the repository at this point in the history
  • Loading branch information
lbooker42 authored Mar 28, 2023
1 parent fc9de8f commit 17fb852
Show file tree
Hide file tree
Showing 26 changed files with 2,374 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
import io.deephaven.engine.table.impl.updateby.fill.*;
import io.deephaven.engine.table.impl.updateby.minmax.*;
import io.deephaven.engine.table.impl.updateby.prod.*;
import io.deephaven.engine.table.impl.updateby.rollinggroup.RollingGroupOperator;
import io.deephaven.engine.table.impl.updateby.rollingavg.*;
import io.deephaven.engine.table.impl.updateby.rollinggroup.*;
import io.deephaven.engine.table.impl.updateby.rollingsum.*;
import io.deephaven.engine.table.impl.updateby.sum.*;
import io.deephaven.engine.table.impl.util.WritableRowRedirection;
Expand Down Expand Up @@ -371,6 +372,20 @@ public Void visit(@NotNull final RollingGroupSpec rg) {
return null;
}

@Override
public Void visit(@NotNull final RollingAvgSpec rs) {
final boolean isTimeBased = rs.revWindowScale().isTimeBased();
final String timestampCol = rs.revWindowScale().timestampCol();

Arrays.stream(pairs)
.filter(p -> !isTimeBased || !p.rightColumn().equals(timestampCol))
.map(fc -> makeRollingAvgOperator(fc,
source,
rs))
.forEach(ops::add);
return null;
}

private UpdateByOperator makeEmaOperator(@NotNull final MatchPair pair,
@NotNull final Table source,
@NotNull final EmaSpec ema) {
Expand All @@ -391,7 +406,7 @@ private UpdateByOperator makeEmaOperator(@NotNull final MatchPair pair,

if (csType == byte.class || csType == Byte.class) {
return new ByteEMAOperator(pair, affectingColumns, rowRedirection, control,
ema.timeScale().timestampCol(), timeScaleUnits, columnSource);
ema.timeScale().timestampCol(), timeScaleUnits, columnSource, NULL_BYTE);
} else if (csType == short.class || csType == Short.class) {
return new ShortEMAOperator(pair, affectingColumns, rowRedirection, control,
ema.timeScale().timestampCol(), timeScaleUnits, columnSource);
Expand Down Expand Up @@ -421,7 +436,7 @@ private UpdateByOperator makeEmaOperator(@NotNull final MatchPair pair,
private UpdateByOperator makeCumProdOperator(MatchPair fc, Table source) {
final Class<?> csType = source.getColumnSource(fc.rightColumn).getType();
if (csType == byte.class || csType == Byte.class) {
return new ByteCumProdOperator(fc, rowRedirection);
return new ByteCumProdOperator(fc, rowRedirection, NULL_BYTE);
} else if (csType == short.class || csType == Short.class) {
return new ShortCumProdOperator(fc, rowRedirection);
} else if (csType == int.class || csType == Integer.class) {
Expand All @@ -445,7 +460,7 @@ private UpdateByOperator makeCumMinMaxOperator(MatchPair fc, Table source, boole
final ColumnSource<?> columnSource = source.getColumnSource(fc.rightColumn);
final Class<?> csType = columnSource.getType();
if (csType == byte.class || csType == Byte.class) {
return new ByteCumMinMaxOperator(fc, isMax, rowRedirection);
return new ByteCumMinMaxOperator(fc, isMax, rowRedirection, NULL_BYTE);
} else if (csType == short.class || csType == Short.class) {
return new ShortCumMinMaxOperator(fc, isMax, rowRedirection);
} else if (csType == int.class || csType == Integer.class) {
Expand Down Expand Up @@ -601,5 +616,67 @@ private UpdateByOperator makeRollingGroupOperator(@NotNull final MatchPair[] pai
rg.revWindowScale().timestampCol(),
prevWindowScaleUnits, fwdWindowScaleUnits, columnSources);
}

private UpdateByOperator makeRollingAvgOperator(@NotNull final MatchPair pair,
@NotNull final Table source,
@NotNull final RollingAvgSpec rs) {
// noinspection rawtypes
final ColumnSource columnSource = source.getColumnSource(pair.rightColumn);
final Class<?> csType = columnSource.getType();

final String[] affectingColumns;
if (rs.revWindowScale().timestampCol() == null) {
affectingColumns = new String[] {pair.rightColumn};
} else {
affectingColumns = new String[] {rs.revWindowScale().timestampCol(), pair.rightColumn};
}

final long prevWindowScaleUnits = rs.revWindowScale().timescaleUnits();
final long fwdWindowScaleUnits = rs.fwdWindowScale().timescaleUnits();

if (csType == Boolean.class || csType == boolean.class) {
return new ByteRollingAvgOperator(pair, affectingColumns, rowRedirection,
rs.revWindowScale().timestampCol(),
prevWindowScaleUnits, fwdWindowScaleUnits, NULL_BOOLEAN_AS_BYTE);
} else if (csType == byte.class || csType == Byte.class) {
return new ByteRollingAvgOperator(pair, affectingColumns, rowRedirection,
rs.revWindowScale().timestampCol(),
prevWindowScaleUnits, fwdWindowScaleUnits, NULL_BYTE);
} else if (csType == char.class || csType == Character.class) {
return new CharRollingAvgOperator(pair, affectingColumns, rowRedirection,
rs.revWindowScale().timestampCol(),
prevWindowScaleUnits, fwdWindowScaleUnits);
} else if (csType == short.class || csType == Short.class) {
return new ShortRollingAvgOperator(pair, affectingColumns, rowRedirection,
rs.revWindowScale().timestampCol(),
prevWindowScaleUnits, fwdWindowScaleUnits);
} else if (csType == int.class || csType == Integer.class) {
return new IntRollingAvgOperator(pair, affectingColumns, rowRedirection,
rs.revWindowScale().timestampCol(),
prevWindowScaleUnits, fwdWindowScaleUnits);
} else if (csType == long.class || csType == Long.class) {
return new LongRollingAvgOperator(pair, affectingColumns, rowRedirection,
rs.revWindowScale().timestampCol(),
prevWindowScaleUnits, fwdWindowScaleUnits);
} else if (csType == float.class || csType == Float.class) {
return new FloatRollingAvgOperator(pair, affectingColumns, rowRedirection,
rs.revWindowScale().timestampCol(),
prevWindowScaleUnits, fwdWindowScaleUnits);
} else if (csType == double.class || csType == Double.class) {
return new DoubleRollingAvgOperator(pair, affectingColumns, rowRedirection,
rs.revWindowScale().timestampCol(),
prevWindowScaleUnits, fwdWindowScaleUnits);
} else if (csType == BigDecimal.class) {
return new BigDecimalRollingAvgOperator(pair, affectingColumns, rowRedirection,
rs.revWindowScale().timestampCol(),
prevWindowScaleUnits, fwdWindowScaleUnits, control.mathContextOrDefault());
} else if (csType == BigInteger.class) {
return new BigIntegerRollingAvgOperator(pair, affectingColumns, rowRedirection,
rs.revWindowScale().timestampCol(),
prevWindowScaleUnits, fwdWindowScaleUnits, control.mathContextOrDefault());
}

throw new IllegalArgumentException("Can not perform RollingSum on type " + csType);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@

public class ByteEMAOperator extends BasePrimitiveEMAOperator {
public final ColumnSource<?> valueSource;
// region extra-fields
final byte nullValue;
// endregion extra-fields

protected class Context extends BasePrimitiveEMAOperator.Context {

Expand All @@ -45,7 +48,7 @@ public void accumulateCumulative(RowSequence inputKeys,
// read the value from the values chunk
final byte input = byteValueChunk.get(ii);

if (input == NULL_BYTE) {
if (input == nullValue) {
handleBadData(this, true, false);
} else {
if (curVal == NULL_DOUBLE) {
Expand All @@ -63,7 +66,7 @@ public void accumulateCumulative(RowSequence inputKeys,
final byte input = byteValueChunk.get(ii);
final long timestamp = tsChunk.get(ii);
//noinspection ConstantConditions
final boolean isNull = input == NULL_BYTE;
final boolean isNull = input == nullValue;
final boolean isNullTime = timestamp == NULL_LONG;
if (isNull) {
handleBadData(this, true, false);
Expand Down Expand Up @@ -96,7 +99,7 @@ public void setValuesChunk(@NotNull final Chunk<? extends Values> valuesChunk) {

@Override
public boolean isValueValid(long atKey) {
return valueSource.getByte(atKey) != NULL_BYTE;
return valueSource.getByte(atKey) != nullValue;
}

@Override
Expand Down Expand Up @@ -124,11 +127,13 @@ public ByteEMAOperator(@NotNull final MatchPair pair,
final long windowScaleUnits,
final ColumnSource<?> valueSource
// region extra-constructor-args
,final byte nullValue
// endregion extra-constructor-args
) {
super(pair, affectingColumns, rowRedirection, control, timestampColumnName, windowScaleUnits);
this.valueSource = valueSource;
// region constructor
this.nullValue = nullValue;
// endregion constructor
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

public class IntEMAOperator extends BasePrimitiveEMAOperator {
public final ColumnSource<?> valueSource;
// region extra-fields
// endregion extra-fields

protected class Context extends BasePrimitiveEMAOperator.Context {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

public class LongEMAOperator extends BasePrimitiveEMAOperator {
public final ColumnSource<?> valueSource;
// region extra-fields
// endregion extra-fields

protected class Context extends BasePrimitiveEMAOperator.Context {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

public class ShortEMAOperator extends BasePrimitiveEMAOperator {
public final ColumnSource<?> valueSource;
// region extra-fields
// endregion extra-fields

protected class Context extends BasePrimitiveEMAOperator.Context {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ public class ByteCumMinMaxOperator extends BaseByteUpdateByOperator {
private final boolean isMax;

// region extra-fields
final byte nullValue;
// endregion extra-fields

protected class Context extends BaseByteUpdateByOperator.Context {
Expand All @@ -42,9 +43,9 @@ public void push(int pos, int count) {

final byte val = byteValueChunk.get(pos);

if (curVal == NULL_BYTE) {
if (curVal == nullValue) {
curVal = val;
} else if (val != NULL_BYTE) {
} else if (val != nullValue) {
if ((isMax && val > curVal) ||
(!isMax && val < curVal)) {
curVal = val;
Expand All @@ -57,11 +58,13 @@ public ByteCumMinMaxOperator(@NotNull final MatchPair pair,
final boolean isMax,
@Nullable final RowRedirection rowRedirection
// region extra-constructor-args
,final byte nullValue
// endregion extra-constructor-args
) {
super(pair, new String[] { pair.rightColumn }, rowRedirection);
this.isMax = isMax;
// region constructor
this.nullValue = nullValue;
// endregion constructor
}
// region extra-methods
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

public class ByteCumProdOperator extends BaseLongUpdateByOperator {
// region extra-fields
final byte nullValue;
// endregion extra-fields

protected class Context extends BaseLongUpdateByOperator.Context {
Expand All @@ -41,7 +42,7 @@ public void push(int pos, int count) {

final byte val = byteValueChunk.get(pos);

if (val != NULL_BYTE) {
if (val != nullValue) {
curVal = curVal == NULL_LONG ? val : curVal * val;
}
}
Expand All @@ -50,10 +51,12 @@ public void push(int pos, int count) {
public ByteCumProdOperator(@NotNull final MatchPair pair,
@Nullable final RowRedirection rowRedirection
// region extra-constructor-args
,final byte nullValue
// endregion extra-constructor-args
) {
super(pair, new String[] { pair.rightColumn }, rowRedirection);
// region constructor
this.nullValue = nullValue;
// endregion constructor
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
package io.deephaven.engine.table.impl.updateby.rollingavg;

import io.deephaven.base.RingBuffer;
import io.deephaven.chunk.Chunk;
import io.deephaven.chunk.ObjectChunk;
import io.deephaven.chunk.attributes.Values;
import io.deephaven.engine.table.MatchPair;
import io.deephaven.engine.table.impl.updateby.UpdateByOperator;
import io.deephaven.engine.table.impl.updateby.internal.BaseObjectUpdateByOperator;
import io.deephaven.engine.table.impl.util.RowRedirection;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import java.math.BigDecimal;
import java.math.MathContext;

public final class BigDecimalRollingAvgOperator extends BaseObjectUpdateByOperator<BigDecimal> {
private static final int RING_BUFFER_INITIAL_CAPACITY = 128;
@NotNull
private final MathContext mathContext;

protected class Context extends BaseObjectUpdateByOperator<BigDecimal>.Context {
protected ObjectChunk<BigDecimal, ? extends Values> objectInfluencerValuesChunk;
protected RingBuffer<BigDecimal> objectWindowValues;

protected Context(final int chunkSize) {
super(chunkSize);
objectWindowValues = new RingBuffer<>(RING_BUFFER_INITIAL_CAPACITY);
}

@Override
public void close() {
super.close();
objectWindowValues = null;
}


@Override
public void setValuesChunk(@NotNull final Chunk<? extends Values> valuesChunk) {
objectInfluencerValuesChunk = valuesChunk.asObjectChunk();
}

@Override
public void push(int pos, int count) {
for (int ii = 0; ii < count; ii++) {
BigDecimal val = objectInfluencerValuesChunk.get(pos + ii);
objectWindowValues.add(val);

// increase the running sum
if (val != null) {
if (curVal == null) {
curVal = val;
} else {
curVal = curVal.add(val, mathContext);
}
} else {
nullCount++;
}
}
}

@Override
public void pop(int count) {
for (int ii = 0; ii < count; ii++) {
BigDecimal val = objectWindowValues.remove();

// reduce the running sum
if (val != null) {
curVal = curVal.subtract(val, mathContext);
} else {
nullCount--;

}
}
}

@Override
public void writeToOutputChunk(int outIdx) {
if (objectWindowValues.size() == nullCount) {
outputValues.set(outIdx, null);
curVal = null;
} else {
final BigDecimal count = new BigDecimal(objectWindowValues.size() - nullCount);
outputValues.set(outIdx, curVal.divide(count, mathContext));
}
}


@Override
public void reset() {
super.reset();
objectWindowValues.clear();
}
}

@NotNull
@Override
public UpdateByOperator.Context makeUpdateContext(final int chunkSize) {
return new Context(chunkSize);
}

public BigDecimalRollingAvgOperator(@NotNull final MatchPair pair,
@NotNull final String[] affectingColumns,
@Nullable final RowRedirection rowRedirection,
@Nullable final String timestampColumnName,
final long reverseWindowScaleUnits,
final long forwardWindowScaleUnits,
@NotNull final MathContext mathContext) {
super(pair, affectingColumns, rowRedirection, timestampColumnName, reverseWindowScaleUnits,
forwardWindowScaleUnits, true, BigDecimal.class);
this.mathContext = mathContext;
}
}
Loading

0 comments on commit 17fb852

Please sign in to comment.