Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Unity] Relax op: creation #13984

Merged
merged 2 commits into from
Feb 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions include/tvm/relax/attrs/create.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/relax/attrs/create.h
* \brief Attributes for tensor creation operators.
*/
#ifndef TVM_RELAX_ATTRS_CREATE_H_
#define TVM_RELAX_ATTRS_CREATE_H_

#include <tvm/relax/expr.h>

namespace tvm {
namespace relax {

/*! \brief Attributes used in full/full_like, ones/ones_like, and zeros/zeros_like operators */
struct InitAttrs : public tvm::AttrsNode<InitAttrs> {
DataType dtype;

TVM_DECLARE_ATTRS(InitAttrs, "relax.attrs.InitAttrs") {
TVM_ATTR_FIELD(dtype).describe("The data type of the created tensor.");
}
}; // struct InitAttrs

/*! \brief Attributes used in tril and triu operator */
struct TriluAttrs : public tvm::AttrsNode<TriluAttrs> {
int k;

TVM_DECLARE_ATTRS(TriluAttrs, "relax.attrs.TriluAttrs") {
TVM_ATTR_FIELD(k).describe(
"The number of diagonals above or below the main diagonal to exclude or include.");
}
}; // struct TriluAttrs

} // namespace relax
} // namespace tvm

#endif // TVM_RELAX_ATTRS_CREATE_H_
1 change: 1 addition & 0 deletions python/tvm/relax/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
# Operators
from .base import *
from .binary import *
from .create import *
from .datatype import *
from .index import *
from .manipulate import *
Expand Down
209 changes: 209 additions & 0 deletions python/tvm/relax/op/create.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Creation operators."""
from typing import Optional, Tuple, Union

from tvm import DataType
from tvm.ir.expr import PrimExpr

from . import _ffi_api
from ..expr import Expr, ShapeExpr

PrimExprLike = Union[int, PrimExpr]


def full(
shape: Union[Tuple[PrimExprLike], Expr],
fill_value: Expr,
dtype: Optional[Union[str, DataType]] = None,
) -> Expr:
"""Fill array with scalar value.
Parameters
----------
shape : Union[Tuple[PrimExprLike], Expr]
The shape of the created tensor.
fill_value : relax.Expr
The value to fill. Must be a scalar tensor.
dtype : Optional[Union[str, DataType]]
The data type of the created tensor.
If dtype is not given, it will by default use the dtype of fill_value.
Returns
-------
result : relax.Expr
The result tensor.
"""
return _ffi_api.full(shape, fill_value, dtype) # type: ignore


def full_like(x: Expr, fill_value: Expr, dtype: Optional[Union[str, DataType]] = None) -> Expr:
"""Construct a tensor such that
- its shape is the same as the input data tensor's shape,
- its value is filled with the input scalar fill value.
Parameters
----------
x : relax.Expr
The input tensor, which provides the shape, and dtype
when the `dtype` field is not specified.
fill_value : relax.Expr
The value to fill. Must be a scalar tensor.
dtype : Optional[Union[str, DataType]]
The data type of the created tensor.
If dtype is not given, it will by default use the dtype of the input tensor.
Returns
-------
result : relax.Expr
The result tensor.
"""
return _ffi_api.full_like(x, fill_value, dtype) # type: ignore


def ones(shape: Union[Tuple[PrimExprLike], Expr], dtype: Union[str, DataType]) -> Expr:
"""Construct a tensor of all ones, with the input shape and dtype.
Parameters
----------
shape : Union[Tuple[PrimExprLike], Expr]
The shape of the created tensor.
dtype : Union[str, DataType]
The data type of the created tensor.
Returns
-------
result : relax.Expr
The result tensor.
"""
if isinstance(shape, (tuple, list)):
shape = ShapeExpr(shape)
return _ffi_api.ones(shape, dtype) # type: ignore


def ones_like(x: Expr, dtype: Optional[Union[str, DataType]] = None) -> Expr:
"""Construct a tensor with all ones, with shape of the input tensor shape.
Parameters
----------
x : relax.Expr
The input tensor, which provides the shape, and dtype
when the `dtype` field is not specified.
dtype : Optional[Union[str, DataType]]
The data type of the created tensor.
If dtype is not given, it will by default use the dtype of the input tensor.
Returns
-------
result : relax.Expr
The result tensor.
"""
return _ffi_api.ones_like(x, dtype) # type: ignore


def zeros(shape: Union[Tuple[PrimExprLike], Expr], dtype: Union[str, DataType]) -> Expr:
"""Construct a tensor of all zeros, with the input shape and dtype.
Parameters
----------
shape : Union[Tuple[PrimExprLike], Expr]
The shape of the created tensor.
dtype : Union[str, DataType]
The data type of the created tensor.
Returns
-------
result : relax.Expr
The result tensor.
"""
if isinstance(shape, (tuple, list)):
shape = ShapeExpr(shape)
return _ffi_api.zeros(shape, dtype) # type: ignore


def zeros_like(x: Expr, dtype: Optional[Union[str, DataType]] = None) -> Expr:
"""Construct a tensor with all zeros, with shape of the input tensor shape.
Parameters
----------
x : relax.Expr
The input tensor, which provides the shape, and dtype
when the `dtype` field is not specified.
dtype : Optional[Union[str, DataType]]
The data type of the created tensor.
If dtype is not given, it will by default use the dtype of the input tensor.
Returns
-------
result : relax.Expr
The result tensor.
"""
return _ffi_api.zeros_like(x, dtype) # type: ignore


def tril(x: Expr, k: int = 0) -> Expr:
"""Return the lower triangular part of a matrix or a batch of matrices.
Parameters
----------
x : relax.Expr
The tensor that tril will be applied to.
It is required to have at least two dimensions.
k : int
The index indicating the diagonal above which to zero elements.
If k = 0, the diagonal is the main diagonal.
If k < 0, the diagonal is below the main diagonal.
If k > 0, the diagonal is above the main diagonal.
Returns
-------
ret : relax.Expr
The result tensor.
"""
return _ffi_api.tril(x, k) # type: ignore


def triu(x: Expr, k: int = 0) -> Expr:
"""Return the upper triangular part of a matrix or a batch of matrices.
Parameters
----------
x : relax.Expr
The tensor that triu will be applied to.
It is required to have at least two dimensions.
k : int
The index indicating the diagonal below which to zero elements.
If k = 0, the diagonal is the main diagonal.
If k < 0, the diagonal is below the main diagonal.
If k > 0, the diagonal is above the main diagonal.
Returns
-------
ret : relax.Expr
The result tensor.
"""
return _ffi_api.triu(x, k) # type: ignore
10 changes: 10 additions & 0 deletions python/tvm/relax/op/op_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,16 @@
import tvm._ffi


@tvm._ffi.register_object("relax.attrs.InitAttrs")
class InitAttrs(Attrs):
"""Attributes used in full/full_like, ones/ones_like, and zeros/zeros_like operator"""


@tvm._ffi.register_object("relax.attrs.TriluAttrs")
class TriluAttrs(Attrs):
"""Attributes used in tril and triu operator"""


@tvm._ffi.register_object("relax.attrs.AstypeAttrs")
class AstypeAttrs(Attrs):
"""Attributes used in astype operator"""
Expand Down
16 changes: 16 additions & 0 deletions python/tvm/script/ir_builder/relax/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@
exp,
floor,
floor_divide,
full,
full_like,
greater,
greater_equal,
image,
Expand All @@ -71,6 +73,8 @@
negative,
not_equal,
null_value,
ones,
ones_like,
print,
prod,
reshape,
Expand All @@ -92,7 +96,11 @@
take,
tan,
tanh,
tril,
triu,
unique,
zeros,
zeros_like,
nn,
)
from tvm.relax.struct_info import StructInfo
Expand Down Expand Up @@ -480,6 +488,8 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"exp",
"floor",
"floor_divide",
"full",
"full_like",
"func_attr",
"func_name",
"func_ret_struct_info",
Expand All @@ -504,6 +514,8 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"negative",
"not_equal",
"null_value",
"ones",
"ones_like",
"output",
"prim_value",
"print",
Expand All @@ -528,8 +540,12 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"take",
"tan",
"tanh",
"tril",
"triu",
"tuple",
"variance",
"unique",
"zeros",
"zeros_like",
"nn",
]
Loading