Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scio execution graph #5508

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 40 additions & 23 deletions scio-core/src/main/scala/com/spotify/scio/ScioContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import java.io.File
import java.net.URI
import java.nio.file.Files
import com.spotify.scio.coders.{Coder, CoderMaterializer, KVCoder}
import com.spotify.scio.graph.{NodeIO, NodeType, ScioGraphNode}
import com.spotify.scio.io._
import com.spotify.scio.metrics.Metrics
import com.spotify.scio.options.ScioOptions
Expand Down Expand Up @@ -524,8 +525,8 @@ class ScioContext private[scio] (
private var _onClose: Unit => Unit = identity

/** Wrap a [[org.apache.beam.sdk.values.PCollection PCollection]]. */
def wrap[T](p: PCollection[T]): SCollection[T] =
new SCollectionImpl[T](p, this)
def wrap[T](p: PCollection[T], step: ScioGraphNode): SCollection[T] =
new SCollectionImpl[T](p, this, step)

/** Add callbacks calls when the context is closed. */
private[scio] def onClose(f: Unit => Unit): Unit =
Expand Down Expand Up @@ -687,25 +688,34 @@ class ScioContext private[scio] (

private[scio] def applyTransform[U](
name: Option[String],
root: PTransform[_ >: PBegin, PCollection[U]]
root: PTransform[_ >: PBegin, PCollection[U]],
step: ScioGraphNode
): SCollection[U] =
wrap(applyInternal(name, root))
wrap(applyInternal(name, root), step)

private[scio] def applyTransform[U](
root: PTransform[_ >: PBegin, PCollection[U]]
root: PTransform[_ >: PBegin, PCollection[U]],
step: ScioGraphNode
): SCollection[U] =
applyTransform(None, root)
applyTransform(None, root, step)

private[scio] def applyTransform[U](
name: String,
root: PTransform[_ >: PBegin, PCollection[U]]
root: PTransform[_ >: PBegin, PCollection[U]],
step: ScioGraphNode
): SCollection[U] =
applyTransform(Option(name), root)
applyTransform(Option(name), root, step)

def transform[U](f: ScioContext => SCollection[U]): SCollection[U] = transform(this.tfName)(f)
def transform[U: ClassTag](f: ScioContext => SCollection[U]): SCollection[U] =
transform(this.tfName)(f)

def transform[U](name: String)(f: ScioContext => SCollection[U]): SCollection[U] =
wrap(transform_(name)(f(_).internal))
def transform[U: ClassTag](name: String)(f: ScioContext => SCollection[U]): SCollection[U] = {
val transformed = transform_(name)(sc => SCollectionOutput(f(sc)))
wrap(
transformed.scioCollection.internal,
ScioGraphNode.node[U](name, NodeType.Transform, List(transformed.scioCollection))
)
}

private[scio] def transform_[U <: POutput](f: ScioContext => U): U =
transform_(tfName)(f)
Expand Down Expand Up @@ -760,7 +770,7 @@ class ScioContext private[scio] (
if (this.isTest) {
TestDataManager.getInput(testId.get)(CustomIO[T](name)).toSCollection(this)
} else {
applyTransform(name, transform)
applyTransform(name, transform, ScioGraphNode.read(name, NodeIO.CustomInput))
}
}

Expand All @@ -786,30 +796,31 @@ class ScioContext private[scio] (

/** Create a union of multiple SCollections. Supports empty lists. */
// `T: Coder` context bound is required since `scs` might be empty.
def unionAll[T: Coder](scs: => Iterable[SCollection[T]]): SCollection[T] = {
def unionAll[T: Coder: ClassTag](scs: => Iterable[SCollection[T]]): SCollection[T] = {
val tfName = this.tfName // evaluate eagerly to avoid overriding `scs` names
scs match {
scs.toList match {
case Nil => empty()
case contents =>
wrap(
PCollectionList
.of(contents.map(_.internal).asJava)
.apply(tfName, Flatten.pCollections())
.apply(tfName, Flatten.pCollections()),
ScioGraphNode.node[T](tfName, NodeType.UnionAll, contents)
)
}
}

/** Form an empty SCollection. */
def empty[T: Coder](): SCollection[T] = parallelize(Nil)
def empty[T: Coder: ClassTag](): SCollection[T] = parallelize(Nil)

/**
* Distribute a local Scala `Iterable` to form an SCollection.
* @group in_memory
*/
def parallelize[T: Coder](elems: Iterable[T]): SCollection[T] =
def parallelize[T: Coder: ClassTag](elems: Iterable[T]): SCollection[T] =
requireNotClosed {
val coder = CoderMaterializer.beam(this, Coder[T])
this.applyTransform(Create.of(elems.asJava).withCoder(coder))
this.applyTransform(Create.of(elems.asJava).withCoder(coder), ScioGraphNode.parallelize[T])
}

/**
Expand All @@ -822,33 +833,39 @@ class ScioContext private[scio] (
requireNotClosed {
val coder = CoderMaterializer.beam(this, KVCoder(koder, voder))
this
.applyTransform(Create.of(elems.asJava).withCoder(coder))
.applyTransform(Create.of(elems.asJava).withCoder(coder), ScioGraphNode.parallelize[(K, V)])
.map(kv => (kv.getKey, kv.getValue))
}

/**
* Distribute a local Scala `Iterable` with timestamps to form an SCollection.
* @group in_memory
*/
def parallelizeTimestamped[T: Coder](elems: Iterable[(T, Instant)]): SCollection[T] =
def parallelizeTimestamped[T: Coder: ClassTag](elems: Iterable[(T, Instant)]): SCollection[T] =
requireNotClosed {
val coder = CoderMaterializer.beam(this, Coder[T])
val v = elems.map(t => TimestampedValue.of(t._1, t._2))
this.applyTransform(Create.timestamped(v.asJava).withCoder(coder))
this.applyTransform(
Create.timestamped(v.asJava).withCoder(coder),
ScioGraphNode.parallelize[T]
)
}

/**
* Distribute a local Scala `Iterable` with timestamps to form an SCollection.
* @group in_memory
*/
def parallelizeTimestamped[T: Coder](
def parallelizeTimestamped[T: Coder: ClassTag](
elems: Iterable[T],
timestamps: Iterable[Instant]
): SCollection[T] =
requireNotClosed {
val coder = CoderMaterializer.beam(this, Coder[T])
val v = elems.zip(timestamps).map(t => TimestampedValue.of(t._1, t._2))
this.applyTransform(Create.timestamped(v.asJava).withCoder(coder))
this.applyTransform(
Create.timestamped(v.asJava).withCoder(coder),
ScioGraphNode.parallelize[T]
)
}

// =======================================================================
Expand Down
127 changes: 127 additions & 0 deletions scio-core/src/main/scala/com/spotify/scio/graph/ScioGraphNode.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
package com.spotify.scio.graph

import com.spotify.scio.values.SCollection

import scala.reflect.ClassTag

private[scio] case class ScioGraphNode(
name: String,
`type`: String,
io: Option[String],
dataClass: Class[_],
schemaPath: Option[String],
sources: List[ScioGraphNode],
properties: Map[String, Any]
)

object NodeType {
val Parallelize: String = "Parallelize"
val Transform: String = "Transform"
val UnionAll: String = "UnionAll"
val FlatMap: String = "FlatMap"
}

object NodeIO {
val Text: String = "Text"
val CustomInput: String = "CustomInput"
}

object ScioGraphNode {
def parallelize[T](implicit ct: ClassTag[T]) =
node[T](null, NodeType.Parallelize, List())

def node[T](
name: String,
`type`: String,
sources: List[SCollection[_]],
properties: Map[String, Any] = Map.empty
)(implicit
ct: ClassTag[T]
): ScioGraphNode = {
ScioGraphNode(
name,
`type`,
None,
ct.runtimeClass,
None,
sources.map(_.step),
properties
)
}

def read[T](name: String, io: String, properties: Map[String, Any] = Map.empty)(implicit
ct: ClassTag[T]
): ScioGraphNode = {
ScioGraphNode(
name,
"read",
Some(io),
ct.runtimeClass,
None,
List(),
properties
)
}

def write[T](
name: String,
io: String,
source: SCollection[_],
properties: Map[String, Any] = Map.empty
)(implicit
ct: ClassTag[T]
): ScioGraphNode = {
ScioGraphNode(
name,
"write",
Some(io),
ct.runtimeClass,
None,
List(source.step),
properties
)
}
}

//private[scio] case class SingleSourceNode(
// name: String,
// source: ScioGraphNode,
// dataClass: Class[_],
// schemaPath: Option[String] = None
//) extends ScioGraphNode {
// override val sources: List[ScioGraphNode] = List(source)
//}
//
//private[scio] case class TransformStep(name: String, sources: List[ScioGraphNode])
// extends ScioGraphNode
//
//private[scio] object TransformStep {
// def apply(name: String, source: ScioGraphNode): TransformStep = TransformStep(name, List(source))
//}

// preserve link to an original PTransform?
//private[scio] case class CustomInput(name: String) extends ScioGraphNode {
// val sources: List[ScioGraphNode] = List.empty
//}

//private[scio] case class UnionAll(name: String, sources: List[ScioGraphNode]) extends ScioGraphNode

//private[scio] object Parallelize extends ScioGraphNode {
// override val name: String = "parallelize"
// override val sources: List[ScioGraphNode] = List.empty
//}

//private[scio] case class ReadTextIO(filePattern: String) extends ScioGraphNode {
// override val name: String = null
// override val sources: List[ScioGraphNode] = List()
//}
//
//private[scio] case class TestInput(kind: String) extends ScioGraphNode {
// override val name: String = kind
// override val sources: List[ScioGraphNode] = List()
//}
//
//private[scio] case class FlatMap(source: ScioGraphNode) extends ScioGraphNode {
// val name: String = null
// override val sources: List[ScioGraphNode] = List(source)
//}
3 changes: 2 additions & 1 deletion scio-core/src/main/scala/com/spotify/scio/io/Tap.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import org.apache.beam.sdk.coders.{ByteArrayCoder, Coder => BCoder}
import org.apache.beam.sdk.util.CoderUtils

import java.io.{EOFException, InputStream}
import scala.reflect.ClassTag

/**
* Placeholder to an external data set that can either be load into memory as an iterator or opened
Expand Down Expand Up @@ -88,7 +89,7 @@ final case class TextTap(path: String, params: TextIO.ReadParam) extends Tap[Str
sc.read(TextIO(path))(params)
}

final private[scio] class InMemoryTap[T: Coder] extends Tap[T] {
final private[scio] class InMemoryTap[T: Coder: ClassTag] extends Tap[T] {
private[scio] val id: String = UUID.randomUUID().toString
override def value: Iterator[T] = InMemorySink.get(id).iterator
override def open(sc: ScioContext): SCollection[T] =
Expand Down
6 changes: 4 additions & 2 deletions scio-core/src/main/scala/com/spotify/scio/io/TextIO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import java.nio.channels.Channels
import java.util.Collections
import com.spotify.scio.ScioContext
import com.spotify.scio.coders.{Coder, CoderMaterializer}
import com.spotify.scio.graph.{NodeIO, ScioGraphNode}
import com.spotify.scio.util.ScioUtil
import com.spotify.scio.util.FilenamePolicySupplier
import com.spotify.scio.values.SCollection
Expand Down Expand Up @@ -49,9 +50,10 @@ final case class TextIO(path: String) extends ScioIO[String] {
.read()
.from(filePattern)
.withCompression(params.compression)
.withEmptyMatchTreatment(params.emptyMatchTreatment)
.withEmptyMatchTreatment(params.emptyMatchTreatment).dis

sc.applyTransform(t)
sc.applyTransform(t, ScioGraphNode.read(null, NodeIO.Text, Map("filePattern" -> filePattern))
)
.setCoder(coder)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package com.spotify.scio.testing

import com.spotify.scio.coders.Coder
import com.spotify.scio.graph.ScioGraphNode
import com.spotify.scio.io.ScioIO
import com.spotify.scio.values.SCollection
import com.spotify.scio.{ScioContext, ScioResult}
Expand All @@ -28,6 +29,7 @@ import org.apache.beam.sdk.testing.TestStream
import scala.collection.concurrent.TrieMap
import scala.collection.mutable.{Set => MSet}
import scala.jdk.CollectionConverters._
import scala.reflect.ClassTag
import scala.util.{Failure, Success, Try}

/* Inputs are Scala Iterables to be parallelized for TestPipeline, or PTransforms to be applied */
Expand All @@ -36,7 +38,7 @@ sealed private[scio] trait JobInputSource[T] {
val asIterable: Try[Iterable[T]]
}

final private[scio] case class TestStreamInputSource[T](
final private[scio] case class TestStreamInputSource[T: ClassTag](
stream: TestStream[T]
) extends JobInputSource[T] {
override val asIterable: Try[Iterable[T]] = Failure(
Expand All @@ -46,12 +48,12 @@ final private[scio] case class TestStreamInputSource[T](
)

override def toSCollection(sc: ScioContext): SCollection[T] =
sc.applyTransform(stream)
sc.applyTransform(stream, ScioGraphNode.read[T]("TestStream", "TestInput"))

override def toString: String = s"TestStream(${stream.getEvents})"
}

final private[scio] case class IterableInputSource[T: Coder](
final private[scio] case class IterableInputSource[T: Coder: ClassTag](
iterable: Iterable[T]
) extends JobInputSource[T] {
override val asIterable: Success[Iterable[T]] = Success(iterable)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@ package com.spotify.scio.transforms.syntax

import com.spotify.scio.values.SCollection
import com.spotify.scio.coders.{Coder, CoderMaterializer}
import com.spotify.scio.graph.{NodeType, ScioGraphNode}
import com.spotify.scio.util.NamedDoFn
import com.twitter.chill.ClosureCleaner
import org.apache.beam.sdk.transforms.DoFn.{Element, MultiOutputReceiver, ProcessElement}
import org.apache.beam.sdk.transforms.ParDo
import org.apache.beam.sdk.values.{TupleTag, TupleTagList}

import scala.collection.compat._
import scala.reflect.ClassTag

trait SCollectionSafeSyntax {

Expand All @@ -43,7 +45,7 @@ trait SCollectionSafeSyntax {
*
* @group transform
*/
def safeFlatMap[U: Coder](
def safeFlatMap[U: Coder: ClassTag](
f: T => TraversableOnce[U]
): (SCollection[U], SCollection[(T, Throwable)]) = {
val (mainTag, errorTag) = (new TupleTag[U], new TupleTag[(T, Throwable)])
Expand Down Expand Up @@ -75,7 +77,11 @@ trait SCollectionSafeSyntax {
tuple
.get(errorTag)
.setCoder(CoderMaterializer.beam(self.context, Coder[(T, Throwable)]))
(self.context.wrap(main), self.context.wrap(errorPipe))
(
self.context.wrap(main, ScioGraphNode.node(null, NodeType.FlatMap, List(self))),
self.context
.wrap(errorPipe, ScioGraphNode.node(null, NodeType.FlatMap, List(self)))
)
}
}
}
Loading
Loading