version 0.1 updates
FDecaYed committed Jun 3, 2022
1 parent 4e74e34 commit d40c589
Showing 7 changed files with 171 additions and 39 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -18,7 +18,7 @@ See more details at [User Guide](https://nvidia-merlin.github.io/distributed-emb

## Installation
### Requirements
Python 3, CUDA 11 or newer, TensorFlow 2.6.0 or newer
Python 3, CUDA 11 or newer, TensorFlow 2
### Containers ###
You can build inside 22.03 or later NGC TF2 [image](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tensorflow):
```bash
1 change: 1 addition & 0 deletions build_pip_pkg.sh
@@ -11,6 +11,7 @@ echo "=== Copy TensorFlow Custom op files"
cp setup.py "${TMPDIR}"
cp MANIFEST.in "${TMPDIR}"
cp requirements.txt "${TMPDIR}"
cp version.txt "${TMPDIR}"
rsync -avm -L --exclude='*_test.py' distributed_embeddings "${TMPDIR}"

pushd ${TMPDIR}
1 change: 1 addition & 0 deletions distributed_embeddings/__init__.py
@@ -15,3 +15,4 @@
"""Distributed embedding API."""

from distributed_embeddings.python.ops.embedding_lookup_ops import embedding_lookup
from .version import __version__
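With `__version__` now re-exported from the package root, the installed version can be checked at runtime. A minimal usage sketch (assumes the package is installed from this commit, where version.txt contains `0.1.0`):

```python
import distributed_embeddings

# __version__ comes from distributed_embeddings/version.py,
# which setup.py generates from version.txt (0.1.0 in this commit).
print(distributed_embeddings.__version__)  # "0.1.0"
```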
161 changes: 125 additions & 36 deletions distributed_embeddings/python/layers/dist_model_parallel.py
@@ -13,9 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Distributed Embedding layers and utils"""
import math
import numpy as np
import tensorflow as tf
from tensorflow.python.keras.utils import tf_utils
import horovod.tensorflow as hvd
from distributed_embeddings.python.ops.embedding_lookup_ops import read_var_no_copy
from .embedding import Embedding


@@ -48,19 +51,21 @@ def __init__(self,
self.input_ids_list = [list(range(len(input_table_map)))]
self.table_ids_list = [list(range(len(embeddings)))]
return

# Create (maybe) sliced configs
sliced_configs, self.sliced_out_ranges = self.create_sliced_configs(
world_size, column_slice_threshold, input_table_map)
# Apply strategy and save nested list containing table indices by rank
self.table_ids_list = self.apply_stragety(strategy, world_size, sliced_configs)
# Nested list to split embedding output from each rank into tables
self.widths_list = []

# Nested list containing input indices by rank
self.input_ids_list = []
# Nested list containing local input to local table map by rank
self.local_map_list = []
# Nested list containing local configs by rank
self.local_configs_list = []
# All local widths, ordered by rank and flattened into a single list
self.widths_list_flat = []
# Each worker loops over all ranks to get a global view of the strategy
for rank_table_ids in self.table_ids_list:
# calculate stats needed for each rank
@@ -73,20 +78,22 @@ def __init__(self,
rank_input_ids.append(k)
rank_input_map.append(m)
self.local_configs_list.append(rank_configs)
self.widths_list.append(rank_widths)
self.widths_list_flat += rank_widths
self.input_ids_list.append(rank_input_ids)
self.local_map_list.append(rank_input_map)
# List of total embedding widths to split embedding output by rank after alltoall
self.total_local_widths = [sum(widths) for widths in self.widths_list]

# List that maps local inputs to local table
self.local_input_table_map = self.local_map_list[rank]

# flatten self.input_ids_list
worker_order_input_ids = [item for sublist in self.input_ids_list for item in sublist]

# List of indices to shuffle worker ordered embedding outputs back to original order
self.rev_global_input_ids = [
index
for _, index in sorted(zip(worker_order_input_ids, range(len(worker_order_input_ids))))
]

# List of configs to create local embedding layers
self.local_configs = self.local_configs_list[rank]
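To make the bookkeeping above concrete, here is a small sketch (toy values, not library code) of how `rev_global_input_ids` is derived from the worker-ordered input ids and later used to restore the original input order:

```python
# Toy example: rank 0 handles inputs [0, 3], rank 1 handles inputs [1, 2]
input_ids_list = [[0, 3], [1, 2]]
worker_order_input_ids = [i for sublist in input_ids_list for i in sublist]  # [0, 3, 1, 2]

# indices that shuffle worker-ordered outputs back into the original input order
rev_global_input_ids = [
    index
    for _, index in sorted(zip(worker_order_input_ids, range(len(worker_order_input_ids))))
]
print(rev_global_input_ids)  # [0, 2, 3, 1]

# worker_order_res[j] belongs to input worker_order_input_ids[j], so indexing with
# rev_global_input_ids yields outputs ordered as inputs 0, 1, 2, 3
worker_order_res = ['out0', 'out3', 'out1', 'out2']
print([worker_order_res[index] for index in rev_global_input_ids])  # ['out0', 'out1', 'out2', 'out3']
```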

@@ -286,18 +293,17 @@ def _call_base(self, inputs):  # pylint: disable=missing-param-doc,missing-type-
for m, inp in zip(self.strategy.local_input_table_map, inputs)
]

# concat last axis to make all2all slice correct, and reshape to make later split easier
# TODO(Deyu): current assume 2D with same batch for all output, ideally should support general case
local_bs = inputs[0].shape[0] // self.world_size
mp_outs = tf.reshape(tf.concat(mp_outs, axis=-1), [-1, local_bs])
mp_outs = [tf.reshape(mp_out, [self.world_size, -1]) for mp_out in mp_outs]
mp_outs = tf.reshape(tf.concat(mp_outs, axis=1), [-1])
# cast before alltoall according to dtype policy
mp_outs = tf.cast(mp_outs, self.compute_dtype)
dp_outs = hvd.alltoall(mp_outs, name='out_mp_to_dp')
dp_outs = [
tf.reshape(t, [local_bs, -1]) for t in tf.split(dp_outs, self.strategy.total_local_widths)
]
# split each worker result and re-order using id
worker_order_res = []
for dp_out, widths in zip(dp_outs, self.strategy.widths_list):
worker_order_res += tf.split(dp_out, widths, 1)
local_bs = inputs[0].shape[0] // self.world_size
num_elements = [local_bs * item for item in self.strategy.widths_list_flat]
split_outs = tf.split(dp_outs, num_elements)
worker_order_res = [tf.reshape(split_out, [local_bs, -1]) for split_out in split_outs]

# reorder outputs to be same as inputs order
result = [worker_order_res[index] for index in self.strategy.rev_global_input_ids]
return result
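The reshuffle after `hvd.alltoall` can be illustrated standalone. A minimal numpy sketch (editorial, with made-up shapes) of splitting the flat received buffer by `widths_list_flat` and reshaping back to per-table `[local_bs, width]` tensors:

```python
import numpy as np

local_bs = 2
widths_list_flat = [3, 1, 2]  # embedding widths of all remote tables, ordered by owning rank

# pretend this flat buffer is what alltoall delivered: one table after another,
# each flattened from shape [local_bs, width]
dp_out = np.arange(local_bs * sum(widths_list_flat), dtype=np.float32)

num_elements = [local_bs * w for w in widths_list_flat]
splits = np.split(dp_out, np.cumsum(num_elements)[:-1])
worker_order_res = [s.reshape(local_bs, -1) for s in splits]
print([t.shape for t in worker_order_res])  # [(2, 3), (2, 1), (2, 2)]
```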
@@ -309,70 +315,149 @@ def _concat_column_slice_outputs(self, outs):
outs[start:end] = [tf.concat(outs[start:end], axis=-1)]
return outs

def set_weights(self, weights): # pylint: disable=missing-param-doc,missing-type-doc
def set_weights(self, weights, chunk=134217728, use_lock=False):
"""Sets the weights of the layer, from NumPy arrays.
This override expects global weights for all tables as input.
Args:
weights (list): list containing global weights for all tables.
Items in the list can be either numpy arrays or file paths to load from.
chunk (int): max number of elements per chunk when setting weights on GPU in chunks.
This will be rounded to a number of rows based on the weight shape.
use_lock (bool): If true, set weights rank by rank in lockstep to avoid OOM. Default False.
"""
if self.world_size == 1:
sliced_local_weights = weights
else:
if use_lock:
for _ in range(self.rank):
hvd.broadcast_object(0)

if self.world_size > 1:
slice_info = [[rank_tids.count(tid)
for rank_tids in self.strategy.table_ids_list]
for tid in range(len(weights))]
local_weights = [weights[index] for index in self.strategy.table_ids_list[self.rank]]
weights = [weights[index] for index in self.strategy.table_ids_list[self.rank]]
if isinstance(weights[0], str):
weights = [np.load(file=path, mmap_mode='r') for path in weights]
local_info = [slice_info[index] for index in self.strategy.table_ids_list[self.rank]]
# array to handle the case of multiple slices into the same table
# TODO(Deyu): avoid this by merging those tables again after finding the strategy
rank_ids = self.strategy.table_ids_list[self.rank]
index_offset = [rank_ids[:i].count(rank_id) for i, rank_id in enumerate(rank_ids)]

def _slice_weight_for_rank(weight, info, global_rank):
def _slice_weight_for_rank(weight, info, global_rank, offset):
num_columns = weight.shape[1]
num_slices = sum(info)
column_per_slice = num_columns // num_slices
remainder = num_columns % num_slices
rank = info[:global_rank].count(1)
rank = sum(info[:global_rank]) + offset

start = column_per_slice * rank + min(rank, remainder)
rank += 1
end = column_per_slice * rank + min(rank, remainder)
return weight[:, start:end]

sliced_local_weights = [
_slice_weight_for_rank(weight, info, self.rank)
for weight, info in zip(local_weights, local_info)
weights = [
_slice_weight_for_rank(weight, info, self.rank, offset)
for weight, info, offset in zip(weights, local_info, index_offset)
]
super().set_weights(sliced_local_weights)
# variable.assign and copy-on-write create an extra copy of the weight that causes OOM,
# so here we scatter-update in chunks of ~128M elements instead of just doing
# super().set_weights(weights)
for weight, arr in zip(self.weights, weights):
if arr.size <= chunk:
weight.assign(arr)
else:
chunk_size_dim0 = chunk // weight.shape[1]
num_chunks = math.ceil(weight.shape[0] / chunk_size_dim0)
last_size = weight.shape[0] - chunk_size_dim0 * (num_chunks - 1)
chunk_sizes = [chunk_size_dim0] * (num_chunks - 1) + [last_size]
for i in range(num_chunks):
start = i * chunk_size_dim0
end = start + chunk_sizes[i]
indices = tf.range(start=start, limit=end, dtype=tf.int64)
update = tf.IndexedSlices(values=arr[start:end],
indices=indices,
dense_shape=weight.shape)
weight.scatter_update(sparse_delta=update)
del weights

if use_lock:
for _ in range(self.world_size - self.rank):
hvd.broadcast_object(0)
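The chunked assignment above can be tried in isolation. A small sketch (tiny made-up shapes, not the layer's actual tables) of updating a variable row-chunk by row-chunk with `scatter_update`, avoiding the extra full copy that a single `assign` can create:

```python
import math
import numpy as np
import tensorflow as tf

weight = tf.Variable(tf.zeros([10, 4]))
arr = np.random.rand(10, 4).astype(np.float32)
chunk = 12  # element budget per chunk, deliberately tiny for the demo

chunk_size_dim0 = max(1, chunk // weight.shape[1])      # round the element budget to whole rows
num_chunks = math.ceil(weight.shape[0] / chunk_size_dim0)
for i in range(num_chunks):
    start = i * chunk_size_dim0
    end = min(start + chunk_size_dim0, weight.shape[0])
    indices = tf.range(start, end, dtype=tf.int64)
    update = tf.IndexedSlices(values=arr[start:end], indices=indices,
                              dense_shape=weight.shape)
    weight.scatter_update(sparse_delta=update)

np.testing.assert_allclose(weight.numpy(), arr)         # same result as weight.assign(arr)
```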

# 1D split that works beyond the 32-bit indexing limit TF supports
def _split_1d(self, tensor, lengths):
# choose a number close to the int32 limit as the maximum chunk size
# This will handle tensors with size up to the square of the int32 max
chunking_threshold = 2147483646
if tensor.shape[0] <= chunking_threshold:
return tf.split(tensor, lengths)
num_chunks = math.ceil(tensor.shape[0] / chunking_threshold)
padding_len = math.ceil(tensor.shape[0] / num_chunks) * num_chunks - tensor.shape[0]
padded_tensor = tf.concat([tensor, tf.zeros(padding_len, tensor.dtype)], axis=0)
tensor_list = tf.unstack(tf.reshape(padded_tensor, [num_chunks, -1]))
result = []
for length in lengths:
this_slice = []
while length > 0:
if length > tensor_list[0].shape[0]:
this_slice.append(tensor_list.pop(0))
else:
this_slice.append(tensor_list[0][:length])
tensor_list[0] = tensor_list[0][length:]
length -= this_slice[-1].shape[0]
result.append(tf.concat(this_slice, axis=0))
return result

def get_weights(self):
def get_weights(self, all_ranks=False):
"""Returns the current weights of the layer, as NumPy arrays.
This override outputs global weights for all tables.
Args:
all_ranks (bool): If true, return weights on all ranks, otherwise only on rank 0.
Default False.
"""
# avoid copy-on-read on dense access
local_weights = [read_var_no_copy(w) for w in self.weights]
if self.world_size == 1:
return [weight.numpy() for weight in self.weights]
return [w.numpy() for w in local_weights]

# MPI segfaults on indices beyond the 32-bit range, so we gather weights chunk by chunk here
# choose a number not too close to the int32 limit as the maximum chunk size, just to be safe
chunking_threshold = 2000000000
num_chunks = 1
for local_configs in self.strategy.local_configs_list:
total_elements = sum([c['input_dim'] * c['output_dim'] for c in local_configs])
num_chunks = max(num_chunks, math.ceil(self.world_size * total_elements / chunking_threshold))

# mpi segfault on large sizes so we gather weights chunk by chunk here
num_chunks = 8
with tf.device('CPU:0'):
local_weights = tf.concat([tf.reshape(w, [-1]) for w in self.weights], axis=0)
local_weights = tf.concat([tf.reshape(w, [-1]) for w in local_weights], axis=0)
chunk_size = local_weights.shape[0] // num_chunks
last_size = local_weights.shape[0] - chunk_size * (num_chunks - 1)
chunk_sizes = [chunk_size] * (num_chunks - 1) + [last_size]
local_weights = tf.split(local_weights, chunk_sizes)
local_weights = self._split_1d(local_weights, chunk_sizes)
# communicate chunk sizes
all_sizes = hvd.allgather(chunk_sizes)

# collect all chunks and split to reverse allgather concat
chunks = []
for i, w in enumerate(local_weights):
chunks += tf.split(hvd.allgather(w), all_sizes[i::num_chunks])
w = hvd.allgather(w)
if all_ranks or self.rank == 0:
chunks += self._split_1d(w, all_sizes[i::num_chunks])
if not chunks:
return []

# re-construct all local weights from chunks
local_weights = []
for i in range(self.world_size):
local_weights.append(tf.concat(chunks[i::self.world_size], axis=0))
del chunks

# split flat local weights into correct sizes
weights = []
for local_weight, local_configs in zip(local_weights, self.strategy.local_configs_list):
local_shapes = [[c['input_dim'], c['output_dim']] for c in local_configs]
local_sizes = [shape[0] * shape[1] for shape in local_shapes]
flat_weights = tf.split(local_weight, local_sizes)
flat_weights = self._split_1d(local_weight, local_sizes)
weights += [tf.reshape(weight, shape) for weight, shape in zip(flat_weights, local_shapes)]
# restore original table order
# flatten self.strategy.table_ids_list
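The interleaved-chunk bookkeeping in `get_weights` is compact but easy to misread. A small numpy sketch (editorial, toy sizes) of why rank r's flat weights come back as `chunks[r::world_size]`:

```python
import numpy as np

world_size = 2
num_chunks = 2
rank_flat_weights = [np.arange(6, dtype=np.float32),          # rank 0, 6 elements
                     np.arange(100, 108, dtype=np.float32)]   # rank 1, 8 elements

# a chunked allgather delivers chunk i of every rank before chunk i+1 of any rank
chunks = []
for i in range(num_chunks):
    for w in rank_flat_weights:
        chunks.append(np.array_split(w, num_chunks)[i])

restored = [np.concatenate(chunks[r::world_size]) for r in range(world_size)]
assert all(np.array_equal(a, b) for a, b in zip(restored, rank_flat_weights))
```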
@@ -408,6 +493,7 @@ def call(self, inputs):  # pylint: disable=missing-function-docstring
self.local_embedding_layers[m](inp)
for m, inp in zip(self.strategy.local_input_table_map, inputs)
]
outputs = [tf.cast(output, self.compute_dtype) for output in outputs]
return outputs

# TODO(skyw): Revisit logics of selecting call functions for different strategy
@@ -460,7 +546,10 @@ def gradient(self, target, sources, output_gradients=None):
dp_vars.append(var)
dp_grads.append(grad)
split_infos.append((False, len(dp_grads) - 1))
dp_grads = self._allreduce_grads(dp_grads, dp_vars) # pylint: disable=protected-access
# TODO(Deyu): make sure that not reusing _allreduce_grads does not lead to any issues
dp_grads = [
hvd.allreduce(g, name=f'dp_gradient_{i}', op=hvd.Average) for i, g in enumerate(dp_grads)
]
# put gradients back in original order
grads = []
for info in split_infos:
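For the data-parallel gradients, the change above swaps the optimizer's private `_allreduce_grads` hook for explicit allreduce calls. A minimal standalone sketch (assumes Horovod is initialized; the gradients here are made up):

```python
import horovod.tensorflow as hvd
import tensorflow as tf

hvd.init()

# made-up dense (data-parallel) gradients for two variables
dp_grads = [tf.constant([1.0, 2.0]), tf.constant([[3.0], [4.0]])]

# average each gradient across workers, matching the semantics of the
# default allreduce applied to data-parallel gradients
dp_grads = [
    hvd.allreduce(g, name=f'dp_gradient_{i}', op=hvd.Average)
    for i, g in enumerate(dp_grads)
]
```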
14 changes: 13 additions & 1 deletion distributed_embeddings/python/layers/dist_model_parallel_test.py
@@ -158,7 +158,7 @@ def run_and_test(self, ref_model, ref_inputs, test_model, test_inputs):
optimizer.apply_gradients(zip(ref_grads, ref_model.variables))
optimizer.apply_gradients(zip(test_grads, test_model.variables))
ref_weights = ref_model.get_weights()
test_weights = test_model.dist_embeddings.get_weights() + test_model.dense.get_weights()
test_weights = test_model.dist_embeddings.get_weights(True) + test_model.dense.get_weights()

for ref_w, test_w in zip(ref_weights, test_weights):
# assert close here since the order of accumulations (inputs and batch dim) might have changed
@@ -269,6 +269,18 @@ def test_column_slice_threshold(self):
dp_inputs, _ = self.gen_inputs(table_sizes)
self.run_and_test(ref_model, dp_inputs, test_model, dp_inputs)

def test_column_slice_dup_worker(self):
table_sizes = [[10, 4], [11, 2], [4, 2], [4, 2]]
ref_model = EmbeddingListModel(table_sizes, distribute=False)
test_model = EmbeddingListModel(table_sizes,
distribute=True,
strategy='memory_balanced',
dp_input=False,
column_slice_threshold=10)
mp_input_ids = test_model.dist_embeddings.strategy.input_ids_list[self.hvd_rank]
dp_inputs, mp_inputs = self.gen_inputs(table_sizes, mp_input_ids=mp_input_ids)
self.run_and_test(ref_model, dp_inputs, test_model, mp_inputs)


if __name__ == "__main__":
test.main()
30 changes: 29 additions & 1 deletion setup.py
@@ -14,16 +14,44 @@
# limitations under the License.
"""Simple setup script"""

import os
from setuptools import setup, find_packages

abspath = os.path.dirname(os.path.realpath(__file__))

with open("requirements.txt", encoding='utf-8') as f:
requirements = f.read().splitlines() # pylint: disable=invalid-name

print(find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]))

license_header = """#
# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""

# Generate version file
with open(os.path.join(abspath, "version.txt"), encoding="utf-8") as f:
version = f.read().strip()
with open(os.path.join(abspath, "distributed_embeddings/version.py"), "w", encoding="utf-8") as f:
f.write(license_header)
f.write(F"__version__ = \"{version}\"")

setup(
name="distributed-embeddings",
version="1.0.0",
version=version,
description="Distributed Embedding",
packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]),
install_requires=requirements,
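For reference, with version.txt containing `0.1.0`, the `distributed_embeddings/version.py` file generated by the setup script above reduces to a single assignment (license header omitted here):

```python
__version__ = "0.1.0"
```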
1 change: 1 addition & 0 deletions version.txt
@@ -0,0 +1 @@
0.1.0
