Features/880 binop ben bou #902
@@ -51,147 +51,134 @@ def __binary_op(
-------
result: ht.DNDarray
A DNDarray containing the results of the element-wise operation.

Warning
-------
If both operands are distributed, they must be distributed along the same dimension, i.e. `t1.split = t2.split`.

MPI communication is necessary when both operands are distributed along the same dimension, but the distribution maps do not match. E.g.:
```
a = ht.ones(10000, split=0)
b = ht.zeros(10000, split=0)
c = a[:-1] + b[1:]
```
In such cases, one of the operands is redistributed OUT-OF-PLACE to match the distribution map of the other operand.
The operand determining the resulting distribution is chosen as follows:
1) split is preferred to no split
2) no (shape-)broadcasting in the split dimension if not necessary
3) t1 is preferred to t2
"""
# Check inputs
if not np.isscalar(t1) and not isinstance(t1, DNDarray):
raise TypeError(
"Only DNDarrays and numeric scalars are supported, but input was {}".format(type(t1))
)
if not np.isscalar(t2) and not isinstance(t2, DNDarray):
raise TypeError(
"Only DNDarrays and numeric scalars are supported, but input was {}".format(type(t2))
)
promoted_type = types.result_type(t1, t2).torch_type()
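A small sketch of the promotion step, assuming `result_type` is also re-exported at the package level (e.g. as `ht.result_type`); the exact promoted type depends on heat's promotion rules:
```
import heat as ht

a = ht.ones(3, dtype=ht.int32)
b = ht.ones(3, dtype=ht.float64)
print(ht.result_type(a, b))    # ht.float64
print(ht.result_type(a, 2.0))  # a floating-point type (the scalar promotes the result)
```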
if np.isscalar(t1):
# Make inputs DNDarrays
if np.isscalar(t1) and np.isscalar(t2):
try:
t1 = factories.array(t1, device=t2.device if isinstance(t2, DNDarray) else None)
t1 = factories.array(t1)
t2 = factories.array(t2)
except (ValueError, TypeError):
raise TypeError("Data type not supported, input was {}".format(type(t1)))

if np.isscalar(t2):
try:
t2 = factories.array(t2)
except (ValueError, TypeError):
raise TypeError(
"Only numeric scalars are supported, but input was {}".format(type(t2))
)
output_shape = (1,)
output_split = None
output_device = t2.device
output_comm = MPI_WORLD
elif isinstance(t2, DNDarray):
output_shape = t2.shape
output_split = t2.split
output_device = t2.device
output_comm = t2.comm
else:
raise TypeError(
"Only tensors and numeric scalars are supported, but input was {}".format(type(t2))
)

if t1.dtype != t2.dtype:
t1 = t1.astype(t2.dtype)
elif isinstance(t1, DNDarray):
if np.isscalar(t2):
try:
t2 = factories.array(t2, device=t1.device)
output_shape = t1.shape
output_split = t1.split
output_device = t1.device
output_comm = t1.comm
except (ValueError, TypeError):
raise TypeError("Data type not supported, input was {}".format(type(t2)))

elif isinstance(t2, DNDarray):
if t1.split is None:
t1 = factories.array(
t1, split=t2.split, copy=False, comm=t1.comm, device=t1.device, ndmin=-t2.ndim
)
elif t2.split is None:
t2 = factories.array(
t2, split=t1.split, copy=False, comm=t2.comm, device=t2.device, ndmin=-t1.ndim
)
elif t1.split != t2.split:
# It is NOT possible to perform binary operations on tensors with different splits, e.g. split=0
# and split=1
raise NotImplementedError("Not implemented for other splittings")

output_shape = stride_tricks.broadcast_shape(t1.shape, t2.shape)
output_split = t1.split
output_device = t1.device
output_comm = t1.comm
if t1.split is not None:
if t1.shape[t1.split] == 1 and t1.comm.is_distributed():
# warnings.warn(
# "Broadcasting requires transferring data of first operator between MPI ranks!"
# )
color = 0 if t1.comm.rank < t2.shape[t1.split] else 1
newcomm = t1.comm.Split(color, t1.comm.rank)
if t1.comm.rank > 0 and color == 0:
t1.larray = torch.zeros(
t1.shape, dtype=t1.dtype.torch_type(), device=t1.device.torch_device
)
newcomm.Bcast(t1)
newcomm.Free()

if t2.split is not None:
if t2.shape[t2.split] == 1 and t2.comm.is_distributed():
# warnings.warn(
# "Broadcasting requires transferring data of second operator between MPI ranks!"
# )
color = 0 if t2.comm.rank < t1.shape[t2.split] else 1
newcomm = t2.comm.Split(color, t2.comm.rank)
if t2.comm.rank > 0 and color == 0:
t2.larray = torch.zeros(
t2.shape, dtype=t2.dtype.torch_type(), device=t2.device.torch_device
)
newcomm.Bcast(t2)
newcomm.Free()

else:
raise TypeError(
"Only tensors and numeric scalars are supported, but input was {}".format(type(t2))
)
else:
raise NotImplementedError("Not implemented for non scalar")
# sanitize output
if out is not None:
sanitation.sanitize_out(out, output_shape, output_split, output_device)

# promoted_type = types.promote_types(t1.dtype, t2.dtype).torch_type()
if t1.split is not None:
if len(t1.lshape) > t1.split and t1.lshape[t1.split] == 0:
result = t1.larray.type(promoted_type)
else:
result = operation(
t1.larray.type(promoted_type), t2.larray.type(promoted_type), **fn_kwargs
"Data type not supported, inputs were {} and {}".format(type(t1), type(t2))
)
elif np.isscalar(t1) and isinstance(t2, DNDarray):
try:
t1 = factories.array(t1, device=t2.device, comm=t2.comm)
except (ValueError, TypeError):
raise TypeError("Data type not supported, input was {}".format(type(t1)))
elif isinstance(t1, DNDarray) and np.isscalar(t2):
try:
t2 = factories.array(t2, device=t1.device, comm=t1.comm)
except (ValueError, TypeError):
raise TypeError("Data type not supported, input was {}".format(type(t2)))
# Make inputs have the same dimensionality
output_shape = stride_tricks.broadcast_shape(t1.shape, t2.shape)
# Broadcasting allows additional empty dimensions on the left side
# TODO simplify this once newaxis-indexing is supported to get rid of the loops
while len(t1.shape) < len(output_shape):
t1 = t1.expand_dims(axis=0)
while len(t2.shape) < len(output_shape):
t2 = t2.expand_dims(axis=0)
# t1 = t1[tuple([None] * (len(output_shape) - t1.ndim))]
# t2 = t2[tuple([None] * (len(output_shape) - t2.ndim))]
# print(t1.lshape, t2.lshape)
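As an illustration of the left-padding loops above, a minimal sketch with made-up shapes; `expand_dims` is the DNDarray method used in the loop, and the `heat.core.stride_tricks` import path is assumed:
```
import heat as ht
from heat.core import stride_tricks  # internal helper referenced in the diff

t1 = ht.ones((3, 4))
t2 = ht.ones((2, 3, 4))
output_shape = stride_tricks.broadcast_shape(t1.shape, t2.shape)  # (2, 3, 4)
while len(t1.shape) < len(output_shape):
    t1 = t1.expand_dims(axis=0)  # prepend singleton dimensions
print(t1.shape)  # (1, 3, 4); the remaining broadcasting happens element-wise later
```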
def __get_out_params(target, other=None, map=None):
"""
Getter for the output parameters of a binop with target.
If other is provided, its distribution will be matched to target or, if provided,
redistributed according to map.

Parameters
----------
target : DNDarray
DNDarray determining the parameters
other : DNDarray
DNDarray to be adapted
map : Tensor
Lshape-Map other should be matched to. Defaults to target's lshape_map
Review comment: Lshape-Map -> lshape_map

Returns
-------
Tuple
split, device, comm, balanced, [other]
"""
if other is not None:
if out is None:
other = sanitation.sanitize_distribution(other, target=target, diff_map=map)
return target.split, target.device, target.comm, target.balanced, other
return target.split, target.device, target.comm, target.balanced
if t1.split is not None and t1.shape[t1.split] == output_shape[t1.split]:  # t1 is "dominant"
output_split, output_device, output_comm, output_balanced, t2 = __get_out_params(t1, t2)
elif t2.split is not None and t2.shape[t2.split] == output_shape[t2.split]:  # t2 is "dominant"
output_split, output_device, output_comm, output_balanced, t1 = __get_out_params(t2, t1)
elif t1.split is not None:
# t1 is split but broadcast -> only on one rank; manipulate lshape_map s.t. this rank has 'full' data
lmap = t1.lshape_map
idx = lmap[:, t1.split].nonzero(as_tuple=True)[0]
lmap[idx.item(), t1.split] = output_shape[t1.split]
output_split, output_device, output_comm, output_balanced, t2 = __get_out_params(
t1, t2, map=lmap
)
elif t2.split is not None:

if len(t2.lshape) > t2.split and t2.lshape[t2.split] == 0:
result = t2.larray.type(promoted_type)
else:
result = operation(
t1.larray.type(promoted_type), t2.larray.type(promoted_type), **fn_kwargs
)
else:
result = operation(
t1.larray.type(promoted_type), t2.larray.type(promoted_type), **fn_kwargs
# t2 is split but broadcast -> only on one rank; manipulate lshape_map s.t. this rank has 'full' data
lmap = t2.lshape_map
idx = lmap[:, t2.split].nonzero(as_tuple=True)[0]
lmap[idx.item(), t2.split] = output_shape[t2.split]
output_split, output_device, output_comm, output_balanced, t1 = __get_out_params(
t2, other=t1, map=lmap
)

if not isinstance(result, torch.Tensor):
result = torch.tensor(result, device=output_device.torch_device)
else: # both are not split
output_split, output_device, output_comm, output_balanced = __get_out_params(t1)
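To make the `lshape_map` manipulation above concrete, a torch-only sketch of the same index arithmetic on a hypothetical 4-process layout (the numbers are illustrative, no MPI is involved):
```
import torch

# Hypothetical case: t1 has global shape (1, 8) with split=0 on 4 processes;
# only one rank holds the single broadcast row, the others hold nothing.
split = 0
output_shape = (8, 8)                # broadcast result of (1, 8) op (8, 8)
lmap = torch.tensor([[1, 8],         # rank 0 holds the broadcast row
                     [0, 8],
                     [0, 8],
                     [0, 8]])

# Same steps as in the diff: find the single rank with data along the split axis
idx = lmap[:, split].nonzero(as_tuple=True)[0]
lmap[idx.item(), split] = output_shape[split]
print(lmap)
# [[8, 8], [0, 8], [0, 8], [0, 8]] -> the other operand is redistributed so that
# its entire split dimension ends up on the rank holding the broadcast operand.
```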
if out is not None:
out_dtype = out.dtype
out.larray = result
out._DNDarray__comm = output_comm
out = out.astype(out_dtype)
sanitation.sanitize_out(out, output_shape, output_split, output_device, output_comm)
t1, t2 = sanitation.sanitize_distribution(t1, t2, target=out)
out.larray[:] = operation(
t1.larray.type(promoted_type), t2.larray.type(promoted_type), **fn_kwargs
)
return out
# print(t1.lshape, t2.lshape)

result = operation(t1.larray.type(promoted_type), t2.larray.type(promoted_type), **fn_kwargs)
return DNDarray(
result,
output_shape,
types.heat_type_of(result),
output_split,
output_device,
output_comm,
balanced=None,
device=output_device,
comm=output_comm,
balanced=output_balanced,
)
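For completeness, a small end-to-end sketch of how this code path is exercised through the public API. It assumes the arithmetic wrappers (e.g. `ht.add`) forward an `out` argument to `__binary_op`; that forwarding is not shown in this diff:
```
import heat as ht

a = ht.arange(8, split=0, dtype=ht.float32)
b = ht.ones(8, split=0)

c = ht.add(a, b)  # element-wise sum; c inherits split=0 per the rules documented above

# If the wrapper forwards `out` (an assumption, not confirmed by this diff),
# the result is written into a preallocated array instead:
out = ht.empty((8,), split=0, dtype=ht.float32)
ht.add(a, b, out=out)
```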
Review comment: "binop with target" -> "binary operation with target distribution"