Skip to content

Commit

Permalink
[Unity] Relax op: creation (#13984)
Browse files Browse the repository at this point in the history
This PR is about the high-level tensor computation operators in Relax.

This PR includes the tensor creation operators.
  • Loading branch information
MasterJH5574 authored and tqchen committed Feb 22, 2023
1 parent 95f609d commit 070e1ae
Show file tree
Hide file tree
Showing 9 changed files with 1,444 additions and 0 deletions.
54 changes: 54 additions & 0 deletions include/tvm/relax/attrs/create.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/relax/attrs/create.h
* \brief Attributes for tensor creation operators.
*/
#ifndef TVM_RELAX_ATTRS_CREATE_H_
#define TVM_RELAX_ATTRS_CREATE_H_

#include <tvm/relax/expr.h>

namespace tvm {
namespace relax {

/*! \brief Attributes used in full/full_like, ones/ones_like, and zeros/zeros_like operators */
struct InitAttrs : public tvm::AttrsNode<InitAttrs> {
DataType dtype;

TVM_DECLARE_ATTRS(InitAttrs, "relax.attrs.InitAttrs") {
TVM_ATTR_FIELD(dtype).describe("The data type of the created tensor.");
}
}; // struct InitAttrs

/*! \brief Attributes used in tril and triu operator */
struct TriluAttrs : public tvm::AttrsNode<TriluAttrs> {
int k;

TVM_DECLARE_ATTRS(TriluAttrs, "relax.attrs.TriluAttrs") {
TVM_ATTR_FIELD(k).describe(
"The number of diagonals above or below the main diagonal to exclude or include.");
}
}; // struct TriluAttrs

} // namespace relax
} // namespace tvm

#endif // TVM_RELAX_ATTRS_CREATE_H_
1 change: 1 addition & 0 deletions python/tvm/relax/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
# Operators
from .base import *
from .binary import *
from .create import *
from .datatype import *
from .index import *
from .manipulate import *
Expand Down
209 changes: 209 additions & 0 deletions python/tvm/relax/op/create.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Creation operators."""
from typing import Optional, Tuple, Union

from tvm import DataType
from tvm.ir.expr import PrimExpr

from . import _ffi_api
from ..expr import Expr, ShapeExpr

PrimExprLike = Union[int, PrimExpr]


def full(
shape: Union[Tuple[PrimExprLike], Expr],
fill_value: Expr,
dtype: Optional[Union[str, DataType]] = None,
) -> Expr:
"""Fill array with scalar value.
Parameters
----------
shape : Union[Tuple[PrimExprLike], Expr]
The shape of the created tensor.
fill_value : relax.Expr
The value to fill. Must be a scalar tensor.
dtype : Optional[Union[str, DataType]]
The data type of the created tensor.
If dtype is not given, it will by default use the dtype of fill_value.
Returns
-------
result : relax.Expr
The result tensor.
"""
return _ffi_api.full(shape, fill_value, dtype) # type: ignore


def full_like(x: Expr, fill_value: Expr, dtype: Optional[Union[str, DataType]] = None) -> Expr:
"""Construct a tensor such that
- its shape is the same as the input data tensor's shape,
- its value is filled with the input scalar fill value.
Parameters
----------
x : relax.Expr
The input tensor, which provides the shape, and dtype
when the `dtype` field is not specified.
fill_value : relax.Expr
The value to fill. Must be a scalar tensor.
dtype : Optional[Union[str, DataType]]
The data type of the created tensor.
If dtype is not given, it will by default use the dtype of the input tensor.
Returns
-------
result : relax.Expr
The result tensor.
"""
return _ffi_api.full_like(x, fill_value, dtype) # type: ignore


def ones(shape: Union[Tuple[PrimExprLike], Expr], dtype: Union[str, DataType]) -> Expr:
"""Construct a tensor of all ones, with the input shape and dtype.
Parameters
----------
shape : Union[Tuple[PrimExprLike], Expr]
The shape of the created tensor.
dtype : Union[str, DataType]
The data type of the created tensor.
Returns
-------
result : relax.Expr
The result tensor.
"""
if isinstance(shape, (tuple, list)):
shape = ShapeExpr(shape)
return _ffi_api.ones(shape, dtype) # type: ignore


def ones_like(x: Expr, dtype: Optional[Union[str, DataType]] = None) -> Expr:
"""Construct a tensor with all ones, with shape of the input tensor shape.
Parameters
----------
x : relax.Expr
The input tensor, which provides the shape, and dtype
when the `dtype` field is not specified.
dtype : Optional[Union[str, DataType]]
The data type of the created tensor.
If dtype is not given, it will by default use the dtype of the input tensor.
Returns
-------
result : relax.Expr
The result tensor.
"""
return _ffi_api.ones_like(x, dtype) # type: ignore


def zeros(shape: Union[Tuple[PrimExprLike], Expr], dtype: Union[str, DataType]) -> Expr:
"""Construct a tensor of all zeros, with the input shape and dtype.
Parameters
----------
shape : Union[Tuple[PrimExprLike], Expr]
The shape of the created tensor.
dtype : Union[str, DataType]
The data type of the created tensor.
Returns
-------
result : relax.Expr
The result tensor.
"""
if isinstance(shape, (tuple, list)):
shape = ShapeExpr(shape)
return _ffi_api.zeros(shape, dtype) # type: ignore


def zeros_like(x: Expr, dtype: Optional[Union[str, DataType]] = None) -> Expr:
"""Construct a tensor with all zeros, with shape of the input tensor shape.
Parameters
----------
x : relax.Expr
The input tensor, which provides the shape, and dtype
when the `dtype` field is not specified.
dtype : Optional[Union[str, DataType]]
The data type of the created tensor.
If dtype is not given, it will by default use the dtype of the input tensor.
Returns
-------
result : relax.Expr
The result tensor.
"""
return _ffi_api.zeros_like(x, dtype) # type: ignore


def tril(x: Expr, k: int = 0) -> Expr:
"""Return the lower triangular part of a matrix or a batch of matrices.
Parameters
----------
x : relax.Expr
The tensor that tril will be applied to.
It is required to have at least two dimensions.
k : int
The index indicating the diagonal above which to zero elements.
If k = 0, the diagonal is the main diagonal.
If k < 0, the diagonal is below the main diagonal.
If k > 0, the diagonal is above the main diagonal.
Returns
-------
ret : relax.Expr
The result tensor.
"""
return _ffi_api.tril(x, k) # type: ignore


def triu(x: Expr, k: int = 0) -> Expr:
"""Return the upper triangular part of a matrix or a batch of matrices.
Parameters
----------
x : relax.Expr
The tensor that triu will be applied to.
It is required to have at least two dimensions.
k : int
The index indicating the diagonal below which to zero elements.
If k = 0, the diagonal is the main diagonal.
If k < 0, the diagonal is below the main diagonal.
If k > 0, the diagonal is above the main diagonal.
Returns
-------
ret : relax.Expr
The result tensor.
"""
return _ffi_api.triu(x, k) # type: ignore
10 changes: 10 additions & 0 deletions python/tvm/relax/op/op_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,16 @@
import tvm._ffi


@tvm._ffi.register_object("relax.attrs.InitAttrs")
class InitAttrs(Attrs):
"""Attributes used in full/full_like, ones/ones_like, and zeros/zeros_like operator"""


@tvm._ffi.register_object("relax.attrs.TriluAttrs")
class TriluAttrs(Attrs):
"""Attributes used in tril and triu operator"""


@tvm._ffi.register_object("relax.attrs.AstypeAttrs")
class AstypeAttrs(Attrs):
"""Attributes used in astype operator"""
Expand Down
16 changes: 16 additions & 0 deletions python/tvm/script/ir_builder/relax/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@
exp,
floor,
floor_divide,
full,
full_like,
greater,
greater_equal,
image,
Expand All @@ -71,6 +73,8 @@
negative,
not_equal,
null_value,
ones,
ones_like,
print,
prod,
reshape,
Expand All @@ -92,7 +96,11 @@
take,
tan,
tanh,
tril,
triu,
unique,
zeros,
zeros_like,
nn,
)
from tvm.relax.struct_info import StructInfo
Expand Down Expand Up @@ -480,6 +488,8 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"exp",
"floor",
"floor_divide",
"full",
"full_like",
"func_attr",
"func_name",
"func_ret_struct_info",
Expand All @@ -504,6 +514,8 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"negative",
"not_equal",
"null_value",
"ones",
"ones_like",
"output",
"prim_value",
"print",
Expand All @@ -528,8 +540,12 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"take",
"tan",
"tanh",
"tril",
"triu",
"tuple",
"variance",
"unique",
"zeros",
"zeros_like",
"nn",
]
Loading

0 comments on commit 070e1ae

Please sign in to comment.