Skip to content

Commit

Permalink
port fmod from TH to ATen (#24405)
Browse files Browse the repository at this point in the history
Summary:
pytorch/pytorch#22803

performance benchmarks:

    import timeit
    import torch
    import itertools
    import statistics

    def test_perf(sizes, device, repeat, times):
        def _tensor(name, sizes, device):
            return '''{0} = torch.rand({1}, device="{2}");'''.format(name, sizes, device)

        setup_code = 'import torch;' + _tensor('x', sizes, device) + _tensor('y', sizes, device)
        test_code = '''torch.fmod(y, x);'''
        if device == "cuda":
            test_code = test_code + 'torch.cuda.synchronize()'
        result = timeit.repeat(setup = setup_code,stmt = test_code,repeat = repeat,number = times)
        mean = statistics.mean(result)
        std = statistics.stdev(result)
        print('''sizes = {0} std = {1} mean = {2}'''.format(sizes, std, mean))

    def test_perf_for_device(device, small, mid, large):
        print(device)
        for s in itertools.product((small, mid, large), (small, mid, large)):
            test_perf(str(s), device, 3, 300)

    test_perf_for_device("cpu", 5, 100, 1000)
    test_perf_for_device("cuda", 5, 100, 10000)

pytorch:master

	cpu
	sizes = (5, 5) std = 0.0004191587896767566 mean = 0.0052408403377436725
	sizes = (5, 100) std = 0.00012129380478190695 mean = 0.006508304664748721
	sizes = (5, 1000) std = 0.00018175678335131663 mean = 0.0363664986701527
	sizes = (100, 5) std = 0.00034399426107962946 mean = 0.006770268999389373
	sizes = (100, 100) std = 0.0006779367543473553 mean = 0.07270567266580959
	sizes = (100, 1000) std = 0.01670362224705441 mean = 0.1300258070017056
	sizes = (1000, 5) std = 0.010281040640935534 mean = 0.045936293997025736
	sizes = (1000, 100) std = 0.012529932966256128 mean = 0.12733882099564653
	sizes = (1000, 1000) std = 0.002150238308503937 mean = 1.1608000710014796
	cuda
	sizes = (5, 5) std = 0.00016137550559233116 mean = 0.014315356330674453
	sizes = (5, 100) std = 0.0014720358192929545 mean = 0.015730336332732502
	sizes = (5, 10000) std = 0.0017510024071247026 mean = 0.015462367334597124
	sizes = (100, 5) std = 0.001569950832690219 mean = 0.015847195667447522
	sizes = (100, 100) std = 0.000935629392520788 mean = 0.015551854667137377
	sizes = (100, 10000) std = 0.002454919985869727 mean = 0.04476405966367262
	sizes = (10000, 5) std = 0.0013192075275361463 mean = 0.015794202001416124
	sizes = (10000, 100) std = 0.001418935833245521 mean = 0.04419450566638261
	sizes = (10000, 10000) std = 0.0070977799177425 mean = 3.267501967328523

shihongzhi:feature/port_fmod

	cpu
	sizes = (5, 5) std = 0.0003939277361171243 mean = 0.008732202996422226
	sizes = (5, 100) std = 7.568185896146914e-05 mean = 0.010897216998273507
	sizes = (5, 1000) std = 3.916722355255723e-05 mean = 0.03223436966557832
	sizes = (100, 5) std = 0.00016529833171236708 mean = 0.011018406672519632
	sizes = (100, 100) std = 0.000155446405937598 mean = 0.055315166668151505
	sizes = (100, 1000) std = 0.005295612670839835 mean = 0.09823771333321929
	sizes = (1000, 5) std = 5.087993715488194e-05 mean = 0.03315563267096877
	sizes = (1000, 100) std = 0.004952377745126246 mean = 0.09605619766807649
	sizes = (1000, 1000) std = 0.10362095898303665 mean = 0.9652185496718934
	cuda
	sizes = (5, 5) std = 5.004076916927963e-05 mean = 0.016851375335439418
	sizes = (5, 100) std = 0.0008912925390246038 mean = 0.01788881132476187
	sizes = (5, 10000) std = 0.0009701942336158022 mean = 0.018210363331794117
	sizes = (100, 5) std = 0.0007897575234315655 mean = 0.017682057005004026
	sizes = (100, 100) std = 0.0012395220098068511 mean = 0.016444508665396523
	sizes = (100, 10000) std = 0.000957364387413519 mean = 0.016943917328414198
	sizes = (10000, 5) std = 0.0011325899538680206 mean = 0.017102815332085203
	sizes = (10000, 100) std = 0.0013052748368152663 mean = 0.017058989333842572
	sizes = (10000, 10000) std = 0.024267574119715446 mean = 0.30735275766831666
Pull Request resolved: pytorch/pytorch#24405

Differential Revision: D16864196

Pulled By: VitalyFedyunin

fbshipit-source-id: d884cc9e74bb8f4ce2ad8d23c676fa914b26d8fb
  • Loading branch information
shihongzhi authored and facebook-github-bot committed Apr 1, 2020
1 parent a736b99 commit ceff21a
Show file tree
Hide file tree
Showing 7 changed files with 33 additions and 99 deletions.
41 changes: 0 additions & 41 deletions aten/src/ATen/Declarations.cwrap
Original file line number Diff line number Diff line change
Expand Up @@ -589,47 +589,6 @@
arguments:
- THTensor* self
]]
[[
name: _th_fmod
return: argument 0
variants:
- function
backends:
- CUDA
options:
- cname: fmod
arguments:
- arg: THTensor* result
output: True
- THTensor* self
- real other
- cname: cfmod
arguments:
- arg: THTensor* result
output: True
- arg: THTensor* self
broadcast: other fallback
- THTensor* other
]]
[[
name: _th_fmod_
return: argument 0
variants: function
backends:
- CUDA
options:
- cname: fmod
arguments:
- THTensor* self
- THTensor* self
- real other
- cname: cfmod
arguments:
- THTensor* self
- arg: THTensor* self
broadcast: other inplace fallback
- THTensor* other
]]
[[
name: _th_clamp
cname: clamp
Expand Down
2 changes: 0 additions & 2 deletions aten/src/ATen/native/BinaryOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -703,15 +703,13 @@ Tensor& floor_divide_(Tensor& self, Scalar other) {
Tensor& fmod_out(Tensor & result, const Tensor& self, const Tensor& other) {
auto iter = TensorIterator::binary_op(result, self, other,
/*check_mem_overlap=*/true);
TORCH_CHECK(iter.device_type() == at::kCPU, "Native fmod only supports CPU");
fmod_stub(iter.device_type(), iter);
return result;
}

Tensor& fmod_out(Tensor & result, const Tensor& self, Scalar other) {
auto iter = TensorIterator::unary_op(result, self,
/*check_mem_overlap=*/true);
TORCH_CHECK(iter.device_type() == at::kCPU, "Native fmod only supports CPU");
fmod_scalar_stub(iter.device_type(), iter, other);
return result;
}
Expand Down
21 changes: 21 additions & 0 deletions aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/zmath.cuh>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/BinaryOps.h>

Expand Down Expand Up @@ -39,8 +40,28 @@ void mse_kernel_cuda(TensorIterator& iter) {
});
}

void fmod_kernel_cuda(TensorIterator& iter) {
AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.dtype(), "fmod_cuda", [&]() {
gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
return ::fmod(a, b);
});
});
}

void fmod_scalar_kernel_cuda(TensorIterator& iter, Scalar divisor) {
AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.dtype(), "fmod_scalar_cuda", [&]() {
using thrust_t = typename ztype_cuda<scalar_t>::thrust_t;
auto div = thrust_t(divisor.to<scalar_t>());
gpu_kernel(iter, [div]GPU_LAMBDA(thrust_t a) -> thrust_t {
return ::fmod(a, div);
});
});
}

REGISTER_DISPATCH(atan2_stub, &atan2_kernel_cuda);
REGISTER_DISPATCH(smooth_l1_stub, &smooth_l1_kernel_cuda);
REGISTER_DISPATCH(mse_stub, &mse_kernel_cuda);
REGISTER_DISPATCH(fmod_stub, &fmod_kernel_cuda);
REGISTER_DISPATCH(fmod_scalar_stub, &fmod_scalar_kernel_cuda)

}} // namespace at::native
14 changes: 4 additions & 10 deletions aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4304,15 +4304,9 @@

- func: fmod_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
variants: method
dispatch:
CPU: fmod_
CUDA: legacy::cuda::_th_fmod_

- func: fmod_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
variants: method
dispatch:
CPU: fmod_
CUDA: legacy::cuda::_th_fmod_

- func: remainder_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
variants: method
Expand Down Expand Up @@ -5035,26 +5029,26 @@
- func: fmod.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
CPU: fmod_out
CUDA: legacy::cuda::_th_fmod_out
CUDA: fmod_out

- func: fmod.Scalar(Tensor self, Scalar other) -> Tensor
use_c10_dispatcher: full
variants: method, function
dispatch:
CPU: fmod
CUDA: legacy::cuda::_th_fmod
CUDA: fmod

- func: fmod.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
CPU: fmod_out
CUDA: legacy::cuda::_th_fmod_out
CUDA: fmod_out

- func: fmod.Tensor(Tensor self, Tensor other) -> Tensor
use_c10_dispatcher: full
variants: method, function
dispatch:
CPU: fmod
CUDA: legacy::cuda::_th_fmod
CUDA: fmod

- func: remainder.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
Expand Down
8 changes: 8 additions & 0 deletions aten/src/THC/THCNumerics.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ struct THCNumerics<uint8_t> {
static inline __host__ __device__ uint8_t pow(uint8_t a, uint8_t b) { return powi<uint8_t>(a, b); }
static inline __host__ __device__ bool isnan(uint8_t a) { return false; }
static inline __host__ __device__ bool isinf(uint8_t a) { return false; }
static inline __host__ __device__ uint8_t fmod(uint8_t a, uint8_t b) { return a % b; }
};

template <>
Expand Down Expand Up @@ -105,6 +106,7 @@ struct THCNumerics<int8_t> {
static inline __host__ __device__ int8_t pow(int8_t a, int8_t b) { return powi<int8_t>(a, b); }
static inline __host__ __device__ bool isnan(int8_t a) { return false; }
static inline __host__ __device__ bool isinf(int8_t a) { return false; }
static inline __host__ __device__ int8_t fmod(int8_t a, int8_t b) { return a % b; }
};

template <>
Expand All @@ -129,6 +131,7 @@ struct THCNumerics<int16_t> {
static inline __host__ __device__ int16_t pow(int16_t a, int16_t b) { return powi<int16_t>(a, b); }
static inline __host__ __device__ bool isnan(int16_t a) { return false; }
static inline __host__ __device__ bool isinf(int16_t a) { return false; }
static inline __host__ __device__ int16_t fmod(int16_t a, int16_t b) { return a % b; }
};

template <>
Expand All @@ -153,6 +156,7 @@ struct THCNumerics<int32_t> {
static inline __host__ __device__ int32_t pow(int32_t a, int32_t b) { return powi<int32_t>(a, b); }
static inline __host__ __device__ bool isnan(int32_t a) { return false; }
static inline __host__ __device__ bool isinf(int32_t a) { return false; }
static inline __host__ __device__ int32_t fmod(int32_t a, int32_t b) { return a % b; }
};

template <>
Expand All @@ -178,6 +182,7 @@ struct THCNumerics<int64_t> {
static inline __host__ __device__ int64_t pow(int64_t a, int64_t b) { return powi<int64_t>(a, b); }
static inline __host__ __device__ bool isnan(int64_t a) { return false; }
static inline __host__ __device__ bool isinf(int64_t a) { return false; }
static inline __host__ __device__ int64_t fmod(int64_t a, int64_t b) { return a % b; }
};

// DEPRECATED: use math functions from std and NumericLimits.cuh
Expand Down Expand Up @@ -210,6 +215,7 @@ struct THCNumerics<at::Half> {
static inline __host__ __device__ at::Half mul(at::Half a, at::Half b) { return a * b; }
static inline __host__ __device__ at::Half sub(at::Half a, at::Half b) { return a - b; }
static inline __host__ __device__ at::Half pow(at::Half a, at::Half b) { return ::pow(a, b); }
static inline __host__ __device__ at::Half fmod(at::Half a, at::Half b) { return ::fmod(a, b); }

static inline __host__ __device__ bool isnan(at::Half a) {
#ifdef _MSC_VER
Expand Down Expand Up @@ -263,6 +269,7 @@ struct THCNumerics<float> {
static inline __host__ __device__ float mul (float a, float b) { return a * b; }
static inline __host__ __device__ float sub (float a, float b) { return a - b; }
static inline __host__ __device__ float pow (float a, float b) { return powf(a, b); }
static inline __host__ __device__ float fmod (float a, float b) { return ::fmod(a, b); }
static inline __host__ __device__ bool isnan(float a) { return ::isnan(a); }
static inline __host__ __device__ bool isinf(float a) { return ::isinf(a); }
};
Expand Down Expand Up @@ -348,6 +355,7 @@ struct THCNumerics<double> {
static inline __host__ __device__ double mul (double a, double b) { return a * b; }
static inline __host__ __device__ double sub (double a, double b) { return a - b; }
static inline __host__ __device__ double pow (double a, double b) { return ::pow(a, b); }
static inline __host__ __device__ double fmod (double a, double b) { return ::fmod(a, b); }
static inline __host__ __device__ bool isnan(double a) { return ::isnan(a); }
static inline __host__ __device__ bool isinf(double a) { return ::isinf(a); }
};
Expand Down
28 changes: 0 additions & 28 deletions aten/src/THC/THCTensorMathPairwise.cu
Original file line number Diff line number Diff line change
Expand Up @@ -91,34 +91,6 @@ struct TensorDivConstantOp<double> {
const double val;
};

template <typename T>
struct TensorFmodOp {
TensorFmodOp(T v) : val((float)v) {}
__device__ __forceinline__ void operator()(T* out, T* in) {
*out = (T) fmodf((float) *in, val);
}

__device__ __forceinline__ void operator()(T* v) {
*v = (T) fmodf((float) *v, val);
}

const float val;
};

template <>
struct TensorFmodOp<double> {
TensorFmodOp(double v) : val(v) {}
__device__ __forceinline__ void operator()(double* out, double* in) {
*out = fmod(*in, val);
}

__device__ __forceinline__ void operator()(double* v) {
*v = fmod(*v, val);
}

const double val;
};

template <typename T, int Upper>
struct TensorTriOp {
TensorTriOp(T *start_, int64_t stride0_, int64_t stride1_, int64_t k_)
Expand Down
18 changes: 0 additions & 18 deletions aten/src/THC/generic/THCTensorMathPairwise.cu
Original file line number Diff line number Diff line change
Expand Up @@ -59,24 +59,6 @@ void THCTensor_(mul)(THCState *state, THCTensor *self_, THCTensor *src_, scalar_
THCudaCheck(cudaGetLastError());
}

void THCTensor_(fmod)(THCState *state, THCTensor *self_, THCTensor *src_, scalar_t value)
{
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src_));
if (self_ == src_) {
if (!THC_pointwiseApply1<scalar_t>(state, self_, TensorFmodOp<scalar_t>(value))) {
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
}
} else {
THCTensor_(resizeAs)(state, self_, src_);

if (!THC_pointwiseApply2<scalar_t, scalar_t>(state, self_, src_, TensorFmodOp<scalar_t>(value))) {
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
}
}

THCudaCheck(cudaGetLastError());
}

#endif

#endif

0 comments on commit ceff21a

Please sign in to comment.