Skip to content

Commit

Permalink
add annotation option for serialization (#1615)
Browse files Browse the repository at this point in the history
* add annotation option for serialization

* add support for guessing type

* formatting

* test both directions

* remove guessing for annotated types

* clarify usage in test

* Update tests/flytekit/unit/core/test_type_engine.py

---------

Co-authored-by: Eli Bixby <eli@cradle.bio>
Co-authored-by: Yee Hing Tong <wild-endeavor@users.noreply.github.com>
  • Loading branch information
3 people authored May 11, 2023
1 parent 993201f commit dd44bba
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 4 deletions.
16 changes: 12 additions & 4 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,24 +718,32 @@ def get_transformer(cls, python_type: Type) -> TypeTransformer[T]:
d = dictionary of registered transformers, where is a python `type`
v = lookup type
Step 1:
find a transformer that matches v exactly
If the type is annotated with a TypeTransformer instance, use that.
Step 2:
find a transformer that matches the generic type of v. e.g List[int], Dict[str, int] etc
find a transformer that matches v exactly
Step 3:
find a transformer that matches the generic type of v. e.g List[int], Dict[str, int] etc
Step 4:
Walk the inheritance hierarchy of v and find a transformer that matches the first base class.
This is potentially non-deterministic - will depend on the registration pattern.
TODO lets make this deterministic by using an ordered dict
Step 4:
Step 5:
if v is of type data class, use the dataclass transformer
"""
cls.lazy_import_transformers()
# Step 1
if get_origin(python_type) is Annotated:
python_type = get_args(python_type)[0]
args = get_args(python_type)
for annotation in args:
if isinstance(annotation, TypeTransformer):
return annotation

python_type = args[0]

if python_type in cls._REGISTRY:
return cls._REGISTRY[python_type]
Expand Down
47 changes: 47 additions & 0 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import datetime
import json
import os
import tempfile
import typing
from dataclasses import asdict, dataclass, field
from datetime import timedelta
from enum import Enum
from typing import Optional, Type

import mock
import pandas as pd
Expand Down Expand Up @@ -170,6 +172,51 @@ class Foo(object):
assert pv[0].b == Bar(v=[1, 2, 99], w=[3.1415, 2.7182])


def test_annotated_type():
class JsonTypeTransformer(TypeTransformer[T]):
LiteralType = LiteralType(
simple=SimpleType.STRING, annotation=TypeAnnotation(annotations=dict(protocol="json"))
)

def get_literal_type(self, t: Type[T]) -> LiteralType:
return self.LiteralType

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> Optional[T]:
return json.loads(lv.scalar.primitive.string_value)

def to_literal(
self, ctx: FlyteContext, python_val: T, python_type: typing.Type[T], expected: LiteralType
) -> Literal:
return Literal(scalar=Scalar(primitive=Primitive(string_value=json.dumps(python_val))))

class JSONSerialized:
def __class_getitem__(cls, item: Type[T]):
return Annotated[item, JsonTypeTransformer(name=f"json[{item}]", t=item)]

MyJsonDict = JSONSerialized[typing.Dict[str, int]]
_, test_transformer = get_args(MyJsonDict)

assert TypeEngine.get_transformer(MyJsonDict) is test_transformer
assert TypeEngine.to_literal_type(MyJsonDict) == JsonTypeTransformer.LiteralType

test_dict = {"foo": 1}
test_literal = Literal(scalar=Scalar(primitive=Primitive(string_value=json.dumps(test_dict))))

assert (
TypeEngine.to_python_value(
FlyteContext.current_context(),
test_literal,
MyJsonDict,
)
== test_dict
)

assert (
TypeEngine.to_literal(FlyteContext.current_context(), test_dict, MyJsonDict, JsonTypeTransformer.LiteralType)
== test_literal
)


def test_list_of_dataclass_getting_python_value():
@dataclass_json
@dataclass()
Expand Down

0 comments on commit dd44bba

Please sign in to comment.