Skip to content

Commit

Permalink
Fix tensor intrin
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Jun 21, 2020
1 parent 1190dec commit dbc0777
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 13 deletions.
8 changes: 4 additions & 4 deletions tests/python/unittest/test_te_schedule_tensor_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def intrin_func(ins, outs):

BA = ins[0]
BC = outs[0]
ib.emit(tvm.tir.call_intrin('handle', 'tvm_load_matrix_sync',
ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_load_matrix_sync',
BC.data, n, m, l, BC.elem_offset // (row * col),
BA.access_ptr('r'), col, 'row_major'))
return ib.get()
Expand All @@ -66,12 +66,12 @@ def intrin_func(ins, outs):

def init():
ib = tvm.tir.ir_builder.create()
ib.emit(tvm.tir.call_intrin('handle', 'tvm_fill_fragment', BC.data, n, m, l, BC.elem_offset // (n * m), 0.0))
ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_fill_fragment', BC.data, n, m, l, BC.elem_offset // (n * m), 0.0))
return ib.get()

def update():
ib = tvm.tir.ir_builder.create()
ib.emit(tvm.tir.call_intrin('handle', 'tvm_mma_sync',
ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_mma_sync',
BC.data, BC.elem_offset // (n * m),
BA.data, BA.elem_offset // (n * l),
BB.data, BB.elem_offset // (l * m),
Expand All @@ -95,7 +95,7 @@ def intrin_func(ins, outs):

BA = ins[0]
BC = outs[0]
ib.emit(tvm.tir.call_intrin('handle', 'tvm_store_matrix_sync',
ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_store_matrix_sync',
BA.data, n, m, l, BA.elem_offset // (n * m),
BC.access_ptr('w'), m, 'row_major'))
return ib.get()
Expand Down
10 changes: 5 additions & 5 deletions topi/python/topi/cuda/tensor_intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def intrin_func(ins, outs):
BC = outs[0]
row = wmma_m * wmma_k
warp_index = BC.elem_offset // row + BC.elem_offset % row // wmma_k
ib.emit(tvm.tir.call_intrin('handle', 'tvm_load_matrix_sync',
ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_load_matrix_sync',
BC.data, wmma_m, wmma_n, wmma_k, warp_index,
BA.access_ptr('r'), strides_from[0], layout))
return ib.get()
Expand Down Expand Up @@ -128,7 +128,7 @@ def intrin_func(ins, outs):
BC = outs[0]
row = wmma_n * wmma_k
warp_index = BC.elem_offset // row + BC.elem_offset % row // wmma_n
ib.emit(tvm.tir.call_intrin('handle', 'tvm_load_matrix_sync',
ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_load_matrix_sync',
BC.data, wmma_m, wmma_n, wmma_k, warp_index,
BA.access_ptr('r'), strides_from[0], layout))
return ib.get()
Expand Down Expand Up @@ -156,7 +156,7 @@ def intrin_func(ins, outs):
BC = outs[0]
row = wmma_m * wmma_n
warp_index = BA.elem_offset // row + BA.elem_offset % row // wmma_n
ib.emit(tvm.tir.call_intrin('handle', 'tvm_store_matrix_sync',
ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_store_matrix_sync',
BA.data, wmma_m, wmma_n, wmma_k, warp_index,
BC.access_ptr('w'), strides_dst[0], 'row_major'))
return ib.get()
Expand Down Expand Up @@ -207,13 +207,13 @@ def warp_idnex(offset, row, col):
def init():
ib = tvm.tir.ir_builder.create()
ib.emit(
tvm.tir.call_intrin('handle', 'tvm_fill_fragment', BC.data, wmma_m, wmma_n, wmma_k,
tvm.tir.call_intrin('handle', 'tir.tvm_fill_fragment', BC.data, wmma_m, wmma_n, wmma_k,
warp_index_C, 0.0))
return ib.get()

def update():
ib = tvm.tir.ir_builder.create()
ib.emit(tvm.tir.call_intrin('handle', 'tvm_mma_sync',
ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_mma_sync',
BC.data, warp_index_C,
BA.data, warp_index_A,
BB.data, warp_index_B,
Expand Down
8 changes: 4 additions & 4 deletions tutorials/optimize/opt_conv_tensorcore.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def intrin_func(ins, outs):

BA = ins[0]
BC = outs[0]
ib.emit(tvm.tir.call_intrin('handle', 'tvm_load_matrix_sync',
ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_load_matrix_sync',
BC.data, n, n, n, BC.elem_offset // 256,
BA.access_ptr('r'), n, 'row_major'))
return ib.get()
Expand All @@ -190,12 +190,12 @@ def intrin_func(ins, outs):

def init():
ib = tvm.tir.ir_builder.create()
ib.emit(tvm.tir.call_intrin('handle', 'tvm_fill_fragment', BC.data, n, n, n, BC.elem_offset // 256, 0.0))
ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_fill_fragment', BC.data, n, n, n, BC.elem_offset // 256, 0.0))
return ib.get()

def update():
ib = tvm.tir.ir_builder.create()
ib.emit(tvm.tir.call_intrin('handle', 'tvm_mma_sync',
ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_mma_sync',
BC.data, BC.elem_offset // 256,
BA.data, BA.elem_offset // 256,
BB.data, BB.elem_offset // 256,
Expand All @@ -218,7 +218,7 @@ def intrin_func(ins, outs):
ib = tvm.tir.ir_builder.create()
BA = ins[0]
BC = outs[0]
ib.emit(tvm.tir.call_intrin('handle', 'tvm_store_matrix_sync',
ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_store_matrix_sync',
BA.data, n, n, n, BA.elem_offset // 256,
BC.access_ptr('w'), n, 'row_major'))
return ib.get()
Expand Down

0 comments on commit dbc0777

Please sign in to comment.