
Commit

fix (#3)
spectrometerHBH authored Jun 6, 2022
1 parent d1e48ad commit 78c316a
Showing 5 changed files with 120 additions and 192 deletions.
241 changes: 78 additions & 163 deletions python/tvm/meta_schedule/testing/tir_tensor_intrin.py
@@ -95,94 +95,16 @@ def dot_product_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
)


# @T.prim_func
# def wmma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
# A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=1, scope="wmma.matrix_a")
# B = T.match_buffer(b, (16, 16), "float16", align=128, offset_factor=1, scope="wmma.matrix_b")
# C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=1, scope="wmma.accumulator")

# with T.block("root"):
# for i, j, k in T.grid(16, 16, 16):
# with T.block("update"):
# vii, vjj, vkk = T.axis.remap("SSR", [i, j, k])
# C[vii, vjj] = C[vii, vjj] + T.cast(A[vii, vkk], "float32") * T.cast(
# B[vkk, vjj], "float32"
# )


# @T.prim_func
# def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
# A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_a")
# B = T.match_buffer(b, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_b")
# C = T.match_buffer(
# c, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator"
# )

# with T.block("root"):
# T.reads(
# [
# C[0:16, 0:16],
# A[0:16, 0:16],
# B[0:16, 0:16],
# ]
# )
# T.writes(C[0:16, 0:16])
# T.evaluate(
# T.tvm_mma_sync(
# C.data,
# C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16),
# A.data,
# A.elem_offset // 256 + T.floordiv(T.floormod(A.elem_offset, 256), 16),
# B.data,
# B.elem_offset // 256 + T.floordiv(T.floormod(B.elem_offset, 256), 16),
# C.data,
# C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16),
# dtype="handle",
# )
# )

@T.prim_func
def wmma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16,
                       scope="wmma.matrix_a")
    B = T.match_buffer(b, (16, 16), "float16", align=128, offset_factor=16,
                       scope="wmma.matrix_b")
    C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=16,
                       scope="wmma.accumulator")

    with T.block("root"):
        T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16])
        T.writes(C[0:16, 0:16])
        for i, j, k in T.grid(16, 16, 16):
            with T.block(""):
                vii, vjj, vkk = T.axis.remap("SSR", [i, j, k])
                C[vii, vjj] = C[vii, vjj] + T.cast(A[vii, vkk], "float32") * T.cast(B[vkk, vjj], "float32")


# (This commit removes the wmma_sync_desc above and its matching wmma_sync_impl
#  from this spot; updated versions are re-added near the end of the file.)
@T.prim_func
def wmma_load_a_desc(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16,
                       scope="shared")
    C = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16,
                       scope="wmma.matrix_a")

    with T.block("root"):
        T.reads(A[0:16, 0:16])
        T.writes(C[0:16, 0:16])
        for i, j in T.grid(16, 16):
            with T.block("load"):
                vii, vjj = T.axis.remap("SS", [i, j])
@@ -193,34 +115,27 @@ def wmma_load_a_desc(a: T.handle, c: T.handle) -> None:
def wmma_load_a_impl(a: T.handle, c: T.handle) -> None:
    s1 = T.var("int32")
    s0 = T.var("int32")
    A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16,
                       scope="shared", strides=[s1, s0])
    C = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16,
                       scope="wmma.matrix_a")

    with T.block("root"):
        T.reads(A[0:16, 0:16])
        T.writes(C[0:16, 0:16])
        T.evaluate(
            T.tvm_load_matrix_sync(
                C.data, 16, 16, 16,
                C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16),
                A.access_ptr("r"), s1, "row_major",
                dtype="handle",
            )
        )


@T.prim_func
def wmma_load_b_desc(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16,
                       scope="shared")
    C = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16,
                       scope="wmma.matrix_b")

    with T.block("root"):
        T.reads(A[0:16, 0:16])
        T.writes(C[0:16, 0:16])
        for i, j in T.grid(16, 16):
            with T.block("load"):
                vii, vjj = T.axis.remap("SS", [i, j])
@@ -231,34 +146,60 @@ def wmma_load_b_desc(a: T.handle, c: T.handle) -> None:
def wmma_load_b_impl(a: T.handle, c: T.handle) -> None:
    s1 = T.var("int32")
    s0 = T.var("int32")
    A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16,
                       scope="shared", strides=[s1, s0])
    C = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16,
                       scope="wmma.matrix_b")

    with T.block("root"):
        T.reads(A[0:16, 0:16])
        T.writes(C[0:16, 0:16])
        T.evaluate(
            T.tvm_load_matrix_sync(
                C.data, 16, 16, 16,
                C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16),
                A.access_ptr("r"), s1, "col_major",  # changed from "row_major" by this commit
                dtype="handle",
            )
        )
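The functional change in this file is the pair of edits around this point: wmma_load_b_impl now issues tvm_load_matrix_sync with "col_major" instead of "row_major", and the re-added wmma_sync_desc below indexes B as B[vjj, vkk] rather than B[vkk, vjj]. Interpreting a row-major 16x16 tile with column-major strides hands the tensor core the transposed tile, so the two edits have to move together. A small NumPy illustration of that reinterpretation (NumPy is used here purely for illustration; it is not part of the commit):

import numpy as np

flat = np.arange(256, dtype="float16")       # one contiguous 16x16 tile in shared memory
row_major = flat.reshape(16, 16)             # the layout the buffer is declared with
col_major = flat.reshape(16, 16, order="F")  # the order a "col_major" load walks the same memory

# The column-major reading of the same 256 elements is exactly the transpose:
assert np.array_equal(col_major, row_major.T)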


@T.prim_func
def wmma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16,
                       scope="wmma.matrix_a")
    B = T.match_buffer(b, (16, 16), "float16", align=128, offset_factor=16,
                       scope="wmma.matrix_b")
    C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=16,
                       scope="wmma.accumulator")

    with T.block("root"):
        T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16])
        T.writes(C[0:16, 0:16])
        for i, j, k in T.grid(16, 16, 16):
            with T.block(""):
                vii, vjj, vkk = T.axis.remap("SSR", [i, j, k])
                C[vii, vjj] = C[vii, vjj] + T.cast(A[vii, vkk], "float32") * T.cast(B[vjj, vkk], "float32")


@T.prim_func
def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16,
                       scope="wmma.matrix_a")
    B = T.match_buffer(b, (16, 16), "float16", align=128, offset_factor=16,
                       scope="wmma.matrix_b")
    C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=16,
                       scope="wmma.accumulator")

    with T.block("root"):
        T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16])
        T.writes(C[0:16, 0:16])
        T.evaluate(
            T.tvm_mma_sync(
                C.data, C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16),
                A.data, A.elem_offset // 256,
                B.data, B.elem_offset // 256,
                C.data, C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16),
                dtype="handle",
            )
        )


@T.prim_func
def wmma_fill_desc(c: T.handle) -> None:
    C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=16,
                       scope="wmma.accumulator")

    with T.block("root"):
        T.reads()
        T.writes(C[0:16, 0:16])
        for i, j in T.grid(16, 16):
            with T.block("init"):
                vii, vjj = T.axis.remap("SS", [i, j])
@@ -267,32 +208,20 @@ def wmma_fill_desc(c: T.handle) -> None:

@T.prim_func
def wmma_fill_impl(c: T.handle) -> None:
    C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=16,
                       scope="wmma.accumulator")

    with T.block("root"):
        T.reads()
        T.writes(C[0:16, 0:16])
        T.evaluate(
            T.tvm_fill_fragment(
                C.data, 16, 16, 16,
                C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16),
                T.float32(0),
                dtype="handle",
            )
        )


@T.prim_func
def wmma_store_desc(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), "float32", align=128, offset_factor=16,
                       scope="wmma.accumulator")
    C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=16, scope="global")

    with T.block("root"):
        T.reads(A[0:16, 0:16])
        T.writes(C[0:16, 0:16])
        for i, j in T.grid(16, 16):
            with T.block("store"):
                vii, vjj = T.axis.remap("SS", [i, j])
@@ -303,28 +232,14 @@ def wmma_store_desc(a: T.handle, c: T.handle) -> None:
def wmma_store_impl(a: T.handle, c: T.handle) -> None:
    s1 = T.var("int32")
    s0 = T.var("int32")
    A = T.match_buffer(a, (16, 16), "float32", align=128, offset_factor=16,
                       scope="wmma.accumulator")
    C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=16,
                       scope="global", strides=[s1, s0])

    with T.block("root"):
        T.reads(A[0:16, 0:16])
        T.writes(C[0:16, 0:16])
        T.evaluate(
            T.tvm_store_matrix_sync(
                A.data, 16, 16, 16,
                A.elem_offset // 256 + T.floordiv(T.floormod(A.elem_offset, 256), 16),
                C.access_ptr("w"), s1, "row_major",
                dtype="handle",
            )
        )


# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks
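Each *_desc / *_impl pair in this file defines one TensorIR tensor intrinsic: the desc function is the compute pattern the scheduler matches, the impl function is the hardware call it is replaced with, and the recurring index expression elem_offset // 256 + floormod(elem_offset, 256) // 16 converts a flat element offset into a 16x16 (256-element) WMMA fragment index. A minimal sketch of how such a pair is typically registered and consumed by a schedule — the intrinsic name "wmma_sync" and the tensorize steps are illustrative assumptions, not part of this commit:

from tvm import tir

# Register the description/implementation pair under a lookup name.
tir.TensorIntrin.register("wmma_sync", wmma_sync_desc, wmma_sync_impl)

# Typical downstream use, assuming a matmul schedule whose innermost loops have
# already been tiled to the 16x16x16 shape the descriptor expects:
#   sch = tir.Schedule(matmul_mod)
#   block = sch.get_block("update")
#   ...split/reorder so the innermost loop nest is 16 x 16 x 16...
#   sch.tensorize(inner_i, "wmma_sync")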
28 changes: 18 additions & 10 deletions src/arith/iter_affine_map.cc
@@ -1547,16 +1547,20 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorDivNode* op) {
<< " may not be an iterator";
return GetRef<PrimExpr>(op);
}

IterSumExpr preprocessed = PreprocessDividend(Downcast<IterMapExpr>(a), op->a);
if (!preprocessed.defined()) {
return GetRef<PrimExpr>(op);
}
  // Before this commit, SplitFloorDivConst was called unconditionally on args[0].
  ICHECK(preprocessed->args.size() <= 1);
  if (preprocessed->args.empty()) {
    // The dividend carries no iterator terms, so the whole expression is a constant.
    return IterSumExpr({}, floordiv(preprocessed->base, b));
  } else {
    PrimExpr remainder = SplitFloorDivConst(preprocessed->args[0], preprocessed->base, b);
    if (!remainder.defined()) {
      return GetRef<PrimExpr>(op);
    }
    return remainder;
  }
}

PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr base, PrimExpr rhs) {
@@ -1636,12 +1640,16 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) {
if (!preprocessed.defined()) {
return GetRef<PrimExpr>(op);
}

  // Before this commit, SplitFloorModConst was called unconditionally on args[0].
  ICHECK(preprocessed->args.size() <= 1);
  if (preprocessed->args.empty()) {
    // Same constant-only case as the floordiv path above.
    return IterSumExpr({}, floormod(preprocessed->base, b));
  } else {
    PrimExpr remainder = SplitFloorModConst(preprocessed->args[0], preprocessed->base, b);
    if (!remainder.defined()) {
      return GetRef<PrimExpr>(op);
    }
    return remainder;
  }
}

/*! * \brief Given an expression that may contain IterVarMapExpr, transform it to normal PrimExpr.
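Both hunks above make the same change: PreprocessDividend may return an IterSumExpr with no iterator terms at all, and the old code then indexed preprocessed->args[0] unconditionally. The rewriter now detects that case and folds the constant directly. A plain-Python sketch of the new control flow (this is not the TVM API; names and return shapes are illustrative):

def rewrite_floor_op(args, base, divisor, op):
    # Mirrors the new branch in VisitExpr_ for FloorDivNode / FloorModNode.
    assert len(args) <= 1                    # ICHECK(preprocessed->args.size() <= 1)
    if not args:                             # dividend has no iterator terms -> pure constant
        folded = base // divisor if op == "div" else base % divisor
        return ("IterSumExpr", [], folded)   # IterSumExpr({}, floordiv/floormod(base, b))
    # otherwise defer to SplitFloorDivConst / SplitFloorModConst as before
    return ("Split", args[0], base, divisor, op)


print(rewrite_floor_op([], 9, 4, "div"))  # ('IterSumExpr', [], 2)
print(rewrite_floor_op([], 9, 4, "mod"))  # ('IterSumExpr', [], 1)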
4 changes: 3 additions & 1 deletion src/driver/driver_api.cc
@@ -51,6 +51,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_storage_rewrite", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array<Array<ObjectRef>>);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.predicate_opt", Bool);

using runtime::PackedFunc;
using runtime::TVMArgs;
@@ -208,6 +209,7 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
bool instrument_bound_checkers =
pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();
bool disable_cse_tir = pass_ctx->GetConfig<Bool>("tir.disable_cse_tir", Bool(false)).value();
bool predicate_opt = pass_ctx->GetConfig<Bool>("tir.predicate_opt", Bool(false)).value();

// Get any user-added passes
Array<Array<ObjectRef>> add_lower_pass =
@@ -304,7 +306,7 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
}

pass_list.push_back(tir::transform::CommonSubexprElimTIR(!disable_cse_tir));
pass_list.push_back(tir::transform::OptimizePredicatedLoad(predicate_opt));  // was hard-coded to `true` before this commit
return pass_list;
}
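With the new "tir.predicate_opt" option, the OptimizePredicatedLoad pass is gated on a PassContext flag instead of always running. A sketch of toggling it from Python, valid only on this branch (the option key is registered by this commit; the workload below is a placeholder):

import tvm
from tvm import te

n = te.var("n")
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
s = te.create_schedule(B.op)

# Opt in to the predicated-load optimization for this build only.
with tvm.transform.PassContext(config={"tir.predicate_opt": True}):
    lib = tvm.build(s, [A, B], target="llvm")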

4 changes: 4 additions & 0 deletions src/tir/schedule/schedule.cc
@@ -170,6 +170,10 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReIndex")
int buffer_index_type) {
return self->ReIndex(block_rv, buffer_index, static_cast<BufferIndexType>(buffer_index_type));
});
/******** (FFI) Data movement ********/
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReadAt").set_body_method<Schedule>(&ScheduleNode::ReadAt);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleWriteAt")
.set_body_method<Schedule>(&ScheduleNode::WriteAt);
/******** (FFI) Compute location ********/
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleComputeAt")
.set_body_method<Schedule>(&ScheduleNode::ComputeAt);
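The two new registrations above expose ScheduleNode::ReadAt and ScheduleNode::WriteAt over the FFI. Only the global names are visible in this diff, so the sketch below merely checks that they are present in a build of this branch; the usual next step — Python-side Schedule.read_at / write_at wrappers taking a loop, a block, a buffer index, and a storage scope — is an assumption, not something this commit adds:

import tvm

for name in ("tir.schedule.ScheduleReadAt", "tir.schedule.ScheduleWriteAt"):
    func = tvm.get_global_func(name, allow_missing=True)
    print(name, "registered:", func is not None)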