Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TVMScript][Fix] Add type hints for more uncovered cases #9505

Merged
merged 7 commits into from
Nov 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 43 additions & 29 deletions python/tvm/script/tir/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,15 @@ class IterVar(Var): ...

class Buffer:
@overload
def __getitem__(self: Buffer, pos: Sequence[Union[PrimExpr, int]]) -> PrimExpr: ...
def __getitem__(self: Buffer, pos: Sequence[Union[PrimExpr, int, slice]]) -> PrimExpr: ...
@overload
def __getitem__(self: Buffer, pos: Union[PrimExpr, int]) -> PrimExpr: ...
def __getitem__(self: Buffer, pos: Union[PrimExpr, int, slice]) -> PrimExpr: ...
@overload
def __setitem__(self: Buffer, pos: Sequence[Union[PrimExpr, int]], value: PrimExpr) -> None: ...
def __setitem__(
self: Buffer, pos: Sequence[Union[PrimExpr, int, slice]], value: PrimExpr
) -> None: ...
@overload
def __setitem__(self: Buffer, pos: Union[PrimExpr, int], value: PrimExpr) -> None: ...
def __setitem__(self: Buffer, pos: Union[PrimExpr, int, slice], value: PrimExpr) -> None: ...
@property
def data(self: Buffer) -> Ptr: ...

Expand Down Expand Up @@ -124,35 +126,47 @@ def reinterpret(value: PrimExpr, dtype: str) -> PrimExpr: ...
def store(
var: Var, index: PrimExpr, value: PrimExpr, predicate: Union[PrimExpr, builtins.bool] = True
) -> None: ...
def comm_reducer(lambda_io: Tuple[List[PrimExpr]], identities: List[PrimExpr]) -> PrimExpr: ...
def comm_reducer(lambda_io: Callable[[Any, Any], Any], identities: List[PrimExpr]) -> PrimExpr: ...

"""
Intrinsics - tvm builtin
"""

def tvm_thread_allreduce(
*freduceargs: Union[PrimExpr, builtins.bool, Ptr], dtype: str
) -> PrimExpr: ...

"""
Unary operator
Note that any intrinsics not registered in script.tir.intrin
should add "dtype" as an argument. This is different from their
definition but intentional.
"""

def exp2(x: PrimExpr) -> PrimExpr: ...
def exp10(x: PrimExpr) -> PrimExpr: ...
def erf(x: PrimExpr) -> PrimExpr: ...
def tanh(x: PrimExpr) -> PrimExpr: ...
def sigmoid(x: PrimExpr) -> PrimExpr: ...
def log(x: PrimExpr) -> PrimExpr: ...
def log2(x: PrimExpr) -> PrimExpr: ...
def log10(x: PrimExpr) -> PrimExpr: ...
def log1p(x: PrimExpr) -> PrimExpr: ...
def tan(x: PrimExpr) -> PrimExpr: ...
def cos(x: PrimExpr) -> PrimExpr: ...
def cosh(x: PrimExpr) -> PrimExpr: ...
def acos(x: PrimExpr) -> PrimExpr: ...
def acosh(x: PrimExpr) -> PrimExpr: ...
def sin(x: PrimExpr) -> PrimExpr: ...
def sinh(x: PrimExpr) -> PrimExpr: ...
def asin(x: PrimExpr) -> PrimExpr: ...
def asinh(x: PrimExpr) -> PrimExpr: ...
def atan(x: PrimExpr) -> PrimExpr: ...
def atanh(x: PrimExpr) -> PrimExpr: ...
def atan2(x: PrimExpr) -> PrimExpr: ...
def sqrt(x: PrimExpr) -> PrimExpr: ...
def rsqrt(x: PrimExpr) -> PrimExpr: ...
def exp(x: PrimExpr, dtype: str) -> PrimExpr: ...
def exp2(x: PrimExpr, dtype: str) -> PrimExpr: ...
def exp10(x: PrimExpr, dtype: str) -> PrimExpr: ...
def erf(x: PrimExpr, dtype: str) -> PrimExpr: ...
def tanh(x: PrimExpr, dtype: str) -> PrimExpr: ...
def sigmoid(x: PrimExpr, dtype: str) -> PrimExpr: ...
def log(x: PrimExpr, dtype: str) -> PrimExpr: ...
def log2(x: PrimExpr, dtype: str) -> PrimExpr: ...
def log10(x: PrimExpr, dtype: str) -> PrimExpr: ...
def log1p(x: PrimExpr, dtype: str) -> PrimExpr: ...
def tan(x: PrimExpr, dtype: str) -> PrimExpr: ...
def cos(x: PrimExpr, dtype: str) -> PrimExpr: ...
def cosh(x: PrimExpr, dtype: str) -> PrimExpr: ...
def acos(x: PrimExpr, dtype: str) -> PrimExpr: ...
def acosh(x: PrimExpr, dtype: str) -> PrimExpr: ...
def sin(x: PrimExpr, dtype: str) -> PrimExpr: ...
def sinh(x: PrimExpr, dtype: str) -> PrimExpr: ...
def asin(x: PrimExpr, dtype: str) -> PrimExpr: ...
def asinh(x: PrimExpr, dtype: str) -> PrimExpr: ...
def atan(x: PrimExpr, dtype: str) -> PrimExpr: ...
def atanh(x: PrimExpr, dtype: str) -> PrimExpr: ...
def atan2(x: PrimExpr, dtype: str) -> PrimExpr: ...
def sqrt(x: PrimExpr, dtype: str) -> PrimExpr: ...
def rsqrt(x: PrimExpr, dtype: str) -> PrimExpr: ...

"""
special_stmt - Buffers
Expand Down Expand Up @@ -334,7 +348,7 @@ def for_range(
end: Union[PrimExpr, int] = None,
annotations: Optional[Mapping[str, Object]] = None,
) -> Iterable[IterVar]: ...
def grid(*extents: Union[PrimExpr, int]) -> Iterable[Tuple[IterVar]]: ...
def grid(*extents: Union[PrimExpr, int]) -> Iterable[Sequence[IterVar]]: ...

"""
ty - redefine types
Expand Down
96 changes: 96 additions & 0 deletions tests/python/unittest/test_tvmscript_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,102 @@ def element_wise_env_thread_x(a: T.handle, b: T.handle, c: T.handle) -> None:
)


"""
This test case is added to test T.grid
"""


@T.prim_func
def loop_split(a: T.handle, b: T.handle) -> None:
A = T.match_buffer(a, [128, 128], dtype="float32")
B = T.match_buffer(b, [128], dtype="float32")
for i, ko in T.grid(128, 4):
for ki in T.thread_binding(0, 32, thread="threadIdx.x"):
with T.block("B"):
vi = T.axis.S(128, i)
vk = T.axis.R(128, ko * 32 + ki)
T.reads([B[vi], A[vi, vk]])
T.writes([B[vi]])
with T.init():
B[vi] = T.float32(0)
B[vi] = B[vi] + A[vi, vk]


"""
This test case is added to test T.comm_reducer, T.reinterpret, T.tvm_thread_allreduce
"""


@T.prim_func
def lowered_loop_split(a: T.handle, b: T.handle) -> None:
A = T.match_buffer(a, [128, 128], dtype="float32")
B = T.match_buffer(b, [128], dtype="float32")
reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
normal_reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
for i in T.serial(0, 128):
for ki in T.thread_binding(0, 32, thread="threadIdx.x"):
normal_reduce_temp0[0] = T.float32(0)
for ko in T.serial(0, 4):
with T.block("B_normal_reduction"):
vi = T.axis.S(128, i)
vk = T.axis.R(128, ko * 32 + ki)
T.reads([A[vi, vk], normal_reduce_temp0[0]])
T.writes([normal_reduce_temp0[0]])
normal_reduce_temp0[0] = normal_reduce_temp0[0] + A[vi, vk]
with T.block("B_cross_thread_reduction"):
T.reads([normal_reduce_temp0[0]])
T.writes([reduce_temp0[0]])
T.attr(
T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
"reduce_scope",
T.reinterpret(T.uint64(0), dtype="handle"),
)
T.evaluate(
T.tvm_thread_allreduce(
T.uint32(1),
normal_reduce_temp0[0],
True,
reduce_temp0.data,
ki,
dtype="handle",
)
)
with T.block("B_write_back"):
vi = T.axis.S(128, i)
T.reads([reduce_temp0[0]])
T.writes([B[vi]])
B[vi] = reduce_temp0[0]


"""
This test case is added to test T.Buffer with slice as argument and T.exp
"""


@T.prim_func
def different_access_indices(a: T.handle, b: T.handle) -> None:
A = T.match_buffer(a, [128, 128, 128], dtype="float32")
B = T.match_buffer(b, [128, 128], dtype="float32")
for i, j in T.grid(128, 128):
for k in T.thread_binding(0, 128, thread="threadIdx.x"):
with T.block("B"):
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
T.reads([B[vi, vj], A[vi, vj, vk]])
T.writes(
[
B[
T.min(vj, vi) : T.min(vj, vi) # type: ignore[misc]
+ (T.max(vj, vi) + 1 - T.min(vj, vi)),
T.min(vi, vj) : T.min(vi, vj) # type: ignore[misc]
+ (T.max(vi, vj) + 1 - T.min(vi, vj)),
]
]
)
with T.init():
B[vj, vi] = T.exp(B[vj, vi], dtype="float32")
B[vi, vj] = B[vi, vj] + A[vi, vj, vk]


# Not running any test as we only want to type-check here
if __name__ == "__main__":
pass