[CustomOp] Split test and add inference test #31078

17 changes: 10 additions & 7 deletions python/paddle/fluid/tests/custom_op/CMakeLists.txt
@@ -1,17 +1,20 @@
 # New custom OP can support Windows/Linux now
-# 'test_simple_custom_op_jit/test_simple_custom_op_setup' compile .cc and .cu file
-py_test(test_simple_custom_op_setup SRCS test_simple_custom_op_setup.py)
-py_test(test_simple_custom_op_jit SRCS test_simple_custom_op_jit.py)
+# 'test_custom_relu_op_setup/jit' compile .cc and .cu file
+py_test(test_custom_relu_op_setup SRCS test_custom_relu_op_setup.py)
+py_test(test_custom_relu_op_jit SRCS test_custom_relu_op_jit.py)
 
 # Compiling shared library will cost some time, but running process is very fast.
-set_tests_properties(test_simple_custom_op_setup PROPERTIES TIMEOUT 250)
-set_tests_properties(test_simple_custom_op_jit PROPERTIES TIMEOUT 180)
+set_tests_properties(test_custom_relu_op_setup PROPERTIES TIMEOUT 250)
+set_tests_properties(test_custom_relu_op_jit PROPERTIES TIMEOUT 180)
 
 py_test(test_sysconfig SRCS test_sysconfig.py)
 
 # 'test_dispatch' compile .cc file
-py_test(test_dispatch SRCS test_dispatch.py)
-set_tests_properties(test_dispatch PROPERTIES TIMEOUT 180)
+py_test(test_dispatch_jit SRCS test_dispatch_jit.py)
+set_tests_properties(test_dispatch_jit PROPERTIES TIMEOUT 180)
 
+py_test(test_multi_out_jit SRCS test_multi_out_jit.py)
+set_tests_properties(test_multi_out_jit PROPERTIES TIMEOUT 180)
+
 if(NOT LINUX)
   return()
python/paddle/fluid/tests/custom_op/custom_relu_op.cc (renamed from relu_op_simple.cc)
@@ -17,13 +17,6 @@

#include "paddle/extension.h"

template <typename data_t>
void fill_constant_cpu_kernel(data_t* out_data, int64_t x_numel, data_t value) {
for (int i = 0; i < x_numel; ++i) {
out_data[i] = value;
}
}

template <typename data_t>
void relu_cpu_forward_kernel(const data_t* x_data,
data_t* out_data,
@@ -53,21 +46,8 @@ std::vector<paddle::Tensor> relu_cpu_forward(const paddle::Tensor& x) {
         relu_cpu_forward_kernel<data_t>(
             x.data<data_t>(), out.mutable_data<data_t>(x.place()), x.size());
       }));
-  // fake multi output: Fake_float64 with float64 dtype
-  auto fake_float64 = paddle::Tensor(paddle::PlaceType::kCPU);
-  fake_float64.reshape(x.shape());
-
-  fill_constant_cpu_kernel<double>(
-      fake_float64.mutable_data<double>(x.place()), x.size(), 0.);
-
-  // fake multi output: ZFake_int32 with int32 dtype
-  auto zfake_int32 = paddle::Tensor(paddle::PlaceType::kCPU);
-  zfake_int32.reshape(x.shape());
-
-  fill_constant_cpu_kernel<int32_t>(
-      zfake_int32.mutable_data<int32_t>(x.place()), x.size(), 1);
-
-  return {out, fake_float64, zfake_int32};
+  return {out};
 }
 
 std::vector<paddle::Tensor> relu_cpu_backward(const paddle::Tensor& x,
@@ -117,16 +97,16 @@ std::vector<paddle::Tensor> ReluBackward(const paddle::Tensor& x,
 }
 
 std::vector<std::vector<int64_t>> ReluInferShape(std::vector<int64_t> x_shape) {
-  return {x_shape, x_shape, x_shape};
+  return {x_shape};
 }
 
 std::vector<paddle::DataType> ReluInferDType(paddle::DataType x_dtype) {
-  return {x_dtype, paddle::DataType::FLOAT64, paddle::DataType::INT32};
+  return {x_dtype};
 }
 
-PD_BUILD_OP("relu2")
+PD_BUILD_OP("custom_relu")
     .Inputs({"X"})
-    .Outputs({"Out", "Fake_float64", "ZFake_int32"})
+    .Outputs({"Out"})
     .SetKernelFn(PD_KERNEL(ReluForward))
     .SetInferShapeFn(PD_INFER_SHAPE(ReluInferShape))
     .SetInferDtypeFn(PD_INFER_DTYPE(ReluInferDType))
python/paddle/fluid/tests/custom_op/custom_relu_op.cu (renamed from relu_op_simple.cu)
@@ -14,16 +14,6 @@

#include "paddle/extension.h"

template <typename data_t>
__global__ void fill_constant_cuda_kernel(data_t* y,
const int num,
data_t value) {
int gid = blockIdx.x * blockDim.x + threadIdx.x;
for (int i = gid; i < num; i += blockDim.x * gridDim.x) {
y[i] = value;
}
}

template <typename data_t>
__global__ void relu_cuda_forward_kernel(const data_t* x,
data_t* y,
@@ -57,18 +47,8 @@ std::vector<paddle::Tensor> relu_cuda_forward(const paddle::Tensor& x) {
         relu_cuda_forward_kernel<data_t><<<grid, block>>>(
             x.data<data_t>(), out.mutable_data<data_t>(x.place()), numel);
       }));
-  // fake multi output: Fake_1
-  auto fake_float64 = paddle::Tensor(paddle::PlaceType::kGPU);
-  fake_float64.reshape(x.shape());
-  fill_constant_cuda_kernel<double><<<grid, block>>>(
-      fake_float64.mutable_data<double>(x.place()), numel, 0.);
-  // fake multi output: ZFake_1
-  auto zfake_int32 = paddle::Tensor(paddle::PlaceType::kGPU);
-  zfake_int32.reshape(x.shape());
-  fill_constant_cuda_kernel<int32_t><<<grid, block>>>(
-      zfake_int32.mutable_data<int32_t>(x.place()), numel, 1);
-
-  return {out, fake_float64, zfake_int32};
+  return {out};
 }
 
 std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor& x,
python/paddle/fluid/tests/custom_op/custom_relu_op_dup.cc (renamed from relu_op3_simple.cc)
@@ -29,11 +29,11 @@ std::vector<std::vector<int64_t>> ReluInferShape(std::vector<int64_t> x_shape);

 std::vector<paddle::DataType> ReluInferDType(paddle::DataType x_dtype);
 
-// Reuse code in `relu_op_simple.cc/cu` to register another custom operator
+// Reuse code in `custom_relu_op.cc/cu` to register another custom operator
 // to test jointly compiling multiple operators at the same time.
-PD_BUILD_OP("relu3")
+PD_BUILD_OP("custom_relu_dup")
     .Inputs({"X"})
-    .Outputs({"Out", "Fake_float64", "ZFake_int32"})
+    .Outputs({"Out"})
     .SetKernelFn(PD_KERNEL(ReluForward))
     .SetInferShapeFn(PD_INFER_SHAPE(ReluInferShape))
     .SetInferDtypeFn(PD_INFER_DTYPE(ReluInferDType))
python/paddle/fluid/tests/custom_op/custom_relu_setup.py (renamed from setup_install_simple.py)
@@ -17,11 +17,14 @@
 from utils import paddle_includes, extra_compile_args
 from paddle.utils.cpp_extension import CUDAExtension, setup
 
+# custom_relu_op_dup.cc is only used for the multi-op test; it does not
+# define a new op. If you want to test a single op only, remove this
+# source file.
 setup(
-    name='simple_setup_relu2',
+    name='custom_relu_module_setup',
     ext_modules=CUDAExtension(  # test for not specifying a name here.
         sources=[
-            'relu_op_simple.cc', 'relu_op_simple.cu', 'relu_op3_simple.cc'
+            'custom_relu_op.cc', 'custom_relu_op.cu', 'custom_relu_op_dup.cc'
         ],  # test for multi ops
         include_dirs=paddle_includes,
         extra_compile_args=extra_compile_args))
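For context, a minimal usage sketch of the module this setup script builds. The module and op names come from the diff above; everything else here is illustrative and not part of the PR:

# Sketch only: assumes the setup script above has already been run with
# `install`, so the built module is importable under the setup() name.
import numpy as np
import paddle

import custom_relu_module_setup

x = paddle.to_tensor(np.random.uniform(-1, 1, [4, 8]).astype('float32'))
x.stop_gradient = False
out = custom_relu_module_setup.custom_relu(x)  # registered via PD_BUILD_OP("custom_relu")
paddle.sum(out).backward()  # gradients flow through the custom backward kernel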
76 changes: 76 additions & 0 deletions python/paddle/fluid/tests/custom_op/multi_out_test_op.cc
@@ -0,0 +1,76 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <iostream>
#include <vector>

#include "paddle/extension.h"

template <typename data_t>
void assign_cpu_kernel(const data_t* x_data,
                       data_t* out_data,
                       int64_t x_numel) {
  for (int i = 0; i < x_numel; ++i) {
    out_data[i] = x_data[i];
  }
}

template <typename data_t>
void fill_constant_cpu_kernel(data_t* out_data, int64_t x_numel, data_t value) {
  for (int i = 0; i < x_numel; ++i) {
    out_data[i] = value;
  }
}

std::vector<paddle::Tensor> MultiOutCPU(const paddle::Tensor& x) {
  auto out = paddle::Tensor(paddle::PlaceType::kCPU);
  out.reshape(x.shape());

  PD_DISPATCH_FLOATING_TYPES(
      x.type(), "assign_cpu_kernel", ([&] {
        assign_cpu_kernel<data_t>(
            x.data<data_t>(), out.mutable_data<data_t>(x.place()), x.size());
      }));

  // fake multi output: Fake_float64 with float64 dtype
  auto fake_float64 = paddle::Tensor(paddle::PlaceType::kCPU);
  fake_float64.reshape(x.shape());

  fill_constant_cpu_kernel<double>(
      fake_float64.mutable_data<double>(x.place()), x.size(), 0.);

  // fake multi output: ZFake_int32 with int32 dtype
  auto zfake_int32 = paddle::Tensor(paddle::PlaceType::kCPU);
  zfake_int32.reshape(x.shape());

  fill_constant_cpu_kernel<int32_t>(
      zfake_int32.mutable_data<int32_t>(x.place()), x.size(), 1);

  return {out, fake_float64, zfake_int32};
}

std::vector<std::vector<int64_t>> InferShape(std::vector<int64_t> x_shape) {
  return {x_shape, x_shape, x_shape};
}

std::vector<paddle::DataType> InferDtype(paddle::DataType x_dtype) {
  return {x_dtype, paddle::DataType::FLOAT64, paddle::DataType::INT32};
}

PD_BUILD_OP("multi_out")
    .Inputs({"X"})
    .Outputs({"Out", "Fake_float64", "ZFake_int32"})
    .SetKernelFn(PD_KERNEL(MultiOutCPU))
    .SetInferShapeFn(PD_INFER_SHAPE(InferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(InferDtype));
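A hedged sketch of how this multi-output op could be exercised through JIT compilation, mirroring what the test_multi_out_jit target registered in the CMake file above presumably does. The module name and load() arguments here are assumptions:

# Sketch only: module name and load() arguments are assumptions.
import numpy as np
import paddle
from paddle.utils.cpp_extension import load

multi_out_module = load(
    name='multi_out_jit',
    sources=['multi_out_test_op.cc'],
    verbose=True)

x = paddle.to_tensor(np.random.uniform(-1, 1, [4, 8]).astype('float32'))
# The op returns its outputs in the order declared in .Outputs(...).
out, fake_float64, zfake_int32 = multi_out_module.multi_out(x)
assert fake_float64.numpy().dtype == np.float64  # constant-filled with 0.0
assert zfake_int32.numpy().dtype == np.int32     # constant-filled with 1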
86 changes: 86 additions & 0 deletions python/paddle/fluid/tests/custom_op/test_custom_relu_op_jit.py
@@ -0,0 +1,86 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import subprocess
import unittest
import paddle
import numpy as np
from paddle.utils.cpp_extension import load, get_build_directory
from paddle.utils.cpp_extension.extension_utils import run_cmd
from utils import paddle_includes, extra_compile_args
from test_custom_relu_op_setup import custom_relu_dynamic, custom_relu_static

# Because Windows doesn't use docker, the shared lib already exists in the
# cache dir; it will not be compiled again unless the shared lib is removed.
if os.name == 'nt':
    cmd = 'del {}\\custom_relu_module_jit.pyd'.format(get_build_directory())
    run_cmd(cmd, True)

# Compile and load the custom op just-in-time.
# custom_relu_op_dup.cc is only used for the multi-op test; it does not
# define a new op. If you want to test a single op only, remove this
# source file.
custom_module = load(
    name='custom_relu_module_jit',
    sources=[
        'custom_relu_op.cc', 'custom_relu_op.cu', 'custom_relu_op_dup.cc'
    ],
    extra_include_paths=paddle_includes,  # add for Coverage CI
    extra_cflags=extra_compile_args,  # add for Coverage CI
    verbose=True)


class TestJITLoad(unittest.TestCase):
    def setUp(self):
        self.custom_ops = [
            custom_module.custom_relu, custom_module.custom_relu_dup
        ]
        self.dtypes = ['float32', 'float64']
        self.devices = ['cpu', 'gpu']

    def test_static(self):
        for device in self.devices:
            for dtype in self.dtypes:
                x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
                for custom_op in self.custom_ops:
                    out = custom_relu_static(custom_op, device, dtype, x)
                    pd_out = custom_relu_static(custom_op, device, dtype, x,
                                                False)
                    self.assertTrue(
                        np.array_equal(out, pd_out),
                        "custom op out: {},\n paddle api out: {}".format(
                            out, pd_out))

    def test_dynamic(self):
        for device in self.devices:
            for dtype in self.dtypes:
                x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
                for custom_op in self.custom_ops:
                    out, x_grad = custom_relu_dynamic(custom_op, device, dtype,
                                                      x)
                    pd_out, pd_x_grad = custom_relu_dynamic(custom_op, device,
                                                            dtype, x, False)
                    self.assertTrue(
                        np.array_equal(out, pd_out),
                        "custom op out: {},\n paddle api out: {}".format(
                            out, pd_out))
                    self.assertTrue(
                        np.array_equal(x_grad, pd_x_grad),
                        "custom op x grad: {},\n paddle api x grad: {}".format(
                            x_grad, pd_x_grad))


if __name__ == '__main__':
    unittest.main()
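The custom_relu_dynamic and custom_relu_static helpers imported above live in test_custom_relu_op_setup.py, which is not part of this excerpt. A rough sketch of the dynamic-graph helper, inferred from its call sites; the signature and body are assumptions, not the PR's actual code:

import numpy as np
import paddle
import paddle.nn.functional as F

def custom_relu_dynamic(func, device, dtype, np_x, use_func=True):
    # Sketch only: run either the custom op or paddle's built-in relu in
    # dynamic graph mode, then return the output and the input gradient.
    paddle.set_device(device)
    x = paddle.to_tensor(np_x)
    x.stop_gradient = False
    out = func(x) if use_func else F.relu(x)
    paddle.sum(out).backward()
    return out.numpy(), x.grad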