
Commit b47e9b9

XiaobingSuper authored and facebook-github-bot committed
Add op bitwise_and (#31104)
Summary: Referring to pytorch/pytorch#25665, add the `bitwise_and` operator.

Benchmark script:

```python
import timeit

# for __and__
for n, t in [(10, 100000), (1000, 10000)]:
    print('__and__ (a.numel() == {}) for {} times'.format(n, t))
    for device in ('cpu', 'cuda'):
        for dtype in ('torch.int8', 'torch.uint8', 'torch.int16', 'torch.int32', 'torch.int64'):
            print(f'device: {device}, dtype: {dtype}, {t} times', end='\t\t')
            print(timeit.timeit(
                f'a & b\nif "{device}" == "cuda": torch.cuda.synchronize()',
                setup=f'import torch; a = torch.randint(0, 10, ({n},), dtype = {dtype}, device="{device}"); b = torch.randint(0, 10, ({n},), dtype = {dtype}, device="{device}")',
                number=t))

# for __iand__
for n, t in [(10, 100000), (1000, 10000)]:
    print('__iand__ (a.numel() == {}) for {} times'.format(n, t))
    for device in ('cpu', 'cuda'):
        for dtype in ('torch.int8', 'torch.uint8', 'torch.int16', 'torch.int32', 'torch.int64'):
            print(f'device: {device}, dtype: {dtype}, {t} times', end='\t\t')
            print(timeit.timeit(
                f'a & b\nif "{device}" == "cuda": torch.cuda.synchronize()',
                setup=f'import torch; a = torch.randint(0, 10, ({n},), dtype = {dtype}, device="{device}"); b = torch.tensor(5, dtype = {dtype}, device="{device}")',
                number=t))
```

Device: **Tesla P100, skx-8180**
CUDA version: **9.0.176**

Before:

```
__and__ (a.numel() == 10) for 100000 times
device: cpu, dtype: torch.int8, 100000 times    0.1766007635742426
device: cpu, dtype: torch.uint8, 100000 times    0.17322628945112228
device: cpu, dtype: torch.int16, 100000 times    0.17650844901800156
device: cpu, dtype: torch.int32, 100000 times    0.17711848113685846
device: cpu, dtype: torch.int64, 100000 times    0.18240160401910543
device: cuda, dtype: torch.int8, 100000 times    1.273967768996954
device: cuda, dtype: torch.uint8, 100000 times    1.2778537990525365
device: cuda, dtype: torch.int16, 100000 times    1.2753686187788844
device: cuda, dtype: torch.int32, 100000 times    1.2797665279358625
device: cuda, dtype: torch.int64, 100000 times    1.2933144550770521
__and__ (a.numel() == 1000) for 10000 times
device: cpu, dtype: torch.int8, 10000 times    0.031139614060521126
device: cpu, dtype: torch.uint8, 10000 times    0.03091452084481716
device: cpu, dtype: torch.int16, 10000 times    0.022756479680538177
device: cpu, dtype: torch.int32, 10000 times    0.025045674294233322
device: cpu, dtype: torch.int64, 10000 times    0.024164282716810703
device: cuda, dtype: torch.int8, 10000 times    0.12820732593536377
device: cuda, dtype: torch.uint8, 10000 times    0.12775669433176517
device: cuda, dtype: torch.int16, 10000 times    0.12697868794202805
device: cuda, dtype: torch.int32, 10000 times    0.12832533661276102
device: cuda, dtype: torch.int64, 10000 times    0.1280576130375266
__iand__ (a.numel() == 10) for 100000 times
device: cpu, dtype: torch.int8, 100000 times    0.3687064303085208
device: cpu, dtype: torch.uint8, 100000 times    0.36253443732857704
device: cpu, dtype: torch.int16, 100000 times    0.362891579978168
device: cpu, dtype: torch.int32, 100000 times    0.37680106051266193
device: cpu, dtype: torch.int64, 100000 times    0.3689364707097411
device: cuda, dtype: torch.int8, 100000 times    1.419940729625523
device: cuda, dtype: torch.uint8, 100000 times    1.4247053815051913
device: cuda, dtype: torch.int16, 100000 times    1.4191444097086787
device: cuda, dtype: torch.int32, 100000 times    1.4305962566286325
device: cuda, dtype: torch.int64, 100000 times    1.4567416654899716
__iand__ (a.numel() == 1000) for 10000 times
device: cpu, dtype: torch.int8, 10000 times    0.06224383972585201
device: cpu, dtype: torch.uint8, 10000 times    0.06205617543309927
device: cpu, dtype: torch.int16, 10000 times    0.05016433447599411
device: cpu, dtype: torch.int32, 10000 times    0.05216377507895231
device: cpu, dtype: torch.int64, 10000 times    0.06139362137764692
device: cuda, dtype: torch.int8, 10000 times    0.14827249851077795
device: cuda, dtype: torch.uint8, 10000 times    0.14801877550780773
device: cuda, dtype: torch.int16, 10000 times    0.14952312968671322
device: cuda, dtype: torch.int32, 10000 times    0.14999118447303772
device: cuda, dtype: torch.int64, 10000 times    0.14951884001493454
```

After:

```
__and__ (a.numel() == 10) for 100000 times
device: cpu, dtype: torch.int8, 100000 times    0.23157884553074837
device: cpu, dtype: torch.uint8, 100000 times    0.23063660878688097
device: cpu, dtype: torch.int16, 100000 times    0.23005440644919872
device: cpu, dtype: torch.int32, 100000 times    0.23748818412423134
device: cpu, dtype: torch.int64, 100000 times    0.24106105230748653
device: cuda, dtype: torch.int8, 100000 times    1.4394256137311459
device: cuda, dtype: torch.uint8, 100000 times    1.4436759827658534
device: cuda, dtype: torch.int16, 100000 times    1.4631587155163288
device: cuda, dtype: torch.int32, 100000 times    1.459101552143693
device: cuda, dtype: torch.int64, 100000 times    1.4784048134461045
__and__ (a.numel() == 1000) for 10000 times
device: cpu, dtype: torch.int8, 10000 times    0.028442862443625927
device: cpu, dtype: torch.uint8, 10000 times    0.028130197897553444
device: cpu, dtype: torch.int16, 10000 times    0.025318274274468422
device: cpu, dtype: torch.int32, 10000 times    0.02519288007169962
device: cpu, dtype: torch.int64, 10000 times    0.028299466706812382
device: cuda, dtype: torch.int8, 10000 times    0.14342594426125288
device: cuda, dtype: torch.uint8, 10000 times    0.145280827768147
device: cuda, dtype: torch.int16, 10000 times    0.14673697855323553
device: cuda, dtype: torch.int32, 10000 times    0.14499565307050943
device: cuda, dtype: torch.int64, 10000 times    0.14582364354282618
__iand__ (a.numel() == 10) for 100000 times
device: cpu, dtype: torch.int8, 100000 times    0.25548241566866636
device: cpu, dtype: torch.uint8, 100000 times    0.2552562616765499
device: cpu, dtype: torch.int16, 100000 times    0.25905191246420145
device: cpu, dtype: torch.int32, 100000 times    0.26635489892214537
device: cpu, dtype: torch.int64, 100000 times    0.26269810926169157
device: cuda, dtype: torch.int8, 100000 times    1.485458506271243
device: cuda, dtype: torch.uint8, 100000 times    1.4742380809038877
device: cuda, dtype: torch.int16, 100000 times    1.507783885113895
device: cuda, dtype: torch.int32, 100000 times    1.4926990242674947
device: cuda, dtype: torch.int64, 100000 times    1.519851053133607
__iand__ (a.numel() == 1000) for 10000 times
device: cpu, dtype: torch.int8, 10000 times    0.03425929415971041
device: cpu, dtype: torch.uint8, 10000 times    0.03293587639927864
device: cpu, dtype: torch.int16, 10000 times    0.029559112153947353
device: cpu, dtype: torch.int32, 10000 times    0.030915481969714165
device: cpu, dtype: torch.int64, 10000 times    0.03292469773441553
device: cuda, dtype: torch.int8, 10000 times    0.15792148280888796
device: cuda, dtype: torch.uint8, 10000 times    0.16000914946198463
device: cuda, dtype: torch.int16, 10000 times    0.1600684942677617
device: cuda, dtype: torch.int32, 10000 times    0.16162546630948782
device: cuda, dtype: torch.int64, 10000 times    0.1629159888252616
```

Fix pytorch/pytorch#24508, pytorch/pytorch#24509, pytorch/pytorch#24655, pytorch/pytorch#24656.
Pull Request resolved: pytorch/pytorch#31104
Differential Revision: D18938930
Pulled By: VitalyFedyunin
fbshipit-source-id: a77e805a0b84e8ace16c6e648c2f67dad44f2e44
1 parent 68f3782 · commit b47e9b9

21 files changed: +182 −222 lines
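Before the per-file diffs, here is a minimal, illustrative sketch of the Python surface this commit adds; the variable names and values are examples and are not taken from the PR itself:

```python
import torch

a = torch.tensor([0b1100, 0b1010], dtype=torch.int32)
b = torch.tensor([0b1010, 0b0110], dtype=torch.int32)

# New explicit operator: function, method, out= and in-place variants.
print(torch.bitwise_and(a, b))     # tensor([8, 2], dtype=torch.int32)
print(a.bitwise_and(5))            # tensor-scalar overload
out = torch.empty_like(a)
torch.bitwise_and(a, b, out=out)   # out= variant
a.bitwise_and_(b)                  # in-place variant, same effect as a &= b

# The legacy operators keep working and now route to bitwise_and.
print(a & b, a.__and__(b))
```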

aten/src/ATen/Declarations.cwrap

Lines changed: 0 additions & 44 deletions
@@ -304,50 +304,6 @@
         - THTensor* self
         - THTensor* other
 ]]
-[[
-  name: _th_and
-  cpu_bool: True
-  cuda_bool: True
-  cname: __and__
-  variants:
-    - function
-  return: argument 0
-  options:
-    - cname: bitand
-      arguments:
-        - arg: THTensor* result
-          output: True
-        - THTensor* self
-        - real other
-    - cname: cbitand
-      arguments:
-        - arg: THTensor* result
-          output: True
-        - arg: THTensor* self
-          broadcast: other fallback
-        - THTensor* other
-]]
-[[
-  name: _th_iand_
-  cname: __iand__
-  cpu_bool: True
-  cuda_bool: True
-  variants:
-    - function
-  return: argument 0
-  options:
-    - cname: bitand
-      arguments:
-        - THTensor* self
-        - THTensor* self
-        - real other
-    - cname: cbitand
-      arguments:
-        - THTensor* self
-        - arg: THTensor* self
-          broadcast: other inplace fallback
-        - THTensor* other
-]]
 [[
   name: _th_or
   cname: __or__

aten/src/ATen/core/aten_interned_strings.h

Lines changed: 1 addition & 0 deletions
@@ -502,6 +502,7 @@ _(aten, native_tensor) \
 _(aten, native_zero) \
 _(aten, ne) \
 _(aten, neg) \
+_(aten, bitwise_and) \
 _(aten, bitwise_not) \
 _(aten, bitwise_xor) \
 _(aten, nll_loss) \

aten/src/ATen/native/BinaryOps.cpp

Lines changed: 48 additions & 0 deletions
@@ -15,6 +15,7 @@ DEFINE_DISPATCH(sub_stub);
 DEFINE_DISPATCH(mul_stub);
 DEFINE_DISPATCH(div_stub);
 DEFINE_DISPATCH(atan2_stub);
+DEFINE_DISPATCH(bitwise_and_stub);
 DEFINE_DISPATCH(bitwise_xor_stub);
 DEFINE_DISPATCH(logical_and_stub);
 DEFINE_DISPATCH(logical_or_stub);
@@ -234,6 +235,53 @@ Tensor rsub(const Tensor& self, Scalar other, Scalar alpha) {
   return native::rsub(self, wrapped_scalar_tensor(other), alpha);
 }
 
+Tensor& bitwise_and_out(Tensor& result, const Tensor& self, const Tensor& other) {
+  auto iter = TensorIterator::binary_op(result, self, other,
+    /*check_mem_overlap=*/true);
+  bitwise_and_stub(iter.device_type(), iter);
+  return result;
+}
+
+Tensor bitwise_and(const Tensor& self, const Tensor& other) {
+  Tensor result = at::empty({0}, self.options());
+  at::bitwise_and_out(result, self, other);
+  return result;
+}
+
+Tensor& bitwise_and_(Tensor& self, const Tensor& other) {
+  return at::bitwise_and_out(self, self, other);
+}
+
+Tensor& bitwise_and_out(Tensor& result, const Tensor& self, Scalar other) {
+  return at::bitwise_and_out(result, self, wrapped_scalar_tensor(other));
+}
+
+Tensor bitwise_and(const Tensor& self, Scalar other) {
+  Tensor result = at::empty({0}, self.options());
+  return at::bitwise_and_out(result, self, other);
+}
+
+Tensor& bitwise_and_(Tensor& self, Scalar other) {
+  return at::bitwise_and_out(self, self, other);
+}
+
+// Legacy and interfaces. They are aliased to bitwise_and* functions
+Tensor __and__(const Tensor& self, const Tensor& other) {
+  return at::bitwise_and(self, other);
+}
+
+Tensor __and__(const Tensor& self, Scalar other) {
+  return at::bitwise_and(self, other);
+}
+
+Tensor& __iand__(Tensor& self, const Tensor& other) {
+  return self.bitwise_and_(other);
+}
+
+Tensor& __iand__(Tensor& self, Scalar other) {
+  return self.bitwise_and_(other);
+}
+
 Tensor& bitwise_xor_out(Tensor& result, const Tensor& self, const Tensor& other) {
   auto iter = TensorIterator::binary_op(result, self, other,
     /*check_mem_overlap=*/true);
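The `__and__`/`__iand__` entry points above simply forward to the new `bitwise_and*` functions, so the Python operators and the explicit op are interchangeable. A small sketch of that equivalence (variable names are illustrative):

```python
import torch

a = torch.randint(0, 10, (4,), dtype=torch.int32)
b = torch.randint(0, 10, (4,), dtype=torch.int32)

# `a & b` calls Tensor.__and__, which now aliases torch.bitwise_and.
assert torch.equal(a & b, torch.bitwise_and(a, b))

# `a &= b` calls Tensor.__iand__, which now aliases Tensor.bitwise_and_.
c = a.clone()
c &= b
assert torch.equal(c, a.bitwise_and(b))
```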

aten/src/ATen/native/BinaryOps.h

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@ DECLARE_DISPATCH(binary_fn_alpha, sub_stub);
 DECLARE_DISPATCH(binary_fn, mul_stub);
 DECLARE_DISPATCH(binary_fn, div_stub);
 DECLARE_DISPATCH(binary_fn, atan2_stub);
+DECLARE_DISPATCH(binary_fn, bitwise_and_stub);
 DECLARE_DISPATCH(binary_fn, bitwise_xor_stub);
 DECLARE_DISPATCH(binary_fn, logical_xor_stub);
 DECLARE_DISPATCH(binary_fn, logical_and_stub);

aten/src/ATen/native/cpu/BinaryOpsKernel.cpp

Lines changed: 22 additions & 0 deletions
@@ -93,6 +93,27 @@ void div_kernel(TensorIterator& iter) {
   }
 }
 
+void bitwise_and_kernel(TensorIterator& iter) {
+  if (iter.dtype() == ScalarType::Bool) {
+    cpu_kernel(
+        iter,
+        [](bool a, bool b) {
+          return a && b;
+        });
+  } else {
+    AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "bitwise_and_cpu", [&]() {
+      cpu_kernel_vec(
+          iter,
+          [](scalar_t a, scalar_t b) -> scalar_t {
+            return a & b;
+          },
+          [](Vec256<scalar_t> a, Vec256<scalar_t> b) {
+            return a & b;
+          });
+    });
+  }
+}
+
 void bitwise_xor_kernel(TensorIterator& iter) {
   if (iter.dtype() == ScalarType::Bool) {
     // Boolean type does not work with ^ (bitwise XOR) in C++. bitwise_xor wraps this operation for both Boolean and
@@ -341,6 +362,7 @@ REGISTER_DISPATCH(sub_stub, &sub_kernel);
 REGISTER_DISPATCH(mul_stub, &mul_kernel);
 REGISTER_DISPATCH(div_stub, &div_kernel);
 REGISTER_DISPATCH(atan2_stub, &atan2_kernel);
+REGISTER_DISPATCH(bitwise_and_stub, &bitwise_and_kernel);
 REGISTER_DISPATCH(bitwise_xor_stub, &bitwise_xor_kernel);
 REGISTER_DISPATCH(logical_xor_stub, &logical_xor_kernel);
 REGISTER_DISPATCH(logical_and_stub, &logical_and_kernel);
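The CPU kernel above special-cases `bool` (element-wise logical AND) and dispatches a vectorized `&` for integral dtypes. A quick illustration of the resulting Python-level behavior (tensor values are arbitrary examples):

```python
import torch

# Integral dtypes: plain bitwise AND of each element pair.
x = torch.tensor([12, 10], dtype=torch.uint8)
y = torch.tensor([10, 6], dtype=torch.uint8)
print(torch.bitwise_and(x, y))   # tensor([8, 2], dtype=torch.uint8)

# Bool dtype: element-wise logical AND (the `a && b` branch above).
p = torch.tensor([True, True, False])
q = torch.tensor([True, False, False])
print(torch.bitwise_and(p, q))   # tensor([ True, False, False])

# Floating-point dtypes are not dispatched, so they raise an error:
# torch.bitwise_and(torch.rand(2), torch.rand(2))  # RuntimeError
```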

aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu

Lines changed: 19 additions & 0 deletions
@@ -18,6 +18,24 @@ void atan2_kernel_cuda(TensorIterator& iter) {
   });
 }
 
+void bitwise_and_kernel_cuda(TensorIterator& iter) {
+  if (iter.dtype() == ScalarType::Bool) {
+    gpu_kernel_with_scalars(
+        iter,
+        []GPU_LAMBDA(bool a, bool b) {
+          return a && b;
+        });
+  } else {
+    AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "bitwise_and_cuda", [&]() {
+      gpu_kernel_with_scalars(
+          iter,
+          []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
+            return a & b;
+          });
+    });
+  }
+}
+
 void bitwise_xor_kernel_cuda(TensorIterator& iter) {
   if (iter.dtype() == ScalarType::Bool) {
     // Boolean type does not work with ^ (bitwise XOR) in C++. bitwise_xor wraps this operation for both Boolean and
@@ -97,6 +115,7 @@ void mse_kernel_cuda(TensorIterator& iter) {
 }
 
 REGISTER_DISPATCH(atan2_stub, &atan2_kernel_cuda);
+REGISTER_DISPATCH(bitwise_and_stub, &bitwise_and_kernel_cuda);
 REGISTER_DISPATCH(bitwise_xor_stub, &bitwise_xor_kernel_cuda);
 REGISTER_DISPATCH(logical_and_stub, &logical_and_kernel_cuda);
 REGISTER_DISPATCH(logical_or_stub, &logical_or_kernel_cuda);

aten/src/ATen/native/native_functions.yaml

Lines changed: 24 additions & 12 deletions
@@ -3934,31 +3934,43 @@
 - func: ne_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
   variants: method
 
+- func: bitwise_and.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  dispatch:
+    CPU: bitwise_and_out
+    CUDA: bitwise_and_out
+
+- func: bitwise_and.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  dispatch:
+    CPU: bitwise_and_out
+    CUDA: bitwise_and_out
+
+- func: bitwise_and.Scalar(Tensor self, Scalar other) -> Tensor
+  variants: method, function
+
+- func: bitwise_and.Tensor(Tensor self, Tensor other) -> Tensor
+  variants: method, function
+
+- func: bitwise_and_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+  variants: method
+
+- func: bitwise_and_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  variants: method
+
 - func: __and__.Scalar(Tensor self, Scalar other) -> Tensor
   use_c10_dispatcher: full
   variants: method, function
-  dispatch:
-    CPU: legacy::cpu::_th_and
-    CUDA: legacy::cuda::_th_and
 
 - func: __and__.Tensor(Tensor self, Tensor other) -> Tensor
   use_c10_dispatcher: full
   variants: method, function
-  dispatch:
-    CPU: legacy::cpu::_th_and
-    CUDA: legacy::cuda::_th_and
 
 - func: __iand__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
   variants: method
-  dispatch:
-    CPU: legacy::cpu::_th_iand_
-    CUDA: legacy::cuda::_th_iand_
 
 - func: __iand__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
   variants: method
-  dispatch:
-    CPU: legacy::cpu::_th_iand_
-    CUDA: legacy::cuda::_th_iand_
 
 - func: __or__.Scalar(Tensor self, Scalar other) -> Tensor
   use_c10_dispatcher: full

aten/src/TH/generic/THTensorEvenMoreMath.cpp

Lines changed: 0 additions & 27 deletions
@@ -295,33 +295,6 @@ accreal THTensor_(sumall)(THTensor *tensor)
   return sum;
 }
 
-void THTensor_(bitand)(THTensor *r_, THTensor *t, scalar_t value)
-{
-#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_HALF) || defined(TH_REAL_IS_BFLOAT16)
-  (void)r_;
-  (void)t;
-  (void)value;
-  return THError("bitand is only supported for integer type tensors");
-#else
-  THTensor_(resizeAs)(r_, t);
-  int64_t r_Size = THTensor_(nElement)(r_);
-  int r_Contig = THTensor_(isContiguous)(r_);
-  int tContig = THTensor_(isContiguous)(t);
-  if (r_Contig && tContig) {
-    scalar_t *tp = t->data<scalar_t>();
-    scalar_t *rp = r_->data<scalar_t>();
-    at::parallel_for(0, r_Size, TH_OMP_OVERHEAD_THRESHOLD * 100,
-        [&](int64_t start, int64_t end) {
-      for (auto i = start; i < end; i++) {
-        rp[i] = tp[i] & value;
-      }
-    });
-  } else {
-    TH_TENSOR_APPLY2_PARALLEL(r_Size, r_Contig, tContig, scalar_t, r_, scalar_t, t, *r__data = *t_data & value;, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD);
-  }
-#endif
-}
-
 scalar_t THTensor_(minall)(THTensor *tensor)
 {
   scalar_t theMin;

aten/src/TH/generic/THTensorMath.cpp

Lines changed: 0 additions & 33 deletions
@@ -22,39 +22,6 @@
 // sense (rather than just having cut the file down the middle, which is
 // what I did when I split these up originally).
 
-void THTensor_(cbitand)(THTensor *r_, THTensor *t, THTensor *src)
-{
-#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_HALF)
-  (void)r_;
-  (void)t;
-  (void)src;
-  return THError("cbitand is only supported for integer type tensors");
-#else
-  THTensor_(resizeAs)(r_, t);
-  int64_t r_Size = THTensor_(nElement)(r_);
-  int64_t srcSize = THTensor_(nElement)(src);
-  int r_Contig = THTensor_(isContiguous)(r_);
-  int tContig = THTensor_(isContiguous)(t);
-  int srcContig = THTensor_(isContiguous)(src);
-  if (srcSize == r_Size){
-    if (r_Contig && tContig && srcContig) {
-      scalar_t *tp = t->data<scalar_t>();
-      scalar_t *sp = src->data<scalar_t>();
-      scalar_t *rp = r_->data<scalar_t>();
-      at::parallel_for(0, r_Size, TH_OMP_OVERHEAD_THRESHOLD,
-          [&](int64_t start, int64_t end) {
-        for (auto i = start; i < end; i++) {
-          rp[i] = tp[i] & sp[i];
-        }
-      });
-    } else {
-      TH_TENSOR_APPLY3_PARALLEL(r_Size, r_Contig, tContig, srcContig, scalar_t, r_, scalar_t, t, scalar_t, src, *r__data = *t_data & *src_data;, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD);
-    }
-  } else {
-    TH_TENSOR_APPLY3(scalar_t, r_, scalar_t, t, scalar_t, src, *r__data = *t_data & *src_data;);
-  }
-#endif
-}
 
 void THTensor_(cbitor)(THTensor *r_, THTensor *t, THTensor *src)
 {

aten/src/TH/generic/THTensorMath.h

Lines changed: 0 additions & 2 deletions
@@ -74,8 +74,6 @@ TH_API void THTensor_(div)(THTensor *r_, THTensor *t, scalar_t value);
 
 TH_API accreal THTensor_(sumall)(THTensor *t);
 
-TH_API void THTensor_(bitand)(THTensor *r_, THTensor *t, scalar_t value);
-TH_API void THTensor_(cbitand)(THTensor *r_, THTensor *t, THTensor *src);
 TH_API void THTensor_(bitor)(THTensor *r_, THTensor *t, scalar_t value);
 TH_API void THTensor_(cbitor)(THTensor *r_, THTensor *t, THTensor *src);
 
