diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index c03fba2d7784..3cbc4c8414b7 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -120,21 +120,25 @@ def split_half_float_double_sparse(tensors): device_type = get_accelerator().device_name() - supported_types = [ - "torch.{}.HalfTensor".format(device_type), "torch.{}.FloatTensor".format(device_type), - "torch.{}.DoubleTensor".format(device_type), "torch.{}.BFloat16Tensor".format(device_type), - SparseTensor.type() - ] + supported_types = get_accelerator().supported_dtypes() for t in tensors: - assert t.type() in supported_types, f"attempting to reduce an unsupported grad type: {t.type()}" + assert t.dtype in supported_types, f"attempting to reduce an unsupported grad type: {t.dtype}" - buckets = [] + sparse_tensor_buckets, dense_tensor_buckets = [], [] for i, dtype in enumerate(supported_types): - bucket = [t for t in tensors if t.type() == dtype] - if bucket: - buckets.append((dtype, bucket)) - return buckets + sparse_bucket, dense_bucket = [], [] + for t in tensors: + if t.dtype == dtype: + if isinstance(t, SparseTensor): + sparse_bucket.append(t) + else: + dense_bucket.append(t) + if sparse_bucket: + sparse_tensor_buckets.append((dtype, sparse_bucket)) + if dense_bucket: + dense_tensor_buckets.append((dtype, dense_bucket)) + return sparse_tensor_buckets, dense_tensor_buckets class EngineTimers(object): @@ -2396,30 +2400,37 @@ def _get_gradients_for_reduction(self): return non_expert_grads, expert_grads def _reduce_non_expert_gradients(self, grads, elements_per_buffer): - split_buckets = split_half_float_double_sparse(grads) - for _, bucket_tuple in enumerate(split_buckets): - bucket_type, bucket = bucket_tuple + split_sparse_tensor_buckets, split_dense_tensor_buckets = split_half_float_double_sparse(grads) + if self.pipeline_parallelism: + dp_group = self.mpu.get_data_parallel_group() + else: + dp_group = groups._get_sequence_data_parallel_group() - if self.pipeline_parallelism: - dp_group = self.mpu.get_data_parallel_group() - else: - dp_group = groups._get_sequence_data_parallel_group() + for _, sparse_bucket_tuple in enumerate(split_sparse_tensor_buckets): + if sparse_bucket_tuple: + bucket_type, sparse_bucket = sparse_bucket_tuple + self.sparse_allreduce_no_retain(sparse_bucket, dp_group=dp_group) - if bucket_type == SparseTensor.type(): - self.sparse_allreduce_no_retain(bucket, dp_group=dp_group) - else: - self.allreduce_no_retain(bucket, dp_group=dp_group, numel_per_bucket=elements_per_buffer) + for _, dense_bucket_tuple in enumerate(split_dense_tensor_buckets): + if dense_bucket_tuple: + bucket_type, dense_bucket = dense_bucket_tuple + self.allreduce_no_retain(dense_bucket, dp_group=dp_group, numel_per_bucket=elements_per_buffer) def _reduce_expert_gradients(self, expert_grads, elements_per_buffer): for ep_name, expert_grads_group in expert_grads.items(): - expert_split_buckets = split_half_float_double_sparse(expert_grads_group) - for i, bucket_tuple in enumerate(expert_split_buckets): - bucket_type, bucket = bucket_tuple - if bucket_type == SparseTensor.type(): - self.sparse_allreduce_no_retain(bucket, groups._get_expert_data_parallel_group(ep_name)) - else: + split_sparse_tensor_buckets, split_dense_tensor_buckets = split_half_float_double_sparse( + expert_grads_group) + + for _, sparse_bucket_tuple in enumerate(split_sparse_tensor_buckets): + if sparse_bucket_tuple: + bucket_type, sparse_bucket = sparse_bucket_tuple + self.sparse_allreduce_no_retain(sparse_bucket, groups._get_expert_data_parallel_group(ep_name)) + + for _, dense_bucket_tuple in enumerate(split_dense_tensor_buckets): + if dense_bucket_tuple: + bucket_type, dense_bucket = dense_bucket_tuple # Separate between diff groups - self.allreduce_no_retain(bucket, + self.allreduce_no_retain(dense_bucket, dp_group=groups._get_expert_data_parallel_group(ep_name), numel_per_bucket=elements_per_buffer) diff --git a/deepspeed/runtime/sparse_tensor.py b/deepspeed/runtime/sparse_tensor.py index f0bb5c75530e..291ba5f0c786 100644 --- a/deepspeed/runtime/sparse_tensor.py +++ b/deepspeed/runtime/sparse_tensor.py @@ -15,6 +15,7 @@ class SparseTensor(object): def __init__(self, dense_tensor=None): self.orig_dense_tensor = dense_tensor + self.dtype = self.orig_dense_tensor.dtype self.is_sparse = dense_tensor.is_sparse if dense_tensor is not None: if dense_tensor.is_sparse: