[Unity] Relax op: linear algebra (#13988)
This PR continues the work on high-level tensor computation operators in Relax.

It adds the linear algebra operators: relax.matmul, together with a Python-side `linear` helper composed from permute_dims, matmul and add.

Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
2 people authored and tqchen committed Feb 14, 2023
1 parent d79b068 commit 58bcc12
Showing 9 changed files with 640 additions and 0 deletions.
44 changes: 44 additions & 0 deletions include/tvm/relax/attrs/linear_algebra.h
@@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/relax/attrs/linear_algebra.h
* \brief Attributes for linear algebra operators.
*/
#ifndef TVM_RELAX_ATTRS_LINEAR_ALGEBRA_H_
#define TVM_RELAX_ATTRS_LINEAR_ALGEBRA_H_

#include <tvm/relax/expr.h>

namespace tvm {
namespace relax {

/*! \brief Attributes for matmul operator */
struct MatmulAttrs : public tvm::AttrsNode<MatmulAttrs> {
  DataType out_dtype;

  TVM_DECLARE_ATTRS(MatmulAttrs, "relax.attrs.MatmulAttrs") {
    TVM_ATTR_FIELD(out_dtype).describe("The data type of the output tensor");
  }
};  // struct MatmulAttrs

} // namespace relax
} // namespace tvm

#endif // TVM_RELAX_ATTRS_LINEAR_ALGEBRA_H_
1 change: 1 addition & 0 deletions python/tvm/relax/op/__init__.py
@@ -23,6 +23,7 @@
from .create import *
from .datatype import *
from .index import *
from .linear_algebra import *
from .manipulate import *
from .op_attrs import *
from .statistical import *
90 changes: 90 additions & 0 deletions python/tvm/relax/op/linear_algebra.py
@@ -0,0 +1,90 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Relax linear algebra operators"""
from typing import Optional, Union

from tvm import DataType

from ..expr import Expr
from . import _ffi_api
from .manipulate import permute_dims


def matmul(x1: Expr, x2: Expr, out_dtype: Optional[Union[str, DataType]] = None) -> Expr:
    """General matrix multiplication of two tensors, with broadcasting on batched dimensions.

    The semantics and output shape deduction rule are specified at
    https://data-apis.org/array-api/latest/API_specification/generated/array_api.matmul.html.

    Parameters
    ----------
    x1 : relax.Expr
        The first input tensor.
    x2 : relax.Expr
        The second input tensor.
    out_dtype: Optional[Union[str, DataType]]
        The data type of the matmul result.
        When it is not specified, the output dtype will be the same as the input dtype.

    Returns
    -------
    result : relax.Expr
        The computed result.
    """
    return _ffi_api.matmul(x1, x2, out_dtype)  # type: ignore


def linear(
    data: Expr,
    weight: Expr,
    bias: Optional[Expr] = None,
    out_dtype: Optional[Union[str, DataType]] = None,
) -> Expr:
    """Applies a linear transformation to the incoming data: y = xA^T + b

    Parameters
    ----------
    data : relax.Expr
        The input data.
    weight : relax.Expr
        The weight tensor.
    bias : Optional[Expr]
        The bias tensor.
    out_dtype: Optional[Union[str, DataType]]
        The data type of the matmul result.
        When it is not specified, the output dtype will be the same as the input dtype.

    Notes
    -----
    Relax does not regard linear as a primitive op;
    instead, it is composed from the permute_dims (transpose), matmul and add ops.

    Returns
    -------
    result : relax.Expr
        The computed result.
    """
    # Since weight can be 1D or 2D, we use `axes=None` to support both cases.
    x = matmul(data, permute_dims(weight, axes=None), out_dtype=out_dtype)
    return x + bias if bias is not None else x
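For orientation, here is a minimal usage sketch of the two new Python ops. It is not part of the diff: the shapes and variable names are made up for illustration, and it assumes the relax.Var and relax.TensorStructInfo constructors available on the unity branch at this point.

# Hypothetical example; shapes are illustrative only.
from tvm import relax
from tvm.relax.op import linear, matmul, permute_dims

x = relax.Var("x", relax.TensorStructInfo((4, 8), "float32"))
w = relax.Var("w", relax.TensorStructInfo((16, 8), "float32"))
b = relax.Var("b", relax.TensorStructInfo((16,), "float32"))

# Plain matmul: (4, 8) x (8, 16) -> (4, 16)
y = matmul(x, permute_dims(w), out_dtype="float32")

# linear is sugar for matmul(data, permute_dims(weight)) followed by an add.
z = linear(x, w, b, out_dtype="float32")

Because linear is composed from existing primitives rather than registered as its own op, downstream passes only ever see permute_dims, matmul and add.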
5 changes: 5 additions & 0 deletions python/tvm/relax/op/op_attrs.py
@@ -44,6 +44,11 @@ class StridedSliceAttrs(Attrs):
"""Attributes used in strided_slice operator"""


@tvm._ffi.register_object("relax.attrs.MatmulAttrs")
class MatmulAttrs(Attrs):
"""Attributes for matmul operator"""


@tvm._ffi.register_object("relax.attrs.Conv2DAttrs")
class Conv2DAttrs(Attrs):
"""Attributes for nn.conv2d"""
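On the Python side, the attrs node registered above surfaces on matmul call nodes. A small sketch of inspecting it, assuming the same illustrative Var construction as earlier:

# Hypothetical inspection example.
from tvm import relax
from tvm.relax.op import matmul

a = relax.Var("a", relax.TensorStructInfo((2, 3), "float32"))
b = relax.Var("b", relax.TensorStructInfo((3, 4), "float32"))
call = matmul(a, b, out_dtype="float16")
print(call.op.name)          # relax.matmul
print(call.attrs.out_dtype)  # float16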
4 changes: 4 additions & 0 deletions python/tvm/script/ir_builder/relax/ir.py
@@ -63,8 +63,10 @@
    isnan,
    less,
    less_equal,
    linear,
    log,
    make_closure,
    matmul,
    max,
    mean,
    memory,
@@ -504,8 +506,10 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"isnan",
"less",
"less_equal",
"linear",
"log",
"make_closure",
"matmul",
"max",
"mean",
"memory",
123 changes: 123 additions & 0 deletions src/relax/op/tensor/linear_algebra.cc
@@ -0,0 +1,123 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file linear_algebra.cc
* \brief Linear algebra operators.
*/

#include "linear_algebra.h"

#include <algorithm>
#include <utility>

namespace tvm {
namespace relax {

/* relax.matmul */
TVM_REGISTER_NODE_TYPE(MatmulAttrs);

Expr matmul(Expr x1, Expr x2, DataType out_dtype) {
  ObjectPtr<MatmulAttrs> attrs = make_object<MatmulAttrs>();
  attrs->out_dtype = out_dtype;

  static const Op& op = Op::Get("relax.matmul");
  return Call(op, {std::move(x1), std::move(x2)}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relax.op.matmul").set_body_typed(matmul);

StructInfo InferStructInfoMatmul(const Call& call, const BlockBuilder& ctx) {
  Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
  TensorStructInfo x1_sinfo = input_sinfo[0];
  TensorStructInfo x2_sinfo = input_sinfo[1];

  const auto* attrs = call->attrs.as<MatmulAttrs>();
  DataType out_dtype = attrs->out_dtype.is_void()
                           ? InferBinaryArithOpOutDtype(call, ctx, x1_sinfo, x2_sinfo)
                           : attrs->out_dtype;

  if (x1_sinfo->IsUnknownNdim() || x2_sinfo->IsUnknownNdim()) {
    return TensorStructInfo(out_dtype, kUnknownNDim);
  }
  int x1_ndim = x1_sinfo->ndim;
  int x2_ndim = x2_sinfo->ndim;
  if (x1_ndim == 0 || x2_ndim == 0) {
    ctx->ReportFatal(Diagnostic::Error(call)
                     << "Matmul requires both inputs to have at least 1 dimension. However, "
                     << (x1_ndim == 0 ? "x1" : "x2") << " is a 0-rank tensor.");
  }

  int x1_prepended = 0;
  int x2_appended = 0;
  if (x1_ndim == 1) {
    x1_ndim = 2;
    x1_prepended = 1;
  }
  if (x2_ndim == 1) {
    x2_ndim = 2;
    x2_appended = 1;
  }
  int output_ndim = std::max(x1_ndim, x2_ndim) - x1_prepended - x2_appended;

  const auto* x1_shape = x1_sinfo->shape.as<ShapeExprNode>();
  const auto* x2_shape = x2_sinfo->shape.as<ShapeExprNode>();
  if (x1_shape == nullptr || x2_shape == nullptr) {
    return TensorStructInfo(out_dtype, output_ndim);
  }

  Array<PrimExpr> x1_shape_prefix{x1_shape->values.begin(),
                                  x1_shape->values.end() - 2 + x1_prepended};
  Array<PrimExpr> x2_shape_prefix{x2_shape->values.begin(),
                                  x2_shape->values.end() - 2 + x2_appended};
  Optional<Array<PrimExpr>> output_shape_prefix =
      InferBinaryBroadcastShape(call, ctx, x1_shape_prefix, x2_shape_prefix);
  if (!output_shape_prefix.defined()) {
    return TensorStructInfo(out_dtype, output_ndim);
  }

  arith::Analyzer* analyzer = ctx->GetAnalyzer();
  PrimExpr x1_reduction_length = x1_shape->values[x1_sinfo->ndim - 1];
  PrimExpr x2_reduction_length = x2_shape->values[x2_ndim - 2];
  if (analyzer->CanProve(x1_reduction_length != x2_reduction_length)) {
    ctx->ReportFatal(Diagnostic::Error(call)
                     << "Matmul requires the reduction length of x1 and x2 to be equal. However, "
                        "the reduction lengths of x1 and x2 are "
                     << x1_reduction_length << " and " << x2_reduction_length << " respectively.");
  }

  Array<PrimExpr> output_shape = output_shape_prefix.value();
  if (!x1_prepended) {
    output_shape.push_back(x1_shape->values[x1_ndim - 2]);
  }
  if (!x2_appended) {
    output_shape.push_back(x2_shape->values[x2_ndim - 1]);
  }
  ICHECK_EQ(static_cast<int>(output_shape.size()), output_ndim);
  return TensorStructInfo(ShapeExpr(output_shape), out_dtype);
}

TVM_REGISTER_OP("relax.matmul")
.set_num_inputs(2)
.add_argument("x1", "Tensor", "The first input tensor.")
.add_argument("x2", "Tensor", "The second input tensor.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoMatmul);

} // namespace relax
} // namespace tvm
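The prepend/append bookkeeping and the batch-dimension broadcasting above follow the array-API matmul rule linked in the docstrings. For reference only, a quick numpy sketch of the same rule (numpy.matmul implements identical shape semantics):

import numpy as np

a = np.zeros((2, 3, 4, 5))                   # batched operand
b = np.zeros((5, 6))                         # 2-D operand, broadcast over the batch dims
print(np.matmul(a, b).shape)                 # (2, 3, 4, 6)

v = np.zeros((5,))                           # 1-D x2: a trailing dim is appended, then dropped
print(np.matmul(a, v).shape)                 # (2, 3, 4)

u = np.zeros((4,))                           # 1-D x1: a leading dim is prepended, then dropped
print(np.matmul(u, np.zeros((4, 5))).shape)  # (5,)

When a shape cannot be fully resolved (unknown rank or no ShapeExpr), the inference falls back to reporting only the output dtype and rank, as the early returns above show.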
49 changes: 49 additions & 0 deletions src/relax/op/tensor/linear_algebra.h
@@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file linear_algebra.h
* \brief The functions to make Relax linear algebra operator calls.
*/
#ifndef TVM_RELAX_OP_TENSOR_LINEAR_ALGEBRA_H_
#define TVM_RELAX_OP_TENSOR_LINEAR_ALGEBRA_H_

#include <tvm/relax/attrs/linear_algebra.h>

#include "../op_common.h"

namespace tvm {
namespace relax {

/*!
* \brief General matrix multiplication of two tensors.
 * The semantics and output shape deduction rule are specified at
 * https://data-apis.org/array-api/latest/API_specification/generated/array_api.matmul.html.
 * \param x1 The first input tensor.
 * \param x2 The second input tensor.
 * \param out_dtype The data type of the matmul result.
 * When it is not specified, the output dtype will be the same as the input dtype.
* \return The computed result.
*/
Expr matmul(Expr x1, Expr x2, DataType out_dtype);

} // namespace relax
} // namespace tvm

#endif // TVM_RELAX_OP_TENSOR_LINEAR_ALGEBRA_H_
(Diffs for the remaining changed files were not loaded in this view.)