Add BucketedInput#getInputs to unlock SMB taps
clairemcginty committed Dec 19, 2023 · 1 parent ab1ec15 · commit 25dbe46
Showing 8 changed files with 46 additions and 23 deletions.
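
In short: SortedBucketIO.Read#toBucketedInput is promoted from protected to public, and SortedBucketSource.BucketedInput gains a getInputs() accessor over its resolved sources, so callers such as Scio's SMB taps can discover which inputs a read will touch. A minimal sketch of the resulting call pattern (not part of this commit; MyRecord and the input path are illustrative placeholders, and the read/from builder calls are assumed from the existing Avro API):

    // Sketch only: MyRecord and the path are illustrative placeholders.
    SortedBucketIO.Read<MyRecord> read =
        AvroSortedBucketIO.read(new TupleTag<MyRecord>("lhs"), MyRecord.class)
            .from("gs://bucket/smb-output");

    // toBucketedInput(Keying) was protected before this change.
    SortedBucketSource.BucketedInput<MyRecord> input =
        read.toBucketedInput(SortedBucketSource.Keying.PRIMARY);

    // New accessor: resolved ResourceId -> (filename suffix, FileOperations).
    input.getInputs().keySet().forEach(System.out::println);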
AvroSortedBucketIO.java

@@ -36,6 +36,7 @@
 import org.apache.beam.sdk.io.FileSystems;
 import org.apache.beam.sdk.io.fs.ResourceId;
 import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
@@ -191,6 +192,9 @@ public static <K1, K2, T extends SpecificRecord> TransformOutput<K1, K2, T> tran
   /** Reads from Avro sorted-bucket files, to be used with {@link SortedBucketIO.CoGbk}. */
   @AutoValue
   public abstract static class Read<T extends IndexedRecord> extends SortedBucketIO.Read<T> {
+    @Nullable
+    abstract ImmutableList<String> getInputDirectories();
+
     abstract String getFilenameSuffix();

     @Nullable
@@ -244,7 +248,7 @@ public Read<T> withPredicate(Predicate<T> predicate) {
     }

     @Override
-    protected SortedBucketSource.BucketedInput<T> toBucketedInput(
+    public SortedBucketSource.BucketedInput<T> toBucketedInput(
         final SortedBucketSource.Keying keying) {
       @SuppressWarnings("unchecked")
       final AvroFileOperations<T> fileOperations =
JsonSortedBucketIO.java

@@ -33,6 +33,7 @@
 import org.apache.beam.sdk.io.FileSystems;
 import org.apache.beam.sdk.io.fs.ResourceId;
 import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
@@ -107,6 +108,9 @@ public static <K1, K2> TransformOutput<K1, K2> transformOutput(
    */
   @AutoValue
   public abstract static class Read extends SortedBucketIO.Read<TableRow> {
+    @Nullable
+    abstract ImmutableList<String> getInputDirectories();
+
     abstract String getFilenameSuffix();

     abstract Compression getCompression();
@@ -153,7 +157,7 @@ public Read withPredicate(Predicate<TableRow> predicate) {
     }

     @Override
-    protected BucketedInput<TableRow> toBucketedInput(final SortedBucketSource.Keying keying) {
+    public BucketedInput<TableRow> toBucketedInput(final SortedBucketSource.Keying keying) {
       return BucketedInput.of(
           keying,
           getTupleTag(),
ParquetAvroSortedBucketIO.java

@@ -36,6 +36,7 @@
 import org.apache.beam.sdk.io.FileSystems;
 import org.apache.beam.sdk.io.fs.ResourceId;
 import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.parquet.filter2.predicate.FilterPredicate;
 import org.apache.parquet.hadoop.metadata.CompressionCodecName;
@@ -199,6 +200,8 @@ public static <K1, K2, T extends SpecificRecord> TransformOutput<K1, K2, T> tran
   /** Reads from Avro sorted-bucket files, to be used with {@link SortedBucketIO.CoGbk}. */
   @AutoValue
   public abstract static class Read<T extends IndexedRecord> extends SortedBucketIO.Read<T> {
+    @Nullable
+    abstract ImmutableList<String> getInputDirectories();

     abstract String getFilenameSuffix();

@@ -272,7 +275,7 @@ public Read<T> withConfiguration(Configuration configuration) {
     }

     @Override
-    protected BucketedInput<T> toBucketedInput(final SortedBucketSource.Keying keying) {
+    public BucketedInput<T> toBucketedInput(final SortedBucketSource.Keying keying) {
       final Schema schema =
           getRecordClass() == null
               ? getSchema()
SortedBucketIO.java

@@ -493,12 +493,9 @@ public abstract static class TransformOutput<K1, K2, V> implements Serializable

   /** Represents a single sorted-bucket source written using {@link SortedBucketSink}. */
   public abstract static class Read<V> implements Serializable {
-    @Nullable
-    public abstract ImmutableList<String> getInputDirectories();
-
     public abstract TupleTag<V> getTupleTag();

-    protected abstract BucketedInput<V> toBucketedInput(SortedBucketSource.Keying keying);
+    public abstract BucketedInput<V> toBucketedInput(SortedBucketSource.Keying keying);
   }

   @FunctionalInterface
SortedBucketSource.java

@@ -378,7 +378,7 @@ public PrimaryKeyedBucketedInput(

     public SourceMetadata<V> getSourceMetadata() {
       if (sourceMetadata == null)
-        sourceMetadata = BucketMetadataUtil.get().getPrimaryKeyedSourceMetadata(directories);
+        sourceMetadata = BucketMetadataUtil.get().getPrimaryKeyedSourceMetadata(inputs);
       return sourceMetadata;
     }
   }
@@ -408,8 +408,7 @@ public PrimaryAndSecondaryKeyedBucktedInput(

     public SourceMetadata<V> getSourceMetadata() {
       if (sourceMetadata == null)
-        sourceMetadata =
-            BucketMetadataUtil.get().getPrimaryAndSecondaryKeyedSourceMetadata(directories);
+        sourceMetadata = BucketMetadataUtil.get().getPrimaryAndSecondaryKeyedSourceMetadata(inputs);
       return sourceMetadata;
     }
   }
@@ -424,7 +423,7 @@ public abstract static class BucketedInput<V> implements Serializable {
     private static final Pattern BUCKET_PATTERN = Pattern.compile("(\\d+)-of-(\\d+)");

     protected TupleTag<V> tupleTag;
-    protected Map<ResourceId, KV<String, FileOperations<V>>> directories;
+    protected Map<ResourceId, KV<String, FileOperations<V>>> inputs;
     protected Predicate<V> predicate;
     protected Keying keying;
     // lazy, internal checks depend on what kind of iteration is requested
@@ -478,7 +477,7 @@ public BucketedInput(
         Predicate<V> predicate) {
       this.keying = keying;
       this.tupleTag = tupleTag;
-      this.directories =
+      this.inputs =
          directories.entrySet().stream()
              .collect(
                  Collectors.toMap(
@@ -496,9 +495,13 @@ public Predicate<V> getPredicate() {
       return predicate;
     }

+    public Map<ResourceId, KV<String, FileOperations<V>>> getInputs() {
+      return inputs;
+    }
+
     public Coder<V> getCoder() {
       final KV<String, FileOperations<V>> sampledSource =
-          directories.entrySet().iterator().next().getValue();
+          inputs.entrySet().iterator().next().getValue();
       return sampledSource.getValue().getCoder();
     }
@@ -520,7 +523,7 @@ private static List<Metadata> sampleDirectory(ResourceId directory, String filep
     }

     long getOrSampleByteSize() {
-      return directories
+      return inputs
           .entrySet()
           .parallelStream()
           .mapToLong(
@@ -596,8 +599,7 @@ public KeyGroupIterator<V> createIterator(
         try {
           Iterator<KV<SortedBucketIO.ComparableKeyBytes, V>> iterator =
               Iterators.transform(
-                  directories.get(dir).getValue().iterator(file),
-                  v -> KV.of(keyFn.apply(v), v));
+                  inputs.get(dir).getValue().iterator(file), v -> KV.of(keyFn.apply(v), v));
           Iterator<KV<SortedBucketIO.ComparableKeyBytes, V>> out =
               (bufferSize > 0) ? new BufferedIterator<>(iterator, bufferSize) : iterator;
           iterators.add(out);
@@ -612,7 +614,7 @@

     @Override
     public String toString() {
-      List<ResourceId> inputDirectories = new ArrayList<>(directories.keySet());
+      List<ResourceId> inputDirectories = new ArrayList<>(inputs.keySet());
       return String.format(
           "BucketedInput[tupleTag=%s, inputDirectories=[%s]]",
           tupleTag.getId(),
@@ -638,7 +640,7 @@ private void writeObject(ObjectOutputStream outStream) throws IOException {
       final Map<ResourceId, Integer> directoriesEncoding = new HashMap<>();
       int i = 0;

-      for (Map.Entry<ResourceId, KV<String, FileOperations<V>>> entry : directories.entrySet()) {
+      for (Map.Entry<ResourceId, KV<String, FileOperations<V>>> entry : inputs.entrySet()) {
         final KV<String, FileOperations<V>> fileOps = entry.getValue();
         final KV<String, String> metadataKey =
             KV.of(fileOps.getKey(), fileOps.getValue().getClass().getName());
@@ -675,7 +677,7 @@ private void readObject(ObjectInputStream inStream) throws ClassNotFoundExceptio
       final Map<ResourceId, Integer> directoriesEncoding =
           MapCoder.of(ResourceIdCoder.of(), VarIntCoder.of()).decode(inStream);

-      this.directories =
+      this.inputs =
           directoriesEncoding.entrySet().stream()
               .collect(
                   Collectors.toMap(
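
Note that both SourceMetadata lookups above now take the same map that the new getInputs() accessor returns, so external callers can work from exactly the inputs the source itself reads. A hedged sketch, continuing the earlier snippet; whether BucketMetadataUtil and SourceMetadata are visible to user code is an assumption:

    // Sketch: fetch bucket metadata for the same resolved inputs the
    // source reads, mirroring getSourceMetadata() above. Visibility of
    // BucketMetadataUtil/SourceMetadata from user code is an assumption.
    SourceMetadata<MyRecord> metadata =
        BucketMetadataUtil.get().getPrimaryKeyedSourceMetadata(input.getInputs());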
TensorFlowBucketIO.java

@@ -30,6 +30,7 @@
 import org.apache.beam.sdk.io.FileSystems;
 import org.apache.beam.sdk.io.fs.ResourceId;
 import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
 import org.tensorflow.proto.example.Example;

 /**
@@ -124,6 +125,9 @@ public static <K1, K2> TransformOutput<K1, K2> transformOutput(
    */
   @AutoValue
   public abstract static class Read extends SortedBucketIO.Read<Example> {
+    @Nullable
+    abstract ImmutableList<String> getInputDirectories();
+
     abstract String getFilenameSuffix();

     abstract Compression getCompression();
@@ -166,7 +170,7 @@ public Read withPredicate(Predicate<Example> predicate) {
     }

     @Override
-    protected BucketedInput<Example> toBucketedInput(final SortedBucketSource.Keying keying) {
+    public BucketedInput<Example> toBucketedInput(final SortedBucketSource.Keying keying) {
       return BucketedInput.of(
           keying,
           getTupleTag(),
ParquetTypeSortedBucketIO.scala

@@ -81,13 +81,13 @@ object ParquetTypeSortedBucketIO {
     def withConfiguration(configuration: Configuration): Read[T] =
       this.copy(configuration = configuration)

-    override def getInputDirectories: ImmutableList[String] =
+    def getInputDirectories: ImmutableList[String] =
       ImmutableList.copyOf(inputDirectories.asJava: java.lang.Iterable[String])
     def getFilenameSuffix: String = filenameSuffix

     override def getTupleTag: TupleTag[T] = tupleTag

-    override protected def toBucketedInput(
+    override def toBucketedInput(
       keying: SortedBucketSource.Keying
     ): SortedBucketSource.BucketedInput[T] = {
       val fileOperations = ParquetTypeFileOperations[T](filterPredicate, configuration)
SortedBucketIOUtil.scala

@@ -23,7 +23,16 @@ import scala.jdk.CollectionConverters._

 object SortedBucketIOUtil {
   def testId(read: beam.SortedBucketIO.Read[_]): String =
-    scio.SortedBucketIO.testId(read.getInputDirectories.asScala.toSeq: _*)
+    scio.SortedBucketIO.testId(
+      read
+        .toBucketedInput(SortedBucketSource.Keying.PRIMARY)
+        .getInputs
+        .asScala
+        .toSeq
+        .map { case (rId, _) =>
+          s"${rId.getCurrentDirectory}${Option(rId.getFilename).getOrElse("")}"
+        }: _*
+    )

   def testId(write: beam.SortedBucketIO.Write[_, _, _]): String =
     scio.SortedBucketIO.testId(write.getOutputDirectory.toString)
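
The testId rewrite above derives display paths from the resolved inputs rather than from the removed base-class getInputDirectories, and the Option(rId.getFilename) guard suggests inputs may resolve to individual files as well as directories. An equivalent helper in Java would look roughly like this (inputPaths is an illustrative name, not part of the commit):

    // Illustrative helper mirroring the Scala formatting above: each
    // resolved ResourceId becomes its directory plus filename, if any.
    static List<String> inputPaths(SortedBucketSource.BucketedInput<?> input) {
      return input.getInputs().keySet().stream()
          .map(rId -> rId.getCurrentDirectory().toString()
              + (rId.getFilename() == null ? "" : rId.getFilename()))
          .collect(Collectors.toList());
    }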
