Spark DataFrames handled as a type if using spark #267

Merged · 7 commits · Dec 3, 2020 · Changes from 1 commit
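For context, a minimal usage sketch of what this PR enables (not part of the diff). The task decorator import path and the Spark config field are assumptions based on flytekit's annotated API of this era; the session lookup mirrors SparkDataFrameSchemaReader further down.

import pyspark.sql

from flytekit import FlyteContext
from flytekit.annotated.task import task
from flytekit.taskplugins.spark import Spark


@task(task_config=Spark(spark_conf={"spark.executor.instances": "2"}))
def create_df() -> pyspark.sql.DataFrame:
    # PysparkFunctionTask.pre_execute (see task.py below) has already attached the
    # SparkSession to the user-space params before this body runs.
    sess = FlyteContext.current_context().user_space_params.spark_session
    # The returned DataFrame is handled by SparkDataFrameTransformer: it is written
    # out as Parquet and passed downstream as a Schema literal instead of being pickled.
    return sess.createDataFrame([("a", 1), ("b", 2)], ["name", "value"])
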
104 changes: 61 additions & 43 deletions flytekit/annotated/base_task.py
@@ -14,6 +14,7 @@
from flytekit.annotated.promise import Promise, VoidPromise, create_task_output, translate_inputs_to_literals
from flytekit.annotated.type_engine import TypeEngine
from flytekit.common.exceptions import user as _user_exceptions
from flytekit.common.tasks.sdk_runnable import ExecutionParameters
from flytekit.common.tasks.task import SdkTask
from flytekit.loggers import logger
from flytekit.models import dynamic_job as _dynamic_job
@@ -246,50 +247,67 @@ def dispatch_execute(
`DynamicJobSpec` is returned when a dynamic workflow is executed
"""

# TODO We could support default values here too - but not part of the plan right now
# Translate the input literals to Python native
native_inputs = TypeEngine.literal_map_to_kwargs(ctx, input_literal_map, self.python_interface.inputs)

# TODO: Logger should auto inject the current context information to indicate if the task is running within
# a workflow or a subworkflow etc
logger.info(f"Invoking {self.name} with inputs: {native_inputs}")
try:
native_outputs = self.execute(**native_inputs)
except Exception as e:
logger.exception(f"Exception when executing {e}")
raise e
logger.info(f"Task executed successfully in user level, outputs: {native_outputs}")

# Short circuit the translation to literal map because what's returned may be a dj spec (or an
# already-constructed LiteralMap if the dynamic task was a no-op), not python native values
if isinstance(native_outputs, _literal_models.LiteralMap) or isinstance(
native_outputs, _dynamic_job.DynamicJobSpec
):
return native_outputs
# Invoked before the task is executed
new_user_params = self.pre_execute(ctx.user_space_params)
with ctx.new_execution_context(mode=ctx.execution_state.mode, execution_params=new_user_params) as exec_ctx:
# TODO We could support default values here too - but not part of the plan right now
# Translate the input literals to Python native
native_inputs = TypeEngine.literal_map_to_kwargs(exec_ctx, input_literal_map, self.python_interface.inputs)

# TODO: Logger should auto inject the current context information to indicate if the task is running within
# a workflow or a subworkflow etc
logger.info(f"Invoking {self.name} with inputs: {native_inputs}")
try:
native_outputs = self.execute(**native_inputs)
except Exception as e:
logger.exception(f"Exception when executing {e}")
raise e
logger.info(f"Task executed successfully in user level, outputs: {native_outputs}")

# Short circuit the translation to literal map because what's returned may be a dj spec (or an
# already-constructed LiteralMap if the dynamic task was a no-op), not python native values
if isinstance(native_outputs, _literal_models.LiteralMap) or isinstance(
native_outputs, _dynamic_job.DynamicJobSpec
):
return native_outputs

expected_output_names = list(self.interface.outputs.keys())
if len(expected_output_names) == 1:
native_outputs_as_map = {expected_output_names[0]: native_outputs}
elif len(expected_output_names) == 0:
return VoidPromise(self.name)
else:
# Question: How do you know you're going to enumerate them in the correct order? Even if autonamed, will
# output2 come before output100 if there's a hundred outputs? We don't! We'll have to circle back to
# the Python task instance and inspect annotations again. Or we change the Python model representation
# of the interface to be an ordered dict and we fill it in correctly to begin with.
native_outputs_as_map = {
expected_output_names[i]: native_outputs[i] for i, _ in enumerate(native_outputs)
}

# We manually construct a LiteralMap here because task inputs and outputs actually violate the assumption
# built into the IDL that all the values of a literal map are of the same type.
literals = {}
for k, v in native_outputs_as_map.items():
literal_type = self.interface.outputs[k].type
py_type = self.get_type_for_output_var(k, v)
if isinstance(v, tuple):
raise AssertionError(f"Output({k}) in task{self.name} received a tuple {v}, instead of {py_type}")
literals[k] = TypeEngine.to_literal(exec_ctx, v, py_type, literal_type)
outputs_literal_map = _literal_models.LiteralMap(literals=literals)
# After the execute has been successfully completed
return outputs_literal_map

expected_output_names = list(self.interface.outputs.keys())
if len(expected_output_names) == 1:
native_outputs_as_map = {expected_output_names[0]: native_outputs}
elif len(expected_output_names) == 0:
return VoidPromise(self.name)
else:
# Question: How do you know you're going to enumerate them in the correct order? Even if autonamed, will
# output2 come before output100 if there's a hundred outputs? We don't! We'll have to circle back to
# the Python task instance and inspect annotations again. Or we change the Python model representation
# of the interface to be an ordered dict and we fill it in correctly to begin with.
native_outputs_as_map = {expected_output_names[i]: native_outputs[i] for i, _ in enumerate(native_outputs)}

# We manually construct a LiteralMap here because task inputs and outputs actually violate the assumption
# built into the IDL that all the values of a literal map are of the same type.
literals = {}
for k, v in native_outputs_as_map.items():
literal_type = self.interface.outputs[k].type
py_type = self.get_type_for_output_var(k, v)
if isinstance(v, tuple):
raise AssertionError(f"Output({k}) in task{self.name} received a tuple {v}, instead of {py_type}")
literals[k] = TypeEngine.to_literal(ctx, v, py_type, literal_type)
outputs_literal_map = _literal_models.LiteralMap(literals=literals)
return outputs_literal_map
def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
Reviewer comment (Contributor): Not a huge fan of this function signature... can we think of a way around this? I'd rather pass in the parent FlyteContext and access the user-space params from there. This function name makes it seem like a generic setup call, but it always takes in and returns just the user params; that seems limiting.

"""
This is the method that will be invoked directly before executing the task method and before all the inputs
are converted. One particular case where this is useful is if the context is to be modified for the user process
to get some user space parameters. This also ensures that things like the SparkSession are already correctly
set up before the type transformers are called.

This should return either the same context or the mutated context.
"""
return user_params

@abstractmethod
def execute(self, **kwargs) -> Any:
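A hedged sketch (not part of this diff) of how a plugin could override the new pre_execute hook; MyPluginTask and MY_CLIENT are hypothetical names, and the builder pattern mirrors PysparkFunctionTask.pre_execute in task.py further down.

from flytekit.annotated.python_function_task import PythonFunctionTask
from flytekit.common.tasks.sdk_runnable import ExecutionParameters


class MyPluginTask(PythonFunctionTask):
    def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
        # Runs before the input literals are converted, so anything the type
        # transformers need (a session, a client) can be attached to the
        # user-space params here.
        b = user_params.builder(user_params)
        b.add_attr("MY_CLIENT", object())  # placeholder for a real client or session
        return b.build()
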
2 changes: 1 addition & 1 deletion flytekit/common/tasks/sdk_runnable.py
@@ -141,7 +141,7 @@ def __getattr__(self, attr_name: str) -> typing.Any:
attr_name = attr_name.upper()
if self._attrs and attr_name in self._attrs:
return self._attrs[attr_name]
raise AssertionError(f"{attr_name} not available as a parameter in Flyte context")
raise AssertionError(f"{attr_name} not available as a parameter in Flyte context - are you in right task-type?")


class SdkRunnableContainer(_task_models.Container, metaclass=_sdk_bases.ExtendedSdkType):
8 changes: 3 additions & 5 deletions flytekit/interfaces/data/data_proxy.py
@@ -359,11 +359,9 @@ def put_data(self, local_path: Union[str, os.PathLike], remote_path: str, is_mul
self.remote.upload(local_path, remote_path)
except Exception as ex:
raise _user_exception.FlyteAssertion(
"Failed to put data from {local_path} to {remote_path} (recursive={is_multipart}).\n\n"
"Original exception: {error_string}".format(
remote_path=remote_path, local_path=local_path, is_multipart=is_multipart, error_string=str(ex),
)
)
f"Failed to put data from {local_path} to {remote_path} (recursive={is_multipart}).\n\n"
f"Original exception: {str(ex)}"
) from ex


timestamped_default_sandbox_location = os.path.join(
3 changes: 2 additions & 1 deletion flytekit/taskplugins/spark/__init__.py
@@ -1,3 +1,4 @@
from .schema import SparkDataFrameSchemaReader, SparkDataFrameSchemaWriter, SparkDataFrameTransformer
from .task import Spark

__all__ = [Spark]
__all__ = [Spark, SparkDataFrameTransformer, SparkDataFrameSchemaReader, SparkDataFrameSchemaWriter]
84 changes: 84 additions & 0 deletions flytekit/taskplugins/spark/schema.py
@@ -0,0 +1,84 @@
import typing
from typing import Type

from flytekit import FlyteContext
from flytekit.annotated.type_engine import T, TypeEngine, TypeTransformer
from flytekit.models.literals import Literal, Scalar, Schema
from flytekit.models.types import LiteralType, SchemaType
from flytekit.plugins import pyspark
from flytekit.types.schema import SchemaEngine, SchemaFormat, SchemaHandler, SchemaReader, SchemaWriter


class SparkDataFrameSchemaReader(SchemaReader[pyspark.sql.DataFrame]):
def __init__(self, from_path: str, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat):
super().__init__(from_path, cols, fmt)

def iter(self, **kwargs) -> typing.Generator[T, None, None]:
raise NotImplementedError("Spark DataFrame reader cannot iterate over individual chunks in spark dataframe")

def all(self, **kwargs) -> pyspark.sql.DataFrame:
if self._fmt == SchemaFormat.PARQUET:
ctx = FlyteContext.current_context().user_space_params
return ctx.spark_session.read.parquet(self.from_path)
raise AssertionError("Only Parquet type files are supported for spark dataframe currently")


class SparkDataFrameSchemaWriter(SchemaWriter[pyspark.sql.DataFrame]):
def __init__(self, to_path: str, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat):
super().__init__(to_path, cols, fmt)

def write(self, *dfs: pyspark.sql.DataFrame, **kwargs):
if dfs is None or len(dfs) == 0:
return
if len(dfs) > 1:
raise AssertionError("Only one Spark.DataFrame can be returned per return variable currently")
if self._fmt == SchemaFormat.PARQUET:
dfs[0].write.mode("overwrite").parquet(self.to_path)
return
raise AssertionError("Only Parquet type files are supported for spark dataframe currently")


class SparkDataFrameTransformer(TypeTransformer[pyspark.sql.DataFrame]):
"""
Transforms Spark DataFrame's to and from a Schema (typed/untyped)
"""

def __init__(self):
super(SparkDataFrameTransformer, self).__init__("spark-df-transformer", t=pyspark.sql.DataFrame)

@staticmethod
def _get_schema_type() -> SchemaType:
return SchemaType(columns=[])

def get_literal_type(self, t: Type[pyspark.sql.DataFrame]) -> LiteralType:
return LiteralType(schema=self._get_schema_type())

def to_literal(
self,
ctx: FlyteContext,
python_val: pyspark.sql.DataFrame,
python_type: Type[pyspark.sql.DataFrame],
expected: LiteralType,
) -> Literal:
remote_path = ctx.file_access.get_random_remote_directory()
w = SparkDataFrameSchemaWriter(to_path=remote_path, cols=None, fmt=SchemaFormat.PARQUET)
w.write(python_val)
return Literal(scalar=Scalar(schema=Schema(remote_path, self._get_schema_type())))

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[pyspark.sql.DataFrame]) -> T:
if not (lv and lv.scalar and lv.scalar.schema):
return pyspark.sql.DataFrame()
r = SparkDataFrameSchemaReader(from_path=lv.scalar.schema.uri, cols=None, fmt=SchemaFormat.PARQUET)
return r.all()


SchemaEngine.register_handler(
SchemaHandler(
"pyspark.sql.DataFrame-Schema",
pyspark.sql.DataFrame,
SparkDataFrameSchemaReader,
SparkDataFrameSchemaWriter,
handles_remote_io=True,
)
)
TypeEngine.register(SparkDataFrameTransformer())
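A hedged sketch (not part of the diff) of the round trip this module performs when a task returns or accepts a pyspark.sql.DataFrame; it assumes the code runs inside a Spark task, so a SparkSession is already attached to the user-space params, and that the TypeEngine classmethod names match this flytekit version.

import pyspark.sql

import flytekit.taskplugins.spark  # importing the package registers SparkDataFrameTransformer
from flytekit import FlyteContext
from flytekit.annotated.type_engine import TypeEngine

ctx = FlyteContext.current_context()
sess = ctx.user_space_params.spark_session
df = sess.createDataFrame([(1, "a"), (2, "b")], ["id", "name"])

# DataFrame -> Literal: the transformer writes Parquet to a random remote directory and
# wraps the path in a Schema literal with an untyped (columns=[]) SchemaType.
lt = TypeEngine.to_literal_type(pyspark.sql.DataFrame)
lit = TypeEngine.to_literal(ctx, df, pyspark.sql.DataFrame, lt)

# Literal -> DataFrame: the reader loads the Parquet back via spark_session.read.parquet.
df2 = TypeEngine.to_python_value(ctx, lit, pyspark.sql.DataFrame)
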
20 changes: 11 additions & 9 deletions flytekit/taskplugins/spark/task.py
@@ -5,9 +5,10 @@

from google.protobuf.json_format import MessageToDict

from flytekit.annotated.context_manager import FlyteContext, RegistrationSettings
from flytekit.annotated.context_manager import RegistrationSettings
from flytekit.annotated.python_function_task import PythonFunctionTask
from flytekit.annotated.task import TaskPlugins
from flytekit.common.tasks.sdk_runnable import ExecutionParameters
from flytekit.models import task as _task_model
from flytekit.sdk.spark_types import SparkType

@@ -32,7 +33,8 @@ def new_spark_session(name: str):

sess = _pyspark.sql.SparkSession.builder.appName(f"FlyteSpark: {name}").getOrCreate()
yield sess
sess.stop()
# SparkSession.stop() does not work correctly here, as it stops the session before all the data is written
# sess.stop()


class PysparkFunctionTask(PythonFunctionTask[Spark]):
@@ -60,13 +62,13 @@ def get_custom(self, settings: RegistrationSettings) -> Dict[str, Any]:
)
return MessageToDict(job.to_flyte_idl())

def execute(self, **kwargs) -> Any:
ctx = FlyteContext.current_context()
with new_spark_session(ctx.user_space_params.execution_id) as sess:
b = ctx.user_space_params.builder(ctx.user_space_params)
b.add_attr("SPARK_SESSION", sess)
with ctx.new_execution_context(mode=ctx.execution_state.mode, execution_params=b.build()):
return self._task_function(**kwargs)
def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
import pyspark as _pyspark

sess = _pyspark.sql.SparkSession.builder.appName(f"FlyteSpark: {user_params.execution_id}").getOrCreate()
b = user_params.builder(user_params)
b.add_attr("SPARK_SESSION", sess)
return b.build()


TaskPlugins.register_pythontask_plugin(Spark, PysparkFunctionTask)
11 changes: 2 additions & 9 deletions flytekit/types/__init__.py
@@ -1,10 +1,3 @@
from flytekit.types.flyte_file import FlyteFile
from flytekit.types.schema import (
FlyteSchema,
PandasSchemaReader,
PandasSchemaWriter,
Schema,
SchemaFormat,
SchemaOpenMode,
SchemaType,
)
from flytekit.types.pandas_schema import PandasSchemaReader, PandasSchemaWriter
from flytekit.types.schema import FlyteSchema