diff --git a/dev-requirements.in b/dev-requirements.in index 313ce1d82b..bf6ded9a7b 100644 --- a/dev-requirements.in +++ b/dev-requirements.in @@ -13,3 +13,5 @@ google-cloud-bigquery google-cloud-bigquery-storage IPython torch +tensorflow +grpcio-status<1.49.0 diff --git a/dev-requirements.txt b/dev-requirements.txt index 2eb5c51706..102fa71455 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,17 +1,25 @@ # -# This file is autogenerated by pip-compile with python 3.7 +# This file is autogenerated by pip-compile with python 3.9 # To update, run: # -# make dev-requirements.txt +# pip-compile --output-file=dev-requirements.txt dev-requirements.in # -e file:.#egg=flytekit # via # -c requirements.txt # pytest-flyte +absl-py==1.3.0 + # via + # tensorboard + # tensorflow +appnope==0.1.3 + # via ipython arrow==1.2.3 # via # -c requirements.txt # jinja2-time +astunparse==1.6.3 + # via tensorflow attrs==20.3.0 # via # -c requirements.txt @@ -26,8 +34,6 @@ binaryornot==0.4.4 # via # -c requirements.txt # cookiecutter -cached-property==1.5.2 - # via docker-compose cachetools==5.2.0 # via google-auth certifi==2022.9.24 @@ -77,7 +83,6 @@ cryptography==38.0.1 # -c requirements.txt # paramiko # pyopenssl - # secretstorage dataclasses-json==0.5.7 # via # -c requirements.txt @@ -120,10 +125,14 @@ docstring-parser==0.15 # flytekit filelock==3.8.0 # via virtualenv +flatbuffers==22.9.24 + # via tensorflow flyteidl==1.1.22 # via # -c requirements.txt # flytekit +gast==0.5.3 + # via tensorflow google-api-core[grpc]==2.10.2 # via # google-cloud-bigquery @@ -132,7 +141,11 @@ google-api-core[grpc]==2.10.2 google-auth==2.13.0 # via # google-api-core + # google-auth-oauthlib # google-cloud-core + # tensorboard +google-auth-oauthlib==0.4.6 + # via tensorboard google-cloud-bigquery==3.3.5 # via -r dev-requirements.in google-cloud-bigquery-storage==2.16.2 @@ -143,6 +156,8 @@ google-cloud-core==2.3.2 # via google-cloud-bigquery google-crc32c==1.5.0 # via google-resumable-media +google-pasta==0.2.0 + # via tensorflow google-resumable-media==2.4.0 # via google-cloud-bigquery googleapis-common-protos==1.56.4 @@ -158,11 +173,16 @@ grpcio==1.47.0 # google-api-core # google-cloud-bigquery # grpcio-status + # tensorboard + # tensorflow grpcio-status==1.47.0 # via # -c requirements.txt + # -r dev-requirements.in # flytekit # google-api-core +h5py==3.7.0 + # via tensorflow identify==2.5.6 # via pre-commit idna==3.4 @@ -172,14 +192,9 @@ idna==3.4 importlib-metadata==5.0.0 # via # -c requirements.txt - # click # flytekit - # jsonschema # keyring - # pluggy - # pre-commit - # pytest - # virtualenv + # markdown iniconfig==1.1.1 # via pytest ipython==7.34.0 @@ -190,11 +205,6 @@ jaraco-classes==3.2.3 # keyring jedi==0.18.1 # via ipython -jeepney==0.8.0 - # via - # -c requirements.txt - # keyring - # secretstorage jinja2==3.1.2 # via # -c requirements.txt @@ -214,14 +224,23 @@ jsonschema==3.2.0 # via # -c requirements.txt # docker-compose +keras==2.8.0 + # via tensorflow +keras-preprocessing==1.1.2 + # via tensorflow keyring==23.9.3 # via # -c requirements.txt # flytekit +libclang==14.0.6 + # via tensorflow +markdown==3.4.1 + # via tensorboard markupsafe==2.1.1 # via # -c requirements.txt # jinja2 + # werkzeug marshmallow==3.18.0 # via # -c requirements.txt @@ -260,9 +279,17 @@ nodeenv==1.7.0 numpy==1.21.6 # via # -c requirements.txt - # flytekit + # h5py + # keras-preprocessing + # opt-einsum # pandas # pyarrow + # tensorboard + # tensorflow +oauthlib==3.2.2 + # via requests-oauthlib +opt-einsum==3.3.0 + # via 
tensorflow packaging==21.3 # via # -c requirements.txt @@ -306,6 +333,8 @@ protobuf==3.20.3 # grpcio-status # proto-plus # protoc-gen-swagger + # tensorboard + # tensorflow protoc-gen-swagger==0.1.0 # via # -c requirements.txt @@ -407,7 +436,11 @@ requests==2.28.1 # flytekit # google-api-core # google-cloud-bigquery + # requests-oauthlib # responses + # tensorboard +requests-oauthlib==1.3.1 + # via google-auth-oauthlib responses==0.22.0 # via # -c requirements.txt @@ -418,23 +451,19 @@ retry==0.9.2 # flytekit rsa==4.9 # via google-auth -secretstorage==3.3.3 - # via - # -c requirements.txt - # keyring -singledispatchmethod==1.0 - # via - # -c requirements.txt - # flytekit six==1.16.0 # via # -c requirements.txt + # astunparse # dockerpty # google-auth + # google-pasta # grpcio # jsonschema + # keras-preprocessing # paramiko # python-dateutil + # tensorflow # websocket-client sortedcontainers==2.4.0 # via @@ -444,6 +473,20 @@ statsd==3.3.0 # via # -c requirements.txt # flytekit +tensorboard==2.8.0 + # via tensorflow +tensorboard-data-server==0.6.1 + # via tensorboard +tensorboard-plugin-wit==1.8.1 + # via tensorboard +tensorflow==2.8.1 + # via -r dev-requirements.in +tensorflow-estimator==2.8.0 + # via tensorflow +tensorflow-io-gcs-filesystem==0.27.0 + # via tensorflow +termcolor==2.0.1 + # via tensorflow text-unidecode==1.3 # via # -c requirements.txt @@ -466,8 +509,6 @@ traitlets==5.5.0 # via # ipython # matplotlib-inline -typed-ast==1.5.4 - # via mypy types-toml==0.10.8 # via # -c requirements.txt @@ -475,11 +516,9 @@ types-toml==0.10.8 typing-extensions==4.4.0 # via # -c requirements.txt - # arrow # flytekit - # importlib-metadata # mypy - # responses + # tensorflow # torch # typing-inspect typing-inspect==0.8.0 @@ -502,15 +541,20 @@ websocket-client==0.59.0 # -c requirements.txt # docker # docker-compose +werkzeug==2.2.2 + # via tensorboard wheel==0.37.1 # via # -c requirements.txt + # astunparse # flytekit + # tensorboard wrapt==1.14.1 # via # -c requirements.txt # deprecated # flytekit + # tensorflow zipp==3.9.0 # via # -c requirements.txt diff --git a/docs/source/extras.keras.rst b/docs/source/extras.keras.rst new file mode 100644 index 0000000000..4660d81711 --- /dev/null +++ b/docs/source/extras.keras.rst @@ -0,0 +1,7 @@ +############ +Keras Type +############ +.. 
automodule:: flytekit.extras.keras
+   :no-members:
+   :no-inherited-members:
+   :no-special-members:
diff --git a/docs/source/types.extend.rst b/docs/source/types.extend.rst
index d72c3da0cf..1c7a65b89a 100644
--- a/docs/source/types.extend.rst
+++ b/docs/source/types.extend.rst
@@ -13,3 +13,4 @@ Refer to :doc:`cookbook ` if you'd like to contribute a F
    types.builtins.file
    types.builtins.directory
    extras.pytorch
+   extras.keras
diff --git a/flytekit/__init__.py b/flytekit/__init__.py
index fa7ca17fe2..a1549d331c 100644
--- a/flytekit/__init__.py
+++ b/flytekit/__init__.py
@@ -184,7 +184,7 @@ from flytekit.core.workflow import ImperativeWorkflow as Workflow
 from flytekit.core.workflow import WorkflowFailurePolicy, reference_workflow, workflow
 from flytekit.deck import Deck
-from flytekit.extras import pytorch
+from flytekit.extras import keras, pytorch
 from flytekit.extras.persistence import GCSPersistence, HttpPersistence, S3Persistence
 from flytekit.loggers import logger
 from flytekit.models.common import Annotations, AuthRole, Labels
diff --git a/flytekit/extras/keras/__init__.py b/flytekit/extras/keras/__init__.py
new file mode 100644
index 0000000000..ea1adf8a5d
--- /dev/null
+++ b/flytekit/extras/keras/__init__.py
@@ -0,0 +1,28 @@
+"""
+Flytekit Keras
+=========================================
+.. currentmodule:: flytekit.extras.keras
+
+.. autosummary::
+   :template: custom.rst
+   :toctree: generated/
+
+"""
+from flytekit.loggers import logger
+
+# tensorflow (which provides keras) is a soft dependency, so guard the import
+try:
+    # isolate the exception to the keras import
+    from tensorflow import keras
+
+    _keras_installed = True
+except (ImportError, OSError):
+    _keras_installed = False
+
+
+if _keras_installed:
+    from .native import KerasModelTransformer, KerasSequentialTransformer
+else:
+    logger.info(
+        "We won't register KerasSequentialTransformer and KerasModelTransformer because keras is not installed."
+    )
diff --git a/flytekit/extras/keras/native.py b/flytekit/extras/keras/native.py
new file mode 100644
index 0000000000..c86d1710d2
--- /dev/null
+++ b/flytekit/extras/keras/native.py
@@ -0,0 +1,85 @@
+import pathlib
+from typing import Generic, Type, TypeVar
+
+import keras
+
+from flytekit.core.context_manager import FlyteContext
+from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError
+from flytekit.models.core import types as _core_types
+from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
+from flytekit.models.types import LiteralType
+
+T = TypeVar("T")
+
+
+class KerasTypeTransformer(TypeTransformer, Generic[T]):
+    def get_literal_type(self, t: Type[T]) -> LiteralType:
+        return LiteralType(
+            blob=_core_types.BlobType(
+                format=self.KERAS_FORMAT,
+                dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
+            )
+        )
+
+    def to_literal(
+        self,
+        ctx: FlyteContext,
+        python_val: T,
+        python_type: Type[T],
+        expected: LiteralType,
+    ) -> Literal:
+        meta = BlobMetadata(
+            type=_core_types.BlobType(
+                format=self.KERAS_FORMAT,
+                dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
+            )
+        )
+
+        local_path = ctx.file_access.get_random_local_path() + ".h5"
+        pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True)
+
+        # the .h5 suffix makes keras save a single HDF5 file rather than a SavedModel directory
+        keras.models.save_model(python_val, local_path)
+
+        remote_path = ctx.file_access.get_random_remote_path(local_path)
+        ctx.file_access.put_data(local_path, remote_path, is_multipart=False)
+        return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path)))
+
+    def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T:
+        try:
+            uri = lv.scalar.blob.uri
+        except AttributeError:
+            raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}")
+
+        local_path = ctx.file_access.get_random_local_path()
+        ctx.file_access.get_data(uri, local_path, is_multipart=False)
+
+        return keras.models.load_model(local_path)
+
+    def guess_python_type(self, literal_type: LiteralType) -> Type[T]:
+        if (
+            literal_type.blob is not None
+            and literal_type.blob.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE
+            and literal_type.blob.format == self.KERAS_FORMAT
+        ):
+            return T
+
+        raise ValueError(f"Transformer {self} cannot reverse {literal_type}")
+
+
+class KerasSequentialTransformer(KerasTypeTransformer[keras.Sequential]):
+    KERAS_FORMAT = "KerasSequential"
+
+    def __init__(self):
+        super().__init__(name="Keras Sequential", t=keras.Sequential)
+
+
+class KerasModelTransformer(KerasTypeTransformer[keras.Model]):
+    KERAS_FORMAT = "KerasModel"
+
+    def __init__(self):
+        super().__init__(name="Keras Model", t=keras.Model)
+
+
+TypeEngine.register(KerasSequentialTransformer())
+TypeEngine.register(KerasModelTransformer())
diff --git a/tests/flytekit/unit/extras/keras/__init__.py b/tests/flytekit/unit/extras/keras/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/flytekit/unit/extras/keras/test_native.py b/tests/flytekit/unit/extras/keras/test_native.py
new file mode 100644
index 0000000000..ad467f6bde
--- /dev/null
+++ b/tests/flytekit/unit/extras/keras/test_native.py
@@ -0,0 +1,51 @@
+from typing import List
+
+import keras
+import numpy as np
+
+from flytekit import task, workflow
+
+
+@task
+def get_model_with_sequential_class() -> keras.Sequential:
+    model = keras.Sequential()
+    model.add(keras.layers.Dense(8, input_shape=(16,)))
+    model.add(keras.layers.Dense(4))
+    return model
+
+
+@task
+def get_model_with_model_class() -> keras.Model:
+    inputs = keras.Input(shape=(3,))
+    x = keras.layers.Dense(4)(inputs)
+    outputs = keras.layers.Dense(5)(x)
+    model = keras.Model(inputs=inputs, outputs=outputs)
+    return model
+
+
+@task
+def get_model_weights(model: keras.Sequential) -> List[np.ndarray]:
+    assert len(model.weights) == 4
+    return model.weights
+
+
+@task
+def get_model_layer(model: keras.Sequential) -> List[keras.layers.Dense]:
+    if isinstance(model, keras.Sequential):
+        assert len(model.layers) == 2
+    elif isinstance(model, keras.Model):
+        assert len(model.layers) == 3
+    return model.layers
+
+
+@workflow
+def wf():
+    models = (get_model_with_sequential_class(), get_model_with_model_class())
+    for m in models:
+        get_model_weights(model=m)
+        get_model_layer(model=m)
+
+
+@workflow
+def test_wf():
+    wf()
diff --git a/tests/flytekit/unit/extras/keras/test_transformations.py b/tests/flytekit/unit/extras/keras/test_transformations.py
new file mode 100644
index 0000000000..ee3bd66fe8
--- /dev/null
+++ b/tests/flytekit/unit/extras/keras/test_transformations.py
@@ -0,0 +1,109 @@
+from collections import OrderedDict
+
+import keras
+import numpy as np
+import pytest
+
+import flytekit
+from flytekit import task
+from flytekit.configuration import Image, ImageConfig
+from flytekit.core import context_manager
+from flytekit.extras.keras import KerasModelTransformer, KerasSequentialTransformer
+from flytekit.models.core.types import BlobType
+from flytekit.models.literals import BlobMetadata
+from flytekit.models.types import LiteralType
+from flytekit.tools.translator import get_serializable
+
+default_img = Image(name="default", fqn="test", tag="tag")
+serialization_settings = flytekit.configuration.SerializationSettings(
+    project="project",
+    domain="domain",
+    version="version",
+    env=None,
+    image_config=ImageConfig(default_image=default_img, images=[default_img]),
+)
+
+
+def build_keras_sequential_model():
+    model = keras.Sequential()
+    model.add(keras.Input(shape=(16,)))
+    model.add(keras.layers.Dense(8))
+    model.add(keras.layers.Dense(1))
+    model.compile(optimizer="sgd", loss="mse")
+    return model
+
+
+def build_keras_model_class() -> keras.Model:
+    inputs = keras.Input(shape=(16,))
+    x = keras.layers.Dense(8)(inputs)
+    outputs = keras.layers.Dense(1)(x)
+    model = keras.Model(inputs=inputs, outputs=outputs)
+    model.compile(optimizer="sgd", loss="mse")
+    return model
+
+
+@pytest.mark.parametrize(
+    "transformer,python_type,format",
+    [
+        (KerasSequentialTransformer(), keras.Sequential, KerasSequentialTransformer.KERAS_FORMAT),
+        (KerasModelTransformer(), keras.Model, KerasModelTransformer.KERAS_FORMAT),
+    ],
+)
+def test_get_literal_type(transformer, python_type, format):
+    tf = transformer
+    lt = tf.get_literal_type(python_type)
+    assert lt == LiteralType(blob=BlobType(format=format, dimensionality=BlobType.BlobDimensionality.SINGLE))
+
+
+@pytest.mark.parametrize(
+    "transformer,python_type,format,python_val",
+    [
+        (
+            KerasSequentialTransformer(),
+            keras.Sequential,
+            KerasSequentialTransformer.KERAS_FORMAT,
+            build_keras_sequential_model(),
+        ),
+        (
+            KerasModelTransformer(),
+            keras.Model,
+            KerasModelTransformer.KERAS_FORMAT,
+            build_keras_model_class(),
+        ),
+    ],
+)
+def test_to_python_value_and_literal(transformer, python_type, format, python_val):
+    ctx = context_manager.FlyteContext.current_context()
+    tf = transformer
+    lt = tf.get_literal_type(python_type)
+    lv = tf.to_literal(ctx, python_val, type(python_val), lt)  # type: ignore
+    assert lv.scalar.blob.metadata == BlobMetadata(
+        type=BlobType(
+            format=format,
+            dimensionality=BlobType.BlobDimensionality.SINGLE,
+        )
+    )
+    assert lv.scalar.blob.uri is not None
+
+    output = tf.to_python_value(ctx, lv, python_type)
+    assert isinstance(output, (keras.Sequential, keras.Model))
+    # the round-tripped model must carry the same weights as the original
+    for p1, p2 in zip(output.weights, python_val.weights):
+        np.testing.assert_array_equal(p1.numpy(), p2.numpy())
+
+
+def test_example_model():
+    @task
+    def t1() -> keras.Sequential:
+        return build_keras_sequential_model()
+
+    @task
+    def t2() -> keras.Model:
+        return build_keras_model_class()
+
+    task_spec1 = get_serializable(OrderedDict(), serialization_settings, t1)
+    task_spec2 = get_serializable(OrderedDict(), serialization_settings, t2)
+    assert task_spec1.template.interface.outputs["o0"].type.blob.format == KerasSequentialTransformer.KERAS_FORMAT
+    assert task_spec2.template.interface.outputs["o0"].type.blob.format == KerasModelTransformer.KERAS_FORMAT
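
As context for the new transformers, here is a minimal usage sketch of what this change enables. It assumes tensorflow is installed so the keras transformers register on import; the task and workflow names below are illustrative and not part of this diff.

import keras

from flytekit import task, workflow


@task
def train() -> keras.Sequential:
    # the return value is handled by KerasSequentialTransformer and
    # uploaded as a single HDF5 blob
    model = keras.Sequential([keras.layers.Dense(1, input_shape=(4,))])
    model.compile(optimizer="sgd", loss="mse")
    return model


@task
def count_params(model: keras.Sequential) -> int:
    # the model arrives rehydrated via keras.models.load_model
    return model.count_params()


@workflow
def keras_wf() -> int:
    return count_params(model=train())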