Simple Schema type support -> pandas dataframe (#221)
kumare3 authored Oct 26, 2020
1 parent 6984964 commit 6e60e72
Showing 5 changed files with 275 additions and 27 deletions.
10 changes: 6 additions & 4 deletions flytekit/annotated/interface.py
@@ -14,7 +14,9 @@ class Interface(object):
A Python native interface object, like inspect.signature but simpler.
"""

def __init__(self, inputs: Dict[str, Union[Type, Tuple[Type, Any]]] = None, outputs: Dict[str, Type] = None):
def __init__(
self, inputs: typing.Dict[str, Union[Type, Tuple[Type, Any]]] = None, outputs: typing.Dict[str, Type] = None,
):
"""
:param outputs: Output variables and their types as a dictionary
:param inputs: the variable and its type only
@@ -29,18 +31,18 @@ def __init__(self, inputs: Dict[str, Union[Type, Tuple[Type, Any]]] = None, outputs: Dict[str, Type] = None):
self._outputs = outputs

@property
def inputs(self) -> Dict[str, Type]:
def inputs(self) -> typing.Dict[str, Type]:
r = {}
for k, v in self._inputs.items():
r[k] = v[0]
return r

@property
def inputs_with_defaults(self) -> Dict[str, Tuple[Type, Any]]:
def inputs_with_defaults(self) -> typing.Dict[str, Tuple[Type, Any]]:
return self._inputs

@property
def outputs(self):
def outputs(self) -> typing.Dict[str, type]:
return self._outputs

def remove_inputs(self, vars: List[str]) -> "Interface":
5 changes: 5 additions & 0 deletions flytekit/annotated/promise.py
@@ -256,6 +256,11 @@ def __init__(self, var: str, val: Union[_NodeOutput, _literal_models.Literal]):
self._promise_ready = False
self._val = None

def with_var(self, new_var: str) -> "Promise":
if self.is_ready:
return Promise(var=new_var, val=self.val)
return Promise(var=new_var, val=self.ref)
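
Note: with_var is used further down in this commit (see _workflow_fn_outputs_to_promise in workflow.py) to recast a task-output promise under the workflow's own output name. A minimal sketch, where node_output stands in for an existing _NodeOutput:

p = Promise(var="o0", val=node_output)
renamed = p.with_var("wf_out")  # same value (or node reference), new variable name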

@property
def is_ready(self) -> bool:
"""
133 changes: 131 additions & 2 deletions flytekit/annotated/type_engine.py
@@ -11,11 +11,22 @@
from flytekit import typing as flyte_typing
from flytekit.annotated.context_manager import FlyteContext
from flytekit.common.types import primitives as _primitives
from flytekit.configuration import sdk
from flytekit.models import interface as _interface_models
from flytekit.models import types as _type_models
from flytekit.models.core import types as _core_types
from flytekit.models.literals import Blob, BlobMetadata, Literal, LiteralCollection, LiteralMap, Primitive, Scalar
from flytekit.models.types import LiteralType, SimpleType
from flytekit.models.literals import (
Blob,
BlobMetadata,
Literal,
LiteralCollection,
LiteralMap,
Primitive,
Scalar,
Schema,
)
from flytekit.models.types import LiteralType, SchemaType, SimpleType
from flytekit.plugins import pandas

T = typing.TypeVar("T")

@@ -390,6 +401,121 @@ def _downloader():
return flyte_typing.FlyteFilePath(local_path, _downloader(), lv.scalar.blob.uri)


class ParquetIO(object):
PARQUET_ENGINE = "pyarrow"

def _read(self, chunk: os.PathLike, columns: typing.List[str], **kwargs):
return pandas.read_parquet(chunk, columns=columns, engine=self.PARQUET_ENGINE, **kwargs)

def read(self, files: typing.List[os.PathLike], columns=None, **kwargs) -> pandas.DataFrame:
frames = [self._read(chunk=f, columns=columns, **kwargs) for f in files if os.path.getsize(f) > 0]
if len(frames) == 1:
return frames[0]
elif len(frames) > 1:
return pandas.concat(frames, copy=True)
return pandas.DataFrame()

def write(
self,
df: pandas.DataFrame,
to_file: os.PathLike,
coerce_timestamps: str = "us",
allow_truncated_timestamps: bool = False,
**kwargs,
):
"""
Writes the data frame as a parquet chunk to the local directory owned by the Schema object; it will later be uploaded to S3.
:param df: data frame to write as parquet
:param to_file: Sink file to write the dataframe to
:param coerce_timestamps: the resolution at which timestamps are stored in parquet; 'us', 'ms', and 's' are allowed values.
Note: if the coercion would lose timestamp data, the write will fail! Nanoseconds are
problematic in the Parquet format and will not work. See allow_truncated_timestamps.
:param allow_truncated_timestamps: default False. Allow truncation when coercing timestamps to a coarser
resolution.
"""
# TODO @ketan: validate and remove this comment; in Python 3 all strings are unicode
# Convert all columns to unicode, as pyarrow's parquet reader cannot handle mixed strings and unicode.
# Since columns from Hive are returned as unicode, a user who wants to add a column to a dataframe returned
# from Hive and then output the new data would have to provide a unicode column name, which is unnatural.
df.to_parquet(
to_file,
coerce_timestamps=coerce_timestamps,
allow_truncated_timestamps=allow_truncated_timestamps,
**kwargs,
)
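
For orientation, a minimal round-trip sketch of this engine (paths and data are illustrative; assumes pandas and pyarrow are installed):

import pandas

engine = ParquetIO()
df = pandas.DataFrame({"a": [1, 2], "ts": pandas.to_datetime(["2020-10-26", "2020-10-27"])})
engine.write(df, to_file="/tmp/chunks/00000")  # coerce_timestamps defaults to "us"
restored = engine.read(["/tmp/chunks/00000"])  # a single chunk comes back as one frame
assert restored.equals(df)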


class FastParquetIO(ParquetIO):
PARQUET_ENGINE = "fastparquet"

def _read(self, chunk: os.PathLike, columns: typing.List[str], **kwargs):
from fastparquet import ParquetFile as _ParquetFile
from fastparquet import thrift_structures as _ts

# TODO Follow up to figure out if this is not needed anymore
# https://github.com/dask/fastparquet/issues/414#issuecomment-478983811
df = pandas.read_parquet(chunk, columns=columns, engine=self.PARQUET_ENGINE, index=False)
df_column_types = df.dtypes
pf = _ParquetFile(chunk)
schema_column_dtypes = {l.name: l.type for l in list(pf.schema.schema_elements)}

for idx in df_column_types[df_column_types == "float16"].index.tolist():
# A hacky way to get the string representations of the column types of a parquet schema
# Reference:
# https://github.com/dask/fastparquet/blob/f4ecc67f50e7bf98b2d0099c9589c615ea4b06aa/fastparquet/schema.py
if _ts.parquet_thrift.Type._VALUES_TO_NAMES[schema_column_dtypes[idx]] == "BOOLEAN":
df[idx] = df[idx].astype("object")
df[idx].replace({0: False, 1: True, pandas.np.nan: None}, inplace=True)
return df


_PARQUETIO_ENGINES: typing.Dict[str, ParquetIO] = {
ParquetIO.PARQUET_ENGINE: ParquetIO(),
FastParquetIO.PARQUET_ENGINE: FastParquetIO(),
}


def generate_ordered_files(directory: os.PathLike, n: int) -> typing.Generator[os.PathLike, None, None]:
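# Zero-padded names ("00000", "00001", ...) keep lexicographic order equal to chunk order.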
for i in range(n):
yield os.path.join(directory, f"{i:05}")


class PandasDataFrameTransformer(TypeTransformer[pandas.DataFrame]):
"""
Transforms a pd.DataFrame to Schema without column types.
"""

def __init__(self, parquet_engine: ParquetIO):
super().__init__("PandasDataFrame<->GenericSchema", pandas.DataFrame)
self._parquet_engine = parquet_engine

@staticmethod
def _get_schema_type() -> SchemaType:
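# An empty column list denotes a generic schema, i.e. no declared column names or types.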
return SchemaType(columns=[])

def get_literal_type(self, t: type) -> LiteralType:
return LiteralType(schema=self._get_schema_type())

def to_literal(
self, ctx: FlyteContext, python_val: pandas.DataFrame, python_type: type, expected: LiteralType
) -> Literal:
remote_path = ctx.file_access.get_random_remote_directory()
local_dir = ctx.file_access.get_random_local_directory()
f = list(generate_ordered_files(local_dir, 1))[0]
self._parquet_engine.write(python_val, f)
ctx.file_access.put_data(local_dir, remote_path, is_multipart=True)
return Literal(scalar=Scalar(schema=Schema(remote_path, self._get_schema_type())))

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: type) -> pandas.DataFrame:
if not (lv and lv.scalar and lv.scalar.schema):
return pandas.DataFrame()
local_dir = ctx.file_access.get_random_local_directory()
ctx.file_access.download_directory(lv.scalar.schema.uri, local_dir)
files = os.listdir(local_dir)
files = [os.path.join(local_dir, f) for f in files]
return self._parquet_engine.read(files)
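
End to end, the transformer writes the frame as parquet chunks in a local directory, uploads that directory as a multipart blob, and hands back a Schema literal; to_python_value reverses the trip. A hedged sketch using only pieces from this diff (assumes a FlyteContext whose file_access is fully configured; data is illustrative):

transformer = PandasDataFrameTransformer(ParquetIO())
ctx = FlyteContext.current_context()
df = pandas.DataFrame({"x": [1, 2, 3]})
lt = transformer.get_literal_type(pandas.DataFrame)
lit = transformer.to_literal(ctx, df, pandas.DataFrame, lt)  # writes parquet, uploads, returns a Schema literal
df2 = transformer.to_python_value(ctx, lit, pandas.DataFrame)  # downloads and reads it back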


def _register_default_type_transformers():
TypeEngine.register(
SimpleTransformer(
@@ -463,6 +589,9 @@ def _register_default_type_transformers():
TypeEngine.register(PathLikeTransformer())
TypeEngine.register(BinaryIOTransformer())

parquet_io_engine = _PARQUETIO_ENGINES[sdk.PARQUET_ENGINE.get()]
TypeEngine.register(PandasDataFrameTransformer(parquet_io_engine))

# inner type is. Also unsupported are typing's Tuples. Even though you can look inside them, Flyte's type system
# doesn't support these currently.
# Confusing note: typing.NamedTuple is in here even though task functions themselves can return them. We just mean
75 changes: 54 additions & 21 deletions flytekit/annotated/workflow.py
@@ -1,6 +1,7 @@
import datetime
import inspect
from typing import Callable, Dict, List, Optional, Tuple, Union
import typing
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import flytekit.annotated.promise
import flytekit.annotated.type_engine
@@ -27,6 +28,46 @@
from flytekit.models.core import workflow as _workflow_model


def _workflow_fn_outputs_to_promise(
ctx: FlyteContext,
native_outputs: typing.Dict[str, type],  # actually an OrderedDict
typed_outputs: Dict[str, _interface_models.Variable],
outputs: Union[Any, Tuple[Any]],
) -> List[Promise]:
if len(native_outputs) == 0:
if outputs is not None:
raise AssertionError("something returned from wf but shouldn't have outputs")
return None

if len(native_outputs) == 1:
if isinstance(outputs, tuple):
if len(outputs) != 1:
raise AssertionError(
f"The Workflow specification indicates only one return value, received {len(outputs)}"
)
else:
outputs = (outputs,)

if len(native_outputs) > 1:
if not isinstance(outputs, tuple) or len(native_outputs) != len(outputs):
# Length check, clean up exception
raise AssertionError(
f"The workflow specification indicates {len(native_outputs)} return vals, but received {len(outputs)}"
)

# This recasts the Promises provided by the outputs of the workflow's tasks into the correct output names
# of the workflow itself
return_vals = []
for (k, t), v in zip(native_outputs.items(), outputs):
if isinstance(v, Promise):
return_vals.append(v.with_var(k))
else:
# Found a return type that is not a promise, so we need to transform it
var = typed_outputs[k]
return_vals.append(Promise(var=k, val=TypeEngine.to_literal(ctx, v, t, var.type)))
return return_vals
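
To make the recasting concrete, a hypothetical two-output case (names are illustrative; typed_outputs is assumed to come from the workflow's compiled interface):

native_outputs = {"o0": int, "o1": int}  # from the Python signature
outputs = (some_task_promise, 5)  # returned by the workflow function
promises = _workflow_fn_outputs_to_promise(ctx, native_outputs, typed_outputs, outputs)
# promises[0] is some_task_promise renamed to "o0" via with_var;
# promises[1] wraps the literal 5 as a Promise named "o1" via TypeEngine.to_literal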


class Workflow(object):
"""
When you assign a name to a node.
Expand Down Expand Up @@ -98,12 +139,18 @@ def compile(self, **kwargs):
# iterate through the list here, instead we should let the binding creation unwrap it and make a binding
# collection/map out of it.
if len(output_names) == 1:
if isinstance(workflow_outputs, tuple) and len(workflow_outputs) != 1:
raise AssertionError(
f"The Workflow specification indicates only one return value, received {len(workflow_outputs)}"
)
t = self._native_interface.outputs[output_names[0]]
b = flytekit.annotated.promise.binding_from_python_std(
ctx, output_names[0], self.interface.outputs[output_names[0]].type, workflow_outputs, t,
)
bindings.append(b)
elif len(output_names) > 1:
if not isinstance(workflow_outputs, tuple):
raise AssertionError("The Workflow specification indicates multiple return values, received only one")
if len(output_names) != len(workflow_outputs):
raise Exception(f"Length mismatch {len(output_names)} vs {len(workflow_outputs)}")
for i, out in enumerate(output_names):
@@ -145,28 +192,14 @@ def _local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise]:
# other things as well? What if someone just returns 5? Should we disallow this?
function_outputs = self._workflow_function(**kwargs)

output_names = list(self._interface.outputs.keys())
if len(output_names) == 0:
if function_outputs is None:
return None
else:
raise Exception("something returned from wf but shouldn't have outputs")

if len(output_names) != len(function_outputs):
# Length check, clean up exception
raise Exception(f"Length difference {len(output_names)} {len(function_outputs)}")

# This recasts the Promises provided by the outputs of the workflow's tasks into the correct output names
# of the workflow itself
vals = [
Promise(var=output_names[idx], val=function_outputs[idx].val)
for idx, promise in enumerate(function_outputs)
]
return create_task_output(vals)
promises = _workflow_fn_outputs_to_promise(
ctx, self._native_interface.outputs, self.interface.outputs, function_outputs
)
return create_task_output(promises)

def __call__(self, *args, **kwargs):
if len(args) > 0:
raise Exception("not allowed")
raise AssertionError("Only Keyword Arguments are supported for Workflow executions")

ctx = FlyteContext.current_context()

Expand Down Expand Up @@ -206,7 +239,7 @@ def __call__(self, *args, **kwargs):
if result is None:
return None
elif isinstance(result, Promise):
k, v = self._native_interface.outputs.items()[0]
v = next(iter(self._native_interface.outputs.values()))
return TypeEngine.to_python_value(ctx, result.val, v)
else:
for prom in result: