[PIR]Support inplace custom op in pir #60529

Merged · 12 commits · Jan 10, 2024

Changes from 5 commits
7 changes: 7 additions & 0 deletions paddle/fluid/framework/custom_operator_utils.h
@@ -86,6 +86,13 @@ inline static const OpMetaInfo& GetOpInfoByPirName(
const std::string& pir_op_name) {
auto custom_name = pir_op_name.substr(strlen(kCustomDialectPrefix));
int pos = custom_name.length();

if (custom_name[pos - 1] == '_') {
// deal with the inplace name: strip the trailing '_'
custom_name = custom_name.substr(0, pos - 1);
}

pos = custom_name.length();
if (custom_name.find("_grad_grad") != custom_name.npos) {
pos = custom_name.find("_grad_grad") + 1;
} else if (custom_name.find("_grad") != custom_name.npos) {
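The hunk above strips the inplace suffix before the grad suffixes are resolved, because an inplace custom op is registered in PIR under the mangled name `<op>_`. A minimal standalone sketch of that lookup-key normalization (hypothetical helper; the prefix argument stands in for kCustomDialectPrefix):

#include <string>

// Sketch: normalize a PIR op name back to the registered custom op name.
// An inplace variant carries a trailing '_' (e.g. "custom_relu_inplace_"),
// which must be removed before "_grad"/"_grad_grad" suffixes are resolved.
std::string NormalizePirName(const std::string& pir_op_name,
                             const std::string& dialect_prefix) {
  std::string custom_name = pir_op_name.substr(dialect_prefix.size());
  if (!custom_name.empty() && custom_name.back() == '_') {
    custom_name.pop_back();  // drop the inplace marker
  }
  return custom_name;
}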

[Large diff not rendered by default.]

@@ -45,15 +45,26 @@ class CustomKernelInstruction : public InstructionBase {
void BuildCustomContext(
const paddle::dialect::OpYamlInfoParser& op_yaml_info);

void BuildShapeDtype();

void UpdateOutputMeta(const std::vector<std::vector<int64_t>>& output_shapes,
const std::vector<DataType>& output_dtypes);

std::vector<std::vector<int64_t>> RunDefaultInferShape();
std::vector<DataType> RunDefaultInferDtype();
void CheckDefaultInferShapeDtype(
const paddle::dialect::OpYamlInfoParser& op_yaml_info);

paddle::CustomOpKernelContext custom_kernel_ctx_;

paddle::InferShapeFunc infershape_func_ = nullptr;
paddle::InferDtypeFunc inferdtype_func_ = nullptr;
paddle::KernelFunc kernel_func_ = nullptr;

// key is input name, value is an index into input_shapes_ or vec_input_shapes_
std::unordered_map<std::string, int> input_name2id_map_;
std::unordered_map<std::string, int> vec_input_name2id_map_;

// used for running infershape
std::vector<std::vector<int64_t>> input_shapes_;
std::vector<std::vector<std::vector<int64_t>>> vec_input_shapes_;
@@ -63,6 +74,10 @@ class CustomKernelInstruction : public InstructionBase {
std::vector<DataType> input_dtypes_;
std::vector<std::vector<DataType>> vec_input_dtypes_;

// used to calculate input shapes and dtypes at runtime
std::vector<phi::DenseTensor*> input_ptrs_;
std::vector<std::vector<phi::DenseTensor*>> vec_input_ptrs_;

// used to update outputs
std::vector<phi::DenseTensor*> cache_out_ptrs_;

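The new members let the instruction refresh input shapes and dtypes on every run (through input_ptrs_/vec_input_ptrs_) and fall back to a default infershape/inferdtype when the custom op registers neither function. A hedged sketch of such a default rule, assuming it simply forwards each input's current dims (the real RunDefaultInferShape lives in the corresponding .cc, not rendered above):

#include <vector>
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/dense_tensor.h"

// Hedged sketch only: assumes the default rule forwards input shapes,
// one output per input, reading the dims freshly at run time.
std::vector<std::vector<int64_t>> DefaultInferShapeSketch(
    const std::vector<phi::DenseTensor*>& input_ptrs) {
  std::vector<std::vector<int64_t>> output_shapes;
  output_shapes.reserve(input_ptrs.size());
  for (const phi::DenseTensor* t : input_ptrs) {
    output_shapes.push_back(phi::vectorize(t->dims()));
  }
  return output_shapes;
}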
9 changes: 4 additions & 5 deletions paddle/fluid/inference/api/demo_ci/custom_relu_op.cc
@@ -38,12 +38,11 @@ void relu_cpu_backward_kernel(const data_t* grad_out_data,
}

std::vector<paddle::Tensor> relu_cpu_forward(const paddle::Tensor& x) {
auto out = paddle::Tensor(paddle::PlaceType::kCPU, x.shape());

auto out = paddle::empty_like(x);
PD_DISPATCH_FLOATING_TYPES(
x.type(), "relu_cpu_forward", ([&] {
relu_cpu_forward_kernel<data_t>(
x.data<data_t>(), out.mutable_data<data_t>(x.place()), x.size());
x.data<data_t>(), out.data<data_t>(), x.size());
}));

return {out};
@@ -52,13 +51,13 @@ std::vector<paddle::Tensor> relu_cpu_forward(const paddle::Tensor& x) {
std::vector<paddle::Tensor> relu_cpu_backward(const paddle::Tensor& x,
const paddle::Tensor& out,
const paddle::Tensor& grad_out) {
auto grad_x = paddle::Tensor(paddle::PlaceType::kCPU, x.shape());
auto grad_x = paddle::empty_like(x);

PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_backward", ([&] {
relu_cpu_backward_kernel<data_t>(
grad_out.data<data_t>(),
out.data<data_t>(),
grad_x.mutable_data<data_t>(x.place()),
grad_x.data<data_t>(),
out.size());
}));

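The change above is an API migration, not a behavior change: the deprecated PlaceType constructor plus mutable_data<T>() is replaced by paddle::empty_like, which allocates immediately with x's shape, dtype, and place. A minimal sketch of the before/after pattern (hypothetical helper name):

#include "paddle/extension.h"

// Before (deprecated): construct with an explicit place, then allocate
// lazily on the first mutable_data<T>() call:
//   auto out = paddle::Tensor(paddle::PlaceType::kCPU, x.shape());
//   float* p = out.mutable_data<float>(x.place());
//
// After: empty_like allocates up front, so data<T>() is immediately valid.
paddle::Tensor AllocLike(const paddle::Tensor& x) {
  auto out = paddle::empty_like(x);
  return out;  // out.data<T>() may be written directly by kernels
}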
8 changes: 4 additions & 4 deletions paddle/fluid/inference/api/demo_ci/custom_relu_op.cu
@@ -36,15 +36,15 @@ __global__ void relu_cuda_backward_kernel(const data_t* dy,
}

std::vector<paddle::Tensor> relu_cuda_forward(const paddle::Tensor& x) {
auto out = paddle::Tensor(paddle::PlaceType::kGPU, x.shape());
auto out = paddle::empty_like(x);

int numel = x.size();
int block = 512;
int grid = (numel + block - 1) / block;
PD_DISPATCH_FLOATING_TYPES(
x.type(), "relu_cuda_forward_kernel", ([&] {
relu_cuda_forward_kernel<data_t><<<grid, block, 0, x.stream()>>>(
x.data<data_t>(), out.mutable_data<data_t>(x.place()), numel);
x.data<data_t>(), out.data<data_t>(), numel);
}));

return {out};
@@ -53,7 +53,7 @@ std::vector<paddle::Tensor> relu_cuda_forward(const paddle::Tensor& x) {
std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor& x,
const paddle::Tensor& out,
const paddle::Tensor& grad_out) {
auto grad_x = paddle::Tensor(paddle::PlaceType::kGPU, x.shape());
auto grad_x = paddle::empty_like(x);

int numel = out.size();
int block = 512;
@@ -63,7 +63,7 @@ std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor& x,
<<<grid, block, 0, x.stream()>>>(
grad_out.data<data_t>(),
out.data<data_t>(),
grad_x.mutable_data<data_t>(x.place()),
grad_x.data<data_t>(),
numel);
}));

27 changes: 25 additions & 2 deletions paddle/fluid/pir/dialect/operator/ir/op_dialect.cc
@@ -19,6 +19,7 @@
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/type_storage.h"
#include "paddle/fluid/pir/dialect/operator/trait/inplace.h"
#include "paddle/fluid/pir/dialect/operator/transforms/param_to_variable.h"
#include "paddle/pir/core/builtin_type_interfaces.h"
#include "paddle/pir/core/interface_value.h"
@@ -356,9 +357,25 @@ struct CustomOpInfoInterfaceModel : public OpYamlInfoInterface::Concept {
output_name, "paddle::dialect::DenseTensorType", is_optional, false});
}

auto& inplace_maps = OpMetaInfoHelper::GetInplaceReverseMap(op_meta);

if (!inplace_maps.empty()) {
VLOG(3) << "Register Custom Operator: op inplace_map: "
<< string::join_strings(inplace_maps, ',', [](auto& pair) {
return pair.first + ": " + pair.second;
});
}

std::vector<std::pair<std::string, std::string>> vec_inplace;
for (auto inplace_map : inplace_maps) {
vec_inplace.push_back(inplace_map);
}

// we only need kernel param names in run_time_info
paddle::dialect::OpRunTimeInfo run_time_info =
paddle::dialect::OpRunTimeInfo("", {}, "", param_names, {}, {}, {}, {});
paddle::dialect::OpRunTimeInfo(
"", {}, "", param_names, {}, {}, vec_inplace, {});

return std::make_tuple(
inputs_info, attributes_info, outputs_info, run_time_info, "");
}
@@ -387,6 +404,13 @@ void CustomOpDialect::RegisterCustomOp(const paddle::OpMetaInfo& op_meta) {
pir::TypeId id = IdManager::Instance().CreateId();
std::string op_name = paddle::framework::kCustomDialectPrefix +
OpMetaInfoHelper::GetOpName(op_meta);
std::vector<pir::TypeId> traits;

auto& inplace_map = OpMetaInfoHelper::GetInplaceMap(op_meta);
if (!inplace_map.empty()) {
op_name += "_";
traits.push_back(pir::TypeId::get<paddle::dialect::InplaceTrait>());
}
op_names_.push_back(op_name);

auto& op_attrs = OpMetaInfoHelper::GetAttrs(op_meta);
@@ -400,7 +424,6 @@ void CustomOpDialect::RegisterCustomOp(const paddle::OpMetaInfo& op_meta) {
AttributeManager::Instance().ToCharPointers(attr_names);
uint32_t attr_num = attr_names.size();

std::vector<pir::TypeId> traits;
std::set<pir::InterfaceValue> interface_values;
pir::InterfaceValue op_info_interface =
pir::InterfaceValue::Get<OpYamlInfoInterface,
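Taken together, RegisterCustomOp now mangles the registered PIR name and attaches InplaceTrait whenever an op declares an inplace map, and the reverse map is threaded into OpRunTimeInfo. A sketch of the naming rule (hypothetical helper; the "custom_op." prefix is an assumption about kCustomDialectPrefix, not confirmed by this diff):

#include <string>

// Sketch: an op with a non-empty inplace map is registered under
// "<name>_", mirroring the convention that inplace variants of PIR ops
// end with a trailing underscore.
std::string RegisteredPirName(const std::string& op_name,
                              bool has_inplace_map) {
  std::string pir_name = "custom_op." + op_name;  // assumed dialect prefix
  if (has_inplace_map) {
    pir_name += "_";  // e.g. "custom_op.custom_relu_inplace_"
  }
  return pir_name;
}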
5 changes: 5 additions & 0 deletions test/custom_op/CMakeLists.txt
@@ -13,6 +13,11 @@ if(WITH_TESTING)
set_tests_properties(test_context_pool PROPERTIES TIMEOUT 180)
endif()

if(WITH_GPU)
py_test(test_inference_inplace SRCS test_inference_inplace.py)
set_tests_properties(test_inference_inplace PROPERTIES TIMEOUT 180)
endif()

# custom OP support TensorRT inference
if(WITH_GPU
AND WITH_TENSORRT
50 changes: 50 additions & 0 deletions test/custom_op/custom_inplace.cu
@@ -0,0 +1,50 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
[Review comment on the copyright line]
Contributor: 2024, same below.
Author: done, thanks.

//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
// express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <iostream>
#include <vector>

#include "paddle/extension.h"

#define CHECK_GPU_INPUT(x) PD_CHECK(x.is_gpu(), #x " must be a GPU Tensor.")

template <typename data_t>
__global__ void relu_cuda_forward_kernel(data_t* x, int64_t num) {
int64_t gid = blockIdx.x * blockDim.x + threadIdx.x;
for (int64_t i = gid; i < num; i += blockDim.x * gridDim.x) {
x[i] = x[i] > static_cast<data_t>(0.) ? x[i] : static_cast<data_t>(0.);
}
}

void ReluForwardInplace(paddle::Tensor& x) { // NOLINT
CHECK_GPU_INPUT(x);

PD_CHECK(x.place() == paddle::DefaultGPUPlace());

int64_t numel = x.numel();
int64_t block = 512;
int64_t grid = (numel + block - 1) / block;
PD_DISPATCH_FLOATING_AND_HALF_TYPES(
x.type(), "relu_cuda_forward_kernel", ([&] {
relu_cuda_forward_kernel<data_t>
<<<grid, block, 0, x.stream()>>>(x.data<data_t>(), numel);
}));
}

PD_BUILD_OP(custom_relu_inplace)
.Inputs({"X"})
.Outputs({"Out"})
.SetInplaceMap({{"X", "Out"}})
.SetKernelFn(PD_KERNEL(ReluForwardInplace));
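
For comparison, a minimal CPU sketch of the same registration contract (hypothetical op, not part of this PR): SetInplaceMap({{"X", "Out"}}) declares that "Out" aliases "X", so the kernel takes a mutable Tensor& and returns nothing.

#include "paddle/extension.h"

// Hypothetical CPU analogue of custom_relu_inplace above, for
// illustration only: the kernel mutates x in place.
void ReluCpuForwardInplace(paddle::Tensor& x) {  // NOLINT
  PD_CHECK(x.is_cpu(), "x must be a CPU Tensor.");
  PD_DISPATCH_FLOATING_TYPES(
      x.type(), "relu_cpu_forward_inplace", ([&] {
        data_t* p = x.data<data_t>();
        for (int64_t i = 0; i < x.numel(); ++i) {
          p[i] = p[i] > static_cast<data_t>(0.) ? p[i]
                                                : static_cast<data_t>(0.);
        }
      }));
}

PD_BUILD_OP(custom_relu_cpu_inplace)
    .Inputs({"X"})
    .Outputs({"Out"})
    .SetInplaceMap({{"X", "Out"}})
    .SetKernelFn(PD_KERNEL(ReluCpuForwardInplace));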
138 changes: 138 additions & 0 deletions test/custom_op/test_inference_inplace.py
@@ -0,0 +1,138 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile
import unittest

import numpy as np
from utils import (
extra_cc_args,
extra_nvcc_args,
paddle_includes,
)

import paddle
from paddle.inference import Config, create_predictor
from paddle.utils.cpp_extension import get_build_directory, load
from paddle.utils.cpp_extension.extension_utils import run_cmd

# Because Windows doesn't use docker, the shared lib already exists in the
# cache dir; it will not be compiled again unless the shared lib is removed.
file = f'{get_build_directory()}\\infer_custom\\infer_custom.pyd'
if os.name == 'nt' and os.path.isfile(file):
cmd = f'del {file}'
run_cmd(cmd, True)

# Compile and load custom op Just-In-Time.
custom_inplace = load(
name='infer_custom',
sources=['custom_inplace.cu'],
extra_include_paths=paddle_includes, # add for Coverage CI
extra_cxx_cflags=extra_cc_args, # test for cflags
extra_cuda_cflags=extra_nvcc_args, # test for cflags
verbose=True,
)


class TestInplaceNet(paddle.nn.Layer):
def __init__(self):
super().__init__()
self.fc = paddle.nn.Linear(4, 4)

def forward(self, x):
fc_out = self.fc(x)
out = custom_inplace.custom_relu_inplace(fc_out)
mean_out = paddle.mean(out)
return mean_out


@unittest.skipIf(
not paddle.is_compiled_with_cuda(), 'should compile with cuda.'
)
class TestPredictorRunWithTensor(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
net = TestInplaceNet()
model = paddle.jit.to_static(
net,
input_spec=[
paddle.static.InputSpec(
shape=[None, 4], dtype='float32', name='x'
),
],
)
paddle.jit.save(
model,
os.path.join(
self.temp_dir.name, 'test_predictor_run_model/inference'
),
)

def tearDown(self):
self.temp_dir.cleanup()

def enable_pir(self, flag: bool):
paddle.set_flags({'FLAGS_enable_pir_in_executor': flag})

def init_predictor(self):
config = Config(
os.path.join(
self.temp_dir.name,
'test_predictor_run_model/inference.pdmodel',
),
os.path.join(
self.temp_dir.name,
'test_predictor_run_model/inference.pdiparams',
),
)
config.enable_use_gpu(256, 0)
config.switch_ir_optim(False)
config.enable_new_executor()
predictor = create_predictor(config)
return predictor

def get_inputs(self):
x = np.array([[1, 2, 3, 4], [2, 3, 4, 5]]).astype(np.float32)

x_tensor = paddle.to_tensor(x)

return [x_tensor]

def get_outputs(self, predictor):
[x_tensor] = self.get_inputs()

input_names = predictor.get_input_names()
x_tensor.name = input_names[0]

# disorder
inputs = [x_tensor]
outputs = predictor.run(inputs)

return outputs[0]

def test_output(self):
self.enable_pir(True)
pir_predictor = self.init_predictor()
pir_output = self.get_outputs(pir_predictor)
self.enable_pir(False)
predictor = self.init_predictor()
output = self.get_outputs(predictor)
np.testing.assert_allclose(
output.numpy().flatten(), pir_output.numpy().flatten()
)


if __name__ == "__main__":
unittest.main()