From a50037f46ea356ed52241320522ec5f9c8c429cc Mon Sep 17 00:00:00 2001
From: edknv
Date: Tue, 14 Feb 2023 09:10:09 -0800
Subject: [PATCH 1/2] Use tf.shape for graph mode support

---
 distributed_embeddings/python/layers/dist_model_parallel.py | 6 +++---
 tests/dist_model_parallel_test.py                           | 2 +-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/distributed_embeddings/python/layers/dist_model_parallel.py b/distributed_embeddings/python/layers/dist_model_parallel.py
index db1f464..aaaae11 100644
--- a/distributed_embeddings/python/layers/dist_model_parallel.py
+++ b/distributed_embeddings/python/layers/dist_model_parallel.py
@@ -395,9 +395,9 @@ def _call_base(self, inputs):  # pylint: disable=missing-param-doc,missing-type-
     local_shapes, local_splits, global_splits, flat_inputs = [], [], [], []
     for rank_input_ids in self.strategy.input_ids_list:
       rank_inputs = [inputs[index] for index in rank_input_ids]
-      local_shapes.append([inp.shape for inp in rank_inputs])
+      local_shapes.append([tf.shape(inp) for inp in rank_inputs])
       rank_inputs = [tf.reshape(inp, [-1]) for inp in rank_inputs]
-      local_splits.append([inp.shape[0] for inp in rank_inputs])
+      local_splits.append([tf.shape(inp)[0] for inp in rank_inputs])
       global_splits.append(sum(local_splits[-1]))
       flat_inputs += rank_inputs
     inputs = tf.concat(flat_inputs, 0)
@@ -405,7 +405,7 @@ def _call_base(self, inputs):  # pylint: disable=missing-param-doc,missing-type-
       inputs = tf.reshape(inputs, [self.world_size, -1])
       inputs = tf.split(inputs, local_splits[self.rank], 1)
       inputs = [
-          tf.reshape(inp, [self.world_size * shape[0]] + shape[1:])
+          tf.reshape(inp, tf.concat([[self.world_size * shape[0]], shape[1:]], 0))
           for inp, shape in zip(inputs, local_shapes[self.rank])
       ]
     else:
diff --git a/tests/dist_model_parallel_test.py b/tests/dist_model_parallel_test.py
index 94a2635..566264a 100644
--- a/tests/dist_model_parallel_test.py
+++ b/tests/dist_model_parallel_test.py
@@ -274,7 +274,7 @@ def test_column_slice_merge(self):
     dp_inputs, _ = self.gen_inputs(table_sizes)
     self.run_and_test(ref_model, dp_inputs, test_model, dp_inputs)
 
-    for tables in test_model.dist_embeddings.strategy.table_ids_list:
+    for tables in test_model.dist_embeddings.strategy.table_ids:
       self.assertEqual(len(tables), len(set(tables)))
 
   def test_column_slice_threshold(self):

From bf76a45dc2fc739b40c680b2f083dd083fad36aa Mon Sep 17 00:00:00 2001
From: edknv
Date: Wed, 22 Feb 2023 13:19:53 -0800
Subject: [PATCH 2/2] Fall back to tf.shape only when static shape is not available

---
 .../python/layers/dist_model_parallel.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/distributed_embeddings/python/layers/dist_model_parallel.py b/distributed_embeddings/python/layers/dist_model_parallel.py
index aaaae11..7a45e9b 100644
--- a/distributed_embeddings/python/layers/dist_model_parallel.py
+++ b/distributed_embeddings/python/layers/dist_model_parallel.py
@@ -395,9 +395,13 @@ def _call_base(self, inputs):  # pylint: disable=missing-param-doc,missing-type-
     local_shapes, local_splits, global_splits, flat_inputs = [], [], [], []
     for rank_input_ids in self.strategy.input_ids_list:
       rank_inputs = [inputs[index] for index in rank_input_ids]
-      local_shapes.append([tf.shape(inp) for inp in rank_inputs])
+      local_shapes.append(
+          [tf.shape(inp) if None in inp.shape else inp.shape for inp in rank_inputs]
+      )
       rank_inputs = [tf.reshape(inp, [-1]) for inp in rank_inputs]
-      local_splits.append([tf.shape(inp)[0] for inp in rank_inputs])
+      local_splits.append(
+          [tf.shape(inp)[0] if inp.shape[0] is None else inp.shape[0] for inp in rank_inputs]
+      )
       global_splits.append(sum(local_splits[-1]))
       flat_inputs += rank_inputs
     inputs = tf.concat(flat_inputs, 0)
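
Note (not part of the patch series): the two patches hinge on the difference between a tensor's static shape, which can contain None inside a tf.function, and tf.shape(), which always yields a runtime tensor. The standalone sketch below illustrates the fallback pattern adopted in the second patch; the names flatten_batch and traced are invented for this example and do not exist in distributed-embeddings.

import tensorflow as tf

def flatten_batch(inp):
  # Prefer the static dimension when it is known at trace time; fall back to
  # the dynamic tf.shape() tensor so the code still works in graph mode.
  batch = inp.shape[0] if inp.shape[0] is not None else tf.shape(inp)[0]
  return tf.reshape(inp, [-1]), batch

@tf.function(input_signature=[tf.TensorSpec([None, 3], tf.int64)])
def traced(inp):
  # During tracing inp.shape is (None, 3), so the dynamic branch is taken.
  return flatten_batch(inp)

flat, batch = traced(tf.zeros([4, 3], tf.int64))
print(flat.shape, batch)  # (12,) and a scalar tensor holding 4

Keeping the static Python value whenever it is available, as the second patch does, preserves the original eager-mode behavior in which shapes and splits were plain ints, and only produces shape tensors when a dimension is genuinely dynamic.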