Add op bitwise_or (#31559)
Summary:
ezyang, this PR adds the bitwise_or operator, following pytorch/pytorch#31104.
Benchmark script:
```
import timeit
import torch
torch.manual_seed(1)

for n, t in [(10, 100000),(1000, 10000)]:
    print('__or__ (a.numel() == {}) for {} times'.format(n, t))
    for device in ('cpu', 'cuda'):
        for dtype in ('torch.int8', 'torch.uint8', 'torch.int16', 'torch.int32', 'torch.int64'):
            print(f'device: {device}, dtype: {dtype}, {t} times', end='\t\t')
            print(timeit.timeit(f'a | b\nif "{device}" == "cuda": torch.cuda.synchronize()', setup=f'import torch; a = torch.randint(0, 10, ({n},), dtype = {dtype}, device="{device}"); b = torch.randint(0, 10, ({n},), dtype = {dtype}, device="{device}")', number=t))

for n, t in [(10, 100000),(1000, 10000)]:
    print('__ior__ (a.numel() == {}) for {} times'.format(n, t))
    for device in ('cpu', 'cuda'):
        for dtype in ('torch.int8', 'torch.uint8', 'torch.int16', 'torch.int32', 'torch.int64'):
            print(f'device: {device}, dtype: {dtype}, {t} times', end='\t\t')
            print(timeit.timeit(f'a |= b\nif "{device}" == "cuda": torch.cuda.synchronize()', setup=f'import torch; a = torch.randint(0, 10, ({n},), dtype = {dtype}, device="{device}"); b = torch.tensor(5, dtype = {dtype}, device="{device}")', number=t))
```
Device: **Tesla P100, skx-8180**
CUDA version: **9.0.176**

Before:
```
__or__ (a.numel() == 10) for 100000 times
device: cpu, dtype: torch.int8, 100000 times            0.17616272252053022
device: cpu, dtype: torch.uint8, 100000 times           0.17148233391344547
device: cpu, dtype: torch.int16, 100000 times           0.17616403382271528
device: cpu, dtype: torch.int32, 100000 times           0.17717823758721352
device: cpu, dtype: torch.int64, 100000 times           0.1801931718364358
device: cuda, dtype: torch.int8, 100000 times           1.270583058707416
device: cuda, dtype: torch.uint8, 100000 times          1.2636413089931011
device: cuda, dtype: torch.int16, 100000 times          1.2839747751131654
device: cuda, dtype: torch.int32, 100000 times          1.2548385225236416
device: cuda, dtype: torch.int64, 100000 times          1.2650810535997152
__or__ (a.numel() == 1000) for 10000 times
device: cpu, dtype: torch.int8, 10000 times             0.031136621721088886
device: cpu, dtype: torch.uint8, 10000 times            0.030786747112870216
device: cpu, dtype: torch.int16, 10000 times            0.02391665056347847
device: cpu, dtype: torch.int32, 10000 times            0.024147341027855873
device: cpu, dtype: torch.int64, 10000 times            0.024414129555225372
device: cuda, dtype: torch.int8, 10000 times            0.12741921469569206
device: cuda, dtype: torch.uint8, 10000 times           0.1249831635504961
device: cuda, dtype: torch.int16, 10000 times           0.1283819805830717
device: cuda, dtype: torch.int32, 10000 times           0.12591975275427103
device: cuda, dtype: torch.int64, 10000 times           0.12655890546739101
__ior__ (a.numel() == 10) for 100000 times
device: cpu, dtype: torch.int8, 100000 times            0.3908365070819855
device: cpu, dtype: torch.uint8, 100000 times           0.38267823681235313
device: cpu, dtype: torch.int16, 100000 times           0.38239253498613834
device: cpu, dtype: torch.int32, 100000 times           0.3817988149821758
device: cpu, dtype: torch.int64, 100000 times           0.3901665909215808
device: cuda, dtype: torch.int8, 100000 times           1.4211318120360374
device: cuda, dtype: torch.uint8, 100000 times          1.4215159295126796
device: cuda, dtype: torch.int16, 100000 times          1.4307750314474106
device: cuda, dtype: torch.int32, 100000 times          1.4123614141717553
device: cuda, dtype: torch.int64, 100000 times          1.4480243818834424
__ior__ (a.numel() == 1000) for 10000 times
device: cpu, dtype: torch.int8, 10000 times             0.06468924414366484
device: cpu, dtype: torch.uint8, 10000 times            0.06442475505173206
device: cpu, dtype: torch.int16, 10000 times            0.05267547257244587
device: cpu, dtype: torch.int32, 10000 times            0.05286940559744835
device: cpu, dtype: torch.int64, 10000 times            0.06211103219538927
device: cuda, dtype: torch.int8, 10000 times            0.15332304500043392
device: cuda, dtype: torch.uint8, 10000 times           0.15353196952492
device: cuda, dtype: torch.int16, 10000 times           0.15300503931939602
device: cuda, dtype: torch.int32, 10000 times           0.15274472255259752
device: cuda, dtype: torch.int64, 10000 times           0.1512152962386608
```
After:
```
__or__ (a.numel() == 10) for 100000 times
device: cpu, dtype: torch.int8, 100000 times            0.2465507509186864
device: cpu, dtype: torch.uint8, 100000 times           0.2472386620938778
device: cpu, dtype: torch.int16, 100000 times           0.2469814233481884
device: cpu, dtype: torch.int32, 100000 times           0.2535214088857174
device: cpu, dtype: torch.int64, 100000 times           0.24855613708496094
device: cuda, dtype: torch.int8, 100000 times           1.4351346511393785
device: cuda, dtype: torch.uint8, 100000 times          1.4434308474883437
device: cuda, dtype: torch.int16, 100000 times          1.4520929995924234
device: cuda, dtype: torch.int32, 100000 times          1.4456610176712275
device: cuda, dtype: torch.int64, 100000 times          1.4580101007595658
__or__ (a.numel() == 1000) for 10000 times
device: cpu, dtype: torch.int8, 10000 times             0.029985425993800163
device: cpu, dtype: torch.uint8, 10000 times            0.03024935908615589
device: cpu, dtype: torch.int16, 10000 times            0.026356655173003674
device: cpu, dtype: torch.int32, 10000 times            0.027377349324524403
device: cpu, dtype: torch.int64, 10000 times            0.029163731262087822
device: cuda, dtype: torch.int8, 10000 times            0.14540370367467403
device: cuda, dtype: torch.uint8, 10000 times           0.1456305105239153
device: cuda, dtype: torch.int16, 10000 times           0.1450125053524971
device: cuda, dtype: torch.int32, 10000 times           0.1472016740590334
device: cuda, dtype: torch.int64, 10000 times           0.14709716010838747
__ior__ (a.numel() == 10) for 100000 times
device: cpu, dtype: torch.int8, 100000 times            0.27195510920137167
device: cpu, dtype: torch.uint8, 100000 times           0.2692424338310957
device: cpu, dtype: torch.int16, 100000 times           0.27726674638688564
device: cpu, dtype: torch.int32, 100000 times           0.2815811652690172
device: cpu, dtype: torch.int64, 100000 times           0.2852728571742773
device: cuda, dtype: torch.int8, 100000 times           1.4743850827217102
device: cuda, dtype: torch.uint8, 100000 times          1.4766502184793353
device: cuda, dtype: torch.int16, 100000 times          1.4774163831025362
device: cuda, dtype: torch.int32, 100000 times          1.4749693805351853
device: cuda, dtype: torch.int64, 100000 times          1.5772947426885366
__ior__ (a.numel() == 1000) for 10000 times
device: cpu, dtype: torch.int8, 10000 times             0.03614502027630806
device: cpu, dtype: torch.uint8, 10000 times            0.03619729354977608
device: cpu, dtype: torch.int16, 10000 times            0.0319912089034915
device: cpu, dtype: torch.int32, 10000 times            0.03319283854216337
device: cpu, dtype: torch.int64, 10000 times            0.0343862259760499
device: cuda, dtype: torch.int8, 10000 times            0.1581476852297783
device: cuda, dtype: torch.uint8, 10000 times           0.15974601730704308
device: cuda, dtype: torch.int16, 10000 times           0.15957212820649147
device: cuda, dtype: torch.int32, 10000 times           0.16002820804715157
device: cuda, dtype: torch.int64, 10000 times           0.16129320487380028
```
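
For reference, here is a minimal usage sketch of the new API (made-up values), exercising the tensor, scalar, out=, and in-place variants:
```
import torch

a = torch.tensor([1, 2, 4], dtype=torch.int32)
b = torch.tensor([1, 3, 4], dtype=torch.int32)

# out-of-place, tensor-tensor and tensor-scalar variants
print(torch.bitwise_or(a, b))  # tensor([1, 3, 4], dtype=torch.int32)
print(torch.bitwise_or(a, 8))  # tensor([ 9, 10, 12], dtype=torch.int32)

# explicit out= variant
out = torch.empty_like(a)
torch.bitwise_or(a, b, out=out)

# in-place variant
a.bitwise_or_(b)
```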

Fixes pytorch/pytorch#24511, pytorch/pytorch#24515, pytorch/pytorch#24658, pytorch/pytorch#24662.
Pull Request resolved: pytorch/pytorch#31559

Differential Revision: D19315875

Pulled By: ezyang

fbshipit-source-id: 4a3ca88fdafbeb796079687e676228111eb44aad
XiaobingSuper authored and facebook-github-bot committed Jan 8, 2020
1 parent 4f9d2f7 commit 9ba6a76
Showing 20 changed files with 191 additions and 208 deletions.
44 changes: 0 additions & 44 deletions aten/src/ATen/Declarations.cwrap
@@ -305,50 +305,6 @@
        - THTensor* self
        - THTensor* other
]]
[[
  name: _th_or
  cname: __or__
  cpu_bool: True
  cuda_bool: True
  variants:
    - function
  return: argument 0
  options:
    - cname: bitor
      arguments:
        - arg: THTensor* result
          output: True
        - THTensor* self
        - real other
    - cname: cbitor
      arguments:
        - arg: THTensor* result
          output: True
        - arg: THTensor* self
          broadcast: other fallback
        - THTensor* other
]]
[[
  name: _th_ior_
  cname: __ior__
  cpu_bool: True
  cuda_bool: True
  variants:
    - function
  return: argument 0
  options:
    - cname: bitor
      arguments:
        - THTensor* self
        - THTensor* self
        - real other
    - cname: cbitor
      arguments:
        - THTensor* self
        - arg: THTensor* self
          broadcast: other inplace fallback
        - THTensor* other
]]
[[
  name: _th_lshift
  cname: __lshift__
1 change: 1 addition & 0 deletions aten/src/ATen/core/aten_interned_strings.h
@@ -504,6 +504,7 @@ _(aten, ne) \
_(aten, neg) \
_(aten, bitwise_and) \
_(aten, bitwise_not) \
_(aten, bitwise_or) \
_(aten, bitwise_xor) \
_(aten, nll_loss) \
_(aten, nll_loss2d) \
48 changes: 48 additions & 0 deletions aten/src/ATen/native/BinaryOps.cpp
@@ -16,6 +16,7 @@ DEFINE_DISPATCH(mul_stub);
DEFINE_DISPATCH(div_stub);
DEFINE_DISPATCH(atan2_stub);
DEFINE_DISPATCH(bitwise_and_stub);
DEFINE_DISPATCH(bitwise_or_stub);
DEFINE_DISPATCH(bitwise_xor_stub);
DEFINE_DISPATCH(logical_and_stub);
DEFINE_DISPATCH(logical_or_stub);
@@ -282,6 +283,53 @@ Tensor& __iand__(Tensor& self, Scalar other) {
  return self.bitwise_and_(other);
}

Tensor& bitwise_or_out(Tensor& result, const Tensor& self, const Tensor& other) {
  auto iter = TensorIterator::binary_op(result, self, other,
    /*check_mem_overlap=*/true);
  bitwise_or_stub(iter.device_type(), iter);
  return result;
}

Tensor bitwise_or(const Tensor& self, const Tensor& other) {
  Tensor result = at::empty({0}, self.options());
  at::bitwise_or_out(result, self, other);
  return result;
}

Tensor& bitwise_or_(Tensor& self, const Tensor& other) {
  return at::bitwise_or_out(self, self, other);
}

Tensor& bitwise_or_out(Tensor& result, const Tensor& self, Scalar other) {
  return at::bitwise_or_out(result, self, wrapped_scalar_tensor(other));
}

Tensor bitwise_or(const Tensor& self, Scalar other) {
  Tensor result = at::empty({0}, self.options());
  return at::bitwise_or_out(result, self, other);
}

Tensor& bitwise_or_(Tensor& self, Scalar other) {
  return at::bitwise_or_out(self, self, other);
}

// Legacy or interfaces. They are aliased to bitwise_or* functions
Tensor __or__(const Tensor& self, const Tensor& other) {
  return at::bitwise_or(self, other);
}

Tensor __or__(const Tensor& self, Scalar other) {
  return at::bitwise_or(self, other);
}

Tensor& __ior__(Tensor& self, const Tensor& other) {
  return self.bitwise_or_(other);
}

Tensor& __ior__(Tensor& self, Scalar other) {
  return self.bitwise_or_(other);
}

Tensor& bitwise_xor_out(Tensor& result, const Tensor& self, const Tensor& other) {
  auto iter = TensorIterator::binary_op(result, self, other,
    /*check_mem_overlap=*/true);
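
The legacy `__or__`/`__ior__` entry points above just forward to the new functions, so the Python `|` and `|=` operators keep working unchanged. A small sanity-check sketch (illustrative values):
```
import torch

a = torch.tensor([1, 2, 4], dtype=torch.int8)
b = torch.tensor([3, 3, 3], dtype=torch.int8)

c = a | b  # __or__  -> at::bitwise_or
a |= b     # __ior__ -> Tensor::bitwise_or_
assert torch.equal(c, a)  # both are tensor([3, 3, 7], dtype=torch.int8)
```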
1 change: 1 addition & 0 deletions aten/src/ATen/native/BinaryOps.h
@@ -33,6 +33,7 @@ DECLARE_DISPATCH(binary_fn, mul_stub);
DECLARE_DISPATCH(binary_fn, div_stub);
DECLARE_DISPATCH(binary_fn, atan2_stub);
DECLARE_DISPATCH(binary_fn, bitwise_and_stub);
DECLARE_DISPATCH(binary_fn, bitwise_or_stub);
DECLARE_DISPATCH(binary_fn, bitwise_xor_stub);
DECLARE_DISPATCH(binary_fn, logical_xor_stub);
DECLARE_DISPATCH(binary_fn, logical_and_stub);
30 changes: 26 additions & 4 deletions aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
@@ -96,10 +96,10 @@ void div_kernel(TensorIterator& iter) {
void bitwise_and_kernel(TensorIterator& iter) {
  if (iter.dtype() == ScalarType::Bool) {
    cpu_kernel(
        iter,
        [](bool a, bool b) {
          return a && b;
        });
  } else {
    AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "bitwise_and_cpu", [&]() {
      cpu_kernel_vec(
@@ -114,6 +114,27 @@ void bitwise_and_kernel(TensorIterator& iter) {
  }
}

void bitwise_or_kernel(TensorIterator& iter) {
  if (iter.dtype() == ScalarType::Bool) {
    cpu_kernel(
        iter,
        [](bool a, bool b) {
          return a || b;
        });
  } else {
    AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "bitwise_or_cpu", [&]() {
      cpu_kernel_vec(
          iter,
          [](scalar_t a, scalar_t b) -> scalar_t {
            return a | b;
          },
          [](Vec256<scalar_t> a, Vec256<scalar_t> b) {
            return a | b;
          });
    });
  }
}

void bitwise_xor_kernel(TensorIterator& iter) {
  if (iter.dtype() == ScalarType::Bool) {
    // Boolean type does not work with ^ (bitwise XOR) in C++. bitwise_xor wraps this operation for both Boolean and
@@ -363,6 +384,7 @@ REGISTER_DISPATCH(mul_stub, &mul_kernel);
REGISTER_DISPATCH(div_stub, &div_kernel);
REGISTER_DISPATCH(atan2_stub, &atan2_kernel);
REGISTER_DISPATCH(bitwise_and_stub, &bitwise_and_kernel);
REGISTER_DISPATCH(bitwise_or_stub, &bitwise_or_kernel);
REGISTER_DISPATCH(bitwise_xor_stub, &bitwise_xor_kernel);
REGISTER_DISPATCH(logical_xor_stub, &logical_xor_kernel);
REGISTER_DISPATCH(logical_and_stub, &logical_and_kernel);
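
Because of the `ScalarType::Bool` branch above, `bitwise_or` doubles as an elementwise logical or on bool tensors, mirroring `bitwise_and`. A quick illustration (made-up values):
```
import torch

x = torch.tensor([True, True, False, False])
y = torch.tensor([True, False, True, False])
print(torch.bitwise_or(x, y))  # tensor([ True,  True,  True, False])
print(x | y)                   # same result via the __or__ alias
```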
27 changes: 23 additions & 4 deletions aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu
@@ -21,10 +21,10 @@ void atan2_kernel_cuda(TensorIterator& iter) {
void bitwise_and_kernel_cuda(TensorIterator& iter) {
  if (iter.dtype() == ScalarType::Bool) {
    gpu_kernel_with_scalars(
        iter,
        []GPU_LAMBDA(bool a, bool b) {
          return a && b;
        });
  } else {
    AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "bitwise_and_cuda", [&]() {
      gpu_kernel_with_scalars(
@@ -36,6 +36,24 @@ void bitwise_and_kernel_cuda(TensorIterator& iter) {
  }
}

void bitwise_or_kernel_cuda(TensorIterator& iter) {
  if (iter.dtype() == ScalarType::Bool) {
    gpu_kernel_with_scalars(
        iter,
        []GPU_LAMBDA(bool a, bool b) {
          return a || b;
        });
  } else {
    AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "bitwise_or_cuda", [&]() {
      gpu_kernel_with_scalars(
          iter,
          []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
            return a | b;
          });
    });
  }
}

void bitwise_xor_kernel_cuda(TensorIterator& iter) {
  if (iter.dtype() == ScalarType::Bool) {
    // Boolean type does not work with ^ (bitwise XOR) in C++. bitwise_xor wraps this operation for both Boolean and
@@ -116,6 +134,7 @@ void mse_kernel_cuda(TensorIterator& iter) {

REGISTER_DISPATCH(atan2_stub, &atan2_kernel_cuda);
REGISTER_DISPATCH(bitwise_and_stub, &bitwise_and_kernel_cuda);
REGISTER_DISPATCH(bitwise_or_stub, &bitwise_or_kernel_cuda);
REGISTER_DISPATCH(bitwise_xor_stub, &bitwise_xor_kernel_cuda);
REGISTER_DISPATCH(logical_and_stub, &logical_and_kernel_cuda);
REGISTER_DISPATCH(logical_or_stub, &logical_or_kernel_cuda);
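
The CUDA registration mirrors the CPU one, so the same Python calls dispatch to `bitwise_or_kernel_cuda` for GPU tensors. A minimal sketch, assuming a CUDA device is available:
```
import torch

if torch.cuda.is_available():
    a = torch.randint(0, 10, (1000,), dtype=torch.int32, device="cuda")
    b = torch.randint(0, 10, (1000,), dtype=torch.int32, device="cuda")
    c = torch.bitwise_or(a, b)  # runs bitwise_or_kernel_cuda
    torch.cuda.synchronize()    # wait for the kernel before reading c
```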
36 changes: 24 additions & 12 deletions aten/src/ATen/native/native_functions.yaml
@@ -3978,31 +3978,43 @@
 - func: __iand__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
   variants: method
 
+- func: bitwise_or.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  dispatch:
+    CPU: bitwise_or_out
+    CUDA: bitwise_or_out
+
+- func: bitwise_or.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  dispatch:
+    CPU: bitwise_or_out
+    CUDA: bitwise_or_out
+
+- func: bitwise_or.Scalar(Tensor self, Scalar other) -> Tensor
+  variants: method, function
+
+- func: bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor
+  variants: method, function
+
+- func: bitwise_or_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+  variants: method
+
+- func: bitwise_or_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  variants: method
+
 - func: __or__.Scalar(Tensor self, Scalar other) -> Tensor
   use_c10_dispatcher: full
   variants: method, function
-  dispatch:
-    CPU: legacy::cpu::_th_or
-    CUDA: legacy::cuda::_th_or
 
 - func: __or__.Tensor(Tensor self, Tensor other) -> Tensor
   use_c10_dispatcher: full
   variants: method, function
-  dispatch:
-    CPU: legacy::cpu::_th_or
-    CUDA: legacy::cuda::_th_or
 
 - func: __ior__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
   variants: method
-  dispatch:
-    CPU: legacy::cpu::_th_ior_
-    CUDA: legacy::cuda::_th_ior_
 
 - func: __ior__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
   variants: method
-  dispatch:
-    CPU: legacy::cpu::_th_ior_
-    CUDA: legacy::cuda::_th_ior_
 
 - func: bitwise_xor.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
   variants: function
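
Per the schema above, `bitwise_or` is exposed as both a method and a free function, each with `Tensor` and `Scalar` overloads. A short sketch with made-up values:
```
import torch

a = torch.tensor([5, 6], dtype=torch.int64)

print(a.bitwise_or(torch.tensor([1, 1])))  # method, Tensor overload: tensor([5, 7])
print(a.bitwise_or(1))                     # method, Scalar overload: tensor([5, 7])
a.bitwise_or_(8)                           # in-place method variant
print(a)                                   # tensor([13, 14])
```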
61 changes: 0 additions & 61 deletions aten/src/TH/generic/THTensorMath.cpp
@@ -23,67 +23,6 @@
// what I did when I split these up originally).


void THTensor_(cbitor)(THTensor *r_, THTensor *t, THTensor *src)
{
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_HALF)
  (void)r_;
  (void)t;
  (void)src;
  return THError("cbitor is only supported for integer type tensors");
#else
  THTensor_(resizeAs)(r_, t);
  int64_t r_Size = THTensor_(nElement)(r_);
  int64_t srcSize = THTensor_(nElement)(src);
  int r_Contig = THTensor_(isContiguous)(r_);
  int tContig = THTensor_(isContiguous)(t);
  int srcContig = THTensor_(isContiguous)(src);
  if (srcSize == r_Size){
    if (r_Contig && tContig && srcContig) {
      scalar_t *tp = t->data<scalar_t>();
      scalar_t *sp = src->data<scalar_t>();
      scalar_t *rp = r_->data<scalar_t>();
      at::parallel_for(0, r_Size, TH_OMP_OVERHEAD_THRESHOLD,
          [&](int64_t start, int64_t end) {
        for (auto i = start; i < end; i++) {
          rp[i] = tp[i] | sp[i];
        }
      });
    } else {
      TH_TENSOR_APPLY3_PARALLEL(r_Size, r_Contig, tContig, srcContig, scalar_t, r_, scalar_t, t, scalar_t, src, *r__data = *t_data | *src_data;, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD);
    }
  } else {
    TH_TENSOR_APPLY3(scalar_t, r_, scalar_t, t, scalar_t, src, *r__data = *t_data | *src_data;);
  }
#endif
}

void THTensor_(bitor)(THTensor *r_, THTensor *t, scalar_t value)
{
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_HALF)
  (void)r_;
  (void)t;
  (void)value;
  return THError("bitor is only supported for integer type tensors");
#else
  THTensor_(resizeAs)(r_, t);
  int64_t r_Size = THTensor_(nElement)(r_);
  int r_Contig = THTensor_(isContiguous)(r_);
  int tContig = THTensor_(isContiguous)(t);
  if (r_Contig && tContig) {
    scalar_t *tp = t->data<scalar_t>();
    scalar_t *rp = r_->data<scalar_t>();
    at::parallel_for(0, r_Size, TH_OMP_OVERHEAD_THRESHOLD * 100,
        [&](int64_t start, int64_t end) {
      for (auto i = start; i < end; i++) {
        rp[i] = tp[i] | value;
      }
    });
  } else {
    TH_TENSOR_APPLY2_PARALLEL(r_Size, r_Contig, tContig, scalar_t, r_, scalar_t, t, *r__data = *t_data | value;, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD);
  }
#endif
}

#if !defined(TH_REAL_IS_BOOL) /* non bool only part */

static void THTensor_(addmmImpl)(THTensor *r_, THTensor *t, THTensor *m1, THTensor *m2, scalar_t beta, scalar_t alpha)
3 changes: 0 additions & 3 deletions aten/src/TH/generic/THTensorMath.h
@@ -74,9 +74,6 @@ TH_API void THTensor_(div)(THTensor *r_, THTensor *t, scalar_t value);

TH_API accreal THTensor_(sumall)(THTensor *t);

TH_API void THTensor_(bitor)(THTensor *r_, THTensor *t, scalar_t value);
TH_API void THTensor_(cbitor)(THTensor *r_, THTensor *t, THTensor *src);

void THTensor_(preserveReduceDimSemantics)(THTensor *r_, int in_dims, int reduce_dimension, int keepdim);

TH_API void THTensor_(max)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension, int keepdim);
14 changes: 0 additions & 14 deletions aten/src/THC/THCTensorMathPairwise.cu
@@ -259,20 +259,6 @@ struct TensorRShiftConstantOp {
  const T val;
};

template <typename T>
struct TensorBitOrConstantOp {
  TensorBitOrConstantOp(T v) : val(v) {}
  __device__ __forceinline__ void operator()(T* out, T* in) {
    *out = *in | val;
  }

  __device__ __forceinline__ void operator()(T* v) {
    *v |= val;
  }

  const T val;
};

#include <THC/generic/THCTensorMathPairwise.cu>
#include <THC/THCGenerateAllTypes.h>
