Skip to content

Commit

Permalink
Enable TensorFlowModelDataset to overwrite existing model, and add …
Browse files Browse the repository at this point in the history
…support for `tf.device` (kedro-org#1915)

* Fix issue with save operation. Add gpu option

Signed-off-by: William Caicedo <williamc@movio.co>

* Add tests

Signed-off-by: William Caicedo <williamc@movio.co>

* Update RELEASE.md

Signed-off-by: William Caicedo <williamc@movio.co>

* Update test description

Signed-off-by: William Caicedo <williamc@movio.co>

* Remove double slash and overwrite flag in fsspec.put method invocation

Signed-off-by: William Caicedo <williamc@movio.co>

* Allow to explicitly set device name

Signed-off-by: William Caicedo <williamc@movio.co>

* Update RELEASE.md

Co-authored-by: Deepyaman Datta <deepyaman.datta@utexas.edu>
Signed-off-by: William Caicedo <williamc@movio.co>

* Update docs

Signed-off-by: William Caicedo <williamc@movio.co>
Co-authored-by: Deepyaman Datta <deepyaman.datta@utexas.edu>
Signed-off-by: Minh Le <m.le@elsevier.com>
  • Loading branch information
2 people authored and mle-els committed Nov 7, 2022
1 parent 9fc83cd commit 1730e9c
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 6 deletions.
2 changes: 2 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
## Bug fixes and other changes
* Fixed `kedro micropkg pull` for packages on PyPI.
* Fixed `format` in `save_args` for `SparkHiveDataSet`, previously it didn't allow you to save it as delta format.
* Fixed save errors in `TensorFlowModelDataset` when used without versioning; previously, it wouldn't overwrite an existing model.
* Added support for `tf.device` in `TensorFlowModelDataset`.
* Updated error message for `VersionNotFoundError` to handle insufficient permission issues for cloud storage.
* Updated Experiment Tracking docs with working examples.

Expand Down
2 changes: 2 additions & 0 deletions kedro/extras/datasets/tensorflow/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ np.testing.assert_allclose(predictions, new_predictions, rtol=1e-6, atol=1e-6)
example_tensorflow_data:
  type: tensorflow.TensorFlowModelDataset
  filepath: data/08_reporting/tf_model_dirname
  load_args:
    tf_device: "/CPU:0"  # optional
```
Contributed by [Aleks Hughes](https://github.com/w0rdsm1th).
16 changes: 10 additions & 6 deletions kedro/extras/datasets/tensorflow/tensorflow_model_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""
import copy
import tempfile
from pathlib import Path, PurePath, PurePosixPath
from pathlib import PurePath, PurePosixPath
from typing import Any, Dict

import fsspec
Expand Down Expand Up @@ -118,15 +118,17 @@ def _load(self) -> tf.keras.Model:
self._fs.get(load_path, path, recursive=True)

# Pass the local temporary directory/file path to keras.load_model
return tf.keras.models.load_model(path, **self._load_args)
device_name = self._load_args.pop("tf_device", None)
if device_name:
with tf.device(device_name):
model = tf.keras.models.load_model(path, **self._load_args)
else:
model = tf.keras.models.load_model(path, **self._load_args)
return model

def _save(self, data: tf.keras.Model) -> None:
save_path = get_filepath_str(self._get_save_path(), self._protocol)

# Make sure all intermediate directories are created.
save_dir = Path(save_path).parent
save_dir.mkdir(parents=True, exist_ok=True)

with tempfile.TemporaryDirectory(prefix=self._tmp_prefix) as path:
if self._is_h5:
path = str(PurePath(path) / TEMPORARY_H5_FILE)
Expand All @@ -138,6 +140,8 @@ def _save(self, data: tf.keras.Model) -> None:
if self._is_h5:
self._fs.copy(path, save_path)
else:
if self._fs.exists(save_path):
self._fs.rm(save_path, recursive=True)
self._fs.put(path, save_path, recursive=True)

def _exists(self) -> bool:
Expand Down
54 changes: 54 additions & 0 deletions tests/extras/datasets/tensorflow/test_tensorflow_model_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,24 @@ def dummy_tf_base_model(dummy_x_train, dummy_y_train, tf):
return model


@pytest.fixture
def dummy_tf_base_model_new(dummy_x_train, dummy_y_train, tf):
    """Build, compile and briefly fit a small functional Keras model.

    The network applies three ``Dense(1)`` layers to a ``(2, 1)`` input: two
    intermediate layers followed by the output layer. It is trained for a
    single epoch on the dummy training data, and its metrics are reset before
    it is returned.
    """
    inputs = tf.keras.Input(shape=(2, 1))

    # Two intermediate Dense layers, applied one after the other.
    hidden = inputs
    for _ in range(2):
        hidden = tf.keras.layers.Dense(1)(hidden)
    outputs = tf.keras.layers.Dense(1)(hidden)

    model = tf.keras.Model(inputs=inputs, outputs=outputs, name="2_layer_dummy")
    model.compile("rmsprop", "mse")
    model.fit(dummy_x_train, dummy_y_train, batch_size=64, epochs=1)

    # From https://www.tensorflow.org/guide/keras/save_and_serialize:
    # reset metrics before saving so that the loaded model has the same state,
    # since metric states are not preserved by Model.save_weights.
    model.reset_metrics()
    return model


@pytest.fixture
def dummy_tf_subclassed_model(dummy_x_train, dummy_y_train, tf):
"""Demonstrate that own class models cannot be saved
Expand Down Expand Up @@ -246,6 +264,19 @@ def test_exists_with_exception(self, tf_model_dataset, mocker):
mocker.patch("kedro.io.core.get_filepath_str", side_effct=DataSetError)
assert not tf_model_dataset.exists()

def test_save_and_overwrite_existing_model(
    self, tf_model_dataset, dummy_tf_base_model, dummy_tf_base_model_new
):
    """Test that saving a second model replaces the one saved before it."""
    # Save both models to the same dataset, the new one last.
    for model in (dummy_tf_base_model, dummy_tf_base_model_new):
        tf_model_dataset.save(model)

    reloaded = tf_model_dataset.load()
    reloaded_layer_count = len(reloaded.layers)

    # The reloaded model must match the second model, not the first.
    assert len(dummy_tf_base_model.layers) != reloaded_layer_count
    assert len(dummy_tf_base_model_new.layers) == reloaded_layer_count


class TestTensorFlowModelDatasetVersioned:
"""Test suite with versioning argument passed into TensorFlowModelDataset creator"""
Expand Down Expand Up @@ -385,3 +416,26 @@ def test_versioning_existing_dataset(
assert tf_model_dataset._filepath == versioned_tf_model_dataset._filepath
versioned_tf_model_dataset.save(dummy_tf_base_model)
assert versioned_tf_model_dataset.exists()

def test_save_and_load_with_device(
    self,
    dummy_tf_base_model,
    dummy_x_test,
    filepath,
    tensorflow_model_dataset,
    load_version,
    save_version,
):
    """Test a versioned TensorFlowModelDataset round-trips a model when an
    explicit ``tf_device`` is passed through ``load_args``."""
    version = Version(load_version, save_version)
    dataset = tensorflow_model_dataset(
        filepath=filepath,
        load_args={"tf_device": "/CPU:0"},
        version=version,
    )

    # Capture reference predictions before the model is written out.
    expected = dummy_tf_base_model.predict(dummy_x_test)
    dataset.save(dummy_tf_base_model)

    # The reloaded model should predict the same values within tolerance.
    actual = dataset.load().predict(dummy_x_test)
    np.testing.assert_allclose(expected, actual, rtol=1e-6, atol=1e-6)

0 comments on commit 1730e9c

Please sign in to comment.