Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TypeTransformer for Keras #1242

Closed
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions dev-requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,5 @@ google-cloud-bigquery
google-cloud-bigquery-storage
IPython
torch
tensorflow
grpcio-status<1.49.0
104 changes: 74 additions & 30 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,17 +1,25 @@
#
# This file is autogenerated by pip-compile with python 3.7
# This file is autogenerated by pip-compile with python 3.9
# To update, run:
#
# make dev-requirements.txt
# pip-compile --output-file=dev-requirements.txt dev-requirements.in
#
-e file:.#egg=flytekit
# via
# -c requirements.txt
# pytest-flyte
absl-py==1.3.0
# via
# tensorboard
# tensorflow
appnope==0.1.3
# via ipython
arrow==1.2.3
# via
# -c requirements.txt
# jinja2-time
astunparse==1.6.3
# via tensorflow
attrs==20.3.0
# via
# -c requirements.txt
Expand All @@ -26,8 +34,6 @@ binaryornot==0.4.4
# via
# -c requirements.txt
# cookiecutter
cached-property==1.5.2
# via docker-compose
cachetools==5.2.0
# via google-auth
certifi==2022.9.24
Expand Down Expand Up @@ -77,7 +83,6 @@ cryptography==38.0.1
# -c requirements.txt
# paramiko
# pyopenssl
# secretstorage
dataclasses-json==0.5.7
# via
# -c requirements.txt
Expand Down Expand Up @@ -120,10 +125,14 @@ docstring-parser==0.15
# flytekit
filelock==3.8.0
# via virtualenv
flatbuffers==22.9.24
# via tensorflow
flyteidl==1.1.22
# via
# -c requirements.txt
# flytekit
gast==0.5.3
# via tensorflow
google-api-core[grpc]==2.10.2
# via
# google-cloud-bigquery
Expand All @@ -132,7 +141,11 @@ google-api-core[grpc]==2.10.2
google-auth==2.13.0
# via
# google-api-core
# google-auth-oauthlib
# google-cloud-core
# tensorboard
google-auth-oauthlib==0.4.6
# via tensorboard
google-cloud-bigquery==3.3.5
# via -r dev-requirements.in
google-cloud-bigquery-storage==2.16.2
Expand All @@ -143,6 +156,8 @@ google-cloud-core==2.3.2
# via google-cloud-bigquery
google-crc32c==1.5.0
# via google-resumable-media
google-pasta==0.2.0
# via tensorflow
google-resumable-media==2.4.0
# via google-cloud-bigquery
googleapis-common-protos==1.56.4
Expand All @@ -158,11 +173,16 @@ grpcio==1.47.0
# google-api-core
# google-cloud-bigquery
# grpcio-status
# tensorboard
# tensorflow
grpcio-status==1.47.0
# via
# -c requirements.txt
# -r dev-requirements.in
# flytekit
# google-api-core
h5py==3.7.0
# via tensorflow
identify==2.5.6
# via pre-commit
idna==3.4
Expand All @@ -172,14 +192,9 @@ idna==3.4
importlib-metadata==5.0.0
# via
# -c requirements.txt
# click
# flytekit
# jsonschema
# keyring
# pluggy
# pre-commit
# pytest
# virtualenv
# markdown
iniconfig==1.1.1
# via pytest
ipython==7.34.0
Expand All @@ -190,11 +205,6 @@ jaraco-classes==3.2.3
# keyring
jedi==0.18.1
# via ipython
jeepney==0.8.0
# via
# -c requirements.txt
# keyring
# secretstorage
jinja2==3.1.2
# via
# -c requirements.txt
Expand All @@ -214,14 +224,23 @@ jsonschema==3.2.0
# via
# -c requirements.txt
# docker-compose
keras==2.8.0
# via tensorflow
keras-preprocessing==1.1.2
# via tensorflow
keyring==23.9.3
# via
# -c requirements.txt
# flytekit
libclang==14.0.6
# via tensorflow
markdown==3.4.1
# via tensorboard
markupsafe==2.1.1
# via
# -c requirements.txt
# jinja2
# werkzeug
marshmallow==3.18.0
# via
# -c requirements.txt
Expand Down Expand Up @@ -260,9 +279,17 @@ nodeenv==1.7.0
numpy==1.21.6
# via
# -c requirements.txt
# flytekit
# h5py
# keras-preprocessing
# opt-einsum
# pandas
# pyarrow
# tensorboard
# tensorflow
oauthlib==3.2.2
# via requests-oauthlib
opt-einsum==3.3.0
# via tensorflow
packaging==21.3
# via
# -c requirements.txt
Expand Down Expand Up @@ -306,6 +333,8 @@ protobuf==3.20.3
# grpcio-status
# proto-plus
# protoc-gen-swagger
# tensorboard
# tensorflow
protoc-gen-swagger==0.1.0
# via
# -c requirements.txt
Expand Down Expand Up @@ -407,7 +436,11 @@ requests==2.28.1
# flytekit
# google-api-core
# google-cloud-bigquery
# requests-oauthlib
# responses
# tensorboard
requests-oauthlib==1.3.1
# via google-auth-oauthlib
responses==0.22.0
# via
# -c requirements.txt
Expand All @@ -418,23 +451,19 @@ retry==0.9.2
# flytekit
rsa==4.9
# via google-auth
secretstorage==3.3.3
# via
# -c requirements.txt
# keyring
singledispatchmethod==1.0
# via
# -c requirements.txt
# flytekit
six==1.16.0
# via
# -c requirements.txt
# astunparse
# dockerpty
# google-auth
# google-pasta
# grpcio
# jsonschema
# keras-preprocessing
# paramiko
# python-dateutil
# tensorflow
# websocket-client
sortedcontainers==2.4.0
# via
Expand All @@ -444,6 +473,20 @@ statsd==3.3.0
# via
# -c requirements.txt
# flytekit
tensorboard==2.8.0
# via tensorflow
tensorboard-data-server==0.6.1
# via tensorboard
tensorboard-plugin-wit==1.8.1
# via tensorboard
tensorflow==2.8.1
# via -r dev-requirements.in
tensorflow-estimator==2.8.0
# via tensorflow
tensorflow-io-gcs-filesystem==0.27.0
# via tensorflow
termcolor==2.0.1
# via tensorflow
text-unidecode==1.3
# via
# -c requirements.txt
Expand All @@ -466,20 +509,16 @@ traitlets==5.5.0
# via
# ipython
# matplotlib-inline
typed-ast==1.5.4
# via mypy
types-toml==0.10.8
# via
# -c requirements.txt
# responses
typing-extensions==4.4.0
# via
# -c requirements.txt
# arrow
# flytekit
# importlib-metadata
# mypy
# responses
# tensorflow
# torch
# typing-inspect
typing-inspect==0.8.0
Expand All @@ -502,15 +541,20 @@ websocket-client==0.59.0
# -c requirements.txt
# docker
# docker-compose
werkzeug==2.2.2
# via tensorboard
wheel==0.37.1
# via
# -c requirements.txt
# astunparse
# flytekit
# tensorboard
wrapt==1.14.1
# via
# -c requirements.txt
# deprecated
# flytekit
# tensorflow
zipp==3.9.0
# via
# -c requirements.txt
Expand Down
7 changes: 7 additions & 0 deletions docs/source/extras.keras.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
############
Keras Type
############
.. automodule:: flytekit.extras.keras
:no-members:
:no-inherited-members:
:no-special-members:
1 change: 1 addition & 0 deletions docs/source/types.extend.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ Refer to :doc:`cookbook <advanced_custom_types>` if you'd like to contribute a F
types.builtins.file
types.builtins.directory
extras.pytorch
extras.keras
2 changes: 1 addition & 1 deletion flytekit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@
from flytekit.core.workflow import ImperativeWorkflow as Workflow
from flytekit.core.workflow import WorkflowFailurePolicy, reference_workflow, workflow
from flytekit.deck import Deck
from flytekit.extras import pytorch
from flytekit.extras import keras, pytorch
from flytekit.extras.persistence import GCSPersistence, HttpPersistence, S3Persistence
from flytekit.loggers import logger
from flytekit.models.common import Annotations, AuthRole, Labels
Expand Down
28 changes: 28 additions & 0 deletions flytekit/extras/keras/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""
Flytekit Keras
=========================================
.. currentmodule:: flytekit.extras.keras

.. autosummary::
:template: custom.rst
:toctree: generated/

"""
from flytekit.loggers import logger

# that have soft dependencies
try:
# isolate the exception to the keras import
from tensorflow import keras

_keras_installed = True
except (ImportError, OSError):
_keras_installed = False


if _keras_installed:
from .native import KerasModelTransformer, KerasSequentialTransformer
else:
logger.info(
"We won't register KerasSequentialTransformer and KerasModelTransformer because keras is not installed."
)
Loading