Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
17b09a3
Support cpu tensors without direct device invocation
abhilash1910 Jun 29, 2023
9174511
Merge branch 'master' into abhilash1910_cpu_fix
tjruwase Jun 30, 2023
ca0f7c7
Merge branch 'microsoft:master' into abhilash1910_cpu_fix
abhilash1910 Jul 18, 2023
85a2b32
Add dtype from accelerator interface
abhilash1910 Jul 18, 2023
ba8e699
Merge branch 'master' into abhilash1910_cpu_fix
tjruwase Jul 20, 2023
37d24e4
Remove sparsetensor type check
abhilash1910 Jul 21, 2023
8ab94ef
Add dtype feature
abhilash1910 Jul 21, 2023
7fe8526
Merge branch 'microsoft:master' into abhilash1910_cpu_fix
abhilash1910 Jul 21, 2023
289719c
change print to dtype
abhilash1910 Jul 21, 2023
a1ba5b7
Merge branch 'microsoft:master' into abhilash1910_cpu_fix
abhilash1910 Jul 25, 2023
ae055ca
Merge branch 'master' into abhilash1910_cpu_fix
tjruwase Jul 25, 2023
04a5779
missing dtype attribute
abhilash1910 Jul 25, 2023
ed29074
remove type
abhilash1910 Jul 31, 2023
fbe7417
Merge branch 'microsoft:master' into abhilash1910_cpu_fix
abhilash1910 Jul 31, 2023
a020465
Merge branch 'master' into abhilash1910_cpu_fix
tjruwase Aug 1, 2023
73281d1
Merge branch 'microsoft:master' into abhilash1910_cpu_fix
abhilash1910 Aug 7, 2023
af0d94b
separate sparse and dense tensors
abhilash1910 Aug 7, 2023
7e6bdb3
Modify review commit
abhilash1910 Aug 10, 2023
64afaa4
Modify expert grads reduce
abhilash1910 Aug 10, 2023
e3dcb05
Merge branch 'master' into abhilash1910_cpu_fix
abhilash1910 Aug 10, 2023
efa0ce5
fix indent
abhilash1910 Aug 10, 2023
938abbb
precommit formatting
abhilash1910 Aug 11, 2023
b617e42
precommit format
abhilash1910 Aug 11, 2023
5287b33
Review commit
abhilash1910 Aug 12, 2023
026ca9e
Review commit
abhilash1910 Aug 12, 2023
1bdc29f
Merge branch 'master' into abhilash1910_cpu_fix
tjruwase Aug 21, 2023
0cc5297
Merge branch 'master' into abhilash1910_cpu_fix
tjruwase Aug 22, 2023
3891b6d
Merge branch 'master' into abhilash1910_cpu_fix
abhilash1910 Aug 23, 2023
50b83f6
Merge branch 'master' into abhilash1910_cpu_fix
tjruwase Aug 24, 2023
8e5dd19
Merge branch 'master' into abhilash1910_cpu_fix
tjruwase Aug 24, 2023
596d108
Bypass string check
abhilash1910 Aug 25, 2023
f160b32
Merge branch 'master' into abhilash1910_cpu_fix
abhilash1910 Aug 25, 2023
d08b518
revert change
abhilash1910 Aug 25, 2023
24c1092
Style format
abhilash1910 Aug 25, 2023
8d127ca
Check for str instance
abhilash1910 Aug 27, 2023
027bfd7
Merge branch 'master' into abhilash1910_cpu_fix
abhilash1910 Sep 3, 2023
990d1f4
Merge branch 'microsoft:master' into abhilash1910_cpu_fix
abhilash1910 Sep 22, 2023
5f0422f
Merge branch 'microsoft:master' into abhilash1910_cpu_fix
abhilash1910 Oct 23, 2023
ad62426
Merge branch 'master' into abhilash1910_cpu_fix
abhilash1910 Nov 22, 2023
3c1471c
Merge branch 'microsoft:master' into abhilash1910_cpu_fix
abhilash1910 Nov 22, 2023
3631025
revert manual typecasting
abhilash1910 Nov 22, 2023
0179d58
Fix bug in split tensor
abhilash1910 Nov 23, 2023
148d182
fix format
abhilash1910 Nov 24, 2023
5adfb17
remove tensor tuples
abhilash1910 Nov 29, 2023
6f09874
fix formats
abhilash1910 Nov 29, 2023
6bd8654
Merge branch 'microsoft:master' into abhilash1910_cpu_fix
abhilash1910 Nov 29, 2023
7481416
Merge branch 'master' into abhilash1910_cpu_fix
abhilash1910 Dec 4, 2023
982ea3e
Merge branch 'master' into abhilash1910_cpu_fix
tjruwase Dec 8, 2023
e34b08a
Merge branch 'microsoft:master' into abhilash1910_cpu_fix
abhilash1910 Dec 11, 2023
58d1712
fix the repetition loop list append
inkcherry Dec 14, 2023
cb7b76a
Merge pull request #2 from inkcherry/fix_repetition
abhilash1910 Dec 14, 2023
8301fc9
Merge branch 'master' into abhilash1910_cpu_fix
abhilash1910 Dec 14, 2023
1654496
Merge branch 'master' into abhilash1910_cpu_fix
abhilash1910 Dec 17, 2023
c81da04
Merge branch 'master' into abhilash1910_cpu_fix
abhilash1910 Dec 19, 2023
b98b4df
Merge branch 'master' into abhilash1910_cpu_fix
tjruwase Dec 20, 2023
81e2fb1
Merge branch 'microsoft:master' into abhilash1910_cpu_fix
abhilash1910 Jan 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 40 additions & 29 deletions deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,21 +120,25 @@

def split_half_float_double_sparse(tensors):
    """Bucket gradient tensors by dtype, keeping sparse and dense tensors apart.

    Args:
        tensors: iterable of gradient tensors; each is either a dense tensor
            or a project ``SparseTensor`` wrapper, and exposes a ``dtype``
            attribute.

    Returns:
        Tuple ``(sparse_tensor_buckets, dense_tensor_buckets)``. Each element
        is a list of ``(dtype, [tensors...])`` pairs in the order the
        accelerator reports its supported dtypes; empty buckets are omitted,
        so every emitted bucket list is non-empty.

    Raises:
        AssertionError: if any tensor's dtype is not in the accelerator's
            supported dtypes (attempting to reduce an unsupported grad type).
    """
    # Ask the accelerator abstraction for its dtypes instead of hard-coding
    # device-specific tensor type strings — this is what lets CPU (and other
    # non-CUDA) accelerators work without direct device invocation.
    supported_types = get_accelerator().supported_dtypes()

    for t in tensors:
        assert t.dtype in supported_types, f"attempting to reduce an unsupported grad type: {t.dtype}"

    sparse_tensor_buckets, dense_tensor_buckets = [], []
    for dtype in supported_types:
        sparse_bucket, dense_bucket = [], []
        for t in tensors:
            if t.dtype == dtype:
                # SparseTensor wrappers must be reduced with the sparse
                # all-reduce path; everything else is a dense tensor.
                if isinstance(t, SparseTensor):
                    sparse_bucket.append(t)
                else:
                    dense_bucket.append(t)
        if sparse_bucket:
            sparse_tensor_buckets.append((dtype, sparse_bucket))
        if dense_bucket:
            dense_tensor_buckets.append((dtype, dense_bucket))
    return sparse_tensor_buckets, dense_tensor_buckets


class EngineTimers(object):
Expand Down Expand Up @@ -2396,30 +2400,37 @@ def _get_gradients_for_reduction(self):
return non_expert_grads, expert_grads

def _reduce_non_expert_gradients(self, grads, elements_per_buffer):
    """All-reduce non-expert gradients across the data-parallel group.

    Gradients are first split into per-dtype sparse and dense buckets, then
    each bucket is reduced with the appropriate collective.

    Args:
        grads: iterable of non-expert gradient tensors to reduce.
        elements_per_buffer: numel budget per fused all-reduce bucket,
            forwarded to ``allreduce_no_retain``.
    """
    split_sparse_tensor_buckets, split_dense_tensor_buckets = split_half_float_double_sparse(grads)

    # Under pipeline parallelism the mpu defines the data-parallel group;
    # otherwise use the sequence-data-parallel group from `groups`.
    if self.pipeline_parallelism:
        dp_group = self.mpu.get_data_parallel_group()
    else:
        dp_group = groups._get_sequence_data_parallel_group()

    # split_half_float_double_sparse only emits non-empty buckets, so no
    # per-bucket emptiness check is needed. The dtype tag is unused here.
    for _dtype, sparse_bucket in split_sparse_tensor_buckets:
        self.sparse_allreduce_no_retain(sparse_bucket, dp_group=dp_group)

    for _dtype, dense_bucket in split_dense_tensor_buckets:
        self.allreduce_no_retain(dense_bucket, dp_group=dp_group, numel_per_bucket=elements_per_buffer)

def _reduce_expert_gradients(self, expert_grads, elements_per_buffer):
    """All-reduce expert (MoE) gradients within each expert-parallel group.

    Args:
        expert_grads: mapping of expert-parallel group name -> iterable of
            gradient tensors belonging to that group.
        elements_per_buffer: numel budget per fused all-reduce bucket,
            forwarded to ``allreduce_no_retain``.
    """
    for ep_name, expert_grads_group in expert_grads.items():
        # Each expert group reduces over its own data-parallel group;
        # look it up once per group rather than once per bucket.
        ep_dp_group = groups._get_expert_data_parallel_group(ep_name)
        split_sparse_tensor_buckets, split_dense_tensor_buckets = split_half_float_double_sparse(
            expert_grads_group)

        # Buckets are non-empty by construction; dtype tag unused here.
        for _dtype, sparse_bucket in split_sparse_tensor_buckets:
            self.sparse_allreduce_no_retain(sparse_bucket, ep_dp_group)

        for _dtype, dense_bucket in split_dense_tensor_buckets:
            self.allreduce_no_retain(dense_bucket,
                                     dp_group=ep_dp_group,
                                     numel_per_bucket=elements_per_buffer)

Expand Down
1 change: 1 addition & 0 deletions deepspeed/runtime/sparse_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class SparseTensor(object):

def __init__(self, dense_tensor=None):
self.orig_dense_tensor = dense_tensor
self.dtype = self.orig_dense_tensor.dtype
self.is_sparse = dense_tensor.is_sparse
if dense_tensor is not None:
if dense_tensor.is_sparse:
Expand Down