Skip to content

Commit

Permalink
[Relax] Express dynamic arguments of strided_slice as arguments (apac…
Browse files Browse the repository at this point in the history
…he#16826)

* [Relax] Express dynamic arguments of strided_slice as arguments

Prior to this commit, `relax.op.strided_slice` stored the `axes`,
`begin`, `end`, and `strides` in the `CallNode::attrs`.  However, the
attributes are only intended to store static values.  The indices used
used for `relax.op.strided_slice` must frequently be in terms of
symbolic shape variables, which should not be stored in the
attributes.  While some utilities have special handling for
`relax.op.strided_slice` (e.g. `tvm::relax::Bind`), many do
not (e.g. `tvm::relax::WellFormed` and
`tvm::relax::FreeSymbolicVars`).  As a result, the symbolic
expressions in `relax.op.strided_slice` will fail to be updated in
generic utilities, and will fail to trigger safeguards when this
occurs.

This commit changes the representation of `relax.op.strided_slice` to
store all arguments in the `relax::CallNode::args`, rather than the
`relax::CallNode::attrs`.  As mentioned in a comment from
apache#13987, which initially implemented
`relax.op.strided_slice`, this was an intended refactor once
`relax::PrimValue` was fully supported.

* Undo unnecessary changes in const_int_bound

* Remove unnecessary changes to rewrite_simplify

* lint fixes

* Fix unit tests

* Improve error message

* Fix additional unit tests

* Mark MSC tests with xfail

* remove commented-out code

* Resolve failing unit test

* Remove unused imports
  • Loading branch information
Lunderberg authored May 1, 2024
1 parent a320b63 commit 20d7696
Show file tree
Hide file tree
Showing 20 changed files with 653 additions and 357 deletions.
11 changes: 0 additions & 11 deletions include/tvm/relax/attrs/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,20 +40,9 @@ struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {

/*! \brief Attributes used in strided_slice operator */
struct StridedSliceAttrs : public tvm::AttrsNode<StridedSliceAttrs> {
Array<Integer> axes;
Array<PrimExpr> begin;
Array<PrimExpr> end;
Optional<Array<PrimExpr>> strides;
bool assume_inbound;

TVM_DECLARE_ATTRS(StridedSliceAttrs, "relax.attrs.StridedSliceAttrs") {
TVM_ATTR_FIELD(axes).describe("Axes along which slicing is applied.");
TVM_ATTR_FIELD(begin).describe("The indices to begin with in the slicing, inclusive.");
TVM_ATTR_FIELD(end).describe("The indices indicating end of the slice, exclusive.");
TVM_ATTR_FIELD(strides).describe(
"Specifies the stride values, it can be negative in that case, the input tensor will be "
"reversed in that particular axis. If not specified, it by default is an list of ones of "
"the same length as `axes`.");
TVM_ATTR_FIELD(assume_inbound)
.set_default(true)
.describe(
Expand Down
6 changes: 6 additions & 0 deletions python/tvm/relax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from tvm.runtime import relax_vm as vm
from tvm.runtime.relax_vm import VirtualMachine, VMInstrumentReturnKind

from .type_converter import args_converter

# Expr
from .expr import (
Expr,
Expand Down Expand Up @@ -92,6 +94,9 @@
from .pipeline import get_pipeline
from .pipeline import register_pipeline

# utils
from .utils import convert_to_expr

# Import submodules in the last to avoid dependency
from . import exec_builder
from . import expr
Expand All @@ -105,6 +110,7 @@
from . import training
from . import distributed
from . import frontend
from . import utils

# VM
from .vm_build import build, Executable
Expand Down
12 changes: 7 additions & 5 deletions python/tvm/relax/op/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
# specific language governing permissions and limitations
# under the License.
"""Indexing operators."""
from typing import List, Optional, Union
from typing import Optional, Union

from tvm.ir.expr import PrimExpr

from . import _ffi_api
from ..expr import Expr
from .. import args_converter

PrimExprLike = Union[int, PrimExpr]

Expand Down Expand Up @@ -52,12 +53,13 @@ def take(x: Expr, indices: Expr, axis: Optional[int] = None) -> Expr:
return _ffi_api.take(x, indices, axis) # type: ignore


@args_converter.auto
def strided_slice(
x: Expr,
axes: List[int],
begin: List[PrimExprLike],
end: List[PrimExprLike],
strides: Optional[List[PrimExprLike]] = None,
axes: Expr,
begin: Expr,
end: Expr,
strides: Optional[Expr] = None,
assume_inbound: bool = False,
) -> Expr:
"""Strided slice of a tensor.
Expand Down
39 changes: 29 additions & 10 deletions python/tvm/relax/transform/legalize_ops/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from ...op import call_pure_packed
from ...block_builder import BlockBuilder
from ...expr import Call, Expr
from ...struct_info import ShapeStructInfo
from ...struct_info import ShapeStructInfo, PrimStructInfo
from .common import register_legalize


Expand All @@ -35,18 +35,37 @@ def _take(bb: BlockBuilder, call: Call) -> Expr:

@register_legalize("relax.strided_slice")
def _strided_slice(bb: BlockBuilder, call: Call) -> Expr:
strides = (
[tir.IntImm("int64", 1)] * len(call.attrs.axes)
if call.attrs.strides is None
else call.attrs.strides
)
def _relax_tuple_to_tir(relax_tuple):
output = []
for field in relax_tuple.struct_info.fields:
assert isinstance(field, PrimStructInfo)
assert field.value is not None
output.append(field.value)
return output

if len(call.args) == 4:
data, axes, begin, end = call.args
strides = [tir.IntImm("int64", 1)] * len(axes.struct_info.fields)
elif len(call.args) == 5:
data, axes, begin, end, strides = call.args
strides = _relax_tuple_to_tir(strides)
else:
raise ValueError(
f"Expression {call} provides {len(call.args)} arguments, "
f"but {call.op} requires either 4 or 5 arguments."
)

axes = _relax_tuple_to_tir(axes)
begin = _relax_tuple_to_tir(begin)
end = _relax_tuple_to_tir(end)

return bb.call_te(
topi.strided_slice,
call.args[0],
call.attrs.begin,
call.attrs.end,
data,
begin,
end,
strides,
call.attrs.axes,
axes,
slice_mode="end",
)

Expand Down
179 changes: 179 additions & 0 deletions python/tvm/relax/type_converter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=invalid-name,too-many-locals

"""Argument converter utility for Relax
This utility is used to decorate constructors of `tvm.relax.Expr`, and
must be able to be imported before `tvm.relax.Expr` or its subtypes
have been defined. Neither the class definitions nor any type
signature in this file may reference relax types. All references must
be exclusively in function bodies to avoid having a circular reference
during module imports.
"""

import functools
import inspect

from typing import List, Optional, Callable, TypeVar, Any

import tvm

FType = TypeVar("FType", bound=Callable[..., "tvm.relax.Expr"])


class _ArgsConverter:
"""A helper class to convert the arguments to Expr."""

@staticmethod
def convert(args_to_expr: List[str], args_to_list_expr: List[str]):
"""Convert the arguments to Expr.
Parameters
----------
args_to_expr : List[str]
The argument names to be converted to Expr.
args_to_list_expr : List[str]
The argument names to be converted to List[Expr].
Returns
-------
output : Callable[[FType], FType]
The decorator.
"""

if any([x in args_to_list_expr for x in args_to_expr]):
raise ValueError("`args_to_expr` and `args_to_list_expr` should be disjoint.")

def _convert(name: str, value: Any) -> Any:
if value is None:
return value
if name in args_to_expr:
try:
return tvm.relax.utils.convert_to_expr(value)
except Exception as err:
raise TypeError(
f"Argument `{name}` is expected to be converted to `Expr`, "
f"but failed with input value: {value}"
) from err
elif name in args_to_list_expr:
try:
return [tvm.relax.utils.convert_to_expr(x) for x in value]
except Exception as err:
raise TypeError(
f"Argument `{name}` is expected to be converted to `List[Expr]`, "
f"but failed with input value: {value}"
) from err
else:
return value

def inner(func: FType) -> FType:
sig = inspect.signature(func)
param_names = list(sig.parameters.keys())
for name in args_to_expr + args_to_list_expr:
if name not in param_names:
raise ValueError(f"Argument `{name}` is not found in function signature.")

@functools.wraps(func)
def wrapper(*args, **kwargs):
bound = sig.bind(*args, **kwargs)
bound.apply_defaults()
for param in sig.parameters.values():
if param.kind == param.VAR_POSITIONAL:
# *args case
values = [_convert(param.name, x) for x in bound.arguments[param.name]]
bound.arguments[param.name] = tuple(values)
elif param.kind == param.VAR_KEYWORD:
# **kwargs case
key_value = {
key: _convert(param.name, value)
for key, value in bound.arguments[param.name].items()
}
bound.arguments[param.name] = key_value
else:
bound.arguments[param.name] = _convert(
param.name, bound.arguments[param.name]
)
return func(*bound.args, **bound.kwargs)

return wrapper # type: ignore

return inner

@staticmethod
def to_expr(*arg_names: str) -> Callable:
"""Convert the arguments to Expr.
Parameters
----------
*arg_names: str
The list of argument names that need to be converted to Expr.
Returns
-------
output: Callable
The decorator.
"""

return _ArgsConverter.convert(args_to_expr=list(arg_names), args_to_list_expr=[])

@staticmethod
def to_list_expr(*arg_names: str) -> Callable:
"""Convert the arguments to List of Expr.
Parameters
----------
*arg_names: str
The list of argument names that need to be converted to List of Expr.
Returns
-------
output: Callable
The decorator.
"""

return _ArgsConverter.convert(args_to_expr=[], args_to_list_expr=list(arg_names))

@staticmethod
def auto(func: FType) -> FType:
"""Decorator for automatically convert the arguments to Expr according to type annotation.
Only two patterns are supported:
1. The argument is Expr or Optional[Expr].
2. The argument is List[Expr] or Optional[List[Expr]].
"""
sig = inspect.signature(func)
args_to_expr = []
args_to_list_expr = []

from . import Expr # pylint: disable=import-outside-toplevel

for param in sig.parameters.values():
anno = param.annotation
if anno in (Expr, Optional[Expr]):
args_to_expr.append(param.name)
if anno in (List[Expr], Optional[List[Expr]]):
args_to_list_expr.append(param.name)

return _ArgsConverter.convert(args_to_expr, args_to_list_expr)(func)


args_converter = _ArgsConverter() # pylint: disable=invalid-name
Loading

0 comments on commit 20d7696

Please sign in to comment.