Skip to content

Commit 02cfc2a

Browse files
authored
[Language] Add type stubs for tir op (#1239)
* add typing stub for tir.ir * remove idents * minor update
1 parent 30d8ded commit 02cfc2a

File tree

1 file changed

+106
-0
lines changed

1 file changed

+106
-0
lines changed

tilelang/language/tir/ir.pyi

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
from typing import TypeVar, Literal
2+
from tvm.tir.expr import Span, PrimExpr, BufferLoad, Var, IntImm
3+
4+
_T = TypeVar('_T')
5+
6+
def abs(x: _T, span: Span | None=None) -> _T: ...
7+
def acos(x: _T) -> _T: ...
8+
def acosh(x: _T) -> _T: ...
9+
def address_of(buffer_load: BufferLoad, span: Span | None=None) -> PrimExpr: ...
10+
def asin(x: _T) -> _T: ...
11+
def asinh(x: _T) -> _T: ...
12+
def atan(x: _T) -> _T: ...
13+
def atan2(x1: _T, x2: _T) -> _T: ...
14+
def atanh(x: _T) -> _T: ...
15+
def bitwise_and(x: _T, y: _T, span: Span | None=None) -> _T: ...
16+
def bitwise_not(x: _T, span: Span | None=None) -> _T: ...
17+
def bitwise_or(x: _T, y: _T, span: Span | None=None) -> _T: ...
18+
def bitwise_xor(x: _T, y: _T, span: Span | None=None) -> _T: ...
19+
def ceil(x: _T, span: Span | None=None) -> _T: ...
20+
def clz(x: _T) -> _T: ...
21+
def copysign(x1: _T, x2: _T) -> _T: ...
22+
def cos(x: _T) -> _T: ...
23+
def cosh(x: _T) -> _T: ...
24+
def erf(x: _T) -> _T: ...
25+
def exp(x: _T) -> _T: ...
26+
def exp2(x: _T) -> _T: ...
27+
def exp10(x: _T) -> _T: ...
28+
def floor(x: _T, span: Span | None=None) -> _T: ...
29+
def ceildiv(lhs: _T, rhs: _T, span: Span | None=None) -> _T: ...
30+
def floordiv(a: _T, b: _T, span: Span | None=None) -> _T: ...
31+
def floormod(a: _T, b: _T, span: Span | None=None) -> _T: ...
32+
def fmod(x: _T, y: _T) -> _T: ...
33+
def hypot(x1: _T, x2: _T) -> _T: ...
34+
def if_then_else(cond: PrimExpr, t: _T, f: _T, span: Span | None=None) -> _T: ...
35+
def infinity(dtype: _T, span: Span | None=None) -> _T: ...
36+
def isfinite(x: _T, span: Span | None=None) -> _T: ...
37+
def isinf(x: _T, span: Span | None=None) -> _T: ...
38+
def isnan(x: _T, span: Span | None=None) -> _T: ...
39+
def isnullptr(x: _T, span: Span | None=None) -> _T: ...
40+
def ldexp(x1: _T, x2: _T) -> _T: ...
41+
def likely(cond: _T, span: Span | None=None) -> _T: ...
42+
def log(x: _T) -> _T: ...
43+
def log1p(x: _T) -> _T: ...
44+
def log2(x: _T) -> _T: ...
45+
def log10(x: _T) -> _T: ...
46+
def lookup_param(param_name: str, span: Span | None=None) -> PrimExpr: ...
47+
def max_value(dtype: str, span: Span | None=None) -> PrimExpr: ...
48+
def min_value(dtype: str, span: Span | None=None) -> PrimExpr: ...
49+
def nearbyint(x: _T, span: Span | None=None) -> _T: ...
50+
def nextafter(x1: _T, x2: _T) -> _T: ...
51+
def popcount(x: _T) -> _T: ...
52+
def pow(x: _T, y: _T, span: Span | None=None) -> _T: ...
53+
def q_multiply_shift(x: _T, y: _T, q: _T, s: _T) -> _T: ...
54+
def q_multiply_shift_per_axis(x: _T, y: _T, ls: _T, rs: _T, q: IntImm, is_lshift_required: IntImm, is_rshift_required: IntImm) -> PrimExpr: ...
55+
def ret(val: _T) -> _T: ...
56+
def round(x: _T, span: Span | None=None) -> _T: ...
57+
def rsqrt(x: _T) -> _T: ...
58+
def shift_left(x: _T, y: _T, span=None) -> _T: ...
59+
def shift_right(x: _T, y: _T, span=None) -> _T: ...
60+
def sigmoid(x: _T) -> _T: ...
61+
def sin(x: _T) -> _T: ...
62+
def sinh(x: _T) -> _T: ...
63+
def sqrt(x: _T) -> _T: ...
64+
def tan(x: _T) -> _T: ...
65+
def tanh(x: _T) -> _T: ...
66+
def trunc(x: _T, span: Span | None=None) -> _T: ...
67+
def truncdiv(a: _T, b: _T, span: Span | None=None) -> _T: ...
68+
def truncmod(a: _T, b: _T, span: Span | None=None) -> _T: ...
69+
def tvm_access_ptr(ptype: PrimExpr, data, offset: int, extent: int, rw_mask: int) -> PrimExpr: ...
70+
def tvm_throw_last_error() -> _T: ...
71+
def tvm_stack_alloca(dtype_str: str, num: int) -> PrimExpr: ...
72+
def tvm_stack_make_shape(*args) -> _T: ...
73+
def tvm_stack_make_array(data: PrimExpr, shape: PrimExpr, strides: PrimExpr, ndim: PrimExpr, arr_dtype: PrimExpr, elem_offset) -> PrimExpr: ...
74+
def tvm_check_return(expected: int, return_unexpected: int, nested_call: PrimExpr) -> PrimExpr: ...
75+
def call_packed(*args, span=None) -> _T: ...
76+
def call_cpacked(*args, span=None) -> _T: ...
77+
def call_packed_lowered(*args, span=None) -> _T: ...
78+
def call_cpacked_lowered(*args, span=None) -> _T: ...
79+
def tvm_tuple(*value) -> _T: ...
80+
def tvm_struct_set(arr, index: int, field: int, value: PrimExpr) -> PrimExpr: ...
81+
def tvm_thread_invariant(cond: _T) -> _T: ...
82+
def tvm_thread_allreduce(*freduce_args) -> _T: ...
83+
def tvm_load_matrix_sync(fragment: Var, m: IntImm, n: IntImm, k: IntImm, index: PrimExpr, buffer_ptr: PrimExpr, stride: PrimExpr, layout: Literal['row_major', 'column_major']) -> PrimExpr: ...
84+
def tvm_mma_sync(fragment_d: Var, index_d: PrimExpr, fragment_a: Var, index_a: PrimExpr, fragment_b: Var, index_b: PrimExpr, fragment_c: Var, index_c: PrimExpr) -> PrimExpr: ...
85+
def tvm_bmma_sync(fragment_d: Var, index_d: PrimExpr, fragment_a: Var, index_a: PrimExpr, fragment_b: Var, index_b: PrimExpr, fragment_c: Var, index_c: PrimExpr) -> PrimExpr: ...
86+
def tvm_fill_fragment(fragment: Var, m: IntImm, n: IntImm, k: IntImm, index: PrimExpr, value: PrimExpr) -> PrimExpr: ...
87+
def tvm_store_matrix_sync(fragment: Var, m: IntImm, n: IntImm, k: IntImm, index: PrimExpr, buffer_ptr: PrimExpr, stride: PrimExpr, layout: Literal['row_major', 'column_major']) -> PrimExpr: ...
88+
def ptx_wait_group(num: int) -> PrimExpr: ...
89+
def ptx_commit_group() -> _T: ...
90+
def ptx_cp_async_barrier(barrier_id: int) -> PrimExpr: ...
91+
def ptx_init_barrier_thread_count(barrier_id: int, thread_count: int) -> PrimExpr: ...
92+
def ptx_arrive_barrier(barrier_id: int) -> PrimExpr: ...
93+
def ptx_arrive_barrier_expect_tx(barrier_id: int, byte_count: int) -> PrimExpr: ...
94+
def ptx_wait_barrier(barrier_id: int) -> PrimExpr: ...
95+
def create_barriers(barrier_count: int) -> PrimExpr: ...
96+
def assume(cond: _T=None) -> _T: ...
97+
def undef() -> _T: ...
98+
def TVMBackendAllocWorkspace(device_type: int, device_id: int, nbytes: int, dtype_code_hint: int, dtype_bits_hint: int) -> PrimExpr: ...
99+
def TVMBackendFreeWorkspace(device_type: int, device_id: int, ptr: Var) -> PrimExpr: ...
100+
def start_profile_intrinsic(id: int) -> PrimExpr: ...
101+
def end_profile_intrinsic(id: int) -> PrimExpr: ...
102+
def anylist_getitem(list_handle, index) -> PrimExpr: ...
103+
def anylist_resetitem(list_handle, index) -> PrimExpr: ...
104+
def anylist_setitem_call_packed(list_handle, index, func_name, *args) -> PrimExpr: ...
105+
def anylist_setitem_call_cpacked(list_handle, index, func_name, *args) -> PrimExpr: ...
106+
def vscale() -> _T: ...

0 commit comments

Comments
 (0)