Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flytekit][2][untyped dict] Binary IDL With MessagePack #2757

Closed
wants to merge 10 commits into from
52 changes: 35 additions & 17 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
from abc import ABC, abstractmethod
from collections import OrderedDict
from functools import lru_cache
from typing import Dict, List, NamedTuple, Optional, Type, cast
from typing import Any, Dict, List, NamedTuple, Optional, Type, cast

import msgpack
from dataclasses_json import DataClassJsonMixin, dataclass_json
from flyteidl.core import literals_pb2
from google.protobuf import json_format as _json_format
Expand All @@ -26,6 +27,7 @@
from google.protobuf.message import Message
from google.protobuf.struct_pb2 import Struct
from mashumaro.codecs.json import JSONDecoder, JSONEncoder
from mashumaro.codecs.msgpack import MessagePackDecoder, MessagePackEncoder
from mashumaro.mixins.json import DataClassJSONMixin
from typing_extensions import Annotated, get_args, get_origin

Expand All @@ -42,22 +44,21 @@
from flytekit.models import types as _type_models
from flytekit.models.annotation import TypeAnnotation as TypeAnnotationModel
from flytekit.models.core import types as _core_types
from flytekit.models.literals import (
Literal,
LiteralCollection,
LiteralMap,
Primitive,
Scalar,
Union,
Void,
)
from flytekit.models.literals import Binary, Literal, LiteralCollection, LiteralMap, Primitive, Scalar, Union, Void
from flytekit.models.types import LiteralType, SimpleType, TypeStructure, UnionType

T = typing.TypeVar("T")
DEFINITIONS = "definitions"
TITLE = "title"


# In Mashumaro, the default encoder uses strict_map_key=False, while the default decoder uses strict_map_key=True.
# This is relevant for cases like Dict[int, str].
# If strict_map_key=False is not used, the decoder will raise an error when trying to decode map keys that are not str or bytes.
def _default_flytekit_decoder(data: bytes) -> Any:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: could we call this _default_msgpack_decoder

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since we may have other decoders in the future, like _default_utf8_decoder

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, pretty good advice, thank you!

return msgpack.unpackb(data, raw=False, strict_map_key=False)


class BatchSize:
"""
This is used to annotate a FlyteDirectory when we want to download/upload the contents of the directory in batches. For example,
Expand Down Expand Up @@ -129,6 +130,8 @@ def __init__(self, name: str, t: Type[T], enable_type_assertions: bool = True):
self._t = t
self._name = name
self._type_assertions_enabled = enable_type_assertions
self._msgpack_encoder: Dict[Type, MessagePackEncoder] = {}
self._msgpack_decoder: Dict[Type, MessagePackDecoder] = {}

@property
def name(self):
Expand Down Expand Up @@ -221,6 +224,17 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
f"Conversion to python value expected type {expected_python_type} from literal not implemented"
)

def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[T]) -> Optional[T]:
    """
    Decode a Binary IDL object into a python value of ``expected_python_type``.

    Only objects tagged "msgpack" are handled; any other tag raises
    ``TypeTransformerFailedError``.
    """
    if binary_idl_object.tag != "msgpack":
        raise TypeTransformerFailedError(f"Unsupported binary format {binary_idl_object.tag}")

    # One decoder is built per expected type and kept on this transformer,
    # so repeated conversions of the same type reuse it.
    if expected_python_type not in self._msgpack_decoder:
        self._msgpack_decoder[expected_python_type] = MessagePackDecoder(
            expected_python_type, pre_decoder_func=_default_flytekit_decoder
        )
    return self._msgpack_decoder[expected_python_type].decode(binary_idl_object.value)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
raise TypeTransformerFailedError(f"Unsupported binary format {binary_idl_object.tag}")
raise TypeTransformerFailedError(f"Unsupported binary format `{binary_idl_object.tag}`")

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thank you!


def to_html(self, ctx: FlyteContext, python_val: T, expected_python_type: Type[T]) -> str:
"""
Converts any python val (dataframe, int, float) to a html string, and it will be wrapped in the HTML div
Expand Down Expand Up @@ -271,6 +285,9 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
f"Cannot convert to type {expected_python_type}, only {self._type} is supported"
)

if lv.scalar and lv.scalar.binary:
return self.from_binary_idl(lv.scalar.binary, expected_python_type) # type: ignore
Future-Outlier marked this conversation as resolved.
Show resolved Hide resolved

try: # todo(maximsmol): this is quite ugly and each transformer should really check their Literal
res = self._from_literal_transformer(lv)
if type(res) != self._type:
Expand Down Expand Up @@ -1697,17 +1714,15 @@ def extract_types_or_metadata(t: Optional[Type[dict]]) -> typing.Tuple:
return None, None

@staticmethod
def dict_to_generic_literal(ctx: FlyteContext, v: dict, allow_pickle: bool) -> Literal:
def dict_to_binary_literal(ctx: FlyteContext, v: dict, allow_pickle: bool) -> Literal:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but we still use generic to save the pickle file, right?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, I didn't change the logic for pickle

"""
Creates a flyte-specific ``Literal`` value from a native python dictionary.
"""
from flytekit.types.pickle import FlytePickle

try:
return Literal(
scalar=Scalar(generic=_json_format.Parse(json.dumps(v), _struct.Struct())),
metadata={"format": "json"},
)
msgpack_bytes = msgpack.dumps(v)
return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag="msgpack")))
except TypeError as e:
if allow_pickle:
remote_path = FlytePickle.to_pickle(ctx, v)
Expand All @@ -1717,7 +1732,7 @@ def dict_to_generic_literal(ctx: FlyteContext, v: dict, allow_pickle: bool) -> L
),
metadata={"format": "pickle"},
)
raise e
raise TypeTransformerFailedError(f"Cannot convert {v} to Flyte Literal.\n" f"Error Message: {e}")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
raise TypeTransformerFailedError(f"Cannot convert {v} to Flyte Literal.\n" f"Error Message: {e}")
raise TypeTransformerFailedError(f"Cannot convert `{v}` to Flyte Literal.\n" f"Error Message: {e}")

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no problem, thank you


@staticmethod
def is_pickle(python_type: Type[dict]) -> typing.Tuple[bool, Type]:
Expand Down Expand Up @@ -1768,7 +1783,7 @@ def to_literal(
allow_pickle, base_type = DictTransformer.is_pickle(python_type)

if expected and expected.simple and expected.simple == SimpleType.STRUCT:
return self.dict_to_generic_literal(ctx, python_val, allow_pickle)
return self.dict_to_binary_literal(ctx, python_val, allow_pickle)

lit_map = {}
for k, v in python_val.items():
Expand All @@ -1785,6 +1800,9 @@ def to_literal(
return Literal(map=LiteralMap(literals=lit_map))

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[dict]) -> dict:
if lv and lv.scalar and lv.scalar.binary is not None:
return self.from_binary_idl(lv.scalar.binary, expected_python_type) # type: ignore

if lv and lv.map and lv.map.literals is not None:
tp = self.dict_types(expected_python_type)

Expand Down
3 changes: 3 additions & 0 deletions plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from datetime import datetime, timedelta
from unittest import mock
import msgpack
import base64

import pytest
from flyteidl.core.execution_pb2 import TaskExecution
Expand Down Expand Up @@ -161,6 +163,7 @@ async def test_agent(mock_boto_call, mock_return_value):
if "pickle_check" in mock_return_value[0][0]:
assert "pickle_file" in outputs["result"]
else:
outputs["result"] = msgpack.loads(base64.b64decode(outputs["result"]))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if scalar.binary:
return base64.b64encode(scalar.binary.value)

Because the bytes in the binary IDL were already being serialized (base64-encoded) here before.

assert (
outputs["result"]["EndpointConfigArn"]
== "arn:aws:sagemaker:us-east-2:000000000:endpoint-config/sagemaker-xgboost-endpoint-config"
Expand Down
5 changes: 3 additions & 2 deletions plugins/flytekit-openai/tests/openai_batch/test_agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from datetime import timedelta
from unittest import mock
from unittest.mock import AsyncMock

import msgpack
import base64
import pytest
from flyteidl.core.execution_pb2 import TaskExecution
from flytekitplugins.openai.batch.agent import BatchEndpointMetadata
Expand Down Expand Up @@ -159,7 +160,7 @@ async def test_openai_batch_agent(mock_retrieve, mock_create, mock_context):
outputs = literal_map_string_repr(resource.outputs)
result = outputs["result"]

assert result == batch_retrieve_result.to_dict()
assert msgpack.loads(base64.b64decode(result)) == batch_retrieve_result.to_dict()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto


# Status: Failed
mock_retrieve.return_value = batch_retrieve_result_failure
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ dependencies = [
"marshmallow-enum",
"marshmallow-jsonschema>=0.12.0",
"mashumaro>=3.11",
"msgpack>=1.1.0",
"protobuf!=4.25.0",
"pygments",
"python-json-logger>=2.0.0",
Expand Down
2 changes: 1 addition & 1 deletion tests/flytekit/unit/core/test_local_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ def test_stable_cache_key():
}
)
key = _calculate_cache_key("task_name_1", "31415", lm)
assert key == "task_name_1-31415-189e755a8f41c006889c291fcaedb4eb"
assert key == "task_name_1-31415-e3a85f91467d1e1f721ebe8129b2de31"


@pytest.mark.skipif("pandas" not in sys.modules, reason="Pandas is not installed.")
Expand Down
4 changes: 2 additions & 2 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3218,10 +3218,10 @@ def test_union_file_directory():
(typing.List[Color], [Color.RED, Color.GREEN, Color.BLUE]),
(typing.List[Annotated[int, "tag"]], [1, 2, 3]),
(typing.List[Annotated[str, "tag"]], ["a", "b", "c"]),
(typing.Dict[int, str], {"1": "a", "2": "b", "3": "c"}),
(typing.Dict[int, str], {1: "a", 2: "b", 3: "c"}),
(typing.Dict[str, int], {"a": 1, "b": 2, "c": 3}),
(typing.Dict[str, typing.List[int]], {"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}),
(typing.Dict[str, typing.Dict[int, str]], {"a": {"1": "a", "2": "b", "3": "c"}, "b": {"4": "d", "5": "e", "6": "f"}}),
(typing.Dict[str, typing.Dict[int, str]], {"a": {1: "a", 2: "b", 3: "c"}, "b": {4: "d", 5: "e", 6: "f"}}),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could we add one more example for Union[dict, str]

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, no problem

(typing.Union[int, str], 42),
(typing.Union[int, str], "hello"),
(typing.Union[typing.List[int], typing.List[str]], [1, 2, 3]),
Expand Down
134 changes: 134 additions & 0 deletions tests/flytekit/unit/core/test_type_engine_binary_idl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
from datetime import datetime, date, timedelta

import msgpack
from mashumaro.codecs.msgpack import MessagePackEncoder

from flytekit.models.literals import Binary, Literal, Scalar
from flytekit.core.context_manager import FlyteContextManager
from flytekit.core.type_engine import TypeEngine

def test_simple_type_transformer():
    """
    Round-trip simple python values through msgpack-tagged Binary literals.

    Every value is encoded with a ``MessagePackEncoder`` built for its type,
    wrapped in a Binary scalar literal tagged "msgpack", and converted back via
    ``TypeEngine.to_python_value``; the result must equal the original value.
    """
    ctx = FlyteContextManager.current_context()

    # (python type, sample values) pairs; each pair gets its own encoder.
    cases = [
        (int, [1, 2, 20240918, -1, -2, -20240918]),
        (float, [2024.0918, 5.0, -2024.0918, -5.0]),
        (bool, [True, False]),
        (str, ["hello", "world", "flyte", "kit", "is", "awesome"]),
        (
            datetime,
            [
                datetime.now(),
                datetime(2024, 9, 18),
                datetime(2024, 9, 18, 1),
                datetime(2024, 9, 18, 1, 1),
                datetime(2024, 9, 18, 1, 1, 1),
                datetime(2024, 9, 18, 1, 1, 1, 1),
            ],
        ),
        (date, [date.today(), date(2024, 9, 18)]),
        (
            timedelta,
            [
                timedelta(days=1),
                timedelta(days=1, seconds=1),
                timedelta(days=1, seconds=1, microseconds=1),
                timedelta(days=1, seconds=1, microseconds=1, milliseconds=1),
                timedelta(days=1, seconds=1, microseconds=1, milliseconds=1, minutes=1),
                timedelta(days=1, seconds=1, microseconds=1, milliseconds=1, minutes=1, hours=1),
                timedelta(days=1, seconds=1, microseconds=1, milliseconds=1, minutes=1, hours=1, weeks=1),
                timedelta(days=-1, seconds=-1, microseconds=-1, milliseconds=-1, minutes=-1, hours=-1, weeks=-1),
            ],
        ),
    ]

    for python_type, values in cases:
        encoder = MessagePackEncoder(python_type)
        for value in values:
            payload = encoder.encode(value)
            lv = Literal(scalar=Scalar(binary=Binary(value=payload, tag="msgpack")))
            round_tripped = TypeEngine.to_python_value(ctx, lv, python_type)
            assert value == round_tripped

def test_untyped_dict():
ctx = FlyteContextManager.current_context()

dict_inputs = [
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we have tests that include datetime and other possibly more complex objects too?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NO PROBLEM

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Enum and datetime are not supported in untyped dicts, and I think this is OK, since no one uses them there.

Copy link
Member Author

@Future-Outlier Future-Outlier Sep 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Other complex objects, such as dataclasses, will not be supported either, since we don't have type hints.

# Basic key-value combinations with int, str, bool, float
{1: "a", "key": 2.5, True: False, 3.14: 100},
{"a": 1, 2: "b", 3.5: True, False: 3.1415},

{
1: [1, "a", 2.5, False],
"key_list": ["str", 3.14, True, 42],
True: [False, 2.718, "test"],
},

{
"nested_dict": {1: 2, "key": "value", True: 3.14, False: "string"},
3.14: {"pi": 3.14, "e": 2.718, 42: True},
},

{
"list_in_dict": [
{"inner_dict_1": [1, 2.5, "a"], "inner_dict_2": [True, False, 3.14]},
[1, 2, 3, {"nested_list_dict": [False, "test"]}],
]
},

{
"complex_nested": {
1: {"nested_dict": {True: [1, "a", 2.5]}},
"string_key": {False: {3.14: {"deep": [1, "deep_value"]}}},
}
},

{
"list_of_dicts": [{"a": 1, "b": 2}, {"key1": "value1", "key2": "value2"}],
10: [{"nested_list": [1, "value", 3.14]}, {"another_list": [True, False]}],
},

# More nested combinations of list and dict
{
"outer_list": [
[1, 2, 3],
{"inner_dict": {"key1": [True, "string", 3.14], "key2": [1, 2.5]}}, # Dict inside list
],
"another_dict": {"key1": {"subkey": [1, 2, "str"]}, "key2": [False, 3.14, "test"]},
},
]

for dict_input in dict_inputs:
dict_msgpack_bytes = msgpack.dumps(dict_input)
lv = Literal(scalar=Scalar(binary=Binary(value=dict_msgpack_bytes, tag="msgpack")))
dict_output = TypeEngine.to_python_value(ctx, lv, dict)
assert dict_input == dict_output
11 changes: 6 additions & 5 deletions tests/flytekit/unit/core/test_type_hints.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
from dataclasses import dataclass
from enum import Enum

import msgpack
import pytest
from dataclasses_json import DataClassJsonMixin
from google.protobuf.struct_pb2 import Struct
from mashumaro.codecs.json import JSONEncoder, JSONDecoder
from typing_extensions import Annotated, get_origin

import flytekit
Expand All @@ -37,6 +37,7 @@
from flytekit.models import literals as _literal_models
from flytekit.models.core import types as _core_types
from flytekit.models.interface import Parameter
from flytekit.models.literals import Binary
from flytekit.models.task import Resources as _resource_models
from flytekit.models.types import LiteralType, SimpleType
from flytekit.tools.translator import get_serializable
Expand Down Expand Up @@ -1488,7 +1489,7 @@ def t2(a: dict) -> str:
guessed_types = {"a": pt}
ctx = context_manager.FlyteContext.current_context()
lm = TypeEngine.dict_to_literal_map(ctx, d=input_map, type_hints=guessed_types)
assert isinstance(lm.literals["a"].scalar.generic, Struct)
assert isinstance(lm.literals["a"].scalar.binary, Binary)

output_lm = t2.dispatch_execute(ctx, lm)
str_value = output_lm.literals["o0"].scalar.primitive.string_value
Expand Down Expand Up @@ -1521,9 +1522,9 @@ def t2() -> dict:

ctx = context_manager.FlyteContextManager.current_context()
output_lm = t2.dispatch_execute(ctx, _literal_models.LiteralMap(literals={}))
expected_struct = Struct()
expected_struct.update({"k1": "v1", "k2": 3, "4": {"one": [1, "two", [3]]}})
assert output_lm.literals["o0"].scalar.generic == expected_struct
msgpack_bytes = msgpack.dumps({"k1": "v1", "k2": 3, 4: {"one": [1, "two", [3]]}})
binary_idl_obj = Binary(value=msgpack_bytes, tag="msgpack")
assert output_lm.literals["o0"].scalar.binary == binary_idl_obj


@pytest.mark.skipif(sys.version_info < (3, 9), reason="Use of dict hints is only supported in Python 3.9+")
Expand Down
Loading