diff --git a/demo/dynamic_embedding/movielens-1m-keras-with-horovod/movielens-1m-keras-with-horovod.py b/demo/dynamic_embedding/movielens-1m-keras-with-horovod/movielens-1m-keras-with-horovod.py
index b1918ef7a..c2426f5f9 100644
--- a/demo/dynamic_embedding/movielens-1m-keras-with-horovod/movielens-1m-keras-with-horovod.py
+++ b/demo/dynamic_embedding/movielens-1m-keras-with-horovod/movielens-1m-keras-with-horovod.py
@@ -466,12 +466,12 @@ def export_to_savedmodel(model, savedmodel_dir):
 
   # TFRA modify the Keras save function with a patch.
   # !!!! Run save_model function in all rank !!!!
-  de.keras.models.de_hvd_save_model(model,
-                                    savedmodel_dir,
-                                    overwrite=True,
-                                    include_optimizer=True,
-                                    save_traces=True,
-                                    options=save_options)
+  de.keras.models.de_save_model(model,
+                                savedmodel_dir,
+                                overwrite=True,
+                                include_optimizer=True,
+                                save_traces=True,
+                                options=save_options)
 
 
 def export_for_serving(model, export_dir):
@@ -521,7 +521,7 @@ def serve(*args, **kwargs):
 
   # TFRA modify the Keras save function with a patch.
   # !!!! Run save_model function in all rank !!!!
-  de.keras.models.de_hvd_save_model(
+  de.keras.models.de_save_model(
       model,
       export_dir,
       overwrite=True,
diff --git a/docs/api_docs/tfra/dynamic_embedding/keras/layers/HvdAllToAllEmbedding.md b/docs/api_docs/tfra/dynamic_embedding/keras/layers/HvdAllToAllEmbedding.md
index 24c33db66..5215d9b71 100644
--- a/docs/api_docs/tfra/dynamic_embedding/keras/layers/HvdAllToAllEmbedding.md
+++ b/docs/api_docs/tfra/dynamic_embedding/keras/layers/HvdAllToAllEmbedding.md
@@ -83,7 +83,7 @@ In addition, we also provide parameter initialization and save callback related
 
 [`dynamic_embedding.keras.callbacks.DEHvdModelCheckpoint`](https://github.com/tensorflow/recommenders-addons/blob/master/tensorflow_recommenders_addons/dynamic_embedding/python/keras/callbacks.py)
 
-[`dynamic_embedding.keras.models.de_hvd_save_model`](https://github.com/tensorflow/recommenders-addons/blob/master/tensorflow_recommenders_addons/dynamic_embedding/python/keras/models.py)
+[`dynamic_embedding.keras.models.de_save_model`](https://github.com/tensorflow/recommenders-addons/blob/master/tensorflow_recommenders_addons/dynamic_embedding/python/keras/models.py)
 
 [`dynamic_embedding.train.DEHvdModelCheckpoint`](https://github.com/tensorflow/recommenders-addons/blob/master/tensorflow_recommenders_addons/dynamic_embedding/python/train/checkpoint.py)
diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/embedding.py b/tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/embedding.py
index 47252ec6c..5aa2af8c2 100644
--- a/tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/embedding.py
+++ b/tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/embedding.py
@@ -535,6 +535,7 @@ class HvdAllToAllEmbedding(BasicEmbedding):
 
   def __init__(self,
                with_unique=True,
+               with_secondary_unique=True,
                mpi_size=None,
                batch_size=None,
                *args,
@@ -547,6 +548,7 @@ def __init__(self,
       )
     self.hvd = hvd
     self.with_unique = with_unique
+    self.with_secondary_unique = with_secondary_unique
     self.batch_size = batch_size
     if mpi_size is None:
       self._mpi_size = self.hvd.size()
@@ -605,7 +607,14 @@ def __alltoall_embedding_lookup__(self, ids):
     reloc_ids, remote_sizes, gather_indices = self.__relocate_dense_feature__(
         ids, batch_size=batch_size_runtime)
 
-    lookup_result = de.shadow_ops.embedding_lookup(self.shadow, reloc_ids)
+    if self.with_secondary_unique:
+      with tf.name_scope(self.name + "/EmbeddingWithUnique"):
+        reloc_unique_ids, reloc_unique_idx = tf.unique(reloc_ids)
+        reloc_unique_embeddings = de.shadow_ops.embedding_lookup(
+            self.shadow, reloc_unique_ids)
+        lookup_result = tf.gather(reloc_unique_embeddings, reloc_unique_idx)
+    else:
+      lookup_result = de.shadow_ops.embedding_lookup(self.shadow, reloc_ids)
     lookup_result, _ = self.hvd.alltoall(lookup_result, splits=remote_sizes)
 
     recover_shape = tf.concat((input_shape, (self.embedding_size,)), axis=0)
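Note on the `with_secondary_unique` path above: it deduplicates the relocated ids before the hash-table lookup, fetches each unique key once, and rebuilds the full batch with a gather. A minimal self-contained sketch of the same pattern, using a dense table and `tf.nn.embedding_lookup` as stand-ins for the dynamic-embedding shadow lookup:

    import tensorflow as tf

    # Toy dense table standing in for the dynamic-embedding shadow variable.
    table = tf.random.uniform(shape=[100, 8])
    ids = tf.constant([3, 7, 3, 9, 7, 3], dtype=tf.int64)

    # Look up each unique id once, then scatter results back to batch order.
    unique_ids, unique_idx = tf.unique(ids)      # unique_ids: [3, 7, 9]
    unique_emb = tf.nn.embedding_lookup(table, unique_ids)
    result = tf.gather(unique_emb, unique_idx)   # shape [6, 8]

    # Same values as a direct lookup, but with fewer table accesses.
    tf.debugging.assert_near(result, tf.nn.embedding_lookup(table, ids))

The extra `tf.unique`/`tf.gather` pays off whenever ids repeat within a batch, since each repeated key hits the hash table only once.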
diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/keras/models.py b/tensorflow_recommenders_addons/dynamic_embedding/python/keras/models.py
index 736e3a2c7..ee54880f4 100644
--- a/tensorflow_recommenders_addons/dynamic_embedding/python/keras/models.py
+++ b/tensorflow_recommenders_addons/dynamic_embedding/python/keras/models.py
@@ -26,6 +26,7 @@
 from tensorflow.python.keras.saving.saved_model import save as tf_saved_model_save
 from tensorflow.python.ops import array_ops
 from tensorflow.python.platform import tf_logging
+from tensorflow.python.saved_model.save_options import SaveOptions
 
 tf_original_save_func = tf_saved_model_save.save
 if keras_saved_model_save is not None:
@@ -56,6 +57,11 @@ def _de_keras_save_func(original_save_func,
   except:
     hvd = None
 
+  if hvd is not None:
+    filepath = hvd.broadcast_object(filepath,
+                                    root_rank=0,
+                                    name='de_hvd_broadcast_filepath')
+
   call_original_save_func = functools.partial(
       original_save_func,
       model=model,
@@ -68,8 +74,9 @@
       *args,
       **kwargs)
 
-  def _traverse_emb_layers_and_save(hvd_rank):
-    de_dir = os.path.join(filepath, "variables", "TFRADynamicEmbedding")
+  de_dir = os.path.join(filepath, "variables", "TFRADynamicEmbedding")
+
+  def _check_saveable_and_redirect_new_de_dir():
     for var in model.variables:
       if not hasattr(var, "params"):
         continue
@@ -85,33 +92,50 @@ def _traverse_emb_layers_and_save(hvd_rank):
           "It will allow TFRA load KV files when Embedding tensor parallel. "
          f"The embedding shards at each horovod rank are now temporarily stored in {de_dir}"
       )
-    else:
-      if not isinstance(de_var.kv_creator.saver, de.FileSystemSaver):
-        # This function only serves FileSystemSaver.
-        continue
-      if hvd_rank == 0:
-        # FileSystemSaver works well at rank 0.
-        continue
-      # save Dynamic Embedding Parameters
-      de_var.save_to_file_system(dirpath=de_dir,
-                                 proc_size=hvd.size(),
-                                 proc_rank=hvd.rank())
-      # save optimizer parameters of Dynamic Embedding
-      if include_optimizer is True:
-        de_opt_vars = a2a_emb.optimizer_vars.as_list() if hasattr(
-            a2a_emb.optimizer_vars, "as_list") else a2a_emb.optimizer_vars
-        for de_opt_var in de_opt_vars:
-          de_opt_var.save_to_file_system(dirpath=de_dir,
-                                         proc_size=hvd.size(),
-                                         proc_rank=hvd.rank())
+      if not isinstance(de_var.kv_creator.saver, de.FileSystemSaver):
+        # This function only serves FileSystemSaver.
+        continue
+      # Redirect new de_dir
+      if hasattr(de_var, 'saveable'):
+        de_var.saveable._saver_config.save_path = de_dir
+
+  def _traverse_emb_layers_and_save(hvd_rank=0):
+    for var in model.variables:
+      if not hasattr(var, "params"):
+        continue
+      if not hasattr(var.params, "_created_in_class"):
+        continue
+      de_var = var.params
+      a2a_emb = de_var._created_in_class
+      if de_var._saveable_object_creator is not None:
+        if not isinstance(de_var.kv_creator.saver, de.FileSystemSaver):
+          # This function only serves FileSystemSaver.
+          continue
+        # save optimizer parameters of Dynamic Embedding
+        if include_optimizer is True:
+          de_opt_vars = a2a_emb.optimizer_vars.as_list() if hasattr(
+              a2a_emb.optimizer_vars, "as_list") else a2a_emb.optimizer_vars
+          for de_opt_var in de_opt_vars:
+            de_opt_var.save_to_file_system(dirpath=de_dir,
+                                           proc_size=hvd.size(),
+                                           proc_rank=hvd.rank())
+        if hvd_rank == 0:
+          # FileSystemSaver works well at rank 0.
+          continue
+        # save Dynamic Embedding Parameters
+        de_var.save_to_file_system(dirpath=de_dir,
+                                   proc_size=hvd.size(),
+                                   proc_rank=hvd.rank())
+
+  _check_saveable_and_redirect_new_de_dir()
   if hvd is None:
     call_original_save_func()
+    _traverse_emb_layers_and_save(0)
   else:
     if hvd.rank() == 0:
       call_original_save_func()
     _traverse_emb_layers_and_save(hvd.rank())
-    hvd.join()  # Sync for avoiding data conflict
+    hvd.join()  # Sync for avoiding rank conflict
 
 
 def de_hvd_save_model(model,
@@ -123,11 +147,37 @@ def de_hvd_save_model(model,
                       save_traces=True,
                       *args,
                       **kwargs):
+  return de_save_model(model=model,
+                       filepath=filepath,
+                       overwrite=overwrite,
+                       include_optimizer=include_optimizer,
+                       signatures=signatures,
+                       options=options,
+                       save_traces=save_traces,
+                       *args,
+                       **kwargs)
+
+
+def de_save_model(model,
+                  filepath,
+                  overwrite=True,
+                  include_optimizer=True,
+                  signatures=None,
+                  options=None,
+                  save_traces=True,
+                  *args,
+                  **kwargs):
   if keras_saved_model_save is not None:
     _save_handle = functools.partial(_de_keras_save_func,
                                      keras_original_save_func)
   else:
     _save_handle = functools.partial(_de_keras_save_func,
                                      tf_original_save_func)
+  if options is None:
+    options = SaveOptions(namespace_whitelist=['TFRA'])
+  elif isinstance(options, SaveOptions) and hasattr(options,
+                                                    'namespace_whitelist'):
+    options.namespace_whitelist.append('TFRA')
+
   return _save_handle(model, filepath, overwrite,
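For context on the `models.py` change: `de_save_model` must be entered by every Horovod rank, as the demo comments stress. The filepath is broadcast from rank 0 so all ranks agree on the directory, rank 0 writes the SavedModel itself, every rank writes its own `TFRADynamicEmbedding` KV shard, and the trailing `hvd.join()` keeps ranks from racing. A minimal sketch of a call site; the `Sequential` model here is a placeholder just to exercise the save path (a real use would contain `de.keras.layers.HvdAllToAllEmbedding` layers), and the export path is hypothetical:

    import tensorflow as tf
    import tensorflow_recommenders_addons.dynamic_embedding as de

    # Placeholder model; real models embed de.keras.layers.HvdAllToAllEmbedding.
    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
    model.compile(optimizer='adam', loss='mse')

    # Run on ALL ranks under horovodrun; de_save_model internally restricts
    # the SavedModel write to rank 0 and saves per-rank KV shards everywhere.
    de.keras.models.de_save_model(model,
                                  '/tmp/de_export',  # hypothetical path
                                  overwrite=True,
                                  include_optimizer=True)

Passing `options` is optional: as the patch shows, `de_save_model` creates a `SaveOptions(namespace_whitelist=['TFRA'])` when none is given, or appends `'TFRA'` to an existing whitelist so the TFRA custom ops survive SavedModel export.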
diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/horovod_sync_train_test.py b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/horovod_sync_train_test.py
index f6f974170..b3c328591 100644
--- a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/horovod_sync_train_test.py
+++ b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/horovod_sync_train_test.py
@@ -327,9 +327,9 @@ def common_all_to_all_embedding_trainable_v2(self, base_opt, test_opt, name):
         shutil.rmtree(save_dir)
     hvd.join()  # Sync for avoiding files conflict
     # base_model.save(save_dir, options=save_options)
-    de.keras.models.de_hvd_save_model(base_model,
-                                      save_dir,
-                                      options=save_options)
+    de.keras.models.de_save_model(base_model,
+                                  save_dir,
+                                  options=save_options)
     ckpt = de.train.DECheckpoint(
         my_model=base_model)  # Test custom model key "my_model"
     ckpt.save(save_dir + '/ckpt/test')
@@ -407,31 +407,38 @@ def call(self, x):
       return self.l2(out)
 
     def check_TFRADynamicEmbedding_directory(save_dir,
-                                             save_it,
+                                             save_it=None,
                                              should_be_exist=True):
       hvd_size = hvd.size()
       if hvd_size <= 1:
         hvd_size = 1
+      base_dir = os.path.join(save_dir, 'variables', 'TFRADynamicEmbedding')
+      if save_it is not None:
+        base_dir = os.path.join(save_dir, f'TFRADynamicEmbedding-{save_it}')
      for tag in ['keys', 'values']:
         for rank in range(hvd_size):
           self.assertTrue(not (os.path.exists(
-              save_dir +
-              f'/TFRADynamicEmbedding-{save_it}/{name}-parameter_mht_1of1_rank{rank}_size{hvd_size}-{tag}'
-          ) ^ should_be_exist))
+              base_dir +
+              f'/{name}-parameter_mht_1of1_rank{rank}_size{hvd_size}-{tag}') ^
+                               should_be_exist))
           self.assertTrue(not (os.path.exists(
-              save_dir +
-              f'/TFRADynamicEmbedding-{save_it}/{name}-parameter_DynamicEmbedding_keras_adam_lazy_build-shadow_m_mht_1of1_rank{rank}_size{hvd_size}-{tag}'
+              base_dir +
+              f'/{name}-parameter_DynamicEmbedding_keras_adam_lazy_build-shadow_m_mht_1of1_rank{rank}_size{hvd_size}-{tag}'
           ) ^ should_be_exist))
-          # f'/TFRADynamicEmbedding-{save_it}/{name}-parameter_no_compile_model_DynamicEmbedding_keras_adam_lazy_build-shadow_m_mht_1of1_rank{rank}_size{hvd_size}-{tag}'
+          # f'/{name}-parameter_no_compile_model_DynamicEmbedding_keras_adam_lazy_build-shadow_m_mht_1of1_rank{rank}_size{hvd_size}-{tag}'
           self.assertTrue(not (os.path.exists(
-              save_dir +
-              f'/TFRADynamicEmbedding-{save_it}/{name}-parameter_DynamicEmbedding_keras_adam_lazy_build-shadow_v_mht_1of1_rank{rank}_size{hvd_size}-{tag}'
+              base_dir +
+              f'/{name}-parameter_DynamicEmbedding_keras_adam_lazy_build-shadow_v_mht_1of1_rank{rank}_size{hvd_size}-{tag}'
           ) ^ should_be_exist))
-          # f'/TFRADynamicEmbedding-{save_it}/{name}-parameter_no_compile_model_DynamicEmbedding_keras_adam_lazy_build-shadow_v_mht_1of1_rank{rank}_size{hvd_size}-{tag}'
+          # f'/{name}-parameter_no_compile_model_DynamicEmbedding_keras_adam_lazy_build-shadow_v_mht_1of1_rank{rank}_size{hvd_size}-{tag}'
 
     with tf.device("/{}:{}".format(_device, _device_id)):
       x = tf.reshape(tf.range(0, 32, dtype=tf.int64), [32, 1])
       y = tf.random.uniform(shape=[32, 1])
+      base_de_emb_standard = {}
+      base_de_opt_standard = {}
+      new_de_emb_compared = {}
+      new_de_opt_compared = {}
       save_dir = self.get_temp_dir()
@@ -454,13 +461,16 @@ def check_TFRADynamicEmbedding_directory(save_dir,
           l.params.upsert(x * 10, tf.random.uniform(shape=[32, 1, dim]))
           emb_size = l.params.size()
           emb_keys, emb_values = l.params.export()
+          base_de_emb_standard[l.name] = (emb_size, emb_keys, emb_values)
           break
       for v in base_opt.variables():
         if name in v.name:
           v.params.upsert(x * 10, tf.random.uniform(shape=[32, 1, dim]))
           opt_size = v.params.size()
-          opt_keys, opt_values = l.params.export()
-          break
+          opt_keys, opt_values = v.params.export()
+          base_de_opt_standard[v._shared_name.split('/')[-1]] = (opt_size,
+                                                                 opt_keys,
+                                                                 opt_values)
       manager.save()
       if hvd.rank() == 0:
         check_TFRADynamicEmbedding_directory(save_dir,
@@ -491,7 +501,9 @@ def check_TFRADynamicEmbedding_directory(save_dir,
       new_model.compile(optimizer=new_opt, loss='mean_absolute_error')
       new_model(x)  # Build vairiables
       try:
-        new_opt._create_all_weights(new_model.variables)
+        new_opt._create_all_weights([
+            new_model.variables[0]
+        ])  # Create DE slot variable from DE shadow variable
       except:
         #TODO(MoFHejia) raise ValueError: Cannot convert a partially known TensorShape to a Tensor.
         pass
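Aside on the `_create_all_weights` change above: this private Keras optimizer hook builds the optimizer's slot variables (Adam's `m`/`v`) ahead of time so the subsequent checkpoint restore has something to load into, and the test now passes only `new_model.variables[0]`, which per the inline comment is the DE shadow variable. For intuition only, a toy public-API sketch (plain Keras, no TFRA) of what slot creation produces:

    import tensorflow as tf

    # Applying one (zero) gradient step forces Adam to create its
    # 'm' and 'v' slot variables for `var`.
    var = tf.Variable([1.0, 2.0])
    opt = tf.keras.optimizers.Adam(learning_rate=1.2)
    opt.apply_gradients([(tf.zeros_like(var), var)])
    print([v.name for v in opt.variables()])  # iteration counter plus m/v slots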
@@ -499,23 +511,92 @@
         if name in l.name:
           new_emb_size = l.params.size()
           new_emb_keys, new_emb_values = l.params.export()
+          new_de_emb_compared[l.name] = (new_emb_size, new_emb_keys,
+                                         new_emb_values)
           break
       for v in new_opt.variables():
         if name in v.name:
           new_opt_size = v.params.size()
-          new_opt_keys, new_opt_values = l.params.export()
+          new_opt_keys, new_opt_values = v.params.export()
+          new_de_opt_compared[v._shared_name.split('/')[-1]] = (new_opt_size,
+                                                                new_opt_keys,
+                                                                new_opt_values)
+
+      for de_l_name in base_de_emb_standard.keys():
+        self.assertEqual(base_de_emb_standard[de_l_name][0],
+                         new_de_emb_compared[de_l_name][0])
+        self.assertAllEqual(np.sort(base_de_emb_standard[de_l_name][1], axis=0),
+                            np.sort(new_de_emb_compared[de_l_name][1], axis=0))
+        self.assertAllClose(np.sort(base_de_emb_standard[de_l_name][2], axis=0),
+                            np.sort(new_de_emb_compared[de_l_name][2], axis=0))
+      for opt_v_name in base_de_opt_standard.keys():
+        self.assertEqual(base_de_opt_standard[opt_v_name][0],
+                         new_de_opt_compared[opt_v_name][0])
+        self.assertAllEqual(
+            np.sort(base_de_opt_standard[opt_v_name][1], axis=0),
+            np.sort(new_de_opt_compared[opt_v_name][1], axis=0))
+        self.assertAllClose(
+            np.sort(base_de_opt_standard[opt_v_name][2], axis=0),
+            np.sort(new_de_opt_compared[opt_v_name][2], axis=0))
+
+      extra_save_dir = self.get_temp_dir() + '/extra_save_dir'
+      de.keras.models.de_save_model(new_model, extra_save_dir)
+      if hvd.rank() == 0:
+        check_TFRADynamicEmbedding_directory(extra_save_dir)
+      del new_opt
+      del new_model
+      del new_ckpt
+      tf.keras.backend.clear_session()
+      tf.compat.v1.reset_default_graph()
+      new_saved_model = NoCompileModel('zeros')
+      new_saved_opt = Adam(1.2)
+      new_saved_opt = de.DynamicEmbeddingOptimizer(new_saved_opt,
+                                                   synchronous=True)
+      new_saved_model.compile(optimizer=new_saved_opt,
+                              loss='mean_absolute_error')
+      new_saved_model(x)  # Build variables
+      try:
+        new_saved_opt._create_all_weights([
+            new_saved_model.variables[0]
+        ])  # Create DE slot variable from DE shadow variable
+      except:
+        #TODO(MoFHejia) raise ValueError: Cannot convert a partially known TensorShape to a Tensor.
+        pass
+      extra_save_dir = hvd.broadcast_object(
+          extra_save_dir, root_rank=0, name='de_utest_hvd_broadcast_filepath'
+      )  # All ranks should share same save directory
+      new_saved_model.load_weights(extra_save_dir + '/variables/variables')
+      for l in new_saved_model.layers:
+        if name in l.name:
+          new_emb_size = l.params.size()
+          new_emb_keys, new_emb_values = l.params.export()
+          new_de_emb_compared[l.name] = (new_emb_size, new_emb_keys,
+                                         new_emb_values)
           break
-
-      self.assertEqual(emb_size, new_emb_size)
-      self.assertEqual(opt_size, new_opt_size)
-      self.assertAllEqual(np.sort(emb_keys, axis=0),
-                          np.sort(new_emb_keys, axis=0))
-      self.assertAllClose(np.sort(emb_values, axis=0),
-                          np.sort(new_emb_values, axis=0))
-      self.assertAllEqual(np.sort(opt_keys, axis=0),
-                          np.sort(new_opt_keys, axis=0))
-      self.assertAllClose(np.sort(opt_values, axis=0),
-                          np.sort(new_opt_values, axis=0))
+      for v in new_saved_opt.variables():
+        if name in v.name:
+          new_opt_size = v.params.size()
+          new_opt_keys, new_opt_values = v.params.export()
+          new_de_opt_compared[v._shared_name.split('/')[-1]] = (new_opt_size,
+                                                                new_opt_keys,
+                                                                new_opt_values)
+
+      for de_l_name in base_de_emb_standard.keys():
+        self.assertEqual(base_de_emb_standard[de_l_name][0],
+                         new_de_emb_compared[de_l_name][0])
+        self.assertAllEqual(np.sort(base_de_emb_standard[de_l_name][1], axis=0),
+                            np.sort(new_de_emb_compared[de_l_name][1], axis=0))
+        self.assertAllClose(np.sort(base_de_emb_standard[de_l_name][2], axis=0),
+                            np.sort(new_de_emb_compared[de_l_name][2], axis=0))
+      for opt_v_name in base_de_opt_standard.keys():
+        self.assertEqual(base_de_opt_standard[opt_v_name][0],
+                         new_de_opt_compared[opt_v_name][0])
+        self.assertAllEqual(
+            np.sort(base_de_opt_standard[opt_v_name][1], axis=0),
+            np.sort(new_de_opt_compared[opt_v_name][1], axis=0))
+        self.assertAllClose(
+            np.sort(base_de_opt_standard[opt_v_name][2], axis=0),
+            np.sort(new_de_opt_compared[opt_v_name][2], axis=0))
 
 
 if __name__ == "__main__":
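A side note on the assertion idiom used throughout this test: keys and values exported from a dynamic-embedding hash table come back in no particular order, so the test sorts both exports before comparing. A small self-contained illustration with hypothetical arrays:

    import numpy as np

    # Two exports of the same table, returned in different orders.
    base_keys = np.array([7, 3, 9])
    base_vals = np.array([[0.7], [0.3], [0.9]])
    new_keys = np.array([3, 9, 7])
    new_vals = np.array([[0.3], [0.9], [0.7]])

    # Sorting each export makes the comparison order-independent; note that
    # np.sort(..., axis=0) sorts values column-wise, so this checks per-column
    # multisets rather than exact key-to-row pairing.
    assert (np.sort(base_keys, axis=0) == np.sort(new_keys, axis=0)).all()
    assert np.allclose(np.sort(base_vals, axis=0), np.sort(new_vals, axis=0))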