-
Notifications
You must be signed in to change notification settings - Fork 293
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[flytekit][2][untyped dict] Binary IDL With MessagePack #2757
Changes from all commits
e3a258a
3562f0c
f93b441
c05a905
b1bf20c
5539c1e
be6c024
6b59d89
bcaf573
dd5e1c9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -15,8 +15,9 @@ | |||||
from abc import ABC, abstractmethod | ||||||
from collections import OrderedDict | ||||||
from functools import lru_cache | ||||||
from typing import Dict, List, NamedTuple, Optional, Type, cast | ||||||
from typing import Any, Dict, List, NamedTuple, Optional, Type, cast | ||||||
|
||||||
import msgpack | ||||||
from dataclasses_json import DataClassJsonMixin, dataclass_json | ||||||
from flyteidl.core import literals_pb2 | ||||||
from google.protobuf import json_format as _json_format | ||||||
|
@@ -26,6 +27,7 @@ | |||||
from google.protobuf.message import Message | ||||||
from google.protobuf.struct_pb2 import Struct | ||||||
from mashumaro.codecs.json import JSONDecoder, JSONEncoder | ||||||
from mashumaro.codecs.msgpack import MessagePackDecoder, MessagePackEncoder | ||||||
from mashumaro.mixins.json import DataClassJSONMixin | ||||||
from typing_extensions import Annotated, get_args, get_origin | ||||||
|
||||||
|
@@ -42,22 +44,21 @@ | |||||
from flytekit.models import types as _type_models | ||||||
from flytekit.models.annotation import TypeAnnotation as TypeAnnotationModel | ||||||
from flytekit.models.core import types as _core_types | ||||||
from flytekit.models.literals import ( | ||||||
Literal, | ||||||
LiteralCollection, | ||||||
LiteralMap, | ||||||
Primitive, | ||||||
Scalar, | ||||||
Union, | ||||||
Void, | ||||||
) | ||||||
from flytekit.models.literals import Binary, Literal, LiteralCollection, LiteralMap, Primitive, Scalar, Union, Void | ||||||
from flytekit.models.types import LiteralType, SimpleType, TypeStructure, UnionType | ||||||
|
||||||
T = typing.TypeVar("T") | ||||||
DEFINITIONS = "definitions" | ||||||
TITLE = "title" | ||||||
|
||||||
|
||||||
# In Mashumaro, the default encoder uses strict_map_key=False, while the default decoder uses strict_map_key=True. | ||||||
# This is relevant for cases like Dict[int, str]. | ||||||
# If strict_map_key=False is not used, the decoder will raise an error when trying to decode keys that are not strictly typed.` | ||||||
def _default_flytekit_decoder(data: bytes) -> Any: | ||||||
return msgpack.unpackb(data, raw=False, strict_map_key=False) | ||||||
|
||||||
|
||||||
class BatchSize: | ||||||
""" | ||||||
This is used to annotate a FlyteDirectory when we want to download/upload the contents of the directory in batches. For example, | ||||||
|
@@ -129,6 +130,8 @@ def __init__(self, name: str, t: Type[T], enable_type_assertions: bool = True): | |||||
self._t = t | ||||||
self._name = name | ||||||
self._type_assertions_enabled = enable_type_assertions | ||||||
self._msgpack_encoder: Dict[Type, MessagePackEncoder] = {} | ||||||
self._msgpack_decoder: Dict[Type, MessagePackDecoder] = {} | ||||||
|
||||||
@property | ||||||
def name(self): | ||||||
|
@@ -221,6 +224,17 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: | |||||
f"Conversion to python value expected type {expected_python_type} from literal not implemented" | ||||||
) | ||||||
|
||||||
def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[T]) -> Optional[T]: | ||||||
if binary_idl_object.tag == "msgpack": | ||||||
try: | ||||||
decoder = self._msgpack_decoder[expected_python_type] | ||||||
except KeyError: | ||||||
decoder = MessagePackDecoder(expected_python_type, pre_decoder_func=_default_flytekit_decoder) | ||||||
self._msgpack_decoder[expected_python_type] = decoder | ||||||
return decoder.decode(binary_idl_object.value) | ||||||
else: | ||||||
raise TypeTransformerFailedError(f"Unsupported binary format {binary_idl_object.tag}") | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. thank you! |
||||||
|
||||||
def to_html(self, ctx: FlyteContext, python_val: T, expected_python_type: Type[T]) -> str: | ||||||
""" | ||||||
Converts any python val (dataframe, int, float) to a html string, and it will be wrapped in the HTML div | ||||||
|
@@ -271,6 +285,9 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: | |||||
f"Cannot convert to type {expected_python_type}, only {self._type} is supported" | ||||||
) | ||||||
|
||||||
if lv.scalar and lv.scalar.binary: | ||||||
return self.from_binary_idl(lv.scalar.binary, expected_python_type) # type: ignore | ||||||
Future-Outlier marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
||||||
try: # todo(maximsmol): this is quite ugly and each transformer should really check their Literal | ||||||
res = self._from_literal_transformer(lv) | ||||||
if type(res) != self._type: | ||||||
|
@@ -1697,17 +1714,15 @@ def extract_types_or_metadata(t: Optional[Type[dict]]) -> typing.Tuple: | |||||
return None, None | ||||||
|
||||||
@staticmethod | ||||||
def dict_to_generic_literal(ctx: FlyteContext, v: dict, allow_pickle: bool) -> Literal: | ||||||
def dict_to_binary_literal(ctx: FlyteContext, v: dict, allow_pickle: bool) -> Literal: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. but we still use generic to save the pickle file, right? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, I didn't change the logic for pickle |
||||||
""" | ||||||
Creates a flyte-specific ``Literal`` value from a native python dictionary. | ||||||
""" | ||||||
from flytekit.types.pickle import FlytePickle | ||||||
|
||||||
try: | ||||||
return Literal( | ||||||
scalar=Scalar(generic=_json_format.Parse(json.dumps(v), _struct.Struct())), | ||||||
metadata={"format": "json"}, | ||||||
) | ||||||
msgpack_bytes = msgpack.dumps(v) | ||||||
return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag="msgpack"))) | ||||||
except TypeError as e: | ||||||
if allow_pickle: | ||||||
remote_path = FlytePickle.to_pickle(ctx, v) | ||||||
|
@@ -1717,7 +1732,7 @@ def dict_to_generic_literal(ctx: FlyteContext, v: dict, allow_pickle: bool) -> L | |||||
), | ||||||
metadata={"format": "pickle"}, | ||||||
) | ||||||
raise e | ||||||
raise TypeTransformerFailedError(f"Cannot convert {v} to Flyte Literal.\n" f"Error Message: {e}") | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no problem, thank you |
||||||
|
||||||
@staticmethod | ||||||
def is_pickle(python_type: Type[dict]) -> typing.Tuple[bool, Type]: | ||||||
|
@@ -1768,7 +1783,7 @@ def to_literal( | |||||
allow_pickle, base_type = DictTransformer.is_pickle(python_type) | ||||||
|
||||||
if expected and expected.simple and expected.simple == SimpleType.STRUCT: | ||||||
return self.dict_to_generic_literal(ctx, python_val, allow_pickle) | ||||||
return self.dict_to_binary_literal(ctx, python_val, allow_pickle) | ||||||
|
||||||
lit_map = {} | ||||||
for k, v in python_val.items(): | ||||||
|
@@ -1785,6 +1800,9 @@ def to_literal( | |||||
return Literal(map=LiteralMap(literals=lit_map)) | ||||||
|
||||||
def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[dict]) -> dict: | ||||||
if lv and lv.scalar and lv.scalar.binary is not None: | ||||||
return self.from_binary_idl(lv.scalar.binary, expected_python_type) # type: ignore | ||||||
|
||||||
if lv and lv.map and lv.map.literals is not None: | ||||||
tp = self.dict_types(expected_python_type) | ||||||
|
||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -1,5 +1,7 @@ | ||||||
from datetime import datetime, timedelta | ||||||
from unittest import mock | ||||||
import msgpack | ||||||
import base64 | ||||||
|
||||||
import pytest | ||||||
from flyteidl.core.execution_pb2 import TaskExecution | ||||||
|
@@ -161,6 +163,7 @@ async def test_agent(mock_boto_call, mock_return_value): | |||||
if "pickle_check" in mock_return_value[0][0]: | ||||||
assert "pickle_file" in outputs["result"] | ||||||
else: | ||||||
outputs["result"] = msgpack.loads(base64.b64decode(outputs["result"])) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why do we need this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. flytekit/flytekit/interaction/string_literals.py Lines 44 to 45 in 9a9f5b2
Because someone serialize the |
||||||
assert ( | ||||||
outputs["result"]["EndpointConfigArn"] | ||||||
== "arn:aws:sagemaker:us-east-2:000000000:endpoint-config/sagemaker-xgboost-endpoint-config" | ||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,8 @@ | ||
from datetime import timedelta | ||
from unittest import mock | ||
from unittest.mock import AsyncMock | ||
|
||
import msgpack | ||
import base64 | ||
import pytest | ||
from flyteidl.core.execution_pb2 import TaskExecution | ||
from flytekitplugins.openai.batch.agent import BatchEndpointMetadata | ||
|
@@ -159,7 +160,7 @@ async def test_openai_batch_agent(mock_retrieve, mock_create, mock_context): | |
outputs = literal_map_string_repr(resource.outputs) | ||
result = outputs["result"] | ||
|
||
assert result == batch_retrieve_result.to_dict() | ||
assert msgpack.loads(base64.b64decode(result)) == batch_retrieve_result.to_dict() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto |
||
|
||
# Status: Failed | ||
mock_retrieve.return_value = batch_retrieve_result_failure | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3218,10 +3218,10 @@ def test_union_file_directory(): | |
(typing.List[Color], [Color.RED, Color.GREEN, Color.BLUE]), | ||
(typing.List[Annotated[int, "tag"]], [1, 2, 3]), | ||
(typing.List[Annotated[str, "tag"]], ["a", "b", "c"]), | ||
(typing.Dict[int, str], {"1": "a", "2": "b", "3": "c"}), | ||
(typing.Dict[int, str], {1: "a", 2: "b", 3: "c"}), | ||
(typing.Dict[str, int], {"a": 1, "b": 2, "c": 3}), | ||
(typing.Dict[str, typing.List[int]], {"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}), | ||
(typing.Dict[str, typing.Dict[int, str]], {"a": {"1": "a", "2": "b", "3": "c"}, "b": {"4": "d", "5": "e", "6": "f"}}), | ||
(typing.Dict[str, typing.Dict[int, str]], {"a": {1: "a", 2: "b", 3: "c"}, "b": {4: "d", 5: "e", 6: "f"}}), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. could we add one more example for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, no problem |
||
(typing.Union[int, str], 42), | ||
(typing.Union[int, str], "hello"), | ||
(typing.Union[typing.List[int], typing.List[str]], [1, 2, 3]), | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
from datetime import datetime, date, timedelta | ||
|
||
import msgpack | ||
from mashumaro.codecs.msgpack import MessagePackEncoder | ||
|
||
from flytekit.models.literals import Binary, Literal, Scalar | ||
from flytekit.core.context_manager import FlyteContextManager | ||
from flytekit.core.type_engine import TypeEngine | ||
|
||
def test_simple_type_transformer(): | ||
ctx = FlyteContextManager.current_context() | ||
|
||
int_inputs = [1, 2, 20240918, -1, -2, -20240918] | ||
encoder = MessagePackEncoder(int) | ||
for int_input in int_inputs: | ||
int_msgpack_bytes = encoder.encode(int_input) | ||
lv = Literal(scalar=Scalar(binary=Binary(value=int_msgpack_bytes, tag="msgpack"))) | ||
int_output = TypeEngine.to_python_value(ctx, lv, int) | ||
assert int_input == int_output | ||
|
||
float_inputs = [2024.0918, 5.0, -2024.0918, -5.0] | ||
encoder = MessagePackEncoder(float) | ||
for float_input in float_inputs: | ||
float_msgpack_bytes = encoder.encode(float_input) | ||
lv = Literal(scalar=Scalar(binary=Binary(value=float_msgpack_bytes, tag="msgpack"))) | ||
float_output = TypeEngine.to_python_value(ctx, lv, float) | ||
assert float_input == float_output | ||
|
||
bool_inputs = [True, False] | ||
encoder = MessagePackEncoder(bool) | ||
for bool_input in bool_inputs: | ||
bool_msgpack_bytes = encoder.encode(bool_input) | ||
lv = Literal(scalar=Scalar(binary=Binary(value=bool_msgpack_bytes, tag="msgpack"))) | ||
bool_output = TypeEngine.to_python_value(ctx, lv, bool) | ||
assert bool_input == bool_output | ||
|
||
str_inputs = ["hello", "world", "flyte", "kit", "is", "awesome"] | ||
encoder = MessagePackEncoder(str) | ||
for str_input in str_inputs: | ||
str_msgpack_bytes = encoder.encode(str_input) | ||
lv = Literal(scalar=Scalar(binary=Binary(value=str_msgpack_bytes, tag="msgpack"))) | ||
str_output = TypeEngine.to_python_value(ctx, lv, str) | ||
assert str_input == str_output | ||
|
||
datetime_inputs = [datetime.now(), | ||
datetime(2024, 9, 18), | ||
datetime(2024, 9, 18, 1), | ||
datetime(2024, 9, 18, 1, 1), | ||
datetime(2024, 9, 18, 1, 1, 1), | ||
datetime(2024, 9, 18, 1, 1, 1, 1)] | ||
encoder = MessagePackEncoder(datetime) | ||
for datetime_input in datetime_inputs: | ||
datetime_msgpack_bytes = encoder.encode(datetime_input) | ||
lv = Literal(scalar=Scalar(binary=Binary(value=datetime_msgpack_bytes, tag="msgpack"))) | ||
datetime_output = TypeEngine.to_python_value(ctx, lv, datetime) | ||
assert datetime_input == datetime_output | ||
|
||
date_inputs = [date.today(), | ||
date(2024, 9, 18)] | ||
encoder = MessagePackEncoder(date) | ||
for date_input in date_inputs: | ||
date_msgpack_bytes = encoder.encode(date_input) | ||
lv = Literal(scalar=Scalar(binary=Binary(value=date_msgpack_bytes, tag="msgpack"))) | ||
date_output = TypeEngine.to_python_value(ctx, lv, date) | ||
assert date_input == date_output | ||
|
||
timedelta_inputs = [timedelta(days=1), | ||
timedelta(days=1, seconds=1), | ||
timedelta(days=1, seconds=1, microseconds=1), | ||
timedelta(days=1, seconds=1, microseconds=1, milliseconds=1), | ||
timedelta(days=1, seconds=1, microseconds=1, milliseconds=1, minutes=1), | ||
timedelta(days=1, seconds=1, microseconds=1, milliseconds=1, minutes=1, hours=1), | ||
timedelta(days=1, seconds=1, microseconds=1, milliseconds=1, minutes=1, hours=1, weeks=1), | ||
timedelta(days=-1, seconds=-1, microseconds=-1, milliseconds=-1, minutes=-1, hours=-1, weeks=-1)] | ||
encoder = MessagePackEncoder(timedelta) | ||
for timedelta_input in timedelta_inputs: | ||
timedelta_msgpack_bytes = encoder.encode(timedelta_input) | ||
lv = Literal(scalar=Scalar(binary=Binary(value=timedelta_msgpack_bytes, tag="msgpack"))) | ||
timedelta_output = TypeEngine.to_python_value(ctx, lv, timedelta) | ||
assert timedelta_input == timedelta_output | ||
|
||
def test_untyped_dict(): | ||
ctx = FlyteContextManager.current_context() | ||
|
||
dict_inputs = [ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we have tests that include There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. NO PROBLEM There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. enum and datetime is not supported in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Other complex objects like data class will not be supported too, since we don't have type hints |
||
# Basic key-value combinations with int, str, bool, float | ||
{1: "a", "key": 2.5, True: False, 3.14: 100}, | ||
{"a": 1, 2: "b", 3.5: True, False: 3.1415}, | ||
|
||
{ | ||
1: [1, "a", 2.5, False], | ||
"key_list": ["str", 3.14, True, 42], | ||
True: [False, 2.718, "test"], | ||
}, | ||
|
||
{ | ||
"nested_dict": {1: 2, "key": "value", True: 3.14, False: "string"}, | ||
3.14: {"pi": 3.14, "e": 2.718, 42: True}, | ||
}, | ||
|
||
{ | ||
"list_in_dict": [ | ||
{"inner_dict_1": [1, 2.5, "a"], "inner_dict_2": [True, False, 3.14]}, | ||
[1, 2, 3, {"nested_list_dict": [False, "test"]}], | ||
] | ||
}, | ||
|
||
{ | ||
"complex_nested": { | ||
1: {"nested_dict": {True: [1, "a", 2.5]}}, | ||
"string_key": {False: {3.14: {"deep": [1, "deep_value"]}}}, | ||
} | ||
}, | ||
|
||
{ | ||
"list_of_dicts": [{"a": 1, "b": 2}, {"key1": "value1", "key2": "value2"}], | ||
10: [{"nested_list": [1, "value", 3.14]}, {"another_list": [True, False]}], | ||
}, | ||
|
||
# More nested combinations of list and dict | ||
{ | ||
"outer_list": [ | ||
[1, 2, 3], | ||
{"inner_dict": {"key1": [True, "string", 3.14], "key2": [1, 2.5]}}, # Dict inside list | ||
], | ||
"another_dict": {"key1": {"subkey": [1, 2, "str"]}, "key2": [False, 3.14, "test"]}, | ||
}, | ||
] | ||
|
||
for dict_input in dict_inputs: | ||
dict_msgpack_bytes = msgpack.dumps(dict_input) | ||
lv = Literal(scalar=Scalar(binary=Binary(value=dict_msgpack_bytes, tag="msgpack"))) | ||
dict_output = TypeEngine.to_python_value(ctx, lv, dict) | ||
assert dict_input == dict_output |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: could we call this _default_msgpack_decoder
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
since we may have other decoder in the future, like
_default_utf8_decoder
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, pretty good advice, thank you!