Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace StateTag.StateBinder to top level StateBinder in SparkStateInternals #31798

Merged
merged 3 commits into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
{
"comment": "Modify this file in a trivial way to cause this test suite to run",
"https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test"
"https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test",
"https://github.com/apache/beam/pull/31798": "noting that PR #31798 should run this test"
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
{
"comment": "Modify this file in a trivial way to cause this test suite to run",
"https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test"
"https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test",
"https://github.com/apache/beam/pull/31798": "noting that PR #31798 should run this test"
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
{
"comment": "Modify this file in a trivial way to cause this test suite to run",
"https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test"
"https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test",
"https://github.com/apache/beam/pull/31798": "noting that PR #31798 should run this test"
}
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,9 @@ public interface StateTag<StateT extends State> extends Serializable {
/**
* Visitor for binding a {@link StateSpec} to the associated {@link State}.
*
* @deprecated for migration only; runners should reference the top level {@link StateBinder} and
* move towards {@link StateSpec} rather than {@link StateTag}.
* @deprecated for migration only; runners should reference the top level {@link
* org.apache.beam.sdk.state.StateBinder} and move towards {@link StateSpec} rather than
* {@link StateTag}.
*/
@Deprecated
public interface StateBinder {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import org.apache.beam.runners.core.StateInternals;
import org.apache.beam.runners.core.StateNamespace;
import org.apache.beam.runners.core.StateTag;
import org.apache.beam.runners.core.StateTag.StateBinder;
import org.apache.beam.runners.spark.coders.CoderHelpers;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.InstantCoder;
Expand All @@ -42,11 +41,13 @@
import org.apache.beam.sdk.state.ReadableStates;
import org.apache.beam.sdk.state.SetState;
import org.apache.beam.sdk.state.State;
import org.apache.beam.sdk.state.StateBinder;
import org.apache.beam.sdk.state.StateContext;
import org.apache.beam.sdk.state.StateSpec;
import org.apache.beam.sdk.state.ValueState;
import org.apache.beam.sdk.state.WatermarkHoldState;
import org.apache.beam.sdk.transforms.Combine.CombineFn;
import org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext;
import org.apache.beam.sdk.transforms.CombineWithContext;
import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
import org.apache.beam.sdk.util.CombineFnUtil;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.HashBasedTable;
Expand Down Expand Up @@ -96,45 +97,47 @@ public K getKey() {
@Override
public <T extends State> T state(
StateNamespace namespace, StateTag<T> address, StateContext<?> c) {
return address.bind(new SparkStateBinder(namespace, c));
return address.getSpec().bind(address.getId(), new SparkStateBinder(namespace, c));
}

private class SparkStateBinder implements StateBinder {
private final StateNamespace namespace;
private final StateContext<?> c;
private final StateContext<?> stateContext;

private SparkStateBinder(StateNamespace namespace, StateContext<?> c) {
private SparkStateBinder(StateNamespace namespace, StateContext<?> stateContext) {
this.namespace = namespace;
this.c = c;
this.stateContext = stateContext;
}

@Override
public <T> ValueState<T> bindValue(StateTag<ValueState<T>> address, Coder<T> coder) {
return new SparkValueState<>(namespace, address, coder);
public <T> ValueState<T> bindValue(String id, StateSpec<ValueState<T>> spec, Coder<T> coder) {
return new SparkValueState<>(namespace, id, coder);
}

@Override
public <T> BagState<T> bindBag(StateTag<BagState<T>> address, Coder<T> elemCoder) {
return new SparkBagState<>(namespace, address, elemCoder);
public <T> BagState<T> bindBag(String id, StateSpec<BagState<T>> spec, Coder<T> elemCoder) {
return new SparkBagState<>(namespace, id, elemCoder);
}

@Override
public <T> SetState<T> bindSet(StateTag<SetState<T>> spec, Coder<T> elemCoder) {
public <T> SetState<T> bindSet(String id, StateSpec<SetState<T>> spec, Coder<T> elemCoder) {
throw new UnsupportedOperationException(
String.format("%s is not supported", SetState.class.getSimpleName()));
}

@Override
public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap(
StateTag<MapState<KeyT, ValueT>> address,
String id,
StateSpec<MapState<KeyT, ValueT>> spec,
Coder<KeyT> mapKeyCoder,
Coder<ValueT> mapValueCoder) {
return new SparkMapState<>(namespace, address, MapCoder.of(mapKeyCoder, mapValueCoder));
return new SparkMapState<>(namespace, id, MapCoder.of(mapKeyCoder, mapValueCoder));
}

@Override
public <KeyT, ValueT> MultimapState<KeyT, ValueT> bindMultimap(
StateTag<MultimapState<KeyT, ValueT>> spec,
String id,
StateSpec<MultimapState<KeyT, ValueT>> spec,
Coder<KeyT> keyCoder,
Coder<ValueT> valueCoder) {
throw new UnsupportedOperationException(
Expand All @@ -143,63 +146,63 @@ public <KeyT, ValueT> MultimapState<KeyT, ValueT> bindMultimap(

@Override
public <T> OrderedListState<T> bindOrderedList(
StateTag<OrderedListState<T>> spec, Coder<T> elemCoder) {
String id, StateSpec<OrderedListState<T>> spec, Coder<T> elemCoder) {
throw new UnsupportedOperationException(
String.format("%s is not supported", OrderedListState.class.getSimpleName()));
}

@Override
public <InputT, AccumT, OutputT> CombiningState<InputT, AccumT, OutputT> bindCombiningValue(
StateTag<CombiningState<InputT, AccumT, OutputT>> address,
public <InputT, AccumT, OutputT> CombiningState<InputT, AccumT, OutputT> bindCombining(
String id,
StateSpec<CombiningState<InputT, AccumT, OutputT>> spec,
Coder<AccumT> accumCoder,
CombineFn<InputT, AccumT, OutputT> combineFn) {
return new SparkCombiningState<>(namespace, address, accumCoder, combineFn);
return new SparkCombiningState<>(namespace, id, accumCoder, combineFn);
}

@Override
public <InputT, AccumT, OutputT>
CombiningState<InputT, AccumT, OutputT> bindCombiningValueWithContext(
StateTag<CombiningState<InputT, AccumT, OutputT>> address,
CombiningState<InputT, AccumT, OutputT> bindCombiningWithContext(
String id,
StateSpec<CombiningState<InputT, AccumT, OutputT>> spec,
Coder<AccumT> accumCoder,
CombineFnWithContext<InputT, AccumT, OutputT> combineFn) {
CombineWithContext.CombineFnWithContext<InputT, AccumT, OutputT> combineFn) {
return new SparkCombiningState<>(
namespace, address, accumCoder, CombineFnUtil.bindContext(combineFn, c));
namespace, id, accumCoder, CombineFnUtil.bindContext(combineFn, stateContext));
}

@Override
public WatermarkHoldState bindWatermark(
StateTag<WatermarkHoldState> address, TimestampCombiner timestampCombiner) {
return new SparkWatermarkHoldState(namespace, address, timestampCombiner);
String id, StateSpec<WatermarkHoldState> spec, TimestampCombiner timestampCombiner) {
return new SparkWatermarkHoldState(namespace, id, timestampCombiner);
}
}

private class AbstractState<T> {
final StateNamespace namespace;
final StateTag<? extends State> address;
final String id;
final Coder<T> coder;

private AbstractState(
StateNamespace namespace, StateTag<? extends State> address, Coder<T> coder) {
private AbstractState(StateNamespace namespace, String id, Coder<T> coder) {
this.namespace = namespace;
this.address = address;
this.id = id;
this.coder = coder;
}

T readValue() {
byte[] buf = stateTable.get(namespace.stringKey(), address.getId());
byte[] buf = stateTable.get(namespace.stringKey(), id);
if (buf != null) {
return CoderHelpers.fromByteArray(buf, coder);
}
return null;
}

void writeValue(T input) {
stateTable.put(
namespace.stringKey(), address.getId(), CoderHelpers.toByteArray(input, coder));
stateTable.put(namespace.stringKey(), id, CoderHelpers.toByteArray(input, coder));
}

public void clear() {
stateTable.remove(namespace.stringKey(), address.getId());
stateTable.remove(namespace.stringKey(), id);
}

@Override
Expand All @@ -212,22 +215,21 @@ public boolean equals(@Nullable Object o) {
}
@SuppressWarnings("unchecked")
AbstractState<?> that = (AbstractState<?>) o;
return namespace.equals(that.namespace) && address.equals(that.address);
return namespace.equals(that.namespace) && id.equals(that.id);
}

@Override
public int hashCode() {
int result = namespace.hashCode();
result = 31 * result + address.hashCode();
result = 31 * result + id.hashCode();
return result;
}
}

private class SparkValueState<T> extends AbstractState<T> implements ValueState<T> {

private SparkValueState(
StateNamespace namespace, StateTag<ValueState<T>> address, Coder<T> coder) {
super(namespace, address, coder);
private SparkValueState(StateNamespace namespace, String id, Coder<T> coder) {
super(namespace, id, coder);
}

@Override
Expand All @@ -252,10 +254,8 @@ private class SparkWatermarkHoldState extends AbstractState<Instant>
private final TimestampCombiner timestampCombiner;

SparkWatermarkHoldState(
StateNamespace namespace,
StateTag<WatermarkHoldState> address,
TimestampCombiner timestampCombiner) {
super(namespace, address, InstantCoder.of());
StateNamespace namespace, String id, TimestampCombiner timestampCombiner) {
super(namespace, id, InstantCoder.of());
this.timestampCombiner = timestampCombiner;
}

Expand Down Expand Up @@ -287,7 +287,7 @@ public ReadableState<Boolean> readLater() {

@Override
public Boolean read() {
return stateTable.get(namespace.stringKey(), address.getId()) == null;
return stateTable.get(namespace.stringKey(), id) == null;
}
};
}
Expand All @@ -299,22 +299,22 @@ public TimestampCombiner getTimestampCombiner() {
}

@SuppressWarnings("TypeParameterShadowing")
private class SparkCombiningState<K, InputT, AccumT, OutputT> extends AbstractState<AccumT>
private class SparkCombiningState<KeyT, InputT, AccumT, OutputT> extends AbstractState<AccumT>
implements CombiningState<InputT, AccumT, OutputT> {

private final CombineFn<InputT, AccumT, OutputT> combineFn;

private SparkCombiningState(
StateNamespace namespace,
StateTag<CombiningState<InputT, AccumT, OutputT>> address,
String id,
Coder<AccumT> coder,
CombineFn<InputT, AccumT, OutputT> combineFn) {
super(namespace, address, coder);
super(namespace, id, coder);
this.combineFn = combineFn;
}

@Override
public SparkCombiningState<K, InputT, AccumT, OutputT> readLater() {
public SparkCombiningState<KeyT, InputT, AccumT, OutputT> readLater() {
return this;
}

Expand Down Expand Up @@ -348,7 +348,7 @@ public ReadableState<Boolean> readLater() {

@Override
public Boolean read() {
return stateTable.get(namespace.stringKey(), address.getId()) == null;
return stateTable.get(namespace.stringKey(), id) == null;
}
};
}
Expand All @@ -369,10 +369,8 @@ private final class SparkMapState<MapKeyT, MapValueT>
extends AbstractState<Map<MapKeyT, MapValueT>> implements MapState<MapKeyT, MapValueT> {

private SparkMapState(
StateNamespace namespace,
StateTag<? extends State> address,
Coder<Map<MapKeyT, MapValueT>> coder) {
super(namespace, address, coder);
StateNamespace namespace, String id, Coder<Map<MapKeyT, MapValueT>> coder) {
super(namespace, id, coder);
}

@Override
Expand Down Expand Up @@ -490,7 +488,7 @@ public ReadableState<Boolean> isEmpty() {
return new ReadableState<Boolean>() {
@Override
public Boolean read() {
return stateTable.get(namespace.stringKey(), address.getId()) == null;
return stateTable.get(namespace.stringKey(), id) == null;
}

@Override
Expand All @@ -502,8 +500,8 @@ public ReadableState<Boolean> readLater() {
}

private final class SparkBagState<T> extends AbstractState<List<T>> implements BagState<T> {
private SparkBagState(StateNamespace namespace, StateTag<BagState<T>> address, Coder<T> coder) {
super(namespace, address, ListCoder.of(coder));
private SparkBagState(StateNamespace namespace, String id, Coder<T> coder) {
super(namespace, id, ListCoder.of(coder));
}

@Override
Expand Down Expand Up @@ -537,7 +535,7 @@ public ReadableState<Boolean> readLater() {

@Override
public Boolean read() {
return stateTable.get(namespace.stringKey(), address.getId()) == null;
return stateTable.get(namespace.stringKey(), id) == null;
}
};
}
Expand Down
Loading