# data_parallel_tensor.py
from enum import Enum, auto
from typing import Any, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch._C import NoneType, device
from torch._utils import _get_all_device_indices
from torch.cuda import comm
from torch.utils._pytree import tree_map
# NOTE: We need to set this because when we lift the module parameters to DataParallelTensors (DPT) using mod._apply,
# we do not want to do an in-place copy of the new parameter value; we want to overwrite it.
# The DPT is a list of tensors, so an in-place copy between the old and new values of the parameter is incompatible.
torch.__future__.set_overwrite_module_params_on_conversion(True)
import concurrent.futures as futures
torch.manual_seed(0)
aten = torch.ops.aten
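# Configuration for the examples below: NUM_DEVICES is the number of CUDA devices assumed to be
# visible, PARALLEL_DISPATCH switches per-device dispatch from a serial loop to a thread pool,
# and ALL_REDUCE controls whether gradients are reduced across devices after the backward pass.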
NUM_DEVICES = 8
PARALLEL_DISPATCH = False
ALL_REDUCE = True
class DPTensorType(Enum):
# This tensor will be replicated across all the devices
replicated = auto()
    # This tensor will be sharded along the first/batch dimension across
    # the devices. NOTE: only equal chunk sizes are supported
distributed_batch = auto()
    # This is a list of tensors, each of which rests on a different device
distributed = auto()
class DataParallelTensor(torch.Tensor):
    # This class is a tensor subclass that stores a list of tensors, one per device, so that
    # operations dispatched on it run on every device's element in a data-parallel fashion.
    # DataParallelTensors (DPT) are categorized in three ways:
    # 1) replicated: when a single tensor is supplied, it is replicated across
    #    all the devices using broadcast.
    # 2) distributed: a DPT can also be initialized from a list/tuple of tensors;
    #    if the elements already rest on different devices they are simply wrapped in a DPT,
    #    otherwise the elements are scattered to the different devices.
    # 3) distributed_batch: this type of DPT is created by sharding the input tensor along
    #    a specified batch dimension (default: 0). Currently only equal chunk sizes are supported.
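    # Illustrative construction sketch (assumes CUDA GPUs are available; the tensor names below
    # are placeholders, mirroring the examples in the __main__ block):
    #
    #   t = torch.randn(16, device="cuda")
    #   rep = DataParallelTensor(t, None, DPTensorType.replicated)        # one copy per device
    #   batch = torch.randn(16 * NUM_DEVICES, 8, device="cuda")
    #   shard = DataParallelTensor(batch, None, DPTensorType.distributed_batch)  # split along dim 0
    #   parts = [torch.randn(8, device=f"cuda:{i}") for i in range(NUM_DEVICES)]
    #   dist = DataParallelTensor(parts, None, DPTensorType.distributed)  # wrap per-device tensors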
elem: List[torch.Tensor]
if torch.cuda.is_available():
# device_ids: List[int] = _get_all_device_indices()
device_ids = [i for i in range(NUM_DEVICES)]
if PARALLEL_DISPATCH:
num_threads: int = len(device_ids)
threadpool: futures.ThreadPoolExecutor = futures.ThreadPoolExecutor(
max_workers=num_threads
)
__slots__ = ["elem"]
@staticmethod
def __new__(
cls,
elem: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
func: Optional[Any] = None,
dpt_type: DPTensorType = DPTensorType.replicated,
batch_dim: Optional[int] = 0,
):
if dpt_type == DPTensorType.replicated:
# NOTE: If the input is None, we return None
if elem is None:
return None
assert isinstance(elem, torch.Tensor)
            # NOTE: For handling meta tensors, if the device of the input tensor is meta,
            # we just return the tensor itself.
if elem.device == device("meta"):
return elem
with torch.no_grad():
dp_tensor: List[torch.Tensor] = comm.broadcast(
elem, devices=cls.device_ids
)
elif dpt_type == DPTensorType.distributed:
assert isinstance(elem, list) or isinstance(elem, tuple)
            # We check if the first element of the list/tuple is a tensor
if isinstance(elem[0], torch.Tensor):
                # Check that all elements are tensors
assert all(isinstance(e, torch.Tensor) for e in elem)
requires_scatter: bool = False
with torch.no_grad():
for t, d_id in zip(elem, cls.device_ids):
if t.device == device("meta"):
# NOTE: For handling meta tensors, if the device of any tensor in such a list/tuple is meta,
# we just return the first element in such a list/tuple. This usually happens for factory functions,
# like torch.ones or torch.zeros generated either during forward or backward mode autodiff.
                            # We cannot check the equality of the elements here since they do not exist physically,
                            # so we only check that all of them are meta tensors.
if all(e.device == torch.device("meta") for e in elem):
return elem[0]
else:
raise TypeError(
f"Device error in {func}: Not all tensors are meta."
)
if t.device != device(d_id):
requires_scatter = True
break
if requires_scatter:
# We first stack all the tensors in the list/tuple along dimension 0, to get a single tensor
# We then scatter the tensor along the 0th dimension to different devices
# The scatter function returns a list of tensors with a redundant 0th dimension for each element
# We squeeze out the redundant dimension from each of these elements to finally get a list of tensors
# each residing on a list of devices
stacked_t: torch.Tensor = torch.stack(elem, dim=0)
scattered_t: Tuple[torch.Tensor] = comm.scatter(
stacked_t, devices=cls.device_ids, dim=0
)
dp_tensor: List[torch.Tensor] = [
torch.squeeze(t, dim=0) for t in scattered_t
]
else:
dp_tensor: List[torch.Tensor] = elem
else:
# Elements of the list/tuple are non-tensors.
                # NOTE: If the list contains non-tensor types, we return a single value only if all elements are identical.
if all(v == elem[0] for v in elem):
return elem[0]
else:
raise ValueError(
f"Operation {func} retuns non-identical non-tensor values for some elemnts of DPT"
)
elif dpt_type == DPTensorType.distributed_batch:
# NOTE: This requires the batch dimension to be divisible by the number of devices.
assert isinstance(elem, torch.Tensor)
with torch.no_grad():
scattered_t: Tuple[torch.Tensor] = comm.scatter(
elem, devices=cls.device_ids, dim=batch_dim
)
dp_tensor: List[torch.Tensor] = list(scattered_t)
meta_t: torch.Tensor = (
elem if dpt_type == DPTensorType.replicated else dp_tensor[0]
)
r = torch.Tensor._make_wrapper_subclass(
cls,
meta_t.size(),
strides=meta_t.stride(),
storage_offset=meta_t.storage_offset(),
            device=meta_t.device,  # the device of either the input tensor or the first tensor of the list
dtype=meta_t.dtype,
layout=meta_t.layout,
requires_grad=meta_t.requires_grad,
)
r.elem = dp_tensor
return r
def __repr__(self):
if self.grad_fn:
return f"DataParallelTensor({self.elem}, grad_fn={self.grad_fn})"
return f"DataParallelTensor({self.elem})"
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
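        # Dispatch flow: wrap any plain tensor args/kwargs as replicated DPTs, run func once per
        # device on the corresponding unwrapped elements, then re-wrap the per-device outputs
        # into DPTs (or collapse identical non-tensor results into a single value).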
def wrap(e):
if isinstance(e, DataParallelTensor):
return e
elif isinstance(e, torch.Tensor):
return DataParallelTensor(e, func, DPTensorType.replicated)
else:
return e
# All the args and kwargs are checked and any leaf tensors are wrapped as replicated DPTs
args = tree_map(wrap, args)
kwargs = tree_map(wrap, kwargs)
def unwrap_with_position(pos):
def get_element(e):
return e.elem[pos] if isinstance(e, DataParallelTensor) else e
return get_element
        # Call the function for each of the DPT elements by unwrapping them, along with the corresponding
        # args and kwargs, into per-device tensors so that each operation runs on tensors residing on the same device.
if PARALLEL_DISPATCH:
future_res: List[futures.Future] = []
for pos in range(cls.num_threads):
future_res.append(
cls.threadpool.submit(
func,
*tree_map(unwrap_with_position(pos), args),
**tree_map(unwrap_with_position(pos), kwargs),
)
)
outs = [future_res[i].result() for i in range(cls.num_threads)]
else:
outs = []
for pos in range(len(cls.device_ids)):
outs.append(
func(
*tree_map(unwrap_with_position(pos), args),
**tree_map(unwrap_with_position(pos), kwargs),
)
)
        # The output will always be a list since we create it above.
        # The list can contain tensors, bools, lists of tensors, tuples of tensors, or None.
        # In the case of tensors we just wrap them in a DPT.
        # In the case of a list/tuple of tensors, the corresponding elements across the lists/tuples are wrapped
        # into DPTs and a list/tuple is returned, respectively.
def out_wrap(e, func):
assert isinstance(e, list)
if isinstance(e[0], torch.Tensor):
return DataParallelTensor(outs, func, DPTensorType.distributed)
elif isinstance(e[0], list):
return list(
DataParallelTensor(list(t), func, DPTensorType.distributed)
for t in zip(*e)
)
elif isinstance(e[0], tuple):
return tuple(
DataParallelTensor(list(t), func, DPTensorType.distributed)
for t in zip(*e)
)
else:
                # NOTE: If the list contains non-tensor types, we return a single value only if all elements are identical.
if all(v == e[0] for v in e):
return e[0]
else:
raise ValueError(
f"Operation {func} retuns non-identical non-tensor values for some elemnts of DPT"
)
outs = out_wrap(outs, func)
return outs
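    # Sum the per-device gradient copies onto r_device with comm.reduce_add and broadcast the
    # result back into self.elem, so that every replica holds the same reduced gradient.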
def all_reduce_grad(
self,
r_device: Optional[int] = torch.cuda.current_device()
if torch.cuda.is_available()
else 0,
):
with torch.no_grad():
reduced_tensor: torch.Tensor = comm.reduce_add(self.elem, r_device)
b_tensor: List[torch.Tensor] = comm.broadcast(reduced_tensor, out=self.elem)
self.elem = b_tensor
return reduced_tensor
def make_data_parallel_module(mod: torch.nn.Module):
    # This function converts the parameters of an nn.Module into replicated DataParallelTensors.
    # The else branch handles the module's buffers (plain tensors, None, or bool).
def wrapper(t):
if isinstance(t, torch.nn.Parameter):
return DataParallelTensor(t.data, None, DPTensorType.replicated)
else:
assert type(t) in (torch.Tensor, NoneType, bool)
return DataParallelTensor(t, None, DPTensorType.replicated)
mod._apply(wrapper)
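# Illustrative end-to-end sketch (not executed; model and batch are placeholders, mirroring the
# examples in the __main__ block below):
#
#   model = SomeModule().cuda()
#   make_data_parallel_module(model)                  # parameters become replicated DPTs
#   inp = DataParallelTensor(batch, None, DPTensorType.distributed_batch)
#   loss = model(inp).sum()
#   loss.backward()
#   for p in model.parameters():
#       p.grad.all_reduce_grad()                      # sum grads across devices, broadcast back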
if __name__ == "__main__":
if torch.cuda.is_available():
print("Devices: ", [i for i in range(NUM_DEVICES)])
else:
print("GPU not found. Need GPUs to run examples. Exiting...")
exit()
try:
from functools import partial
from functorch import hessian, jacfwd, jacrev, vjp, vmap
D = 16
x: torch.Tensor = torch.randn(D, device="cuda")
dpt_x = DataParallelTensor(x, None, DPTensorType.replicated)
def predict(weight, bias, x):
return F.linear(x, weight, bias).tanh()
weight = torch.randn(D, D, device="cuda")
bias = torch.randn(D, device="cuda")
        # Compute the Jacobian using vmap and vjp, and compare against jacrev
clone_x = dpt_x.clone().requires_grad_()
unit_vectors = torch.eye(D).cuda()
_, vjp_fn = vjp(partial(predict, weight, bias), clone_x)
(ft_jacobian,) = vmap(vjp_fn)(unit_vectors)
clone_x = dpt_x.clone().requires_grad_()
jacobian_rev = jacrev(predict, argnums=2)(weight, bias, clone_x)
print(torch.allclose(ft_jacobian, jacobian_rev))
        # Compute the Hessian using a composition of jacfwd and jacrev, and compare against the hessian API
clone_x = dpt_x.clone().requires_grad_()
hess_api = hessian(predict, argnums=2)(weight, bias, clone_x)
hess_fwdrev = jacfwd(jacrev(predict, argnums=2), argnums=2)(
weight, bias, clone_x
)
print(torch.allclose(hess_api, hess_fwdrev))
except ImportError:
print("Skipping functorch example, package missing.")
try:
# Example with a torchvision model
import torchvision.models as models
batch_size = 256
test_tensor: torch.Tensor = torch.randn(
batch_size * NUM_DEVICES, 3, 224, 224, device="cuda"
)
dp_tensor = DataParallelTensor(
test_tensor, None, DPTensorType.distributed_batch
)
model = models.resnet50().cuda()
make_data_parallel_module(model)
        # Warm-up iteration
out = model(dp_tensor)
loss = out.sum()
loss.backward()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for i in range(1):
out = model(dp_tensor)
loss = out.sum()
loss.backward()
if ALL_REDUCE:
for p in model.parameters():
p.grad.all_reduce_grad()
# p = p - 0.5 * p.grad
end_event.record()
torch.cuda.synchronize()
print("Timing for 1 iteration (ms) DPT: ", start_event.elapsed_time(end_event))
test_tensor: torch.Tensor = torch.randn(batch_size, 3, 224, 224, device="cuda")
model = models.resnet50().cuda()
        # Warm-up iteration
out = model(test_tensor)
loss = out.sum()
loss.backward()
start_event.record()
for i in range(NUM_DEVICES):
out = model(test_tensor)
loss = out.sum()
loss.backward()
# for p in model.parameters():
# p = p - 0.5 * p.grad
end_event.record()
torch.cuda.synchronize()
print(
"Timing for " + str(NUM_DEVICES) + " iterations(ms): ",
start_event.elapsed_time(end_event),
)
except ImportError:
print("Running custom model since torchvision package is absent.")
# Custom Model Example
class MyModel(torch.nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = torch.flatten(x, 1) # flatten all dimensions except batch
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
mod: torch.nn.Module = MyModel().cuda()
inp: torch.Tensor = torch.randn(512, 3, 32, 32, device="cuda")
dpt_inp = DataParallelTensor(inp, None, DPTensorType.distributed_batch)
make_data_parallel_module(mod)
out = mod(dpt_inp)
loss = out.sum()
loss.backward()
for p in mod.parameters():
p.grad.all_reduce_grad()
p = p - 0.5 * p.grad
# Custom Function Example
test_tensor = torch.randn(8, 5, device="cuda", requires_grad=True)
dp_tensor = DataParallelTensor(test_tensor, None, DPTensorType.distributed_batch)
def custom_func(x):
return x.cos().cos().sum()
res_tensor = custom_func(dp_tensor)
print(res_tensor)
res_tensor.backward()
print(dp_tensor.grad)