Use tf.shape for graph mode support (#6)
* Fall back to tf.shape for graph-mode support when the static shape is not available (see the sketch after the file summary below).
edknv authored Feb 23, 2023
1 parent 34cc5d7 commit 6c401c3
Showing 2 changed files with 8 additions and 4 deletions.
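
The change boils down to the following pattern, shown here as a standalone sketch. The helper name static_or_dynamic_shape is hypothetical and used only for illustration; the library inlines the check directly inside _call_base rather than defining such a function.

import tensorflow as tf

def static_or_dynamic_shape(inp):
  """Return the static shape when fully defined, otherwise fall back to tf.shape."""
  # inp.shape is a tf.TensorShape; a None dimension means the size is not
  # known at trace time (typical when tracing under tf.function in graph mode).
  if None in inp.shape:
    return tf.shape(inp)  # dynamic shape tensor, resolved at run time
  return inp.shape        # static shape, already known at trace time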
10 changes: 7 additions & 3 deletions distributed_embeddings/python/layers/dist_model_parallel.py
@@ -395,17 +395,21 @@ def _call_base(self, inputs):  # pylint: disable=missing-param-doc,missing-type-doc
       local_shapes, local_splits, global_splits, flat_inputs = [], [], [], []
       for rank_input_ids in self.strategy.input_ids_list:
         rank_inputs = [inputs[index] for index in rank_input_ids]
-        local_shapes.append([inp.shape for inp in rank_inputs])
+        local_shapes.append(
+            [tf.shape(inp) if None in inp.shape else inp.shape for inp in rank_inputs]
+        )
         rank_inputs = [tf.reshape(inp, [-1]) for inp in rank_inputs]
-        local_splits.append([inp.shape[0] for inp in rank_inputs])
+        local_splits.append(
+            [tf.shape(inp)[0] if inp.shape[0] is None else inp.shape[0] for inp in rank_inputs]
+        )
         global_splits.append(sum(local_splits[-1]))
         flat_inputs += rank_inputs
       inputs = tf.concat(flat_inputs, 0)
       inputs, _ = hvd.alltoall(inputs, splits=global_splits, name='inp_dp_to_mp')
       inputs = tf.reshape(inputs, [self.world_size, -1])
       inputs = tf.split(inputs, local_splits[self.rank], 1)
       inputs = [
-          tf.reshape(inp, [self.world_size * shape[0]] + shape[1:])
+          tf.reshape(inp, tf.concat([[self.world_size * shape[0]], shape[1:]], 0))
           for inp, shape in zip(inputs, local_shapes[self.rank])
       ]
     else:
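
For context, a minimal standalone reproduction of the situation the hunk above handles (not taken from the repository): under tf.function with an unspecified leading dimension, the static shape contains None, so only tf.shape yields a value usable in the reshape.

import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec(shape=[None, 3], dtype=tf.int64)])
def flatten(inp):
  # At trace time the leading dimension is unknown: inp.shape == (None, 3),
  # so inp.shape[0] is None and cannot be used to build the target shape.
  rows = tf.shape(inp)[0]  # scalar int32 tensor, resolved at run time
  return tf.reshape(inp, [rows * 3])

print(flatten(tf.zeros([4, 3], dtype=tf.int64)).shape)  # (12,)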
2 changes: 1 addition & 1 deletion tests/dist_model_parallel_test.py
@@ -274,7 +274,7 @@ def test_column_slice_merge(self):
 
     dp_inputs, _ = self.gen_inputs(table_sizes)
     self.run_and_test(ref_model, dp_inputs, test_model, dp_inputs)
-    for tables in test_model.dist_embeddings.strategy.table_ids_list:
+    for tables in test_model.dist_embeddings.strategy.table_ids:
       self.assertEqual(len(tables), len(set(tables)))
 
   def test_column_slice_threshold(self):
