diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h
index c28b0514dac1..de48ac24822a 100644
--- a/include/tvm/target/target.h
+++ b/include/tvm/target/target.h
@@ -177,17 +177,6 @@ TVM_DLL Target hexagon(const std::vector<std::string>& options = std::vector<st
  */
 class BuildConfigNode : public Object {
  public:
-  /*!
-   * \brief The data alignment to use when constructing buffers. If this is set to
-   * -1, then TVM's internal default will be used.
-   */
-  int data_alignment = -1;
-  /*!
-   * \brief The offset factor to use when constructing buffers. If this is set to
-   * 0, then the offset field is not used.
-   */
-  int offset_factor = 0;
-
   /*! \brief Splitting factor for loop splitting. */
   int double_buffer_split_loop = 1;
@@ ... @@ class BuildConfigNode : public Object {
   /*! \brief List of passes to be injected into the low-level pipeline. */
   std::vector<std::pair<int, PackedFunc>> add_lower_pass;
 
-  /*! \brief Whether to dump the IR of each pass (only when building from python) */
-  bool dump_pass_ir = false;
-
   /*! \brief Whether to instrument loads and stores with check for out of the bounds. */
   bool instrument_bound_checkers = false;
@@ -233,8 +219,6 @@ class BuildConfigNode : public Object {
   bool disable_assert = false;
 
   void VisitAttrs(AttrVisitor* v) {
-    v->Visit("data_alignment", &data_alignment);
-    v->Visit("offset_factor", &offset_factor);
     v->Visit("double_buffer_split_loop", &double_buffer_split_loop);
     v->Visit("auto_unroll_max_step", &auto_unroll_max_step);
     v->Visit("auto_unroll_max_depth", &auto_unroll_max_depth);
@@ -243,7 +227,6 @@ class BuildConfigNode : public Object {
     v->Visit("restricted_func", &restricted_func);
     v->Visit("detect_global_barrier", &detect_global_barrier);
     v->Visit("partition_const_loop", &partition_const_loop);
-    v->Visit("dump_pass_ir", &dump_pass_ir);
     v->Visit("instrument_bound_checkers", &instrument_bound_checkers);
     v->Visit("disable_select_rewriting", &disable_select_rewriting);
     v->Visit("disable_vectorize", &disable_vectorize);
diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py
index 216cad992d98..97ed8d8042a1 100644
--- a/python/tvm/driver/build_module.py
+++ b/python/tvm/driver/build_module.py
@@ -56,7 +56,6 @@ def get_binds(args, compact=False, binds=None):
         The list of symbolic buffers of arguments.
     """
     binds = {} if binds is None else binds.copy()
-    cfg = BuildConfig.current()
     arg_list = []
     for x in args:
         if isinstance(x, tensor.Tensor):
@@ -66,9 +65,6 @@ def get_binds(args, compact=False, binds=None):
             buf = tvm.tir.decl_buffer(
                 x.shape,
                 dtype=x.dtype,
-                name=x.name,
-                data_alignment=cfg.data_alignment,
-                offset_factor=cfg.offset_factor,
                 buffer_type=buffer_type)
             binds[x] = buf
             arg_list.append(buf)
@@ -157,8 +153,6 @@ def lower(sch,
     """
     cfg = BuildConfig.current()
     add_lower_pass = cfg.add_lower_pass if cfg.add_lower_pass else []
-    if cfg.dump_pass_ir:
-        add_lower_pass = BuildConfig._dump_ir.decorate_custompass(add_lower_pass)
     lower_phase0 = [x[1] for x in add_lower_pass if x[0] == 0]
     lower_phase1 = [x[1] for x in add_lower_pass if x[0] == 1]
     lower_phase2 = [x[1] for x in add_lower_pass if x[0] == 2]
diff --git a/python/tvm/target/build_config.py b/python/tvm/target/build_config.py
index 538ee7d5f544..a99797a0c397 100644
--- a/python/tvm/target/build_config.py
+++ b/python/tvm/target/build_config.py
@@ -45,11 +45,8 @@ class BuildConfig(Object):
         "unroll_explicit": True,
         "detect_global_barrier": False,
         "partition_const_loop": False,
-        "offset_factor": 0,
-        "data_alignment": -1,
         "restricted_func": True,
         "double_buffer_split_loop": 1,
-        "dump_pass_ir": False,
         "instrument_bound_checkers": False,
         "disable_select_rewriting": False,
         "disable_vectorize": False,
@@ -129,14 +126,6 @@ def build_config(**kwargs):
     partition_const_loop: bool, default=False
         Whether partition const loop
 
-    data_alignment: int, optional
-        The alignment of data pointer in bytes.
-        If -1 is passed, the alignment will be set to TVM's internal default.
-
-    offset_factor: int, default=0
-        The factor used in default buffer declaration.
-        If specified as 0, offset field is not used.
-
     restricted_func: bool, default=True
         Whether build restricted function.
         That is each buffer argument to the function are guaranteed
@@ -152,8 +141,6 @@ def build_config(**kwargs):
         phase contains an integer on which optimization pass we apply the pass.
         Additional lowering passes to be applied before make_api.
 
-    dump_pass_ir: dump ir of each pass into file idx_passname_ir.cc, default=False
-
     Returns
     -------
     config: BuildConfig
diff --git a/python/tvm/te/tensor_intrin.py b/python/tvm/te/tensor_intrin.py
index c5c2afef1c93..cd488a7fbd14 100644
--- a/python/tvm/te/tensor_intrin.py
+++ b/python/tvm/te/tensor_intrin.py
@@ -20,7 +20,6 @@
 from tvm.runtime import Object, convert
 from tvm.ir import Range
-from tvm.target import BuildConfig
 from .tensor import PlaceholderOp
 
 from . import tensor as _tensor
@@ -68,7 +67,9 @@ def __call__(self, *args, **kwargs):
 
 def decl_tensor_intrin(op,
                        fcompute,
                        name="tensor_intrin",
-                       binds=None, scalar_params=None):
+                       binds=None,
+                       scalar_params=None,
+                       default_buffer_params=None):
     """Declare a tensor intrinsic function.
 
     Parameters
@@ -104,6 +105,9 @@ def decl_tensor_intrin(op,
     scalar_params: a list of variables used by op, whose values will be passed
                    as scalar_inputs when the tensor intrinsic is called.
 
+    default_buffer_params: Optional[dict]
+        Dictionary of buffer arguments to be passed when constructing a buffer.
+
     Returns
     -------
     intrin: TensorIntrin
@@ -122,12 +126,11 @@ def decl_tensor_intrin(op,
         if not isinstance(t.op, PlaceholderOp):
             raise ValueError("Do not yet support composition op")
 
-    cfg = BuildConfig.current()
+    default_buffer_params = {} if default_buffer_params is None else default_buffer_params
     for t in tensors:
         buf = (binds[t] if t in binds else
                tvm.tir.decl_buffer(t.shape, t.dtype, t.op.name,
-                                   data_alignment=cfg.data_alignment,
-                                   offset_factor=cfg.offset_factor))
+                                   **default_buffer_params))
         binds_list.append(buf)
 
     if scalar_params:
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index cdd9d5441b25..ca1e122ca13c 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -91,8 +91,7 @@ void GetBinds(const Array<te::Tensor>& args, bool compact,
 
   for (const auto& x : args) {
     if (out_binds->find(x) == out_binds->end()) {
-      auto buf = BufferWithOffsetAlignment(x->shape, x->dtype, x->op->name, config->data_alignment,
-                                           config->offset_factor, compact);
+      auto buf = BufferWithOffsetAlignment(x->shape, x->dtype, x->op->name, -1, 0, compact);
       out_binds->Set(x, buf);
       out_arg_list->push_back(buf);
     } else {
diff --git a/src/target/target.cc b/src/target/target.cc
index 644ebdfdd579..aac5a2be25a0 100644
--- a/src/target/target.cc
+++ b/src/target/target.cc
@@ -357,8 +357,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 .set_dispatch<BuildConfigNode>([](const ObjectRef& node, ReprPrinter* p) {
     auto* op = static_cast<const BuildConfigNode*>(node.get());
     p->stream << "build_config(";
-    p->stream << "data_alignment=" << op->data_alignment << ", ";
-    p->stream << "offset_factor=" << op->offset_factor << ", ";
     p->stream << "double_buffer_split_loop=" << op->double_buffer_split_loop << ", ";
     p->stream << "auto_unroll_max_step=" << op->auto_unroll_max_step << ", ";
     p->stream << "auto_unroll_max_depth=" << op->auto_unroll_max_depth << ", ";
@@ -367,7 +365,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     p->stream << "restricted_func=" << op->restricted_func << ", ";
     p->stream << "detect_global_barrier=" << op->detect_global_barrier << ", ";
     p->stream << "partition_const_loop=" << op->partition_const_loop << ", ";
-    p->stream << "dump_pass_ir=" << op->dump_pass_ir << ", ";
     p->stream << "instrument_bound_checkers=" << op->instrument_bound_checkers << ", ";
"disable_select_rewriting=" << op->disable_select_rewriting; p->stream << "disable_vectorize=" << op->disable_vectorize; diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 45b9680ab3f3..8b98ed9d14d9 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -34,7 +34,7 @@ namespace tvm { namespace tir { -// TODO(tqchen): change to floormod/div + using IndexMod = tir::FloorModNode; using IndexDiv = tir::FloorDivNode; diff --git a/tests/python/unittest/test_te_schedule.py b/tests/python/unittest/test_te_schedule.py index 9b8d4061afb4..2c851cc39789 100644 --- a/tests/python/unittest/test_te_schedule.py +++ b/tests/python/unittest/test_te_schedule.py @@ -115,7 +115,6 @@ def test_fuse_with_split(): assert any(isinstance(x, tvm.te.schedule.Fuse) for x in s[T].relations) assert tuple(s[T].leaf_iter_vars) == (xo, fused) -@pytest.mark.xfail def test_fuse_with_out_of_order_axis(): m = te.size_var('m') n = te.size_var('n') @@ -125,9 +124,10 @@ def test_fuse_with_out_of_order_axis(): s = te.create_schedule(T.op) y = T.op.axis[1] xo, xi = s[T].split(T.op.axis[0], factor=10) - fused = s[T].fuse(xo, y) # should throw here -@pytest.mark.xfail + with pytest.raises(RuntimeError): + fused = s[T].fuse(xo, y) # should throw here + def test_fuse_with_out_of_order_axis_with_reorder(): m = te.size_var('m') n = te.size_var('n') @@ -144,23 +144,21 @@ def test_fuse_with_out_of_order_axis_with_reorder(): y = T.op.axis[1] xo, xi = s[T].split(T.op.axis[0], factor=10) s[T].reorder(y, xo, xi) - fused = s[T].fuse(y, xi) # should throw here + + with pytest.raises(RuntimeError): + fused = s[T].fuse(y, xi) # should throw here def test_singleton(): - print("test singleton") A = te.placeholder((), name='A') T = te.compute((), lambda : A() + 1) s = te.create_schedule(T.op) - print("test singleton fin1") fused = s[T].fuse() assert any(isinstance(x, tvm.te.schedule.Singleton) for x in s[T].relations) assert tuple(s[T].leaf_iter_vars) == (fused,) dump = pkl.dumps(s) - print("test singleton fin3") s_loaded = pkl.loads(dump) - print("test singleton fin2") assert isinstance(s_loaded, tvm.te.schedule.Schedule) - print("test singleton fin") + def test_vectorize(): m = te.size_var('m') @@ -177,13 +175,14 @@ def test_vectorize(): assert s[T].iter_var_attrs[xi].iter_type == UNROLL assert s[T].iter_var_attrs[yi].iter_type == VECTORIZE -@pytest.mark.xfail + def test_vectorize_commreduce(): V = te.placeholder((128,), name='V') ax = te.reduce_axis((0, 128), name='ax') O = te.compute((1,), lambda _: te.sum(V[ax], axis=[ax])) s = te.create_schedule(O.op) - s[O].vectorize(ax) # should throw here + with pytest.raises(RuntimeError): + s[O].vectorize(ax) # should throw here def test_pragma(): m = 100 @@ -271,8 +270,9 @@ def intrin_func(ins, outs, sp): assert(sp[1] == w) return tvm.tir.call_packed("hw_func", ins[0].data, outs[0].data, sp[0], sp[1]) - with tvm.target.build_config(offset_factor=1): - intrin = te.decl_tensor_intrin(z.op, intrin_func, scalar_params=[v, w]) + intrin = te.decl_tensor_intrin(z.op, intrin_func, scalar_params=[v, w], default_buffer_params={ + "offset_factor": 1 + }) assert intrin.op == z.op assert intrin.reduce_init is None assert tuple(intrin.inputs) == tuple(z.op.input_tensors) diff --git a/tests/python/unittest/test_te_schedule_ops.py b/tests/python/unittest/test_te_schedule_ops.py index 7cbf20eccf12..3f93c772a037 100644 --- a/tests/python/unittest/test_te_schedule_ops.py +++ b/tests/python/unittest/test_te_schedule_ops.py @@ -321,10 +321,9 @@ def intrin_func(ins, outs): "gemv_add", ww_ptr, xx_ptr, 
                                         zz_ptr, n, ww.strides[0])
         return body, reset, update
 
-    with tvm.target.build_config(data_alignment=16,
-                                 offset_factor=16):
-        return te.decl_tensor_intrin(z.op, intrin_func,
-                                     binds={w: Wb})
+    buffer_params = {"data_alignment": 16, "offset_factor": 16}
+    return te.decl_tensor_intrin(
+        z.op, intrin_func, binds={w: Wb}, default_buffer_params=buffer_params)
 
 
 def test_schedule_tensor_compute1():
@@ -377,8 +376,6 @@ def intrin_func(ins, outs):
         ib.emit(tvm.tir.call_extern(outs[0].dtype, 'vadd',
                                     ins[0].access_ptr("r"), ins[1].access_ptr('r'),
                                     outs[0].access_ptr('wr')))
         return ib.get()
-    with tvm.target.build_config(offset_factor=16):
-        return te.decl_tensor_intrin(z.op, intrin_func, binds=binds)
+    return te.decl_tensor_intrin(z.op, intrin_func, binds=binds, default_buffer_params={
+        "offset_factor": 16
+    })
 
 
 def test_schedule_tensor_compute2():
diff --git a/tests/python/unittest/test_te_schedule_tensorize.py b/tests/python/unittest/test_te_schedule_tensorize.py
index ef5b3fd3a44e..5152235ef379 100644
--- a/tests/python/unittest/test_te_schedule_tensorize.py
+++ b/tests/python/unittest/test_te_schedule_tensorize.py
@@ -25,8 +25,8 @@ def intrin_func(ins, outs):
         xx, yy = ins
         zz = outs[0]
         return tvm.tir.call_packed("vadd", xx, yy, zz)
-    with tvm.target.build_config(offset_factor=16):
-        return te.decl_tensor_intrin(z.op, intrin_func)
+    buffer_params = {"offset_factor": 16}
+    return te.decl_tensor_intrin(z.op, intrin_func, default_buffer_params=buffer_params)
 
 def intrin_gemv(m, n):
     w = te.placeholder((m, n), name='w')
@@ -52,10 +52,9 @@ def intrin_func(ins, outs):
                                    "gemv_add", ww_ptr, xx_ptr,
                                    zz_ptr, n, ww.strides[0])
         return body, reset, update
-    with tvm.target.build_config(data_alignment=16,
-                                 offset_factor=16):
-        return te.decl_tensor_intrin(z.op, intrin_func,
-                                     binds={w: Wb})
+    buffer_params = {"offset_factor": 16, "data_alignment": 16}
+    return te.decl_tensor_intrin(
+        z.op, intrin_func, binds={w: Wb}, default_buffer_params=buffer_params)
 
 def intrin_gemv_no_reset(m, n):
     w = te.placeholder((m, n), name='w')
@@ -79,10 +78,10 @@ def intrin_func(ins, outs):
                                    "gemv_add", ww_ptr, xx_ptr,
                                    zz_ptr, n, ww.strides[0])
         return body, None, update
-    with tvm.target.build_config(data_alignment=16,
-                                 offset_factor=16):
-        return te.decl_tensor_intrin(z.op, intrin_func,
-                                     binds={w: Wb})
+
+    buffer_params = {"offset_factor": 16, "data_alignment": 16}
+    return te.decl_tensor_intrin(
+        z.op, intrin_func, binds={w: Wb}, default_buffer_params=buffer_params)
 
 
 def test_tensorize_vadd():
@@ -248,8 +247,9 @@ def intrin_func(ins, outs):
             zz = outs[0]
             return tvm.tir.call_packed("op", xx, zz)
 
-        with tvm.target.build_config(offset_factor=2):
-            return te.decl_tensor_intrin(y.op, intrin_func)
+        return te.decl_tensor_intrin(y.op, intrin_func, default_buffer_params={
+            "offset_factor": 2
+        })
 
     A = te.placeholder((5, 5), name='A')
     B = te.compute((9,9), lambda i, j: A[idxd(j,3) + idxm(i,3), idxm(j,3) + idxd(i,3)])
@@ -286,8 +286,7 @@ def intrin_multivadd(n):
     def intrin_func(ins, outs):
         return tvm.tir.call_packed("multivadd")
 
-    with tvm.target.build_config():
-        return te.decl_tensor_intrin(z.op, intrin_func, name="multivadd")
+    return te.decl_tensor_intrin(z.op, intrin_func, name="multivadd")
 
 def intrin_vadd(n):
     dtype = 'float32'
@@ -297,9 +296,7 @@ def intrin_vadd(n):
     s = te.create_schedule(z.op)
 
     def create_buffer(t):
-        return tvm.tir.decl_buffer(t.shape, t.dtype,
-                                   name='W'+t.name,
-                                   offset_factor=16)
+        return tvm.tir.decl_buffer(t.shape, t.dtype, name='W'+t.name, offset_factor=16)
 
     def intrin_func(ins, outs):
         ib = tvm.tir.ir_builder.create()
@@ -307,11 +304,9 @@ def intrin_func(ins, outs):
                                     ins[0].access_ptr("r"), ins[1].access_ptr('r'),
                                     outs[0].access_ptr('wr')))
         return ib.get()
-
-    with tvm.target.build_config(offset_factor=16):
-        return te.decl_tensor_intrin(z.op, intrin_func, binds={x: create_buffer(x),
-                                                               y: create_buffer(y),
-                                                               z: create_buffer(z)})
+    return te.decl_tensor_intrin(z.op, intrin_func, binds={x: create_buffer(x),
+                                                           y: create_buffer(y),
+                                                           z: create_buffer(z)})
 
 # cache_read, cache_write
 M = 1024
diff --git a/tests/python/unittest/test_te_tensor.py b/tests/python/unittest/test_te_tensor.py
index 5d3cbadce165..a8ab3cfda25a 100644
--- a/tests/python/unittest/test_te_tensor.py
+++ b/tests/python/unittest/test_te_tensor.py
@@ -117,8 +117,9 @@ def intrin_func(ins, outs):
         ib.emit(tvm.tir.call_extern(outs[0].dtype, 'vadd',
                                     ins[0].access_ptr("r"), ins[1].access_ptr('r'),
                                     outs[0].access_ptr('wr')))
         return ib.get()
-    with tvm.target.build_config(offset_factor=n):
-        return te.decl_tensor_intrin(z.op, intrin_func)
+    return te.decl_tensor_intrin(z.op, intrin_func, default_buffer_params={
+        "offset_factor": n
+    })
 
     vadd = intrin_vadd(factor)
@@ -159,8 +160,8 @@ def intrin_func(ins, outs):
                                 "gemv_add", x_ptr, y_ptr, z_ptr, m, n, l)
         return body, reset, update
 
-    with tvm.target.build_config(offset_factor=n):
-        return te.decl_tensor_intrin(z.op, intrin_func)
+    return te.decl_tensor_intrin(z.op, intrin_func,
+                                 default_buffer_params={"offset_factor": n})
 
     vgemm = intrin_gemm(factor1, factor2, factor)
@@ -290,8 +291,8 @@ def intrin_func(ins, outs):
         dinp = ins[0]
         dout = outs[0]
         return tvm.tir.call_packed("op", dinp, dout)
 
-    with tvm.target.build_config(offset_factor=1):
-        return te.decl_tensor_intrin(P.op, intrin_func)
+    return te.decl_tensor_intrin(P.op, intrin_func,
+                                 default_buffer_params={"offset_factor": 1})
 
     A = te.placeholder((1, 64, 16, 16), name='A')
     P = pool(data=A, kernel=(3, 3), stride=(1, 1), padding=(0, 0, 0, 0),
diff --git a/topi/python/topi/cuda/tensor_intrin.py b/topi/python/topi/cuda/tensor_intrin.py
index f8fce342e212..3941c00cc464 100644
--- a/topi/python/topi/cuda/tensor_intrin.py
+++ b/topi/python/topi/cuda/tensor_intrin.py
@@ -69,14 +69,15 @@ def _instr(index):
 
         return _instr(0), _instr(1), _instr(2) # body, reset, update
 
-    with tvm.target.build_config(data_alignment=4, offset_factor=1) as cfg:
-        scopes = {x: x_scope, y: y_scope, z: z_scope}
-        binds = {t: tvm.tir.decl_buffer(t.shape, t.dtype, t.op.name,
-                                        data_alignment=cfg.data_alignment,
-                                        offset_factor=cfg.offset_factor,
-                                        scope=scopes[t]) for t in [x, y, z]}
-
-        return te.decl_tensor_intrin(z.op, _intrin_func, binds=binds)
+    default_buffer_params = {
+        "data_alignment": 4, "offset_factor": 1
+    }
+    scopes = {x: x_scope, y: y_scope, z: z_scope}
+    binds = {t: tvm.tir.decl_buffer(t.shape, t.dtype, t.op.name,
+                                    scope=scopes[t], **default_buffer_params) for t in [x, y, z]}
+
+    return te.decl_tensor_intrin(
+        z.op, _intrin_func, binds=binds, default_buffer_params=default_buffer_params)
 
 
 def intrin_wmma_load_matrix_A(strides_dst, strides_from, shape, layout, A_shape, C_shape, in_dtype):
diff --git a/topi/python/topi/x86/tensor_intrin.py b/topi/python/topi/x86/tensor_intrin.py
index 955b6b4ad280..ee8d83dbef07 100644
--- a/topi/python/topi/x86/tensor_intrin.py
+++ b/topi/python/topi/x86/tensor_intrin.py
@@ -110,8 +110,10 @@ def _instr(index):
         # body, reset, update
         return _instr(0), _instr(1), _instr(2)
 
-    with tvm.target.build_config(offset_factor=1, partition_const_loop=True):
-        return te.decl_tensor_intrin(C.op, _intrin_func, binds={data:a_buffer,
-                                                                kernel:b_buffer})
+    buffer_params = {"offset_factor": 1}
+    return te.decl_tensor_intrin(
+        C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer},
+        default_buffer_params=buffer_params)
 
 
 def dot_16x1x16_uint8_int8_int16():
@@ -191,9 +193,10 @@ def _instr(index):
         # body, reset, update
         return _instr(0), _instr(1), _instr(2)
 
-
-    with tvm.target.build_config(offset_factor=1, partition_const_loop=True):
-        return te.decl_tensor_intrin(C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer})
+    buffer_params = {"offset_factor": 1}
+    return te.decl_tensor_intrin(
+        C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer},
+        default_buffer_params=buffer_params)
 
 
 def dot_16x1x16_uint8_int8_int32_cascadelake():
@@ -287,5 +290,7 @@ def _instr(index):
         # body, reset, update
         return _instr(0), _instr(1), _instr(2)
 
-    with tvm.target.build_config(offset_factor=1, partition_const_loop=True):
-        return te.decl_tensor_intrin(C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer})
+    buffer_params = {"offset_factor": 1}
+    return te.decl_tensor_intrin(
+        C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer},
+        default_buffer_params=buffer_params)
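
Usage note (not part of the patch): after this change, buffer defaults that used to be
read from the ambient tvm.target.build_config(...) scope travel explicitly with each
intrinsic declaration via default_buffer_params. A minimal sketch of the migration,
mirroring the te/tir APIs exercised by the tests above; the shapes and the "vadd"
packed-function name are illustrative only:

import tvm
from tvm import te

def intrin_vadd(n):
    # Compute rule for an n-element vector add.
    x = te.placeholder((n,), name='x')
    y = te.placeholder((n,), name='y')
    z = te.compute(x.shape, lambda i: x[i] + y[i], name='z')

    def intrin_func(ins, outs):
        # ins/outs are the buffers bound to x, y and z.
        return tvm.tir.call_packed("vadd", ins[0], ins[1], outs[0])

    # Before: with tvm.target.build_config(data_alignment=16, offset_factor=16): ...
    # After: the same defaults are passed per intrinsic.
    return te.decl_tensor_intrin(
        z.op, intrin_func,
        default_buffer_params={"data_alignment": 16, "offset_factor": 16})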