Skip to content

Commit

Permalink
[Unity] Relax op: manipulation (apache#13989)
Browse files Browse the repository at this point in the history
This PR is about the high-level tensor computation operators in Relax.

This PR includes the tensor manipulation operators.

Co-authored-by: Prakalp Srivastava <prakalp@octoml.ai>
  • Loading branch information
2 people authored and yongwww committed Feb 27, 2023
1 parent b061f74 commit 85479ca
Show file tree
Hide file tree
Showing 8 changed files with 3,803 additions and 10 deletions.
108 changes: 108 additions & 0 deletions include/tvm/relax/attrs/manipulate.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/relax/attrs/manipulate.h
* \brief Attributes for tensor manipulation operators.
*/
#ifndef TVM_RELAX_ATTRS_MANIPULATE_H_
#define TVM_RELAX_ATTRS_MANIPULATE_H_

#include <tvm/relax/expr.h>
#include <tvm/tir/index_map.h>

namespace tvm {
namespace relax {

/*! \brief Attributes used in concat operators */
struct ConcatAttrs : public tvm::AttrsNode<ConcatAttrs> {
Optional<Integer> axis;

TVM_DECLARE_ATTRS(ConcatAttrs, "relax.attrs.ConcatAttrs") {
TVM_ATTR_FIELD(axis).describe(
"The axis at which the input arrays are concatenated."
"Should lie in range `[-ndim, ndim)`.");
}
}; // struct ConcatAttrs

/*! \brief Attributes used in expand_dims operators */
struct ExpandDimsAttrs : public tvm::AttrsNode<ExpandDimsAttrs> {
Array<Integer> axis;

TVM_DECLARE_ATTRS(ExpandDimsAttrs, "relax.attrs.ExpandDimsAttrs") {
TVM_ATTR_FIELD(axis).describe(
"The axes at which the input array are expanded. "
"All values are required to lie in range `[-data.ndim - 1, data.ndim]`, "
"with the convention of negative indexing.");
}
}; // struct ExpandDimsAttrs

/*! \brief Attributes used in layout_transform operator */
struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
tir::IndexMap index_map;
// pad_value is chosen to be of PrimValue type, as it represents constant TIR POD expression. This
// needs to be revisited in case PrimValue is evolved to represent symbolic expression in future.
Optional<PrimValue> pad_value;

TVM_DECLARE_ATTRS(LayoutTransformAttrs, "relax.attrs.LayoutTransformAttrs") {
TVM_ATTR_FIELD(index_map).describe("The layout transformation to apply.");
TVM_ATTR_FIELD(pad_value).describe(
"The specific value to be used to pad if the layout transform would result in implicit "
"padding. If not specified, the compiler is free to choose any value.");
}
}; // struct LayoutTransformAttrs

/*! \brief Attributes used in permute_dims operator */
struct PermuteDimsAttrs : public tvm::AttrsNode<PermuteDimsAttrs> {
Optional<Array<Integer>> axes;

TVM_DECLARE_ATTRS(PermuteDimsAttrs, "relax.attrs.PermuteDimsAttrs") {
TVM_ATTR_FIELD(axes).describe("The target axes order, reverse order if not specified.");
}
}; // struct PermuteDimsAttrs

/*! \brief Attributes used in split operator */
struct SplitAttrs : public tvm::AttrsNode<SplitAttrs> {
ObjectRef indices_or_sections;
int axis;

TVM_DECLARE_ATTRS(SplitAttrs, "relax.attrs.SplitAttrs") {
TVM_ATTR_FIELD(indices_or_sections)
.describe("The input array of indices or the number of split sections.");
TVM_ATTR_FIELD(axis).describe("The axis to be splitted");
}
}; // struct SplitAttrs

/*! \brief Attributes used in squeeze operators */
struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
Optional<Array<Integer>> axis;

TVM_DECLARE_ATTRS(SqueezeAttrs, "relax.attrs.SqueezeAttrs") {
TVM_ATTR_FIELD(axis).describe(
"The axis to squeeze in the input tensor."
"If `axis = None`, all axis of dimension 1 get squeezed;"
"Else, the dimension in axes get squeezed."
"It is an error if an axis does not has dimension 1.");
}
}; // struct SqueezeAttrs

} // namespace relax
} // namespace tvm

#endif // TVM_RELAX_ATTRS_MANIPULATE_H_
207 changes: 204 additions & 3 deletions python/tvm/relax/op/manipulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,161 @@
# specific language governing permissions and limitations
# under the License.
"""Manipulation operators."""
from typing import Tuple, Union
from typing import List, Optional, Tuple, Union, Callable

from tvm.ir.expr import PrimExpr

from tvm.tir import IntImm, FloatImm, IndexMap

from . import _ffi_api
from ..expr import Expr
from ..expr import Expr, PrimValue, ShapeExpr, Tuple as RxTuple


PrimExprLike = Union[int, PrimExpr]


def broadcast_to(x: Expr, shape: Union[Tuple[PrimExprLike], Expr]) -> Expr:
"""Broadcasts a tensor to a specified shape.
Parameters
----------
x : relax.Expr
The input data to the operator.
shape : Union[Tuple[PrimExprLike], Expr]
The target shape.
Returns
-------
result : relax.Expr
The broadcasted tensor.
"""
if isinstance(shape, (tuple, list)):
shape = ShapeExpr(shape)
return _ffi_api.broadcast_to(x, shape) # type: ignore


def concat(tensors: Union[Expr, List[Expr]], axis: Optional[int] = 0) -> Expr:
"""Concatenate the input tensors along the given axis.
Parameters
----------
tensors : Union[relax.Expr, List[relax.Expr]]
An Expr in Tuple type, containing the tensors to be concatenated,
or a list of Tensors.
axis : Optional[int]
The axis along which the tensors are concatenated.
If `axis` is `None`, the input tensor is required to be flattened before concatenation.
Returns
-------
result: relax.Expr
The concatenated tensor.
"""
if isinstance(tensors, (list, tuple)):
tensors = RxTuple(tensors)
return _ffi_api.concat(tensors, axis) # type: ignore


def expand_dims(x: Expr, axis: Union[int, List[int]]) -> Expr:
"""Insert new axes at the positions given by `axis`.
Parameters
----------
x : relax.Expr
The input data to the operator.
axis : Union[int, List[int]]
The axes at which the input array are expanded.
All values are required to lie in range `[-data.ndim - 1, data.ndim]`, with the convention
of negative indexing.
Returns
-------
result : relax.Expr
The transformed result.
"""
if isinstance(axis, int):
axis = [axis]
return _ffi_api.expand_dims(x, axis) # type: ignore


def flatten(x: Expr) -> Expr:
"""Flatten all the tensor dimensions into one.
Parameters
----------
x : relax.Expr
The input data to the operator.
Returns
-------
result : relax.Expr
The flattened result.
"""
return _ffi_api.flatten(x) # type: ignore


def layout_transform(
x: Expr,
index_map: Union[Callable, IndexMap],
pad_value: Optional[Union[int, float, PrimValue]] = None,
):
"""Modifies the layout of a tensor.
Parameters
----------
x : relax.Expr
The input tensor to the operator.
index_map : Union[Callable, IndexMap]
The transformation to apply.
pad_value : Optional[Union[int, float, PrimValue]]
The value used for padding if the transformation results in implicit padding.
If not specified, any value can be used.
Returns
-------
result : relax.Expr
The transformed tensor.
"""
if callable(index_map):
index_map = IndexMap.from_func(index_map)
x_dtype = x.checked_type.dtype

# Explicitly convert python int/float pad_value to the x's type. If the default behavior
# is applied, it would be converted to int32/float32, which may not match the x's type.
if pad_value is None:
pass
elif not isinstance(pad_value, PrimValue):
if "int" in x_dtype and isinstance(pad_value, int):
pad_value = IntImm(x_dtype, pad_value)
elif "float" in x_dtype and (isinstance(pad_value, (int, float))):
pad_value = FloatImm(x_dtype, float(pad_value))
pad_value = PrimValue(pad_value)
return _ffi_api.layout_transform(x, index_map, pad_value) # type: ignore


def permute_dims(x: Expr, axes: Optional[List[int]] = None) -> Expr:
"""Permutes the dimensions of an array.
Parameters
----------
x : relax.Expr
The input data to the operator.
axes : Optional[List[int]]
The target axes order, reverse order if not specified.
Returns
-------
result : relax.Expr
The transposed result.
"""
return _ffi_api.permute_dims(x, axes) # type: ignore


def reshape(x: Expr, shape: Union[Tuple[PrimExprLike], Expr]) -> Expr:
"""Reshape the input array.
Expand Down Expand Up @@ -60,3 +203,61 @@ def reshape(x: Expr, shape: Union[Tuple[PrimExprLike], Expr]) -> Expr:
compile-time, an error will be thrown.
"""
return _ffi_api.reshape(x, shape) # type: ignore


def split(
x: Expr,
indices_or_sections: Union[int, List[PrimExprLike]],
axis: int = 0,
) -> Expr:
"""Split input tensor along axis by sections or indices.
If indices_or_sections is an integer, the input will be divided equally
along given axis (if possible). Last section will be smaller if the tensor
size along the given dimension is not divisible by the integer.
If indices_or_sections is a tuple of mixture of int or PrimExpr,
the entries indicate the indices where along axis the array is split.
Parameters
----------
x : relax.Expr
The tensor to be split.
indices_or_sections : Union[int, List[PrimExprLike]]
Indices or sections to split into. Accepts an int or a list.
axis : int
The axis over which to split.
Returns
-------
ret : relax.Expr
The computed result.
"""
if isinstance(indices_or_sections, int):
indices_or_sections = IntImm("int64", indices_or_sections)
return _ffi_api.split(x, indices_or_sections, axis) # type: ignore


def squeeze(x: Expr, axis: Optional[Union[int, List[int]]] = None) -> Expr:
"""Squeeze axes in the array.
Parameters
----------
x : relax.Expr
The input data to the operator.
axis : Optional[Union[int, List[int]]
The set of axes to remove.
If axis = None, remove all axis of dimensions 1.
If any specified axis has dimension that does not equal 1, it is an error.
Returns
-------
result : relax.Expr
The squeezed result.
"""
if isinstance(axis, int):
axis = [axis]
return _ffi_api.squeeze(x, axis) # type: ignore
30 changes: 30 additions & 0 deletions python/tvm/relax/op/op_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,36 @@ class StatisticalAttrs(Attrs):
"""Attributes used in statistical operator"""


@tvm._ffi.register_object("relax.attrs.ConcatAttrs")
class ConcatAttrs(Attrs):
"""Attributes for concat operator"""


@tvm._ffi.register_object("relax.attrs.ExpandDimsAttrs")
class ExpandDimsAttrs(Attrs):
"""Attributes for expand_dims operator"""


@tvm._ffi.register_object("relax.attrs.PermuteDimsAttrs")
class PermuteDimsAttrs(Attrs):
"""Attributes for permute_dims operator"""


@tvm._ffi.register_object("relax.attrs.SplitAttrs")
class SplitAttrs(Attrs):
"""Attributes used in split operator"""


@tvm._ffi.register_object("relax.attrs.SqueezeAttrs")
class SqueezeAttrs(Attrs):
"""Attributes for squeeze operator"""


@tvm._ffi.register_object("relax.attrs.LayoutTransformAttrs")
class LayoutTransformAttrs(Attrs):
"""Attributes used in layout_transform operator"""


@tvm._ffi.register_object("relax.attrs.Resize2DAttrs")
class Resize2DAttrs(Attrs):
"""Attributes used in image resize2d operator"""
Expand Down
Loading

0 comments on commit 85479ca

Please sign in to comment.