Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MSGPACK IDL] Gate feature by setting ENV #2894

Merged
merged 16 commits into from
Nov 6, 2024
99 changes: 92 additions & 7 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import inspect
import json
import mimetypes
import os
import sys
import textwrap
import threading
Expand All @@ -29,7 +30,7 @@
from google.protobuf.json_format import ParseDict as _ParseDict
from google.protobuf.message import Message
from google.protobuf.struct_pb2 import Struct
from mashumaro.codecs.json import JSONDecoder
from mashumaro.codecs.json import JSONDecoder, JSONEncoder
from mashumaro.codecs.msgpack import MessagePackDecoder, MessagePackEncoder
from mashumaro.mixins.json import DataClassJSONMixin
from typing_extensions import Annotated, get_args, get_origin
Expand Down Expand Up @@ -498,7 +499,8 @@ class Test(DataClassJsonMixin):

def __init__(self) -> None:
super().__init__("Object-Dataclass-Transformer", object)
self._decoder: Dict[Type, JSONDecoder] = dict()
self._json_encoder: Dict[Type, JSONEncoder] = dict()
self._json_decoder: Dict[Type, JSONDecoder] = dict()

def assert_type(self, expected_type: Type[DataClassJsonMixin], v: T):
# Skip iterating all attributes in the dataclass if the type of v already matches the expected_type
Expand Down Expand Up @@ -655,14 +657,55 @@ def get_literal_type(self, t: Type[T]) -> LiteralType:
)
)

# This is for attribute access in FlytePropeller.
ts = TypeStructure(tag="", dataclass_type=literal_type)

return _type_models.LiteralType(simple=_type_models.SimpleType.STRUCT, metadata=schema, structure=ts)

def to_old_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
"""
Serializes a dataclass or dictionary to a Flyte literal, handling both JSON and MessagePack formats.
Set `FLYTE_USE_OLD_DC_FORMAT=true` to use the old JSON-based format.
"""
if isinstance(python_val, dict):
json_str = json.dumps(python_val)
return Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct())))

if not dataclasses.is_dataclass(python_val):
raise TypeTransformerFailedError(
f"{type(python_val)} is not of type @dataclass, only Dataclasses are supported for "
f"user defined datatypes in Flytekit"
)

self._make_dataclass_serializable(python_val, python_type)

# JSON serialization using mashumaro's DataClassJSONMixin
if isinstance(python_val, DataClassJSONMixin):
json_str = python_val.to_json()
else:
try:
encoder = self._json_encoder[python_type]
except KeyError:
encoder = JSONEncoder(python_type)
self._json_encoder[python_type] = encoder

try:
json_str = encoder.encode(python_val)
except NotImplementedError:
raise NotImplementedError(
f"{python_type} should inherit from mashumaro.types.SerializableType"
f" and implement _serialize and _deserialize methods."
)

return Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct()))) # type: ignore

def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
if os.getenv("FLYTE_USE_OLD_DC_FORMAT", "false").lower() == "true":
return self.to_old_literal(ctx, python_val, python_type, expected)

if isinstance(python_val, dict):
msgpack_bytes = msgpack.dumps(python_val)
return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag="msgpack")))
return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag=MESSAGEPACK)))

if not dataclasses.is_dataclass(python_val):
raise TypeTransformerFailedError(
Expand Down Expand Up @@ -697,7 +740,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp
f" and implement _serialize and _deserialize methods."
)

return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag="msgpack")))
return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag=MESSAGEPACK)))

def _get_origin_type_in_annotation(self, python_type: Type[T]) -> Type[T]:
# dataclass will try to hash python type when calling dataclass.schema(), but some types in the annotation is
Expand Down Expand Up @@ -887,10 +930,10 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
# The function looks up or creates a JSONDecoder specifically designed for the object's type.
# This decoder is then used to convert a JSON string into a data class.
try:
decoder = self._decoder[expected_python_type]
decoder = self._json_decoder[expected_python_type]
except KeyError:
decoder = JSONDecoder(expected_python_type)
self._decoder[expected_python_type] = decoder
self._json_decoder[expected_python_type] = decoder

dc = decoder.decode(json_str)

Expand Down Expand Up @@ -1954,6 +1997,42 @@ def extract_types_or_metadata(t: Optional[Type[dict]]) -> typing.Tuple:
return _args # type: ignore
return None, None

@staticmethod
async def dict_to_old_generic_literal(
ctx: FlyteContext, v: dict, python_type: Type[dict], allow_pickle: bool
) -> Literal:
"""
This is deprecated from flytekit 1.14.0.
Creates a flyte-specific ``Literal`` value from a native python dictionary.
"""
from flytekit.types.pickle import FlytePickle

try:
try:
# JSONEncoder is mashumaro's codec and this can triggered Flyte Types customized serialization and deserialization.
encoder = JSONEncoder(python_type)
json_str = encoder.encode(v)
except NotImplementedError:
raise NotImplementedError(
f"{python_type} should inherit from mashumaro.types.SerializableType"
f" and implement _serialize and _deserialize methods."
)

return Literal(
scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct())),
metadata={"format": "json"},
)
except TypeError as e:
if allow_pickle:
remote_path = await FlytePickle.to_pickle(ctx, v)
return Literal(
scalar=Scalar(
generic=_json_format.Parse(json.dumps({"pickle_file": remote_path}), _struct.Struct())
),
metadata={"format": "pickle"},
)
raise TypeTransformerFailedError(f"Cannot convert `{v}` to Flyte Literal.\n" f"Error Message: {e}")

@staticmethod
async def dict_to_binary_literal(
ctx: FlyteContext, v: dict, python_type: Type[dict], allow_pickle: bool
Expand All @@ -1968,7 +2047,7 @@ async def dict_to_binary_literal(
# Handle dictionaries with non-string keys (e.g., Dict[int, Type])
encoder = MessagePackEncoder(python_type)
msgpack_bytes = encoder.encode(v)
return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag="msgpack")))
return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag=MESSAGEPACK)))
except TypeError as e:
if allow_pickle:
remote_path = await FlytePickle.to_pickle(ctx, v)
Expand Down Expand Up @@ -2029,6 +2108,12 @@ async def async_to_literal(
allow_pickle, base_type = DictTransformer.is_pickle(python_type)

if expected and expected.simple and expected.simple == SimpleType.STRUCT:
"""
TODO: add more comments
if FLYTE_USE_OLD_DC_FORMAT = true, then the old format is used, which is a binary literal with a struct
"""
if os.getenv("FLYTE_USE_OLD_DC_FORMAT", "false").lower() == "true":
return await self.dict_to_old_generic_literal(ctx, python_val, python_type, allow_pickle)
return await self.dict_to_binary_literal(ctx, python_val, python_type, allow_pickle)

lit_map = {}
Expand Down
19 changes: 19 additions & 0 deletions flytekit/extras/pydantic_transformer/transformer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import json
import os
from typing import Type

import msgpack
from google.protobuf import json_format as _json_format
from google.protobuf import struct_pb2 as _struct
from pydantic import BaseModel

from flytekit import FlyteContext
Expand Down Expand Up @@ -31,10 +33,24 @@ def get_literal_type(self, t: Type[BaseModel]) -> LiteralType:
"Field {} of type {} cannot be converted to a literal type. Error: {}".format(name, python_type, e)
)

# This is for attribute access in FlytePropeller.
ts = TypeStructure(tag="", dataclass_type=literal_type)

return types.LiteralType(simple=types.SimpleType.STRUCT, metadata=schema, structure=ts)

def to_old_generic_literal(
self,
ctx: FlyteContext,
python_val: BaseModel,
python_type: Type[BaseModel],
expected: types.LiteralType,
) -> Literal:
"""
This is for users who don't want to upgrade the flytepropeller.
"""
json_str = python_val.model_dump_json()
return Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct())))

def to_literal(
self,
ctx: FlyteContext,
Expand All @@ -47,6 +63,9 @@ def to_literal(
This is for handling enum in basemodel.
More details: https://github.com/flyteorg/flytekit/pull/2792
"""
if os.getenv("FLYTE_USE_OLD_DC_FORMAT", "false").lower() == "true":
return self.to_old_generic_literal(ctx, python_val, python_type, expected)

json_str = python_val.model_dump_json()
dict_obj = json.loads(json_str)
msgpack_bytes = msgpack.dumps(dict_obj)
Expand Down
1 change: 1 addition & 0 deletions flytekit/types/structured/structured_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,7 @@ async def async_to_literal(
# that we will need to invoke an encoder for. Figure out which encoder to call and invoke it.
df_type = type(python_val.dataframe)
protocol = self._protocol_from_type_or_prefix(ctx, df_type, python_val.uri)

return self.encode(
ctx,
python_val,
Expand Down
Loading
Loading