[CINN] Add local var value for cinn ir::Tensor (PaddlePaddle#64459)
* cinn(backend): change slice's attribute from Int to Expr

* add precision check

* fix bug of cuda kernel name

* fix bug

* revert test_cinn_transform_symbolic

---------

Co-authored-by: 6clc <chaoliu.lc@foxmail.com>
2 people authored and co63oc committed May 23, 2024
1 parent 61f30af commit 72e8e70
Showing 9 changed files with 220 additions and 63 deletions.
7 changes: 7 additions & 0 deletions paddle/cinn/backends/codegen_device_util.cc
@@ -72,6 +72,13 @@ std::string
detail::CollectBucketStrategyHostFunctionVisitor::GenDeviceKernelName(
const std::string &fn_name, ir::Expr predicate) {
std::string cond_str = Predicate2String(predicate);
// replace '-' with 'NEG'
size_t pos = cond_str.find("-", 0);
const std::string replacement = "NEG";
while (pos != std::string::npos) {
cond_str.replace(pos, 1, replacement);
pos = cond_str.find("-", pos + replacement.length());
}
VLOG(3) << "predicate string: " << cond_str;
return fn_name + "__COND_" + cond_str + "__kernel";
}
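
The loop added above rewrites every '-' in the stringified predicate to "NEG", so that a negative value in the predicate does not put a '-' into the generated CUDA kernel name. A minimal standalone sketch of the same substitution, using the made-up predicate string "S0>-1" (not taken from the commit):

#include <iostream>
#include <string>

std::string ReplaceMinusWithNeg(std::string cond_str) {
  const std::string replacement = "NEG";
  std::size_t pos = cond_str.find('-');
  while (pos != std::string::npos) {
    cond_str.replace(pos, 1, replacement);                 // one char -> "NEG"
    pos = cond_str.find('-', pos + replacement.length());  // resume after "NEG"
  }
  return cond_str;
}

int main() {
  std::cout << ReplaceMinusWithNeg("S0>-1") << '\n';  // prints "S0>NEG1"
  return 0;
}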
22 changes: 20 additions & 2 deletions paddle/cinn/hlir/framework/pir/op_lowering_impl.cc
@@ -19,6 +19,7 @@
#include "paddle/cinn/adt/map_expr_ctx.h"
#include "paddle/cinn/ast_gen_ius/tensor_group.h"
#include "paddle/cinn/backends/codegen_device_util.h"
#include "paddle/cinn/common/dim_expr_converter.h"
#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
#include "paddle/cinn/hlir/framework/compile_error.h"
#include "paddle/cinn/hlir/framework/pir/op_lowering_util.h"
@@ -72,6 +73,17 @@ NodeAttr CollectAttrs(const ::pir::Operation& op) {
return node_attrs;
}

std::optional<std::vector<ir::Expr>> GetTensorValueFromShapeOrData(
const symbol::ShapeOrDataDimExprs& shape_or_data) {
if (!shape_or_data.data()) return std::nullopt;
std::vector<ir::Expr> result;
result.reserve(shape_or_data.data()->size());
for (const auto& data : *shape_or_data.data()) {
result.push_back(common::DimExprConverter().ConvertToIrExpr(data));
}
return result;
}

} // namespace details

std::shared_ptr<GroupInfo> OpLowererImpl::GetGroupInfo(
@@ -871,7 +883,7 @@ std::vector<ir::Expr> OpLowererImpl::LowerOps(
}

for (auto* op : ops) {
VLOG(4) << "start lowering op:" << op->name();
VLOG(4) << "start lowering op:" << op->name() << " id: " << op->id();
std::string cinn_op_name = CompatibleInfo::OpName(*op);

VLOG(4) << "cinn op name " << cinn_op_name << std::endl;
@@ -1089,8 +1101,14 @@ ir::Tensor OpLowererImpl::GetTensor(const OpLoweringGroupPtr& group,
if (sym_shape.empty()) {
sym_shape.emplace_back(input_id, symbol::DimExpr{1});
}
return lang::CreatePlaceHolder(
auto tensor = lang::CreatePlaceHolder(
sym_shape, CompatibleInfo::ConvertIRType(dtype), input_id);
const auto& tensor_value = details::GetTensorValueFromShapeOrData(
group->GetShapeOrDataExprs(value));
if (tensor_value.has_value()) {
tensor->set_value(*tensor_value);
}
return tensor;
} else {
auto shape = ::common::vectorize<int>(type_info.dims());
return lang::CreatePlaceHolder(
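GetTensor now does more than create a placeholder: when symbolic shape inference knows the element values of an input (shape_or_data.data() is non-null, e.g. for a shape tensor), GetTensorValueFromShapeOrData converts them to ir::Expr and they are stored on the tensor via set_value(). A small self-contained sketch of the pattern, with int64_t standing in for ir::Expr and a plain struct standing in for ir::_Tensor_ (both are stand-ins, not CINN's real types):

#include <cstdint>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Stand-in for a placeholder tensor that may carry its flattened value.
struct Placeholder {
  std::string name;
  std::vector<int64_t> shape;
  std::optional<std::vector<int64_t>> value;  // unset for ordinary dense inputs

  void set_value(std::vector<int64_t> v) { value = std::move(v); }
};

Placeholder MakePlaceholder(std::string name,
                            std::vector<int64_t> shape,
                            std::optional<std::vector<int64_t>> data) {
  Placeholder t{std::move(name), std::move(shape), std::nullopt};
  if (data.has_value()) {
    t.set_value(*data);  // mirrors tensor->set_value(*tensor_value)
  }
  return t;
}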
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/framework/pir/utils.cc
@@ -323,7 +323,7 @@ const std::unordered_set<std::string> TOCINN_OPS = {
PD_OP_NAME(ScaleOp),
PD_OP_NAME(Pool2dOp),
PD_OP_NAME(IscloseOp),
PD_OP_NAME(SliceOp),
// PD_OP_NAME(SliceOp),
PD_OP_NAME(ConcatOp),
PD_OP_NAME(SplitOp),
PD_OP_NAME(SplitWithNumOp),
147 changes: 102 additions & 45 deletions paddle/cinn/hlir/op/transform.cc
@@ -1790,53 +1790,97 @@ std::shared_ptr<OpStrategy> StrategyForSlice(
return strategy;
}

template <typename T = int>
std::vector<T> GetIntVectorFromAttr(const utils::Attribute &attr) {
if (absl::holds_alternative<std::vector<int64_t>>(attr)) {
const auto &attr_data = absl::get<std::vector<int64_t>>(attr);
return std::vector<T>(attr_data.begin(), attr_data.end());
} else if (absl::holds_alternative<std::vector<int>>(attr)) {
const auto &attr_data = absl::get<std::vector<int>>(attr);
return std::vector<T>(attr_data.begin(), attr_data.end());
} else if (absl::holds_alternative<bool>(attr)) {
return std::vector<T>{};
} else {
PADDLE_THROW(
phi::errors::InvalidArgument("attribute's vector type is invalid!"));
}
}
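
GetIntVectorFromAttr above normalizes an attribute that may be stored as std::vector<int64_t> or std::vector<int> into one element type, and treats a bool-typed attribute as empty. A self-contained sketch of the same dispatch, using std::variant in place of the absl-based utils::Attribute (an illustrative assumption):

#include <cstdint>
#include <stdexcept>
#include <variant>
#include <vector>

using Attribute = std::variant<std::vector<int>, std::vector<int64_t>, bool>;

template <typename T = int>
std::vector<T> IntVectorFromAttr(const Attribute& attr) {
  if (std::holds_alternative<std::vector<int64_t>>(attr)) {
    const auto& data = std::get<std::vector<int64_t>>(attr);
    return std::vector<T>(data.begin(), data.end());
  }
  if (std::holds_alternative<std::vector<int>>(attr)) {
    const auto& data = std::get<std::vector<int>>(attr);
    return std::vector<T>(data.begin(), data.end());
  }
  if (std::holds_alternative<bool>(attr)) {
    return {};  // mirror the original: a bool-typed attribute yields an empty vector
  }
  throw std::invalid_argument("attribute's vector type is invalid!");
}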
std::shared_ptr<OpStrategy> StrategyForSliceSymbolic(
const framework::NodeAttr &attrs,
const std::vector<ir::Tensor> &inputs,
const std::vector<Type> &out_type,
const std::vector<std::vector<ir::Dim>> &output_shapes,
const Target &target) {
std::vector<int> starts, ends, axes, strides, decrease_axis;
if (attrs.attr_store.find("starts") != attrs.attr_store.end()) {
starts = absl::get<std::vector<int>>(attrs.attr_store.at("starts"));
}
if (attrs.attr_store.find("ends") != attrs.attr_store.end()) {
ends = absl::get<std::vector<int>>(attrs.attr_store.at("ends"));
}
if (attrs.attr_store.find("axes") != attrs.attr_store.end()) {
axes = absl::get<std::vector<int>>(attrs.attr_store.at("axes"));
}
if (attrs.attr_store.find("strides") != attrs.attr_store.end()) {
strides = absl::get<std::vector<int>>(attrs.attr_store.at("strides"));
}
if (attrs.attr_store.find("decrease_axis") != attrs.attr_store.end()) {
decrease_axis =
absl::get<std::vector<int>>(attrs.attr_store.at("decrease_axis"));
}

CHECK(!starts.empty()) << "The Slice op doesn't find [starts] attribute! It "
"it a mandatory attribute, please check.";
CHECK(!ends.empty()) << "The Slice op doesn't find [ends] attribute! It is a "
"mandatory attribute, please check.";
CHECK_EQ(starts.size(), ends.size())
<< "The size of [starts] and [ends] must be identical! Please check.";
if (!axes.empty()) {
CHECK_EQ(starts.size(), axes.size())
<< "The size of [starts] and [axes] must be identical! Please check.";
} else {
for (int i = 0; i < starts.size(); i++) {
axes.push_back(i);
const std::vector<Expr> starts_expr = [&] {
if (inputs.size() == 3) {
const auto &value = inputs.at(1).self()->value();
CHECK(value.has_value());
return value.value();
}
}
if (!strides.empty()) {
CHECK_EQ(starts.size(), strides.size())
<< "The size of [starts] and [strides] must be identical! Please "
"check.";
} else {
for (int i = 0; i < starts.size(); i++) {
strides.push_back(1);
if (attrs.attr_store.find("starts") != attrs.attr_store.end()) {
return ToCinnExprs(GetIntVectorFromAttr(attrs.attr_store.at("starts")));
} else {
PADDLE_THROW(::common::errors::InvalidArgument(
"The Slice op doesn't find [starts] attribute!"));
}
}
}();
const std::vector<Expr> ends_expr = [&] {
if (inputs.size() == 3) {
const auto &value = inputs.at(2).self()->value();
CHECK(value.has_value());
return value.value();
}
if (attrs.attr_store.find("ends") != attrs.attr_store.end()) {
return ToCinnExprs(GetIntVectorFromAttr(attrs.attr_store.at("ends")));
} else {
PADDLE_THROW(::common::errors::InvalidArgument(
"The Slice op doesn't find [ends] attribute!"));
}
}();
const std::vector<int> axes = [&] {
std::vector<int> axes;
if (attrs.attr_store.find("axes") != attrs.attr_store.end()) {
axes = GetIntVectorFromAttr(attrs.attr_store.at("axes"));
}
if (axes.empty()) {
for (int i = 0; i < starts_expr.size(); i++) {
axes.push_back(i);
}
}
return axes;
}();
const std::vector<Expr> strides_expr = [&] {
std::vector<int> strides;
if (attrs.attr_store.find("strides") != attrs.attr_store.end()) {
strides = GetIntVectorFromAttr(attrs.attr_store.at("strides"));
}
if (strides.empty()) {
for (int i = 0; i < starts_expr.size(); i++) {
strides.push_back(1);
}
}
return ToCinnExprs(strides);
}();
const std::vector<int> decrease_axis = [&] {
if (attrs.attr_store.find("decrease_axis") != attrs.attr_store.end()) {
return GetIntVectorFromAttr(attrs.attr_store.at("decrease_axis"));
}
return std::vector<int>{};
}();

CHECK(!starts_expr.empty())
<< "The Slice op doesn't find [starts] attribute! It "
"it a mandatory attribute, please check.";
CHECK(!ends_expr.empty())
<< "The Slice op doesn't find [ends] attribute! It it a "
"mandatory attribute, please check.";
CHECK_EQ(starts_expr.size(), ends_expr.size())
<< "The size of [starts] and [ends] must be identical! Please check.";
CHECK_EQ(starts_expr.size(), axes.size())
<< "The size of [starts] and [axes] must be identical! Please check.";
CHECK_EQ(starts_expr.size(), strides_expr.size())
<< "The size of [starts] and [strides] must be identical! Please "
"check.";

std::vector<Expr> output_shape;
for (auto &i : output_shapes[0]) {
@@ -1855,12 +1899,25 @@ std::shared_ptr<OpStrategy> StrategyForSliceSymbolic(
CHECK(A_expr.as_tensor());
ir::Tensor A = A_expr.as_tensor_ref();

CHECK_EQ(arg_pack.size(), 2U);
CHECK(arg_pack[1].is_string());
std::string tensor_name = arg_pack[1].operator std::string();

auto out = pe::SliceSymbolic(
A, starts, axes, strides, decrease_axis, output_shape, tensor_name);
const std::string tensor_name = [&] {
if (arg_pack.size() == 2 || arg_pack.size() == 4) {
CHECK(arg_pack.back().is_string());
return arg_pack.back().operator std::string();
}
PADDLE_THROW(::common::errors::InvalidArgument(
"The slice op doesn't find output tensor name! The size of "
"arg_pack is %d.",
arg_pack.size()));
}();

auto out = pe::SliceSymbolic(A,
starts_expr,
axes,
strides_expr,
decrease_axis,
output_shape,
tensor_name);
VLOG(4) << "out: " << out;
auto stages = CreateStages({out});
*ret = CINNValuePack{{CINNValue(out), CINNValue(stages)}};
});
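In the new StrategyForSliceSymbolic, starts and ends are read from the value() recorded on the extra start/end input tensors when the op has three inputs, and only otherwise from the integer attributes; missing both is an error. A compact sketch of that resolution order for starts, with std::vector<int> standing in for the ir::Expr vector (types are stand-ins, not CINN's):

#include <map>
#include <optional>
#include <stdexcept>
#include <string>
#include <vector>

std::vector<int> ResolveStarts(
    bool has_start_end_inputs,  // i.e. inputs.size() == 3 in the real code
    const std::optional<std::vector<int>>& starts_tensor_value,
    const std::map<std::string, std::vector<int>>& attr_store) {
  if (has_start_end_inputs) {
    // Mirrors CHECK(value.has_value()): the starts tensor must carry its data.
    if (!starts_tensor_value.has_value()) {
      throw std::logic_error("starts tensor carries no value");
    }
    return *starts_tensor_value;
  }
  auto it = attr_store.find("starts");
  if (it != attr_store.end()) {
    return it->second;  // fall back to the "starts" attribute
  }
  throw std::invalid_argument("The Slice op doesn't find [starts] attribute!");
}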
10 changes: 3 additions & 7 deletions paddle/cinn/hlir/pe/transform.cc
@@ -1129,9 +1129,9 @@ ir::Tensor Slice(const ir::Tensor& A,
}

ir::Tensor SliceSymbolic(const ir::Tensor& A,
const std::vector<int>& starts,
const std::vector<Expr>& starts,
const std::vector<int>& const_axes,
const std::vector<int>& strides,
const std::vector<Expr>& strides,
const std::vector<int>& decrease_axis,
const std::vector<Expr>& output_shape,
const std::string& output_name) {
@@ -1140,11 +1140,7 @@ ir::Tensor SliceSymbolic(const ir::Tensor& A,
input_shape.emplace_back(shape);
}

std::vector<Expr> new_starts;
std::transform(starts.begin(),
starts.end(),
std::back_inserter(new_starts),
[](const int start) { return ir::Expr(start); });
std::vector<Expr> new_starts = starts;
std::vector<int> axes;
std::transform(const_axes.begin(),
const_axes.end(),
4 changes: 2 additions & 2 deletions paddle/cinn/hlir/pe/transform.h
@@ -185,9 +185,9 @@ ir::Tensor Slice(const ir::Tensor& A,
const std::string& output_name);

ir::Tensor SliceSymbolic(const ir::Tensor& A,
const std::vector<int>& starts,
const std::vector<Expr>& starts,
const std::vector<int>& axes,
const std::vector<int>& strides,
const std::vector<Expr>& strides,
const std::vector<int>& decrease_axis,
const std::vector<Expr>& output_shape,
const std::string& output_name);
9 changes: 9 additions & 0 deletions paddle/cinn/ir/tensor.h
@@ -20,6 +20,7 @@

#include <map>
#include <memory>
#include <optional>
#include <set>
#include <string>
#include <utility>
@@ -318,6 +319,10 @@ class _Tensor_ : public ExprNode<_Tensor_> {
poly::StageMap stages,
const Target& target = cinn::common::DefaultHostTarget()) const;

const std::optional<std::vector<Expr>>& value() const { return value_; }

void set_value(const std::vector<Expr>& value) { value_ = value; }

private:
//! Initialize the axis field after the shape field is assigned.
void InitAxis() const;
@@ -328,6 +333,10 @@ class _Tensor_ : public ExprNode<_Tensor_> {
//! this.
std::set<std::string> buffer_depended_tensor_names_;

// The flattened compute value of the tensor, e.g. Tensor[[1, 2], [3, 4]] ->
// Tensor[1, 2, 3, 4]
std::optional<std::vector<Expr>> value_;

friend Shared<poly::Stage> CreateStage(Tensor tensor);
};

8 changes: 2 additions & 6 deletions test/ir/pir/cinn/llama_test_model.py
@@ -84,12 +84,8 @@ def _set_cos_sin_cache(self, seq_len):

def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
# TODO(phlrain): cinn slice not support end is a DimExpr
# WIP for support it
# cos = self.cos_cached[:, :seq_len, :, :]
# sin = self.sin_cached[:, :seq_len, :, :]
cos = self.cos_cached
sin = self.sin_cached
cos = self.cos_cached[:, :seq_len, :, :]
sin = self.sin_cached[:, :seq_len, :, :]
return (
cos.cast(x.dtype) if cos.dtype != x.dtype else cos,
sin.cast(x.dtype) if sin.dtype != x.dtype else sin,
74 changes: 74 additions & 0 deletions test/ir/pir/cinn/symbolic/test_dyshape_slice.py
@@ -0,0 +1,74 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import unittest
from os.path import dirname

import numpy as np

import paddle
from paddle import nn
from paddle.static import InputSpec

sys.path.append(dirname(dirname(__file__)))

import utils


class CastLayer(nn.Layer):
def __init__(self):
super().__init__()

def forward(self, x):
end = x.shape[1] - 1
return x[:, :end, :]


class TestCast(unittest.TestCase):
def setUp(self):
paddle.seed(2022)
self.prepare_data()

def prepare_data(self):
self.shape = [1024, 32, 1024, 17]
self.x = paddle.randn(self.shape, dtype="float32")
self.x.stop_gradient = True

def check_jit_kernel_info(self, static_fn):
utils.check_jit_kernel_number(static_fn, 1)
utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 1})

def eval(self, use_cinn):
net = CastLayer()
input_spec = [
InputSpec(shape=[None, None, 1024, None], dtype='float32'),
]
net = utils.apply_to_static(net, use_cinn, input_spec)
net.eval()
out = net(self.x)
if use_cinn:
self.check_jit_kernel_info(net.forward)
return out

def test_eval(self):
cinn_out = self.eval(use_cinn=True)
dy_out = self.eval(use_cinn=False)
np.testing.assert_allclose(
cinn_out.numpy(), dy_out.numpy(), atol=1e-6, rtol=1e-6
)


if __name__ == '__main__':
unittest.main()
