Skip to content

Commit

Permalink
More efficient map side inputs for small maps. (#29587)
Browse files Browse the repository at this point in the history
The ability to perform point lookups for multi-map side inputs is great
for maps that are too large to fit into memory, but can be very inefficient
in requiring an entire state request per key for small maps.

This change adds an optional protocol to request an entire map as a stream
of key-values in one (possibly paginated) API call, and uses this to bulk
pre-fetch an initial set of values from the map before falling back to
point lookups.
  • Loading branch information
robertwb authored Dec 8, 2023
1 parent ed3a582 commit 85bf388
Show file tree
Hide file tree
Showing 8 changed files with 274 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ import "google/protobuf/duration.proto";
message FnApiTransforms {
enum Runner {
// DataSource is a Root Transform, and a source of data for downstream
// transforms in the same ProcessBundleDescriptor.
// transforms in the same ProcessBundleDescriptor.
// It represents a stream of values coming in from an external source/over
// a data channel, typically from the runner. It's not the PCollection itself
// but a description of how to get the portion of the PCollection for a given
Expand Down Expand Up @@ -82,7 +82,7 @@ message FnApiTransforms {
// request will be sent with the matching instruction ID and transform ID.
// Each PCollection that exits the ProcessBundleDescriptor subgraph will have
// it's own DataSink, keyed by a transform ID determined by the runner.
//
//
// The DataSink will take in a stream of elements for a given instruction ID
// and encode them for transmission to the remote sink. The coder ID must be
// for a windowed value coder.
Expand Down Expand Up @@ -924,6 +924,35 @@ message StateKey {
bytes window = 3;
}

// Represents a request for the keys and values associated with a specified window in a PCollection. See
// https://s.apache.org/beam-fn-state-api-and-bundle-processing for further
// details.
//
// This is expected to be more efficient than iterating over they keys and
// looking up the values one at a time. If a runner chooses not to implement
// this protocol, or a key has too many values to fit into a single response,
// the runner is free to fail the request and a fallback to point lookups
// will be performed by the SDK.
//
// Can only be used to perform StateGetRequests on side inputs of the URN
// beam:side_input:multimap:v1.
//
// For a PCollection<KV<K, V>>, the response data stream will be a
// concatenation of all KVs associated with the specified window,
// encoded with the the KV<K, Iterable<V>> coder.
// See
// https://s.apache.org/beam-fn-api-send-and-receive-data for further
// details.
message MultimapKeysValuesSideInput {
// (Required) The id of the PTransform containing a side input.
string transform_id = 1;
// (Required) The id of the side input.
string side_input_id = 2;
// (Required) The window (after mapping the currently executing elements
// window into the side input windows domain) encoded in a nested context.
bytes window = 3;
}

// Represents a request for an unordered set of values associated with a
// specified user key and window for a PTransform. See
// https://s.apache.org/beam-fn-state-api-and-bundle-processing for further
Expand Down Expand Up @@ -999,6 +1028,7 @@ message StateKey {
BagUserState bag_user_state = 3;
IterableSideInput iterable_side_input = 4;
MultimapKeysSideInput multimap_keys_side_input = 5;
MultimapKeysValuesSideInput multimap_keys_values_side_input = 8;
MultimapKeysUserState multimap_keys_user_state = 6;
MultimapUserState multimap_user_state = 7;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -915,9 +915,10 @@ public Coder<V> valueCoder() {
// Expect the following requests for the first bundle:
// * one to read iterable side input
// * one to read keys from multimap side input
// * one to attempt multimap side input bulk read
// * one to read key1 iterable from multimap side input
// * one to read key2 iterable from multimap side input
assertEquals(4, stateRequestHandler.receivedRequests.size());
assertEquals(5, stateRequestHandler.receivedRequests.size());
assertEquals(
stateRequestHandler.receivedRequests.get(0).getStateKey().getIterableSideInput(),
BeamFnApi.StateKey.IterableSideInput.newBuilder()
Expand All @@ -931,14 +932,20 @@ public Coder<V> valueCoder() {
.setTransformId(transformId)
.build());
assertEquals(
stateRequestHandler.receivedRequests.get(2).getStateKey().getMultimapSideInput(),
stateRequestHandler.receivedRequests.get(2).getStateKey().getMultimapKeysValuesSideInput(),
BeamFnApi.StateKey.MultimapKeysValuesSideInput.newBuilder()
.setSideInputId(multimapView.getTagInternal().getId())
.setTransformId(transformId)
.build());
assertEquals(
stateRequestHandler.receivedRequests.get(3).getStateKey().getMultimapSideInput(),
BeamFnApi.StateKey.MultimapSideInput.newBuilder()
.setSideInputId(multimapView.getTagInternal().getId())
.setTransformId(transformId)
.setKey(encode("key1"))
.build());
assertEquals(
stateRequestHandler.receivedRequests.get(3).getStateKey().getMultimapSideInput(),
stateRequestHandler.receivedRequests.get(4).getStateKey().getMultimapSideInput(),
BeamFnApi.StateKey.MultimapSideInput.newBuilder()
.setSideInputId(multimapView.getTagInternal().getId())
.setTransformId(transformId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,21 @@
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.function.Function;
import org.apache.beam.fn.harness.Cache;
import org.apache.beam.fn.harness.Caches;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.transforms.Materializations.MultimapView;
import org.apache.beam.sdk.util.ByteStringOutputStream;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.ByteString;

/**
Expand All @@ -38,11 +46,14 @@
})
public class MultimapSideInput<K, V> implements MultimapView<K, V> {

private static final int BULK_READ_SIZE = 100;

private final Cache<?, ?> cache;
private final BeamFnStateClient beamFnStateClient;
private final StateRequest keysRequest;
private final Coder<K> keyCoder;
private final Coder<V> valueCoder;
private volatile Function<ByteString, Iterable<V>> bulkReadResult;

public MultimapSideInput(
Cache<?, ?> cache,
Expand Down Expand Up @@ -71,17 +82,66 @@ public Iterable<K> get() {

@Override
public Iterable<V> get(K k) {
ByteStringOutputStream output = new ByteStringOutputStream();
try {
keyCoder.encode(k, output);
} catch (IOException e) {
throw new IllegalStateException(
String.format(
"Failed to encode key %s for side input id %s.",
k, keysRequest.getStateKey().getMultimapKeysSideInput().getSideInputId()),
e);
ByteString encodedKey = encodeKey(k);

if (bulkReadResult == null) {
synchronized (this) {
if (bulkReadResult == null) {
Map<ByteString, Iterable<V>> bulkRead = new HashMap<>();
StateKey bulkReadStateKey =
StateKey.newBuilder()
.setMultimapKeysValuesSideInput(
StateKey.MultimapKeysValuesSideInput.newBuilder()
.setTransformId(
keysRequest.getStateKey().getMultimapKeysSideInput().getTransformId())
.setSideInputId(
keysRequest.getStateKey().getMultimapKeysSideInput().getSideInputId())
.setWindow(
keysRequest.getStateKey().getMultimapKeysSideInput().getWindow()))
.build();

StateRequest bulkReadRequest =
keysRequest.toBuilder().setStateKey(bulkReadStateKey).build();
try {
Iterator<KV<K, Iterable<V>>> entries =
StateFetchingIterators.readAllAndDecodeStartingFrom(
Caches.subCache(cache, "ValuesForKey", encodedKey),
beamFnStateClient,
bulkReadRequest,
KvCoder.of(keyCoder, IterableCoder.of(valueCoder)))
.iterator();
while (bulkRead.size() < BULK_READ_SIZE && entries.hasNext()) {
KV<K, Iterable<V>> entry = entries.next();
bulkRead.put(encodeKey(entry.getKey()), entry.getValue());
}
if (entries.hasNext()) {
bulkReadResult = bulkRead::get;
} else {
bulkReadResult =
key -> {
Iterable<V> result = bulkRead.get(key);
if (result == null) {
// As we read the entire set of values, we don't have to do a lookup to know
// this key doesn't exist.
// Missing keys are treated as empty iterables in this multimap.
return Collections.emptyList();
} else {
return result;
}
};
}
} catch (Exception exn) {
bulkReadResult = bulkRead::get;
}
}
}
}

Iterable<V> bulkReadValues = bulkReadResult.apply(encodedKey);
if (bulkReadValues != null) {
return bulkReadValues;
}
ByteString encodedKey = output.toByteString();

StateKey stateKey =
StateKey.newBuilder()
.setMultimapSideInput(
Expand All @@ -98,4 +158,18 @@ public Iterable<V> get(K k) {
return StateFetchingIterators.readAllAndDecodeStartingFrom(
Caches.subCache(cache, "ValuesForKey", encodedKey), beamFnStateClient, request, valueCoder);
}

private ByteString encodeKey(K k) {
ByteStringOutputStream output = new ByteStringOutputStream();
try {
keyCoder.encode(k, output);
} catch (IOException e) {
throw new IllegalStateException(
String.format(
"Failed to encode key %s for side input id %s.",
k, keysRequest.getStateKey().getMultimapKeysSideInput().getSideInputId()),
e);
}
return output.toByteString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,10 @@ public CompletableFuture<StateResponse> handle(StateRequest.Builder requestBuild
if (key.getTypeCase() == TypeCase.MULTIMAP_SIDE_INPUT || key.getTypeCase() == TypeCase.RUNNER) {
assertEquals(RequestCase.GET, request.getRequestCase());
}
if (key.getTypeCase() == TypeCase.MULTIMAP_KEYS_VALUES_SIDE_INPUT && !data.containsKey(key)) {
// Allow testing this not being supported rather than blindly returning the empty list.
throw new UnsupportedOperationException("No multimap keys values states provided.");
}

switch (request.getRequestCase()) {
case GET:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
import org.apache.beam.sdk.coders.ByteArrayCoder;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.util.ByteStringOutputStream;
import org.apache.beam.sdk.values.KV;
Expand All @@ -50,12 +52,38 @@ public class MultimapSideInputTest {
private static final byte[] B = "B".getBytes(StandardCharsets.UTF_8);
private static final byte[] UNKNOWN = "UNKNOWN".getBytes(StandardCharsets.UTF_8);

@Test
public void testGetWithBulkRead() throws Exception {
FakeBeamFnStateClient fakeBeamFnStateClient =
new FakeBeamFnStateClient(
ImmutableMap.of(
keysValuesStateKey(),
KV.of(
KvCoder.of(ByteArrayCoder.of(), IterableCoder.of(StringUtf8Coder.of())),
asList(KV.of(A, asList("A1", "A2", "A3")), KV.of(B, asList("B1", "B2"))))));

MultimapSideInput<byte[], String> multimapSideInput =
new MultimapSideInput<>(
Caches.noop(),
fakeBeamFnStateClient,
"instructionId",
keysStateKey(),
ByteArrayCoder.of(),
StringUtf8Coder.of());
assertArrayEquals(
new String[] {"A1", "A2", "A3"}, Iterables.toArray(multimapSideInput.get(A), String.class));
assertArrayEquals(
new String[] {"B1", "B2"}, Iterables.toArray(multimapSideInput.get(B), String.class));
assertArrayEquals(
new String[] {}, Iterables.toArray(multimapSideInput.get(UNKNOWN), String.class));
}

@Test
public void testGet() throws Exception {
FakeBeamFnStateClient fakeBeamFnStateClient =
new FakeBeamFnStateClient(
ImmutableMap.of(
stateKey(), KV.of(ByteArrayCoder.of(), asList(A, B)),
keysStateKey(), KV.of(ByteArrayCoder.of(), asList(A, B)),
key(A), KV.of(StringUtf8Coder.of(), asList("A1", "A2", "A3")),
key(B), KV.of(StringUtf8Coder.of(), asList("B1", "B2"))));

Expand All @@ -64,7 +92,7 @@ public void testGet() throws Exception {
Caches.noop(),
fakeBeamFnStateClient,
"instructionId",
stateKey(),
keysStateKey(),
ByteArrayCoder.of(),
StringUtf8Coder.of());
assertArrayEquals(
Expand All @@ -82,7 +110,7 @@ public void testGetCached() throws Exception {
FakeBeamFnStateClient fakeBeamFnStateClient =
new FakeBeamFnStateClient(
ImmutableMap.of(
stateKey(), KV.of(ByteArrayCoder.of(), asList(A, B)),
keysStateKey(), KV.of(ByteArrayCoder.of(), asList(A, B)),
key(A), KV.of(StringUtf8Coder.of(), asList("A1", "A2", "A3")),
key(B), KV.of(StringUtf8Coder.of(), asList("B1", "B2"))));

Expand All @@ -94,7 +122,7 @@ public void testGetCached() throws Exception {
cache,
fakeBeamFnStateClient,
"instructionId",
stateKey(),
keysStateKey(),
ByteArrayCoder.of(),
StringUtf8Coder.of());
assertArrayEquals(
Expand All @@ -117,7 +145,7 @@ public void testGetCached() throws Exception {
throw new IllegalStateException("Unexpected call for test.");
},
"instructionId",
stateKey(),
keysStateKey(),
ByteArrayCoder.of(),
StringUtf8Coder.of());
assertArrayEquals(
Expand All @@ -132,7 +160,7 @@ public void testGetCached() throws Exception {
}
}

private StateKey stateKey() throws IOException {
private StateKey keysStateKey() throws IOException {
return StateKey.newBuilder()
.setMultimapKeysSideInput(
StateKey.MultimapKeysSideInput.newBuilder()
Expand All @@ -142,6 +170,16 @@ private StateKey stateKey() throws IOException {
.build();
}

private StateKey keysValuesStateKey() throws IOException {
return StateKey.newBuilder()
.setMultimapKeysValuesSideInput(
StateKey.MultimapKeysValuesSideInput.newBuilder()
.setTransformId("ptransformId")
.setSideInputId("sideInputId")
.setWindow(ByteString.copyFromUtf8("encodedWindow")))
.build();
}

private StateKey key(byte[] key) throws IOException {
ByteStringOutputStream out = new ByteStringOutputStream();
ByteArrayCoder.of().encode(key, out);
Expand Down
Loading

0 comments on commit 85bf388

Please sign in to comment.