Add op bitwise_and (#31104)
Summary:
Following pytorch/pytorch#25665, this adds the `bitwise_and` operator.
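For reference, a minimal usage sketch of the surface being added (overload names are taken from the `native_functions.yaml` entries in the diff below; the snippet itself is illustrative and not part of the original commit):
```
import torch

a = torch.tensor([0b1100, 0b1010], dtype=torch.int8)
b = torch.tensor([0b1010, 0b0110], dtype=torch.int8)

torch.bitwise_and(a, b)  # function variant, Tensor overload
a.bitwise_and(5)         # method variant, Scalar overload
a.bitwise_and_(b)        # in-place variant
a & b                    # __and__/__iand__ are now aliases for bitwise_and/bitwise_and_
```
For `torch.bool` tensors the kernel computes an element-wise logical AND; for integral dtypes it is the bitwise AND.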
Benchmark script:
```
import timeit
# for __and__
for n, t in [(10, 100000), (1000, 10000)]:
    print('__and__ (a.numel() == {}) for {} times'.format(n, t))
    for device in ('cpu', 'cuda'):
        for dtype in ('torch.int8', 'torch.uint8', 'torch.int16', 'torch.int32', 'torch.int64'):
            print(f'device: {device}, dtype: {dtype}, {t} times', end='\t\t')
            print(timeit.timeit(f'a & b\nif "{device}" == "cuda": torch.cuda.synchronize()', setup=f'import torch; a = torch.randint(0, 10, ({n},), dtype = {dtype}, device="{device}"); b = torch.randint(0, 10, ({n},), dtype = {dtype}, device="{device}")', number=t))
# for __iand__
for n, t in [(10, 100000), (1000, 10000)]:
    print('__iand__ (a.numel() == {}) for {} times'.format(n, t))
    for device in ('cpu', 'cuda'):
        for dtype in ('torch.int8', 'torch.uint8', 'torch.int16', 'torch.int32', 'torch.int64'):
            print(f'device: {device}, dtype: {dtype}, {t} times', end='\t\t')
            print(timeit.timeit(f'a &= b\nif "{device}" == "cuda": torch.cuda.synchronize()', setup=f'import torch; a = torch.randint(0, 10, ({n},), dtype = {dtype}, device="{device}"); b = torch.tensor(5, dtype = {dtype}, device="{device}")', number=t))
```
Device: **Tesla P100, skx-8180**
CUDA version: **9.0.176**

Before:
```
__and__ (a.numel() == 10) for 100000 times
device: cpu, dtype: torch.int8, 100000 times            0.1766007635742426
device: cpu, dtype: torch.uint8, 100000 times           0.17322628945112228
device: cpu, dtype: torch.int16, 100000 times           0.17650844901800156
device: cpu, dtype: torch.int32, 100000 times           0.17711848113685846
device: cpu, dtype: torch.int64, 100000 times           0.18240160401910543
device: cuda, dtype: torch.int8, 100000 times           1.273967768996954
device: cuda, dtype: torch.uint8, 100000 times          1.2778537990525365
device: cuda, dtype: torch.int16, 100000 times          1.2753686187788844
device: cuda, dtype: torch.int32, 100000 times          1.2797665279358625
device: cuda, dtype: torch.int64, 100000 times          1.2933144550770521
__and__ (a.numel() == 1000) for 10000 times
device: cpu, dtype: torch.int8, 10000 times             0.031139614060521126
device: cpu, dtype: torch.uint8, 10000 times            0.03091452084481716
device: cpu, dtype: torch.int16, 10000 times            0.022756479680538177
device: cpu, dtype: torch.int32, 10000 times            0.025045674294233322
device: cpu, dtype: torch.int64, 10000 times            0.024164282716810703
device: cuda, dtype: torch.int8, 10000 times            0.12820732593536377
device: cuda, dtype: torch.uint8, 10000 times           0.12775669433176517
device: cuda, dtype: torch.int16, 10000 times           0.12697868794202805
device: cuda, dtype: torch.int32, 10000 times           0.12832533661276102
device: cuda, dtype: torch.int64, 10000 times           0.1280576130375266
__iand__ (a.numel() == 10) for 100000 times
device: cpu, dtype: torch.int8, 100000 times            0.3687064303085208
device: cpu, dtype: torch.uint8, 100000 times           0.36253443732857704
device: cpu, dtype: torch.int16, 100000 times           0.362891579978168
device: cpu, dtype: torch.int32, 100000 times           0.37680106051266193
device: cpu, dtype: torch.int64, 100000 times           0.3689364707097411
device: cuda, dtype: torch.int8, 100000 times           1.419940729625523
device: cuda, dtype: torch.uint8, 100000 times          1.4247053815051913
device: cuda, dtype: torch.int16, 100000 times          1.4191444097086787
device: cuda, dtype: torch.int32, 100000 times          1.4305962566286325
device: cuda, dtype: torch.int64, 100000 times          1.4567416654899716
__iand__ (a.numel() == 1000) for 10000 times
device: cpu, dtype: torch.int8, 10000 times             0.06224383972585201
device: cpu, dtype: torch.uint8, 10000 times            0.06205617543309927
device: cpu, dtype: torch.int16, 10000 times            0.05016433447599411
device: cpu, dtype: torch.int32, 10000 times            0.05216377507895231
device: cpu, dtype: torch.int64, 10000 times            0.06139362137764692
device: cuda, dtype: torch.int8, 10000 times            0.14827249851077795
device: cuda, dtype: torch.uint8, 10000 times           0.14801877550780773
device: cuda, dtype: torch.int16, 10000 times           0.14952312968671322
device: cuda, dtype: torch.int32, 10000 times           0.14999118447303772
device: cuda, dtype: torch.int64, 10000 times           0.14951884001493454
```
After:
```
__and__ (a.numel() == 10) for 100000 times
device: cpu, dtype: torch.int8, 100000 times            0.23157884553074837
device: cpu, dtype: torch.uint8, 100000 times           0.23063660878688097
device: cpu, dtype: torch.int16, 100000 times           0.23005440644919872
device: cpu, dtype: torch.int32, 100000 times           0.23748818412423134
device: cpu, dtype: torch.int64, 100000 times           0.24106105230748653
device: cuda, dtype: torch.int8, 100000 times           1.4394256137311459
device: cuda, dtype: torch.uint8, 100000 times          1.4436759827658534
device: cuda, dtype: torch.int16, 100000 times          1.4631587155163288
device: cuda, dtype: torch.int32, 100000 times          1.459101552143693
device: cuda, dtype: torch.int64, 100000 times          1.4784048134461045
__and__ (a.numel() == 1000) for 10000 times
device: cpu, dtype: torch.int8, 10000 times             0.028442862443625927
device: cpu, dtype: torch.uint8, 10000 times            0.028130197897553444
device: cpu, dtype: torch.int16, 10000 times            0.025318274274468422
device: cpu, dtype: torch.int32, 10000 times            0.02519288007169962
device: cpu, dtype: torch.int64, 10000 times            0.028299466706812382
device: cuda, dtype: torch.int8, 10000 times            0.14342594426125288
device: cuda, dtype: torch.uint8, 10000 times           0.145280827768147
device: cuda, dtype: torch.int16, 10000 times           0.14673697855323553
device: cuda, dtype: torch.int32, 10000 times           0.14499565307050943
device: cuda, dtype: torch.int64, 10000 times           0.14582364354282618
__iand__ (a.numel() == 10) for 100000 times
device: cpu, dtype: torch.int8, 100000 times            0.25548241566866636
device: cpu, dtype: torch.uint8, 100000 times           0.2552562616765499
device: cpu, dtype: torch.int16, 100000 times           0.25905191246420145
device: cpu, dtype: torch.int32, 100000 times           0.26635489892214537
device: cpu, dtype: torch.int64, 100000 times           0.26269810926169157
device: cuda, dtype: torch.int8, 100000 times           1.485458506271243
device: cuda, dtype: torch.uint8, 100000 times          1.4742380809038877
device: cuda, dtype: torch.int16, 100000 times          1.507783885113895
device: cuda, dtype: torch.int32, 100000 times          1.4926990242674947
device: cuda, dtype: torch.int64, 100000 times          1.519851053133607
__iand__ (a.numel() == 1000) for 10000 times
device: cpu, dtype: torch.int8, 10000 times             0.03425929415971041
device: cpu, dtype: torch.uint8, 10000 times            0.03293587639927864
device: cpu, dtype: torch.int16, 10000 times            0.029559112153947353
device: cpu, dtype: torch.int32, 10000 times            0.030915481969714165
device: cpu, dtype: torch.int64, 10000 times            0.03292469773441553
device: cuda, dtype: torch.int8, 10000 times            0.15792148280888796
device: cuda, dtype: torch.uint8, 10000 times           0.16000914946198463
device: cuda, dtype: torch.int16, 10000 times           0.1600684942677617
device: cuda, dtype: torch.int32, 10000 times           0.16162546630948782
device: cuda, dtype: torch.int64, 10000 times           0.1629159888252616
```
Fixes pytorch/pytorch#24508, pytorch/pytorch#24509, pytorch/pytorch#24655, pytorch/pytorch#24656.
Pull Request resolved: pytorch/pytorch#31104

Differential Revision: D18938930

Pulled By: VitalyFedyunin

fbshipit-source-id: a77e805a0b84e8ace16c6e648c2f67dad44f2e44
XiaobingSuper authored and facebook-github-bot committed Jan 3, 2020
1 parent 68f3782 commit b47e9b9
Showing 21 changed files with 182 additions and 222 deletions.
44 changes: 0 additions & 44 deletions aten/src/ATen/Declarations.cwrap
@@ -304,50 +304,6 @@
        - THTensor* self
        - THTensor* other
]]
[[
  name: _th_and
  cpu_bool: True
  cuda_bool: True
  cname: __and__
  variants:
    - function
  return: argument 0
  options:
    - cname: bitand
      arguments:
        - arg: THTensor* result
          output: True
        - THTensor* self
        - real other
    - cname: cbitand
      arguments:
        - arg: THTensor* result
          output: True
        - arg: THTensor* self
          broadcast: other fallback
        - THTensor* other
]]
[[
  name: _th_iand_
  cname: __iand__
  cpu_bool: True
  cuda_bool: True
  variants:
    - function
  return: argument 0
  options:
    - cname: bitand
      arguments:
        - THTensor* self
        - THTensor* self
        - real other
    - cname: cbitand
      arguments:
        - THTensor* self
        - arg: THTensor* self
          broadcast: other inplace fallback
        - THTensor* other
]]
[[
  name: _th_or
  cname: __or__
1 change: 1 addition & 0 deletions aten/src/ATen/core/aten_interned_strings.h
@@ -502,6 +502,7 @@ _(aten, native_tensor) \
_(aten, native_zero) \
_(aten, ne) \
_(aten, neg) \
_(aten, bitwise_and) \
_(aten, bitwise_not) \
_(aten, bitwise_xor) \
_(aten, nll_loss) \
48 changes: 48 additions & 0 deletions aten/src/ATen/native/BinaryOps.cpp
@@ -15,6 +15,7 @@ DEFINE_DISPATCH(sub_stub);
DEFINE_DISPATCH(mul_stub);
DEFINE_DISPATCH(div_stub);
DEFINE_DISPATCH(atan2_stub);
DEFINE_DISPATCH(bitwise_and_stub);
DEFINE_DISPATCH(bitwise_xor_stub);
DEFINE_DISPATCH(logical_and_stub);
DEFINE_DISPATCH(logical_or_stub);
@@ -234,6 +235,53 @@ Tensor rsub(const Tensor& self, Scalar other, Scalar alpha) {
  return native::rsub(self, wrapped_scalar_tensor(other), alpha);
}

Tensor& bitwise_and_out(Tensor& result, const Tensor& self, const Tensor& other) {
  auto iter = TensorIterator::binary_op(result, self, other,
                                        /*check_mem_overlap=*/true);
  bitwise_and_stub(iter.device_type(), iter);
  return result;
}

Tensor bitwise_and(const Tensor& self, const Tensor& other) {
  Tensor result = at::empty({0}, self.options());
  at::bitwise_and_out(result, self, other);
  return result;
}

Tensor& bitwise_and_(Tensor& self, const Tensor& other) {
  return at::bitwise_and_out(self, self, other);
}

Tensor& bitwise_and_out(Tensor& result, const Tensor& self, Scalar other) {
  return at::bitwise_and_out(result, self, wrapped_scalar_tensor(other));
}

Tensor bitwise_and(const Tensor& self, Scalar other) {
  Tensor result = at::empty({0}, self.options());
  return at::bitwise_and_out(result, self, other);
}

Tensor& bitwise_and_(Tensor& self, Scalar other) {
  return at::bitwise_and_out(self, self, other);
}

// Legacy __and__/__iand__ interfaces; they are aliased to the bitwise_and* functions.
Tensor __and__(const Tensor& self, const Tensor& other) {
  return at::bitwise_and(self, other);
}

Tensor __and__(const Tensor& self, Scalar other) {
  return at::bitwise_and(self, other);
}

Tensor& __iand__(Tensor& self, const Tensor& other) {
  return self.bitwise_and_(other);
}

Tensor& __iand__(Tensor& self, Scalar other) {
  return self.bitwise_and_(other);
}

Tensor& bitwise_xor_out(Tensor& result, const Tensor& self, const Tensor& other) {
  auto iter = TensorIterator::binary_op(result, self, other,
                                        /*check_mem_overlap=*/true);
1 change: 1 addition & 0 deletions aten/src/ATen/native/BinaryOps.h
@@ -32,6 +32,7 @@ DECLARE_DISPATCH(binary_fn_alpha, sub_stub);
DECLARE_DISPATCH(binary_fn, mul_stub);
DECLARE_DISPATCH(binary_fn, div_stub);
DECLARE_DISPATCH(binary_fn, atan2_stub);
DECLARE_DISPATCH(binary_fn, bitwise_and_stub);
DECLARE_DISPATCH(binary_fn, bitwise_xor_stub);
DECLARE_DISPATCH(binary_fn, logical_xor_stub);
DECLARE_DISPATCH(binary_fn, logical_and_stub);
22 changes: 22 additions & 0 deletions aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
@@ -93,6 +93,27 @@ void div_kernel(TensorIterator& iter) {
  }
}

void bitwise_and_kernel(TensorIterator& iter) {
  if (iter.dtype() == ScalarType::Bool) {
    cpu_kernel(
        iter,
        [](bool a, bool b) {
          return a && b;
        });
  } else {
    AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "bitwise_and_cpu", [&]() {
      cpu_kernel_vec(
          iter,
          [](scalar_t a, scalar_t b) -> scalar_t {
            return a & b;
          },
          [](Vec256<scalar_t> a, Vec256<scalar_t> b) {
            return a & b;
          });
    });
  }
}

void bitwise_xor_kernel(TensorIterator& iter) {
  if (iter.dtype() == ScalarType::Bool) {
    // Boolean type does not work with ^ (bitwise XOR) in C++. bitwise_xor wraps this operation for both Boolean and
@@ -341,6 +362,7 @@ REGISTER_DISPATCH(sub_stub, &sub_kernel);
REGISTER_DISPATCH(mul_stub, &mul_kernel);
REGISTER_DISPATCH(div_stub, &div_kernel);
REGISTER_DISPATCH(atan2_stub, &atan2_kernel);
REGISTER_DISPATCH(bitwise_and_stub, &bitwise_and_kernel);
REGISTER_DISPATCH(bitwise_xor_stub, &bitwise_xor_kernel);
REGISTER_DISPATCH(logical_xor_stub, &logical_xor_kernel);
REGISTER_DISPATCH(logical_and_stub, &logical_and_kernel);
19 changes: 19 additions & 0 deletions aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu
@@ -18,6 +18,24 @@ void atan2_kernel_cuda(TensorIterator& iter) {
  });
}

void bitwise_and_kernel_cuda(TensorIterator& iter) {
  if (iter.dtype() == ScalarType::Bool) {
    gpu_kernel_with_scalars(
        iter,
        []GPU_LAMBDA(bool a, bool b) {
          return a && b;
        });
  } else {
    AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "bitwise_and_cuda", [&]() {
      gpu_kernel_with_scalars(
          iter,
          []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
            return a & b;
          });
    });
  }
}

void bitwise_xor_kernel_cuda(TensorIterator& iter) {
  if (iter.dtype() == ScalarType::Bool) {
    // Boolean type does not work with ^ (bitwise XOR) in C++. bitwise_xor wraps this operation for both Boolean and
@@ -97,6 +115,7 @@ void mse_kernel_cuda(TensorIterator& iter) {
}

REGISTER_DISPATCH(atan2_stub, &atan2_kernel_cuda);
REGISTER_DISPATCH(bitwise_and_stub, &bitwise_and_kernel_cuda);
REGISTER_DISPATCH(bitwise_xor_stub, &bitwise_xor_kernel_cuda);
REGISTER_DISPATCH(logical_and_stub, &logical_and_kernel_cuda);
REGISTER_DISPATCH(logical_or_stub, &logical_or_kernel_cuda);
36 changes: 24 additions & 12 deletions aten/src/ATen/native/native_functions.yaml
@@ -3934,31 +3934,43 @@
- func: ne_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
  variants: method

- func: bitwise_and.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  variants: function
  dispatch:
    CPU: bitwise_and_out
    CUDA: bitwise_and_out

- func: bitwise_and.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
  variants: function
  dispatch:
    CPU: bitwise_and_out
    CUDA: bitwise_and_out

- func: bitwise_and.Scalar(Tensor self, Scalar other) -> Tensor
  variants: method, function

- func: bitwise_and.Tensor(Tensor self, Tensor other) -> Tensor
  variants: method, function

- func: bitwise_and_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
  variants: method

- func: bitwise_and_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
  variants: method

- func: __and__.Scalar(Tensor self, Scalar other) -> Tensor
  use_c10_dispatcher: full
  variants: method, function
  dispatch:
    CPU: legacy::cpu::_th_and
    CUDA: legacy::cuda::_th_and

- func: __and__.Tensor(Tensor self, Tensor other) -> Tensor
  use_c10_dispatcher: full
  variants: method, function
  dispatch:
    CPU: legacy::cpu::_th_and
    CUDA: legacy::cuda::_th_and

- func: __iand__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
  variants: method
  dispatch:
    CPU: legacy::cpu::_th_iand_
    CUDA: legacy::cuda::_th_iand_

- func: __iand__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
  variants: method
  dispatch:
    CPU: legacy::cpu::_th_iand_
    CUDA: legacy::cuda::_th_iand_

- func: __or__.Scalar(Tensor self, Scalar other) -> Tensor
  use_c10_dispatcher: full
27 changes: 0 additions & 27 deletions aten/src/TH/generic/THTensorEvenMoreMath.cpp
@@ -295,33 +295,6 @@ accreal THTensor_(sumall)(THTensor *tensor)
  return sum;
}

void THTensor_(bitand)(THTensor *r_, THTensor *t, scalar_t value)
{
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_HALF) || defined(TH_REAL_IS_BFLOAT16)
  (void)r_;
  (void)t;
  (void)value;
  return THError("bitand is only supported for integer type tensors");
#else
  THTensor_(resizeAs)(r_, t);
  int64_t r_Size = THTensor_(nElement)(r_);
  int r_Contig = THTensor_(isContiguous)(r_);
  int tContig = THTensor_(isContiguous)(t);
  if (r_Contig && tContig) {
    scalar_t *tp = t->data<scalar_t>();
    scalar_t *rp = r_->data<scalar_t>();
    at::parallel_for(0, r_Size, TH_OMP_OVERHEAD_THRESHOLD * 100,
        [&](int64_t start, int64_t end) {
      for (auto i = start; i < end; i++) {
        rp[i] = tp[i] & value;
      }
    });
  } else {
    TH_TENSOR_APPLY2_PARALLEL(r_Size, r_Contig, tContig, scalar_t, r_, scalar_t, t, *r__data = *t_data & value;, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD);
  }
#endif
}

scalar_t THTensor_(minall)(THTensor *tensor)
{
  scalar_t theMin;
33 changes: 0 additions & 33 deletions aten/src/TH/generic/THTensorMath.cpp
@@ -22,39 +22,6 @@
// sense (rather than just having cut the file down the middle, which is
// what I did when I split these up originally).

void THTensor_(cbitand)(THTensor *r_, THTensor *t, THTensor *src)
{
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_HALF)
  (void)r_;
  (void)t;
  (void)src;
  return THError("cbitand is only supported for integer type tensors");
#else
  THTensor_(resizeAs)(r_, t);
  int64_t r_Size = THTensor_(nElement)(r_);
  int64_t srcSize = THTensor_(nElement)(src);
  int r_Contig = THTensor_(isContiguous)(r_);
  int tContig = THTensor_(isContiguous)(t);
  int srcContig = THTensor_(isContiguous)(src);
  if (srcSize == r_Size){
    if (r_Contig && tContig && srcContig) {
      scalar_t *tp = t->data<scalar_t>();
      scalar_t *sp = src->data<scalar_t>();
      scalar_t *rp = r_->data<scalar_t>();
      at::parallel_for(0, r_Size, TH_OMP_OVERHEAD_THRESHOLD,
          [&](int64_t start, int64_t end) {
        for (auto i = start; i < end; i++) {
          rp[i] = tp[i] & sp[i];
        }
      });
    } else {
      TH_TENSOR_APPLY3_PARALLEL(r_Size, r_Contig, tContig, srcContig, scalar_t, r_, scalar_t, t, scalar_t, src, *r__data = *t_data & *src_data;, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD);
    }
  } else {
    TH_TENSOR_APPLY3(scalar_t, r_, scalar_t, t, scalar_t, src, *r__data = *t_data & *src_data;);
  }
#endif
}

void THTensor_(cbitor)(THTensor *r_, THTensor *t, THTensor *src)
{
2 changes: 0 additions & 2 deletions aten/src/TH/generic/THTensorMath.h
@@ -74,8 +74,6 @@ TH_API void THTensor_(div)(THTensor *r_, THTensor *t, scalar_t value);

TH_API accreal THTensor_(sumall)(THTensor *t);

TH_API void THTensor_(bitand)(THTensor *r_, THTensor *t, scalar_t value);
TH_API void THTensor_(cbitand)(THTensor *r_, THTensor *t, THTensor *src);
TH_API void THTensor_(bitor)(THTensor *r_, THTensor *t, scalar_t value);
TH_API void THTensor_(cbitor)(THTensor *r_, THTensor *t, THTensor *src);

28 changes: 0 additions & 28 deletions aten/src/THC/THCTensorMathPairwise.cu
@@ -259,20 +259,6 @@ struct TensorRShiftConstantOp {
  const T val;
};

template <typename T>
struct TensorBitAndConstantOp {
  TensorBitAndConstantOp(T v) : val(v) {}
  __device__ __forceinline__ void operator()(T* out, T* in) {
    *out = *in & val;
  }

  __device__ __forceinline__ void operator()(T* v) {
    *v &= val;
  }

  const T val;
};

template <typename T>
struct TensorBitOrConstantOp {
  TensorBitOrConstantOp(T v) : val(v) {}
@@ -287,20 +273,6 @@ struct TensorBitOrConstantOp {
  const T val;
};

template <typename T>
struct TensorBitXorConstantOp {
  TensorBitXorConstantOp(T v) : val(v) {}
  __device__ __forceinline__ void operator()(T* out, T* in) {
    *out = *in ^ val;
  }

  __device__ __forceinline__ void operator()(T* v) {
    *v ^= val;
  }

  const T val;
};

#include <THC/generic/THCTensorMathPairwise.cu>
#include <THC/THCGenerateAllTypes.h>

(Diffs for the remaining changed files are not shown.)
