From 1730e9c3ddf9db709d4bb5011689219be1313acf Mon Sep 17 00:00:00 2001
From: williamcaicedo
Date: Sat, 22 Oct 2022 17:14:58 +1300
Subject: [PATCH] Enable `TensorFlowModelDataset` to overwrite existing model,
 and add support for `tf.device` (#1915)

* Fix issue with save operation. Add gpu option

Signed-off-by: William Caicedo

* Add tests

Signed-off-by: William Caicedo

* Update RELEASE.md

Signed-off-by: William Caicedo

* Update test description

Signed-off-by: William Caicedo

* Remove double slash and overwrite flag in fsspec.put method invocation

Signed-off-by: William Caicedo

* Allow to explicitly set device name

Signed-off-by: William Caicedo

* Update RELEASE.md

Co-authored-by: Deepyaman Datta
Signed-off-by: William Caicedo

* Update docs

Signed-off-by: William Caicedo

Co-authored-by: Deepyaman Datta
Signed-off-by: Minh Le
---
 RELEASE.md                                          |  2 +
 kedro/extras/datasets/tensorflow/README.md          |  2 +
 .../tensorflow/tensorflow_model_dataset.py          | 16 +++---
 .../test_tensorflow_model_dataset.py                | 54 +++++++++++++++++++
 4 files changed, 68 insertions(+), 6 deletions(-)

diff --git a/RELEASE.md b/RELEASE.md
index c7af457002..13ac55056f 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -17,6 +17,8 @@
 ## Bug fixes and other changes
 * Fixed `kedro micropkg pull` for packages on PyPI.
 * Fixed `format` in `save_args` for `SparkHiveDataSet`, previously it didn't allow you to save it as delta format.
+* Fixed save errors in `TensorFlowModelDataset` when used without versioning; previously, it wouldn't overwrite an existing model.
+* Added support for `tf.device` in `TensorFlowModelDataset`.
 * Updated error message for `VersionNotFoundError` to handle insufficient permission issues for cloud storage.
 * Updated Experiment Tracking docs with working examples.

diff --git a/kedro/extras/datasets/tensorflow/README.md b/kedro/extras/datasets/tensorflow/README.md
index 5f079787b1..704d164977 100644
--- a/kedro/extras/datasets/tensorflow/README.md
+++ b/kedro/extras/datasets/tensorflow/README.md
@@ -27,6 +27,8 @@ np.testing.assert_allclose(predictions, new_predictions, rtol=1e-6, atol=1e-6)
 example_tensorflow_data:
   type: tensorflow.TensorFlowModelDataset
   filepath: data/08_reporting/tf_model_dirname
+  load_args:
+    tf_device: "/CPU:0" # optional
 ```

 Contributed by (Aleks Hughes)[https://github.com/w0rdsm1th].
diff --git a/kedro/extras/datasets/tensorflow/tensorflow_model_dataset.py b/kedro/extras/datasets/tensorflow/tensorflow_model_dataset.py
index 441ffd455f..70331c3d9b 100644
--- a/kedro/extras/datasets/tensorflow/tensorflow_model_dataset.py
+++ b/kedro/extras/datasets/tensorflow/tensorflow_model_dataset.py
@@ -3,7 +3,7 @@
 """
 import copy
 import tempfile
-from pathlib import Path, PurePath, PurePosixPath
+from pathlib import PurePath, PurePosixPath
 from typing import Any, Dict

 import fsspec
@@ -118,15 +118,17 @@ def _load(self) -> tf.keras.Model:
             self._fs.get(load_path, path, recursive=True)

             # Pass the local temporary directory/file path to keras.load_model
-            return tf.keras.models.load_model(path, **self._load_args)
+            device_name = self._load_args.pop("tf_device", None)
+            if device_name:
+                with tf.device(device_name):
+                    model = tf.keras.models.load_model(path, **self._load_args)
+            else:
+                model = tf.keras.models.load_model(path, **self._load_args)
+            return model

     def _save(self, data: tf.keras.Model) -> None:
         save_path = get_filepath_str(self._get_save_path(), self._protocol)

-        # Make sure all intermediate directories are created.
-        save_dir = Path(save_path).parent
-        save_dir.mkdir(parents=True, exist_ok=True)
-
         with tempfile.TemporaryDirectory(prefix=self._tmp_prefix) as path:
             if self._is_h5:
                 path = str(PurePath(path) / TEMPORARY_H5_FILE)
@@ -138,6 +140,8 @@ def _save(self, data: tf.keras.Model) -> None:
             if self._is_h5:
                 self._fs.copy(path, save_path)
             else:
+                if self._fs.exists(save_path):
+                    self._fs.rm(save_path, recursive=True)
                 self._fs.put(path, save_path, recursive=True)

     def _exists(self) -> bool:
diff --git a/tests/extras/datasets/tensorflow/test_tensorflow_model_dataset.py b/tests/extras/datasets/tensorflow/test_tensorflow_model_dataset.py
index 51b8b6ab9b..80fbc02c7f 100644
--- a/tests/extras/datasets/tensorflow/test_tensorflow_model_dataset.py
+++ b/tests/extras/datasets/tensorflow/test_tensorflow_model_dataset.py
@@ -94,6 +94,24 @@ def dummy_tf_base_model(dummy_x_train, dummy_y_train, tf):
     return model


+@pytest.fixture
+def dummy_tf_base_model_new(dummy_x_train, dummy_y_train, tf):
+    # dummy 2 layer model
+    inputs = tf.keras.Input(shape=(2, 1))
+    x = tf.keras.layers.Dense(1)(inputs)
+    x = tf.keras.layers.Dense(1)(x)
+    outputs = tf.keras.layers.Dense(1)(x)
+
+    model = tf.keras.Model(inputs=inputs, outputs=outputs, name="2_layer_dummy")
+    model.compile("rmsprop", "mse")
+    model.fit(dummy_x_train, dummy_y_train, batch_size=64, epochs=1)
+    # from https://www.tensorflow.org/guide/keras/save_and_serialize
+    # Reset metrics before saving so that loaded model has same state,
+    # since metric states are not preserved by Model.save_weights
+    model.reset_metrics()
+    return model
+
+
 @pytest.fixture
 def dummy_tf_subclassed_model(dummy_x_train, dummy_y_train, tf):
     """Demonstrate that own class models cannot be saved
@@ -246,6 +264,19 @@ def test_exists_with_exception(self, tf_model_dataset, mocker):
         mocker.patch("kedro.io.core.get_filepath_str", side_effect=DataSetError)
         assert not tf_model_dataset.exists()

+    def test_save_and_overwrite_existing_model(
+        self, tf_model_dataset, dummy_tf_base_model, dummy_tf_base_model_new
+    ):
+        """Test models are correctly overwritten."""
+        tf_model_dataset.save(dummy_tf_base_model)
+
+        tf_model_dataset.save(dummy_tf_base_model_new)
+
+        reloaded = tf_model_dataset.load()
+
+        assert len(dummy_tf_base_model.layers) != len(reloaded.layers)
+        assert len(dummy_tf_base_model_new.layers) == len(reloaded.layers)
+

 class TestTensorFlowModelDatasetVersioned:
     """Test suite with versioning argument passed into TensorFlowModelDataset creator"""
@@ -385,3 +416,26 @@ def test_versioning_existing_dataset(
         assert tf_model_dataset._filepath == versioned_tf_model_dataset._filepath
         versioned_tf_model_dataset.save(dummy_tf_base_model)
         assert versioned_tf_model_dataset.exists()
+
+    def test_save_and_load_with_device(
+        self,
+        dummy_tf_base_model,
+        dummy_x_test,
+        filepath,
+        tensorflow_model_dataset,
+        load_version,
+        save_version,
+    ):
+        """Test versioned TensorFlowModelDataset can load models using an explicit tf_device."""
+        hdf5_dataset = tensorflow_model_dataset(
+            filepath=filepath,
+            load_args={"tf_device": "/CPU:0"},
+            version=Version(load_version, save_version),
+        )
+
+        predictions = dummy_tf_base_model.predict(dummy_x_test)
+        hdf5_dataset.save(dummy_tf_base_model)
+
+        reloaded = hdf5_dataset.load()
+        new_predictions = reloaded.predict(dummy_x_test)
+        np.testing.assert_allclose(predictions, new_predictions, rtol=1e-6, atol=1e-6)
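
Note (not part of the patch): the following is a minimal sketch of how the two behaviours introduced above are meant to be exercised from user code, assuming Kedro with the TensorFlow extras and `tensorflow` are installed; the `data/06_models/tf_model_dirname` path and the dummy model are purely illustrative.

```python
import numpy as np
import tensorflow as tf

from kedro.extras.datasets.tensorflow import TensorFlowModelDataset

# Unversioned dataset; the filepath below is a hypothetical local directory.
data_set = TensorFlowModelDataset(
    filepath="data/06_models/tf_model_dirname",
    load_args={"tf_device": "/CPU:0"},  # new load arg: load_model runs under tf.device("/CPU:0")
)

# A tiny single-layer model, mirroring the dummy fixtures used in the tests.
inputs = tf.keras.Input(shape=(2, 1))
outputs = tf.keras.layers.Dense(1)(inputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
model.compile("rmsprop", "mse")

# Saving twice over the same unversioned path now overwrites the existing
# model (per the release note, this previously failed).
data_set.save(model)
data_set.save(model)

reloaded = data_set.load()  # deserialised on /CPU:0
print(reloaded.predict(np.random.rand(3, 2, 1)).shape)
```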