Add triangular solve op (#1224)
* add triangular solve op

* add tests

* update cublas call

* post pre-commit ;)

* post-commit ;)

* conditional compilation for triangular solve test

* update test cases && add broadcast

* update test case for left_side=False && remove header

* update left_side=False test case

* update singular test case

* refine broadcast
zzk0 authored Mar 3, 2023
1 parent d0ac73e commit b3591cd
Showing 13 changed files with 742 additions and 1 deletion.
37 changes: 37 additions & 0 deletions cinn/frontend/net_builder.cc
@@ -19,6 +19,7 @@
#include <vector>

#include "cinn/frontend/syntax.h"
#include "cinn/hlir/pe/broadcast.h"

namespace cinn {
namespace frontend {
@@ -692,6 +693,42 @@ Variable NetBuilder::Cholesky(const Variable& x, bool upper) {
return CustomInstr("cholesky", {x}, {{"upper", upper}}).front();
}

Variable NetBuilder::TriangularSolve(
const Variable& input1, const Variable& input2, bool left_side, bool upper, bool transpose_a, bool unit_diagonal) {
// broadcast
std::vector<Variable> inputs{input1, input2};
{
auto a_ndim = input1->shape.size();
auto b_ndim = input2->shape.size();
CHECK_GE(a_ndim, 2) << "The shape size of input matrix A should be >= 2! Please check again.";
CHECK_GE(b_ndim, 2) << "The shape size of input matrix B should be >= 2! Please check again.";
std::vector<int> input1_shape_cut(input1->shape.begin(), input1->shape.end() - 2);
std::vector<int> input2_shape_cut(input2->shape.begin(), input2->shape.end() - 2);
std::vector<int> common_shape;
hlir::pe::GetBroadcastOutShape(input1_shape_cut, input2_shape_cut, &common_shape);

// broadcast input1
std::vector<int> input1_shape(common_shape.begin(), common_shape.end());
input1_shape.push_back(input1->shape[a_ndim - 2]);
input1_shape.push_back(input1->shape[a_ndim - 1]);
inputs[0] = BroadcastTo(input1, input1_shape);

// broadcast input2
std::vector<int> input2_shape(common_shape.begin(), common_shape.end());
input2_shape.push_back(input2->shape[b_ndim - 2]);
input2_shape.push_back(input2->shape[b_ndim - 1]);
inputs[1] = BroadcastTo(input2, input2_shape);
}

return CustomInstr("triangular_solve",
inputs,
{{"left_side", left_side},
{"upper", upper},
{"transpose_a", transpose_a},
{"unit_diagonal", unit_diagonal}})
.front();
}
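
Only the leading (batch) dimensions take part in this broadcast; the trailing two matrix dimensions of each operand are preserved. A minimal sketch of the effect, assuming GetBroadcastOutShape follows NumPy-style broadcasting rules (the concrete shapes below are invented for illustration):

// input1 (A): [3, 1, 4, 4] -> batch dims [3, 1], matrix dims [4, 4]
// input2 (B): [5, 4, 2]    -> batch dims [5],    matrix dims [4, 2]
// GetBroadcastOutShape([3, 1], [5]) -> common batch shape [3, 5]
// after BroadcastTo: A -> [3, 5, 4, 4], B -> [3, 5, 4, 2]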

Variable NetBuilder::Norm(const Variable& x, int axis, float epsilon) {
Instruction instr("norm", {x});
instr.SetAttr<int32_t>("axis", axis);
15 changes: 15 additions & 0 deletions cinn/frontend/net_builder.h
@@ -1041,6 +1041,21 @@ class NetBuilder {
*/
Variable Cholesky(const Variable& x, bool upper = false);

/**
* @brief Solve triangular linear systems with multiple right-hand sides.
* @param input1 The triangular coefficient matrix A, stored in lower or upper mode.
* @param input2 The right-hand-side matrix B.
* @param left_side When left_side is true, solve A*X = B;
when left_side is false, solve X*A = B.
* @param upper When upper is true, use the upper triangle of A;
when upper is false, use the lower triangle.
* @param transpose_a When transpose_a is true, solve with the transpose of A.
* @param unit_diagonal When unit_diagonal is true, assume the elements on the main diagonal of A are all ones.
* @return The solution X of the triangular linear systems.
*/
Variable TriangularSolve(
const Variable& input1, const Variable& input2, bool left_side, bool upper, bool transpose_a, bool unit_diagonal);
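
A usage sketch (a minimal example assuming the usual NetBuilder conventions; the builder name and CreateInput calls are illustrative, not taken from this change):

// A is a 4x4 triangular coefficient matrix, B holds two right-hand sides.
// With left_side = true this solves A * X = B using the lower triangle of A.
NetBuilder builder("triangular_solve_example");
auto a = builder.CreateInput(Float(32), {4, 4}, "A");
auto b = builder.CreateInput(Float(32), {4, 2}, "B");
auto x = builder.TriangularSolve(a, b,
                                 /*left_side=*/true,
                                 /*upper=*/false,
                                 /*transpose_a=*/false,
                                 /*unit_diagonal=*/false);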

/**
* @brief l2-Norm
* @param x The input operand to be normed.
1 change: 1 addition & 0 deletions cinn/hlir/op/contrib/CMakeLists.txt
@@ -14,6 +14,7 @@ gather_srcs(cinnapi_src SRCS
gaussian_random.cc
uniform_random.cc
cholesky.cc
triangular_solve.cc
)

cc_test(test_gather_nd SRCS gather_nd_test.cc DEPS cinncore)
121 changes: 121 additions & 0 deletions cinn/hlir/op/contrib/triangular_solve.cc
@@ -0,0 +1,121 @@
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <memory>
#include <vector>

#include "cinn/common/common.h"
#include "cinn/common/macros.h"
#include "cinn/hlir/framework/node.h"
#include "cinn/hlir/framework/op.h"
#include "cinn/hlir/framework/op_strategy.h"
#include "cinn/hlir/op/op_util.h"
#include "cinn/hlir/pe/elementwise.h"
#include "cinn/hlir/pe/ir_schedule_pe.h"
#include "cinn/ir/ir.h"
#include "cinn/ir/ir_base.h"
#include "cinn/ir/tensor.h"
#include "cinn/lang/builtin.h"
#include "cinn/lang/compute.h"

namespace cinn {
namespace hlir {
namespace op {

using common::CINNValue;
using common::CINNValuePack;

std::shared_ptr<framework::OpStrategy> StrategyForTriangularSolve(const framework::NodeAttr &attrs,
const std::vector<ir::Tensor> &inputs,
const std::vector<Type> &out_type,
const std::vector<std::vector<int>> &output_shapes,
const Target &target) {
framework::CINNCompute triangular_solve_compute([=](lang::Args args, lang::RetValue *ret) {
CHECK(!args.empty()) << "The input argument of triangular_solve is empty! Please check.";
CINNValuePack pack_args = args[0];
CHECK_GE(pack_args.size(), 2U) << "Two input tensors are required for the computation of triangular_solve.";
Expr a_expr = pack_args[0];
Expr b_expr = pack_args[1];
ir::Tensor a = a_expr.as_tensor_ref();
ir::Tensor b = b_expr.as_tensor_ref();
std::string tensor_name = "triangular_solve_out";
// The compute body only forwards b (an identity copy); the actual solve is
// performed by the external custom call this op is mapped to
// (cinn_call_triangular_solve_nvgpu, see external_api_registry.cc below).
auto out = pe::Identity(b, tensor_name).front();
auto stages = CreateStages({out});
std::vector<CINNValue> res{CINNValue(out), CINNValue(stages)};
*ret = CINNValuePack{res};
});
auto strategy = std::make_shared<framework::OpStrategy>();
strategy->AddImpl(
triangular_solve_compute, GetInjectiveScheduleFunc(output_shapes, target), "strategy.triangular_solve.x86", 1);
return strategy;
}

std::vector<framework::shape_t> InferShapeForTriangularSolve(const std::vector<framework::shape_t> &inputs_shape,
const framework::AttrMapType &attrs) {
CHECK_EQ(inputs_shape.size(), 2U) << "The number of input shapes should be 2! Please check again.";
framework::shape_t a_shape = inputs_shape[0];
framework::shape_t b_shape = inputs_shape[1];
int a_shape_size = a_shape.size();
int b_shape_size = b_shape.size();
CHECK_GE(a_shape_size, 2) << "The shape size of input matrix A should be >= 2! Please check again.";
CHECK_GE(b_shape_size, 2) << "The shape size of input matrix B should be >= 2! Please check again.";

// Default to the left-side solve (A * X = B) when the attribute is absent.
bool left_side = true;
for (auto &iter : attrs) {
if (iter.first == "left_side") {
left_side = absl::get<bool>(iter.second);
break;
}
}

CHECK_EQ(a_shape[a_shape_size - 2], a_shape[a_shape_size - 1])
<< "The last two dimensions of input A must be equal, i.e. A must be a square matrix!";
if (left_side) {
CHECK_EQ(a_shape[a_shape_size - 2], b_shape[b_shape_size - 2])
<< "The second-to-last dimensions of A and B must match when left_side is true.";
} else {
CHECK_EQ(a_shape[a_shape_size - 1], b_shape[b_shape_size - 1])
<< "The last dimensions of A and B must match when left_side is false.";
}

return {b_shape};
}
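
A worked example of these checks (shapes invented for illustration); in both cases the inferred output shape is exactly B's shape, since the solution X has the same dimensions as B:

// left_side = true : A [2, 4, 4], B [2, 4, 3]
//   A is square (4 == 4); rows of A match rows of B (4 == 4); output: [2, 4, 3]
// left_side = false: A [2, 3, 3], B [2, 4, 3]
//   A is square (3 == 3); cols of A match cols of B (3 == 3); output: [2, 4, 3]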

std::vector<Type> InferDtypeForTriangularSolve(const std::vector<Type> &inputs_type,
const framework::AttrMapType &attrs) {
CHECK_EQ(inputs_type.size(), 2U) << "The number of input types should be 2! Please check again.";
CHECK(inputs_type[0].is_float(32) || inputs_type[0].is_float(64))
<< "The dtype of input A should be float32 or float64! Please check again.";
CHECK(inputs_type[1].is_float(32) || inputs_type[1].is_float(64))
<< "The dtype of input B should be float32 or float64! Please check again.";
return std::vector<Type>{inputs_type[1]};
}

} // namespace op
} // namespace hlir
} // namespace cinn

CINN_REGISTER_HELPER(triangular_solve_ops) {
CINN_REGISTER_OP(triangular_solve)
.describe("TriangularSolve")
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr<cinn::hlir::framework::StrategyFunction>("CINNStrategy", cinn::hlir::op::StrategyForTriangularSolve)
.set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForTriangularSolve))
.set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForTriangularSolve))
.set_attr<cinn::hlir::framework::OpPatternKind>("OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible)
.set_support_level(4);

return true;
}
40 changes: 40 additions & 0 deletions cinn/hlir/op/custom_call.cc
@@ -722,6 +722,44 @@ std::vector<ir::Expr> CustomCallArgsForCholesky(const framework::NodeAttr &attrs
return args;
}

std::vector<ir::Expr> CustomCallArgsForTriangularSolve(const framework::NodeAttr &attrs,
const std::vector<ir::Tensor> &inputs,
const std::vector<std::vector<int>> &output_shapes) {
CHECK_EQ(inputs.size(), 2UL);
auto attr_store = attrs.attr_store;
CHECK(attr_store.count("left_side"));
CHECK(attr_store.count("upper"));
CHECK(attr_store.count("transpose_a"));
CHECK(attr_store.count("unit_diagonal"));

ir::Tensor a = inputs[0];
ir::Tensor b = inputs[1];
int a_ndim = static_cast<int>(a->shape.size());
int b_ndim = static_cast<int>(b->shape.size());
int batch_size = 1;
for (int i = 0; i < a_ndim - 2; i++) {
batch_size *= a->shape[i].as_int32();
}

auto left_side = absl::get<bool>(attrs.attr_store.at("left_side"));
auto upper = absl::get<bool>(attrs.attr_store.at("upper"));
auto transpose_a = absl::get<bool>(attrs.attr_store.at("transpose_a"));
auto unit_diagonal = absl::get<bool>(attrs.attr_store.at("unit_diagonal"));

int m = a->shape[a_ndim - 1].as_int32();
int k = left_side ? b->shape[b_ndim - 1].as_int32() : b->shape[b_ndim - 2].as_int32();

std::vector<ir::Expr> args = {ir::Expr(batch_size),
ir::Expr(m),
ir::Expr(k),
ir::Expr(left_side),
ir::Expr(upper),
ir::Expr(transpose_a),
ir::Expr(unit_diagonal)};

return args;
}
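
Here m is the order of the triangular matrix A and k is the number of right-hand sides; since the NetBuilder-level broadcast has already made the batch dimensions of both inputs identical, batch_size is taken from A alone. A worked example (shapes invented for illustration):

// left_side = true : A [2, 4, 4], B [2, 4, 3]  (solving A * X = B)
//   batch_size = 2, m = 4, k = B.shape[-1] = 3
// left_side = false: A [2, 3, 3], B [2, 4, 3]  (solving X * A = B)
//   batch_size = 2, m = 3, k = B.shape[-2] = 4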

bool RegisteryCustomCallArgsFunc() {
#ifdef CINN_WITH_CUDA
CustomCallArgsFuncRegistry::Global().Register(
@@ -734,6 +772,8 @@ bool RegisteryCustomCallArgsFunc() {
"cinn_call_cholesky_nvgpu", common::DefaultNVGPUTarget(), CustomCallArgsForCholesky);
CustomCallArgsFuncRegistry::Global().Register(
"cinn_call_batched_cublas", common::DefaultNVGPUTarget(), CustomCallArgsForBatchedCublas);
CustomCallArgsFuncRegistry::Global().Register(
"cinn_call_triangular_solve_nvgpu", common::DefaultNVGPUTarget(), CustomCallArgsForTriangularSolve);
#endif

#ifdef CINN_WITH_CUDNN
1 change: 1 addition & 0 deletions cinn/hlir/op/external_api_registry.cc
@@ -58,6 +58,7 @@ CINN_REGISTER_HELPER(op_external_api) {
CINN_OP_REGISTER_EXTERNAL_API(uniform_random, default_nvgpu).set_api_name("cinn_call_uniform_random");
CINN_OP_REGISTER_EXTERNAL_API(cholesky, default_nvgpu).set_api_name("cinn_call_cholesky_nvgpu");
CINN_OP_REGISTER_EXTERNAL_API(cholesky, default_host).set_api_name("cinn_call_cholesky_host");
CINN_OP_REGISTER_EXTERNAL_API(triangular_solve, default_nvgpu).set_api_name("cinn_call_triangular_solve_nvgpu");
#ifdef CINN_WITH_CUDNN
CINN_OP_REGISTER_EXTERNAL_API(conv2d, default_nvgpu).set_trans_func([](const ::cinn::hlir::framework::Node* node) {
CHECK(node->attrs.attr_store.count("conv_type"));
1 change: 1 addition & 0 deletions cinn/hlir/op/use_ops.h
@@ -37,4 +37,5 @@ CINN_USE_REGISTER(reciprocal_ops)
CINN_USE_REGISTER(gaussian_random_ops)
CINN_USE_REGISTER(uniform_random_ops)
CINN_USE_REGISTER(cholesky_ops)
CINN_USE_REGISTER(triangular_solve_ops)
CINN_USE_REGISTER(op_external_api)
10 changes: 9 additions & 1 deletion cinn/pybind/frontend.cc
@@ -697,7 +697,15 @@ void BindFrontend(pybind11::module *m) {
py::arg("seed") = 0,
py::arg("dtype") = "float32")
.def("norm", &NetBuilder::Norm, py::arg("x"), py::arg("axis") = -1, py::arg("epsilon") = 1e-12f)
.def("cholesky", &NetBuilder::Cholesky, py::arg("x"), py::arg("upper") = false);
.def("cholesky", &NetBuilder::Cholesky, py::arg("x"), py::arg("upper") = false)
.def("triangular_solve",
&NetBuilder::TriangularSolve,
py::arg("input1"),
py::arg("input2"),
py::arg("left_side") = true,
py::arg("upper") = false,
py::arg("transpose_a") = false,
py::arg("unit_diagonal") = false);

auto computation = py::class_<CinnComputation, std::shared_ptr<CinnComputation>>(*m, "Computation");
py::class_<CinnComputation::CompileOptions>(computation, "CompileOptions")
15 changes: 15 additions & 0 deletions cinn/runtime/cuda/cuda_intrinsics.cc
@@ -369,6 +369,21 @@ CINN_REGISTER_HELPER(cinn_cuda_host_api) {
.AddInputType<void *>() // stream
.End();

using cinn::runtime::cuda::cinn_call_triangular_solve_nvgpu;
REGISTER_EXTERN_FUNC_HELPER(cinn_call_triangular_solve_nvgpu, cinn::common::DefaultNVGPUTarget())
.SetRetType<void>()
.AddInputType<void *>() // v_args
.AddInputType<int>() // num_args
.AddInputType<int>() // batch_size
.AddInputType<int>() // m
.AddInputType<int>() // k
.AddInputType<bool>() // left_side
.AddInputType<bool>() // upper
.AddInputType<bool>() // transpose_a
.AddInputType<bool>() // unit_diagonal
.AddInputType<void *>() // stream
.End();
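
From the registered return and argument types above, the host-side wrapper is expected to have a signature along these lines (a reconstruction from the registration, not shown in this diff; the actual definition lives in the CUDA runtime sources):

namespace cinn {
namespace runtime {
namespace cuda {

void cinn_call_triangular_solve_nvgpu(void *v_args,
                                      int num_args,
                                      int batch_size,
                                      int m,
                                      int k,
                                      bool left_side,
                                      bool upper,
                                      bool transpose_a,
                                      bool unit_diagonal,
                                      void *stream);

}  // namespace cuda
}  // namespace runtime
}  // namespace cinn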

#ifdef CINN_WITH_CUDNN
using cinn::runtime::cuda::cinn_call_cudnn_conv2d_forward;
REGISTER_EXTERN_FUNC_HELPER(cinn_call_cudnn_conv2d_forward, cinn::common::DefaultHostTarget())