From 663370dd7de8964a30fd9a2b7b197f159c5ccc95 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Wed, 30 Jun 2021 16:25:35 +0000 Subject: [PATCH 1/4] specialize --- include/tvm/tir/analysis.h | 2 +- include/tvm/tir/buffer.h | 1 + python/tvm/tir/function.py | 20 +- src/tir/ir/specialize.cc | 326 ++++++++++++++++++ .../test_tvmscript_meta_programming.py | 185 ++++++++++ 5 files changed, 532 insertions(+), 2 deletions(-) create mode 100644 src/tir/ir/specialize.cc create mode 100644 tests/python/unittest/test_tvmscript_meta_programming.py diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index 262ac688f2e0..63d6fa375c83 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -96,7 +96,7 @@ TVM_DLL Array UndefinedVars(const PrimExpr& expr); TVM_DLL CallEffectKind SideEffect(const PrimExpr& expr); /*! - * \brief Whether e expression used any var in variable set.. + * \brief Whether e expression used any var in variable set. * \param expr The expression to be checked. * \param vset_contains The check function to see if var is in the vset. * \return Whether e uses vset. diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index a01d69b372d2..017f4f7052b1 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -183,6 +183,7 @@ class Buffer : public ObjectRef { TVM_DLL Stmt vstore(Array begin, PrimExpr value) const; TVM_DEFINE_OBJECT_REF_METHODS(Buffer, ObjectRef, BufferNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferNode); }; /*! diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index 79d18d8970b5..cb66ece750a6 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -16,12 +16,14 @@ # under the License. """Function data types.""" +from typing import Mapping, Union + import tvm._ffi import tvm.runtime from tvm.runtime import Object from tvm.ir import BaseFunc from .buffer import Buffer -from .expr import Var +from .expr import Var, PrimExpr from . 
import _ffi_api @@ -85,3 +87,19 @@ def with_body(self, new_body, span=None): The created new function. """ return PrimFunc(self.params, new_body, self.ret_type, self.buffer_map, self.attrs, span) + + def specialize(self, param_map: Mapping[Var, Union[PrimExpr, Buffer]]): + """Metaprogramming usage: specialize parameters of PrimFunc + + Parameters + ---------- + + param_map : Mapping[Var, Union[PrimExpr, Buffer]] + The mapping from function params to the instance + + Returns + ------- + func : PrimFunc + The new function with parameter specialized + """ + return _ffi_api.Specialize(self, param_map) diff --git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc new file mode 100644 index 000000000000..1ef031fd49f7 --- /dev/null +++ b/src/tir/ir/specialize.cc @@ -0,0 +1,326 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/tir/ir/specialize.cc + * \brief Specialize parameters of PrimFunc. + */ +#include +#include +#include +#include + +#include + +#include "functor_common.h" + +namespace tvm { +namespace tir { + +using VarMap = std::unordered_map; + +/**************** Helper functions ****************/ + +/*! 
\brief Helper function to check whether the given var is in function parameter list. */ +inline bool IsParam(const PrimFunc& func, const Var& param) { + return std::any_of(func->params.begin(), func->params.end(), + [&](const Var& var) { return var.same_as(param); }); +} + +/*! \brief Mutator to specialize function and remove const parameters */ +class PrimFuncSpecializer : public StmtExprMutator { + public: + explicit PrimFuncSpecializer(VarMap var_map) : var_map_(var_map) {} + + static PrimFunc Specialize(PrimFunc f, const VarMap& var_map) { + PrimFuncSpecializer specializer(var_map); + // Updating Buffer map + Map buffer_map; + for (const auto& it : f->buffer_map) { + const Var& var = it.first; + const Buffer& buffer = it.second; + Buffer new_buffer = specializer.MutateBuffer(buffer); + buffer_map.Set(var, new_buffer); + if (!new_buffer.same_as(buffer)) { + specializer.buffer_map_[buffer] = new_buffer; + } + } + + // Updating parmeters + Array params; + for (const auto& var : f->params) { + // Remove parmeters which has been specialized. 
+ if (var_map.find(var) == var_map.end()) { + params.push_back(var); + } + } + + PrimFuncNode* f_ptr = f.CopyOnWrite(); + f_ptr->params = std::move(params); + f_ptr->buffer_map = std::move(buffer_map); + f_ptr->body = specializer(std::move(f_ptr->body)); + + // Updating attrs + if (f->attrs.defined()) { + auto& attr_dict = f_ptr->attrs.CopyOnWrite()->dict; + for (const auto& kv : attr_dict) { + const String& key = kv.first; + const ObjectRef& value = kv.second; + if (value->IsInstance()) { + attr_dict.Set(key, Substitute(Downcast(value), var_map)); + } + } + } + return f; + } + + private: + Stmt VisitStmt_(const BlockNode* op) final { + Array alloc_buffers = MutateArray( + op->alloc_buffers, + std::bind(&PrimFuncSpecializer::MutateAllocBuffer, this, std::placeholders::_1)); + Array reads = MutateArray( + op->reads, + std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1)); + Array writes = MutateArray( + op->writes, + std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1)); + Array block_vars = MutateArray( + op->iter_vars, std::bind(&PrimFuncSpecializer::MutateIterVar, this, std::placeholders::_1)); + Optional init = NullOpt; + if (op->init.defined()) { + init = VisitStmt(op->init.value()); + } + Stmt body = VisitStmt(op->body); + + if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads) && + writes.same_as(op->writes) && block_vars.same_as(op->iter_vars) && body.same_as(op->body) && + init.same_as(op->init)) { + return GetRef(op); + } else { + ObjectPtr n = CopyOnWrite(op); + n->alloc_buffers = std::move(alloc_buffers); + n->reads = std::move(reads); + n->writes = std::move(writes); + n->iter_vars = std::move(block_vars); + n->body = std::move(body); + n->init = std::move(init); + return Stmt(n); + } + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + auto it = buffer_map_.find(op->buffer); + if (it == buffer_map_.end()) { + return GetRef(op); + } + + PrimExpr value = 
VisitExpr(op->value); + Array indices = + MutateArray(op->indices, [this](const PrimExpr& e) { return this->VisitExpr(e); }); + + auto n = CopyOnWrite(op); + n->buffer = it->second; + n->value = std::move(value); + n->indices = std::move(indices); + return Stmt(n); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + auto it = buffer_map_.find(op->buffer); + if (it == buffer_map_.end()) { + return GetRef(op); + } + + Array indices = + MutateArray(op->indices, [this](const PrimExpr& e) { return this->VisitExpr(e); }); + + auto n = CopyOnWrite(op); + n->buffer = it->second; + n->indices = std::move(indices); + return PrimExpr(n); + } + + PrimExpr VisitExpr_(const VarNode* op) final { + auto it = var_map_.find(GetRef(op)); + if (it == var_map_.end()) { + return GetRef(op); + } else { + return it->second; + } + } + + private: + Buffer MutateBuffer(Buffer buffer) const { + BufferNode* buffer_ptr = buffer.CopyOnWrite(); + Array new_shape, new_stride; + new_shape.reserve(buffer_ptr->shape.size()); + new_shape.reserve(buffer_ptr->strides.size()); + for (const auto& dim : buffer_ptr->shape) { + new_shape.push_back(Substitute(dim, var_map_)); + } + for (const auto& stride : buffer_ptr->strides) { + new_shape.push_back(Substitute(stride, var_map_)); + } + buffer_ptr->elem_offset = Substitute(buffer_ptr->elem_offset, var_map_); + buffer_ptr->shape = std::move(new_shape); + buffer_ptr->strides = std::move(new_stride); + return buffer; + } + + Range MutateRange(const Range& range) { + PrimExpr min = this->VisitExpr(range->min); + PrimExpr extent = this->VisitExpr(range->extent); + if (min.same_as(range->min) && extent.same_as(range->extent)) { + return range; + } else { + ObjectPtr n = CopyOnWrite(range.get()); + n->min = std::move(min); + n->extent = std::move(extent); + return Range(n); + } + } + + IterVar MutateIterVar(const IterVar& iter_var) { + Range range = MutateRange(iter_var->dom); + if (range.same_as(iter_var->dom)) { + return iter_var; + } else { + auto n 
= CopyOnWrite(iter_var.get()); + n->dom = std::move(range); + return IterVar(n); + } + } + + Buffer MutateAllocBuffer(const Buffer& alloc_buf) { + Buffer buf = MutateBuffer(alloc_buf); + if (buf.same_as(alloc_buf)) { + return alloc_buf; + } else { + buffer_map_[alloc_buf] = buf; + return buf; + } + } + + BufferRegion MutateBufferRegion(const BufferRegion& buffer_region) { + auto it = buffer_map_.find(buffer_region->buffer); + Array region = + MutateArray(buffer_region->region, + std::bind(&PrimFuncSpecializer::MutateRange, this, std::placeholders::_1)); + if (it == buffer_map_.end() && region.same_as(buffer_region->region)) { + return buffer_region; + } else { + auto n = CopyOnWrite(buffer_region.get()); + n->buffer = it->second; + n->region = std::move(region); + return BufferRegion(n); + } + } + + private: + /*! \brief The vars to be substitute and their values */ + VarMap var_map_; + /*! \brief map from old buffer to mutated buffer */ + std::unordered_map buffer_map_; +}; + +/**************** Implementation ****************/ + +PrimFunc Specialize(PrimFunc func, const Var& param, const Buffer& specific_buf) { + // preliminaries + tir::ExprDeepEqual equal; + VarMap var_map; + + auto it = func->buffer_map.find(param); + CHECK(it != func->buffer_map.end()) + << "ValueError: specialize expects param to be in PrimFunc's buffer_map"; + const Buffer& buf_to_specialize = (*it).second; + + // build var mapping using specific_buf's parameters + auto build_var_mapping = [&](const PrimExpr& new_expr, const PrimExpr& old_expr) { + if (!equal(new_expr, old_expr)) { + CHECK(old_expr->IsInstance()) + << "TypeError: The signature of target buffer exprected an independent Var, but got " + << old_expr << "."; + const Var& var = Downcast(old_expr); + auto it = var_map.find(var); + if (it != var_map.end()) { + CHECK(equal(it->second, new_expr)) + << "ValueError: The assigned value of var " << var << " mismatched. " << it->second + << " vs. 
" << new_expr << "."; + } else { + var_map[var] = new_expr; + } + } + }; + + // Check buffer dimensions + CHECK(specific_buf->shape.size() == buf_to_specialize->shape.size() && + specific_buf->strides.size() == buf_to_specialize->strides.size()) + << "ValueError: The buffer dimensions mismatched" << buf_to_specialize->shape.size() + << " vs. " << specific_buf->shape.size() << "."; + + // Updating var mapping using specific_expr + for (size_t i = 0; i < specific_buf->shape.size(); ++i) { + build_var_mapping(specific_buf->shape[i], buf_to_specialize->shape[i]); + } + for (size_t i = 0; i < specific_buf->strides.size(); ++i) { + build_var_mapping(specific_buf->strides[i], buf_to_specialize->strides[i]); + } + build_var_mapping(specific_buf->elem_offset, buf_to_specialize->elem_offset); + // Specialize function with var mapping + return PrimFuncSpecializer::Specialize(func, var_map); +} + +PrimFunc Specialize(PrimFunc func, const Var& param, const PrimExpr& specific_expr) { + // preliminaries + VarMap var_map; + // check param is in PrimFunc's parameters + CHECK(IsParam(func, param)) << "ValueError: Specialize expects param to be in PrimFunc's params"; + // specialize a param not in buffer_map + CHECK_EQ(func->buffer_map.count(param), 0) + << "ValueError: Specialize expects param to not be in PrimFunc's buffer_map"; + // build var mapping using specific_expr + var_map[param] = specific_expr; + // Specialize function with var mapping + return PrimFuncSpecializer::Specialize(std::move(func), var_map); +} + +/**************** FFI ****************/ + +TVM_REGISTER_GLOBAL("tir.Specialize") + .set_body_typed)>([](PrimFunc func, + Map param_map) { + for (const auto& kv : param_map) { + const Var& param = kv.first; + const ObjectRef& instance = kv.second; + if (instance->IsInstance()) { + func = Specialize(std::move(func), param, Downcast(instance)); + } else if (instance->IsInstance()) { + func = Specialize(std::move(func), param, Downcast(instance)); + } else { + LOG(FATAL) 
<< "TypeError: specialize expected instance to be Buffer or PrimExpr, but got " + << instance->GetTypeKey(); + } + } + return func; + }); + +} // namespace tir +} // namespace tvm diff --git a/tests/python/unittest/test_tvmscript_meta_programming.py b/tests/python/unittest/test_tvmscript_meta_programming.py new file mode 100644 index 000000000000..4068719089d3 --- /dev/null +++ b/tests/python/unittest/test_tvmscript_meta_programming.py @@ -0,0 +1,185 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-function-docstring, missing-module-docstring + +import tvm +from tvm import tir +from tvm.script import ty + + +@tvm.script.tir +def matmul(a: ty.handle, b: ty.handle, c: ty.handle, n: ty.int32) -> None: + m = tir.var("int32") + A = tir.match_buffer(a, [m, n]) + B = tir.match_buffer(b, [m, n]) + C = tir.match_buffer(c, [m, m]) + + with tir.block([m, m, tir.reduce_axis(0, n)], "update") as [vi, vj, vk]: + with tir.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +@tvm.script.tir +def matmul_128(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + C = tir.match_buffer(c, [128, 128]) + + with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + with tir.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +@tvm.script.tir +def matmul_m_128(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + m = tir.var("int32") + A = tir.match_buffer(a, [m, 128]) + B = tir.match_buffer(b, [m, 128]) + C = tir.match_buffer(c, [m, m]) + + with tir.block([m, m, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + with tir.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +@tvm.script.tir +def matmul_m_8x(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + x = tir.var("int32") + m = tir.var("int32") + A = tir.match_buffer(a, [m, x * 8]) + B = tir.match_buffer(b, [m, x * 8]) + C = tir.match_buffer(c, [m, m]) + + with tir.block([m, m, tir.reduce_axis(0, x * 8)], "update") as [vi, vj, vk]: + with tir.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +@tvm.script.tir +def element_wise(a: ty.handle, c: ty.handle) -> None: + m = tir.var("int32") + n = tir.var("int32") + A = tir.match_buffer(a, (m, n), "float32") + C = tir.match_buffer(c, (m, n), "float32") + + B = tir.alloc_buffer((m, n), "float32") + + with tir.block([m, n], "B") as [vi, vj]: + B[vi, vj] = 
A[vi, vj] * 2.0 + + with tir.block([m, n], "C") as [vi, vj]: + C[vi, vj] = B[vi, vj] + 1.0 + + +@tvm.script.tir +def element_wise_128_64(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 64), "float32") + C = tir.match_buffer(c, (128, 64), "float32") + B = tir.alloc_buffer((128, 64), "float32") + + with tir.block([128, 64], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + + with tir.block([128, 64], "C") as [vi, vj]: + C[vi, vj] = B[vi, vj] + 1.0 + + +@tvm.script.tir +def element_wise_128_n(a: ty.handle, c: ty.handle) -> None: + n = tir.var("int32") + A = tir.match_buffer(a, (128, n), "float32") + C = tir.match_buffer(c, (128, n), "float32") + B = tir.alloc_buffer((128, n), "float32") + + with tir.block([128, n], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + + with tir.block([128, n], "C") as [vi, vj]: + C[vi, vj] = B[vi, vj] + 1.0 + + +@tvm.script.tir +def mem_copy(a: ty.handle, b: ty.handle, m: ty.int32, n: ty.int32) -> None: + A = tir.match_buffer(a, (m, n), "float32") + B = tir.match_buffer(b, (m, n), "float32") + + with tir.block([m, n], "") as [vi, vj]: + B[vi, vj] = A[vi, vj] + + +@tvm.script.tir +def mem_copy_16_16(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16), "float32") + B = tir.match_buffer(b, (16, 16), "float32") + + with tir.block([16, 16], "") as [vi, vj]: + B[vi, vj] = A[vi, vj] + + +@tvm.script.tir +def mem_copy_m_16(a: ty.handle, b: ty.handle, m: ty.int32) -> None: + A = tir.match_buffer(a, (m, 16), "float32") + B = tir.match_buffer(b, (m, 16), "float32") + + with tir.block([m, 16], "") as [vi, vj]: + B[vi, vj] = A[vi, vj] + + +def test_tensor_dimension_invariant_code_matmul(): + a, _, _, n = matmul.params + # fully specialized + func = matmul.specialize({a: tir.decl_buffer((128, 128))}) + tvm.ir.assert_structural_equal(func, matmul_128) + # partially specialized + func = matmul.specialize({n: 128}) + tvm.ir.assert_structural_equal(func, matmul_m_128) + # symbolic specialized + func = 
matmul.specialize({n: tir.Var("x", "int32") * 8}) + tvm.ir.assert_structural_equal(func, matmul_m_8x) + + +def test_tensor_dimension_invariant_code_elemwise(): + a, c = element_wise.params + C = element_wise.buffer_map[c] + # fully specialized + func = element_wise.specialize({a: tir.decl_buffer((128, 64))}) + tvm.ir.assert_structural_equal(func, element_wise_128_64) + # partially specialized + func = element_wise.specialize({c: tir.decl_buffer((128, C.shape[1]))}) + tvm.ir.assert_structural_equal(func, element_wise_128_n) + + +def test_tensor_dimension_invariant_code_mem_copy(): + a, _, m, n = mem_copy.params + # fully specialized + func = mem_copy.specialize({a: tir.decl_buffer((16, 16))}) + tvm.ir.assert_structural_equal(func, mem_copy_16_16) + func = mem_copy.specialize({n: 16, m: 16}) + tvm.ir.assert_structural_equal(func, mem_copy_16_16) + # partially specialized + func = mem_copy.specialize({n: 16}) + tvm.ir.assert_structural_equal(func, mem_copy_m_16) + + +if __name__ == "__main__": + test_tensor_dimension_invariant_code_matmul() + test_tensor_dimension_invariant_code_elemwise() + test_tensor_dimension_invariant_code_mem_copy() From ce4abcb96691c6081972afcb7e9550f209bacc37 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Wed, 30 Jun 2021 16:25:35 +0000 Subject: [PATCH 2/4] update doc string --- python/tvm/tir/function.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index cb66ece750a6..f75e61c8859e 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -97,6 +97,41 @@ def specialize(self, param_map: Mapping[Var, Union[PrimExpr, Buffer]]): param_map : Mapping[Var, Union[PrimExpr, Buffer]] The mapping from function params to the instance + Examples + -------- + We can define a Meta TIR function with symbolic shape: + + .. 
code-block:: python + + @tvm.script.tir + def mem_copy(a: ty.handle, b: ty.handle, m: ty.int32, n: ty.int32) -> None: + A = tir.match_buffer(a, (m, n), "float32") + B = tir.match_buffer(b, (m, n), "float32") + + with tir.block([m, n], "") as [vi, vj]: + B[vi, vj] = A[vi, vj] + + Then we can make it specialized with given shapes or buffers. + + .. code-block:: python + + a, _, m, n = mem_copy.params + func = mem_copy.specialize({a: tir.decl_buffer((16, 16))}) + # or + func = mem_copy.specialize({n: 16, m: 16}) + + The specialized function: + + .. code-block:: python + + @tvm.script.tir + def mem_copy_16_16(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16), "float32") + B = tir.match_buffer(b, (16, 16), "float32") + + with tir.block([16, 16], "") as [vi, vj]: + B[vi, vj] = A[vi, vj] + Returns ------- func : PrimFunc From b629c0b2768fc50882e78806b11ec4d1c95fec98 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Wed, 30 Jun 2021 16:25:35 +0000 Subject: [PATCH 3/4] address comment --- src/tir/ir/specialize.cc | 187 ++++++++---------- .../test_tvmscript_meta_programming.py | 42 ++-- 2 files changed, 111 insertions(+), 118 deletions(-) diff --git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc index 1ef031fd49f7..c6d21ca70760 100644 --- a/src/tir/ir/specialize.cc +++ b/src/tir/ir/specialize.cc @@ -43,118 +43,109 @@ inline bool IsParam(const PrimFunc& func, const Var& param) { [&](const Var& var) { return var.same_as(param); }); } +/**************** Specializer ****************/ + /*! 
\brief Mutator to specialize function and remove const parameters */ class PrimFuncSpecializer : public StmtExprMutator { public: - explicit PrimFuncSpecializer(VarMap var_map) : var_map_(var_map) {} + explicit PrimFuncSpecializer(const VarMap& var_map) : var_map_(var_map) {} static PrimFunc Specialize(PrimFunc f, const VarMap& var_map) { PrimFuncSpecializer specializer(var_map); // Updating Buffer map Map buffer_map; + bool buffer_map_updated = false; for (const auto& it : f->buffer_map) { const Var& var = it.first; const Buffer& buffer = it.second; Buffer new_buffer = specializer.MutateBuffer(buffer); buffer_map.Set(var, new_buffer); if (!new_buffer.same_as(buffer)) { + buffer_map_updated = true; specializer.buffer_map_[buffer] = new_buffer; } } // Updating parmeters Array params; + bool param_updated = false; for (const auto& var : f->params) { // Remove parmeters which has been specialized. if (var_map.find(var) == var_map.end()) { params.push_back(var); + } else { + param_updated = true; } } - PrimFuncNode* f_ptr = f.CopyOnWrite(); - f_ptr->params = std::move(params); - f_ptr->buffer_map = std::move(buffer_map); - f_ptr->body = specializer(std::move(f_ptr->body)); - - // Updating attrs - if (f->attrs.defined()) { - auto& attr_dict = f_ptr->attrs.CopyOnWrite()->dict; - for (const auto& kv : attr_dict) { - const String& key = kv.first; - const ObjectRef& value = kv.second; - if (value->IsInstance()) { - attr_dict.Set(key, Substitute(Downcast(value), var_map)); - } - } + // Updating function body + Stmt body = specializer(f->body); + + if (param_updated || buffer_map_updated || !f->body.same_as(body)) { + PrimFuncNode* f_ptr = f.CopyOnWrite(); + f_ptr->params = std::move(params); + f_ptr->buffer_map = std::move(buffer_map); + f_ptr->body = std::move(body); } return f; } private: Stmt VisitStmt_(const BlockNode* op) final { + // Step.0. 
Define buffer mappings which is allocated inside the block Array alloc_buffers = MutateArray( op->alloc_buffers, std::bind(&PrimFuncSpecializer::MutateAllocBuffer, this, std::placeholders::_1)); + + // Step.1. Recursively visit block body + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); + ICHECK(op != nullptr); + Array reads = MutateArray( op->reads, std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1)); Array writes = MutateArray( op->writes, std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1)); - Array block_vars = MutateArray( - op->iter_vars, std::bind(&PrimFuncSpecializer::MutateIterVar, this, std::placeholders::_1)); - Optional init = NullOpt; - if (op->init.defined()) { - init = VisitStmt(op->init.value()); - } - Stmt body = VisitStmt(op->body); - if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads) && - writes.same_as(op->writes) && block_vars.same_as(op->iter_vars) && body.same_as(op->body) && - init.same_as(op->init)) { + if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads)) { return GetRef(op); } else { ObjectPtr n = CopyOnWrite(op); n->alloc_buffers = std::move(alloc_buffers); n->reads = std::move(reads); n->writes = std::move(writes); - n->iter_vars = std::move(block_vars); - n->body = std::move(body); - n->init = std::move(init); return Stmt(n); } } Stmt VisitStmt_(const BufferStoreNode* op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); + ICHECK(op != nullptr); auto it = buffer_map_.find(op->buffer); if (it == buffer_map_.end()) { return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->buffer = it->second; + return Stmt(n); } - - PrimExpr value = VisitExpr(op->value); - Array indices = - MutateArray(op->indices, [this](const PrimExpr& e) { return this->VisitExpr(e); }); - - auto n = CopyOnWrite(op); - n->buffer = it->second; - n->value = std::move(value); - n->indices = std::move(indices); - return Stmt(n); 
} PrimExpr VisitExpr_(const BufferLoadNode* op) final { + PrimExpr expr = StmtExprMutator::VisitExpr_(op); + op = expr.as(); + ICHECK(op != nullptr); auto it = buffer_map_.find(op->buffer); if (it == buffer_map_.end()) { return GetRef(op); + } else { + auto n = make_object(*op); + n->buffer = it->second; + return PrimExpr(n); } - - Array indices = - MutateArray(op->indices, [this](const PrimExpr& e) { return this->VisitExpr(e); }); - - auto n = CopyOnWrite(op); - n->buffer = it->second; - n->indices = std::move(indices); - return PrimExpr(n); } PrimExpr VisitExpr_(const VarNode* op) final { @@ -167,21 +158,24 @@ class PrimFuncSpecializer : public StmtExprMutator { } private: - Buffer MutateBuffer(Buffer buffer) const { - BufferNode* buffer_ptr = buffer.CopyOnWrite(); - Array new_shape, new_stride; - new_shape.reserve(buffer_ptr->shape.size()); - new_shape.reserve(buffer_ptr->strides.size()); - for (const auto& dim : buffer_ptr->shape) { - new_shape.push_back(Substitute(dim, var_map_)); - } - for (const auto& stride : buffer_ptr->strides) { - new_shape.push_back(Substitute(stride, var_map_)); + Buffer MutateBuffer(const Buffer& buffer) const { + Array shape = + MutateArray(buffer->shape, [this](const PrimExpr& e) { return Substitute(e, var_map_); }); + Array strides = + MutateArray(buffer->strides, [this](const PrimExpr& e) { return Substitute(e, var_map_); }); + + PrimExpr elem_offset = Substitute(buffer->elem_offset, var_map_); + + if (buffer->elem_offset.same_as(elem_offset) && buffer->shape.same_as(shape) && + buffer->strides.same_as(strides)) { + return buffer; + } else { + auto n = make_object(*buffer.get()); + n->elem_offset = std::move(elem_offset); + n->shape = std::move(shape); + n->strides = std::move(strides); + return Buffer(n); } - buffer_ptr->elem_offset = Substitute(buffer_ptr->elem_offset, var_map_); - buffer_ptr->shape = std::move(new_shape); - buffer_ptr->strides = std::move(new_stride); - return buffer; } Range MutateRange(const Range& range) { 
@@ -190,21 +184,7 @@ class PrimFuncSpecializer : public StmtExprMutator { if (min.same_as(range->min) && extent.same_as(range->extent)) { return range; } else { - ObjectPtr n = CopyOnWrite(range.get()); - n->min = std::move(min); - n->extent = std::move(extent); - return Range(n); - } - } - - IterVar MutateIterVar(const IterVar& iter_var) { - Range range = MutateRange(iter_var->dom); - if (range.same_as(iter_var->dom)) { - return iter_var; - } else { - auto n = CopyOnWrite(iter_var.get()); - n->dom = std::move(range); - return IterVar(n); + return Range::FromMinExtent(std::move(min), std::move(extent)); } } @@ -213,6 +193,7 @@ class PrimFuncSpecializer : public StmtExprMutator { if (buf.same_as(alloc_buf)) { return alloc_buf; } else { + ICHECK(buffer_map_.find(alloc_buf) == buffer_map_.end()); buffer_map_[alloc_buf] = buf; return buf; } @@ -226,26 +207,21 @@ class PrimFuncSpecializer : public StmtExprMutator { if (it == buffer_map_.end() && region.same_as(buffer_region->region)) { return buffer_region; } else { - auto n = CopyOnWrite(buffer_region.get()); - n->buffer = it->second; - n->region = std::move(region); - return BufferRegion(n); + return BufferRegion(it->second, std::move(region)); } } private: /*! \brief The vars to be substitute and their values */ - VarMap var_map_; + const VarMap& var_map_; /*! 
\brief map from old buffer to mutated buffer */ std::unordered_map buffer_map_; }; -/**************** Implementation ****************/ - -PrimFunc Specialize(PrimFunc func, const Var& param, const Buffer& specific_buf) { +void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const Buffer& specific_buf, + VarMap* var_map) { // preliminaries tir::ExprDeepEqual equal; - VarMap var_map; auto it = func->buffer_map.find(param); CHECK(it != func->buffer_map.end()) @@ -259,23 +235,26 @@ PrimFunc Specialize(PrimFunc func, const Var& param, const Buffer& specific_buf) << "TypeError: The signature of target buffer exprected an independent Var, but got " << old_expr << "."; const Var& var = Downcast(old_expr); - auto it = var_map.find(var); - if (it != var_map.end()) { + auto it = var_map->find(var); + if (it != var_map->end()) { CHECK(equal(it->second, new_expr)) << "ValueError: The assigned value of var " << var << " mismatched. " << it->second << " vs. " << new_expr << "."; } else { - var_map[var] = new_expr; + (*var_map)[var] = new_expr; } } }; // Check buffer dimensions - CHECK(specific_buf->shape.size() == buf_to_specialize->shape.size() && - specific_buf->strides.size() == buf_to_specialize->strides.size()) + CHECK(specific_buf->shape.size() == buf_to_specialize->shape.size()) << "ValueError: The buffer dimensions mismatched" << buf_to_specialize->shape.size() << " vs. " << specific_buf->shape.size() << "."; + CHECK(specific_buf->strides.size() == buf_to_specialize->strides.size()) + << "ValueError: The buffer strides dimensions mismatched" << buf_to_specialize->strides.size() + << " vs. 
" << specific_buf->strides.size() << "."; + // Updating var mapping using specific_expr for (size_t i = 0; i < specific_buf->shape.size(); ++i) { build_var_mapping(specific_buf->shape[i], buf_to_specialize->shape[i]); @@ -284,22 +263,27 @@ PrimFunc Specialize(PrimFunc func, const Var& param, const Buffer& specific_buf) build_var_mapping(specific_buf->strides[i], buf_to_specialize->strides[i]); } build_var_mapping(specific_buf->elem_offset, buf_to_specialize->elem_offset); - // Specialize function with var mapping - return PrimFuncSpecializer::Specialize(func, var_map); + + // Check data_alignment and offset_factor. + // These two signatures are int, so we do not need map them. + CHECK_EQ(specific_buf->data_alignment, buf_to_specialize->data_alignment) + << "ValueError: The buffer data_alignment mismatched" << buf_to_specialize->data_alignment + << " vs. " << specific_buf->data_alignment << "."; + + CHECK_EQ(specific_buf->offset_factor, buf_to_specialize->offset_factor) + << "ValueError: The buffer offset_factor mismatched" << buf_to_specialize->offset_factor + << " vs. 
" << specific_buf->offset_factor << "."; } -PrimFunc Specialize(PrimFunc func, const Var& param, const PrimExpr& specific_expr) { - // preliminaries - VarMap var_map; +void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const PrimExpr& specific_expr, + VarMap* var_map) { // check param is in PrimFunc's parameters CHECK(IsParam(func, param)) << "ValueError: Specialize expects param to be in PrimFunc's params"; // specialize a param not in buffer_map CHECK_EQ(func->buffer_map.count(param), 0) << "ValueError: Specialize expects param to not be in PrimFunc's buffer_map"; // build var mapping using specific_expr - var_map[param] = specific_expr; - // Specialize function with var mapping - return PrimFuncSpecializer::Specialize(std::move(func), var_map); + (*var_map)[param] = specific_expr; } /**************** FFI ****************/ @@ -307,19 +291,20 @@ PrimFunc Specialize(PrimFunc func, const Var& param, const PrimExpr& specific_ex TVM_REGISTER_GLOBAL("tir.Specialize") .set_body_typed)>([](PrimFunc func, Map param_map) { + VarMap var_map; for (const auto& kv : param_map) { const Var& param = kv.first; const ObjectRef& instance = kv.second; if (instance->IsInstance()) { - func = Specialize(std::move(func), param, Downcast(instance)); + UpdateSpecializeVarMap(func, param, Downcast(instance), &var_map); } else if (instance->IsInstance()) { - func = Specialize(std::move(func), param, Downcast(instance)); + UpdateSpecializeVarMap(func, param, Downcast(instance), &var_map); } else { LOG(FATAL) << "TypeError: specialize expected instance to be Buffer or PrimExpr, but got " << instance->GetTypeKey(); } } - return func; + return PrimFuncSpecializer::Specialize(std::move(func), std::move(var_map)); }); } // namespace tir diff --git a/tests/python/unittest/test_tvmscript_meta_programming.py b/tests/python/unittest/test_tvmscript_meta_programming.py index 4068719089d3..5b304d42d785 100644 --- a/tests/python/unittest/test_tvmscript_meta_programming.py +++ 
b/tests/python/unittest/test_tvmscript_meta_programming.py @@ -117,32 +117,39 @@ def element_wise_128_n(a: ty.handle, c: ty.handle) -> None: @tvm.script.tir -def mem_copy(a: ty.handle, b: ty.handle, m: ty.int32, n: ty.int32) -> None: - A = tir.match_buffer(a, (m, n), "float32") - B = tir.match_buffer(b, (m, n), "float32") +def mem_copy( + a: ty.handle, b: ty.handle, m: ty.int32, n: ty.int32, p: ty.int32, q: ty.int32 +) -> None: + A = tir.match_buffer(a, (m, n), "float32", strides=[p, 1], elem_offset=q) + B = tir.match_buffer(b, (m, n), "float32", strides=[p, 1], elem_offset=q) with tir.block([m, n], "") as [vi, vj]: B[vi, vj] = A[vi, vj] @tvm.script.tir -def mem_copy_16_16(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16), "float32") - B = tir.match_buffer(b, (16, 16), "float32") +def mem_copy_16_16_8_4(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16), "float32", strides=[8, 1], elem_offset=4) + B = tir.match_buffer(b, (16, 16), "float32", strides=[8, 1], elem_offset=4) with tir.block([16, 16], "") as [vi, vj]: B[vi, vj] = A[vi, vj] @tvm.script.tir -def mem_copy_m_16(a: ty.handle, b: ty.handle, m: ty.int32) -> None: - A = tir.match_buffer(a, (m, 16), "float32") - B = tir.match_buffer(b, (m, 16), "float32") +def mem_copy_m_n_p_n(a: ty.handle, b: ty.handle, m: ty.int32, n: ty.int32, p: ty.int32) -> None: + A = tir.match_buffer(a, (m, n), "float32", strides=[p, 1], elem_offset=n) + B = tir.match_buffer(b, (m, n), "float32", strides=[p, 1], elem_offset=n) - with tir.block([m, 16], "") as [vi, vj]: + with tir.block([m, n], "") as [vi, vj]: B[vi, vj] = A[vi, vj] +def test_specialize_nothing(): + func = matmul.specialize({}) + assert func.same_as(matmul) # Pointer the same + + def test_tensor_dimension_invariant_code_matmul(): a, _, _, n = matmul.params # fully specialized @@ -168,18 +175,19 @@ def test_tensor_dimension_invariant_code_elemwise(): def test_tensor_dimension_invariant_code_mem_copy(): - a, _, m, n = 
mem_copy.params + a, _, m, n, p, q = mem_copy.params # fully specialized - func = mem_copy.specialize({a: tir.decl_buffer((16, 16))}) - tvm.ir.assert_structural_equal(func, mem_copy_16_16) - func = mem_copy.specialize({n: 16, m: 16}) - tvm.ir.assert_structural_equal(func, mem_copy_16_16) + func = mem_copy.specialize({a: tir.decl_buffer((16, 16), strides=[8, 1], elem_offset=4)}) + tvm.ir.assert_structural_equal(func, mem_copy_16_16_8_4) + func = mem_copy.specialize({n: 16, m: 16, p: 8, q: 4}) + tvm.ir.assert_structural_equal(func, mem_copy_16_16_8_4) # partially specialized - func = mem_copy.specialize({n: 16}) - tvm.ir.assert_structural_equal(func, mem_copy_m_16) + func = mem_copy.specialize({q: n}) + tvm.ir.assert_structural_equal(func, mem_copy_m_n_p_n) if __name__ == "__main__": + test_specialize_nothing() test_tensor_dimension_invariant_code_matmul() test_tensor_dimension_invariant_code_elemwise() test_tensor_dimension_invariant_code_mem_copy() From 22a61ff6756a007517ab93e1e0834fbf80db5d4b Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Wed, 30 Jun 2021 16:25:36 +0000 Subject: [PATCH 4/4] address comments --- include/tvm/tir/function.h | 38 ++++++++++++ python/tvm/tir/function.py | 2 +- src/tir/ir/specialize.cc | 62 +++++++++++++------ ..._programming.py => test_tir_specialize.py} | 18 ++++-- 4 files changed, 95 insertions(+), 25 deletions(-) rename tests/python/unittest/{test_tvmscript_meta_programming.py => test_tir_specialize.py} (94%) diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 97ee7f7211d4..25ed2f9ae8d1 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -187,6 +187,44 @@ class LinkedParam : public ObjectRef { TVM_DEFINE_OBJECT_REF_COW_METHOD(LinkedParamNode); }; +/*! + * \brief Specialize parameters of PrimFunc. + * \param func The PrimFunc to be specialized. + * \param param_map The mapping from function params to the instance. + * \return The new function with parameter specialized. 
+ * \note We can define a Meta TIR function with symbolic shape: + * + * \code + * @tvm.script.tir + * def mem_copy(a: ty.handle, b: ty.handle, m: ty.int32, n: ty.int32) -> None: + * A = tir.match_buffer(a, (m, n), "float32") + * B = tir.match_buffer(b, (m, n), "float32") + * + * with tir.block([m, n], "") as [vi, vj]: + * B[vi, vj] = A[vi, vj] + * \endcode + * + * Then we can make it specialized with given shapes or buffers. + * + * \code + * a, _, m, n = mem_copy.params + * func = mem_copy.specialize({a: tir.decl_buffer((16, 16))}) + * # or + * func = mem_copy.specialize({n: 16, m: 16}) + * \endcode + * + * \code {.language-id} + * @tvm.script.tir + * def mem_copy_16_16(a: ty.handle, b: ty.handle) -> None: + * A = tir.match_buffer(a, (16, 16), "float32") + * B = tir.match_buffer(b, (16, 16), "float32") + * + * with tir.block([16, 16], "") as [vi, vj]: + * B[vi, vj] = A[vi, vj] + * \endcode + */ +PrimFunc Specialize(PrimFunc func, const Map& param_map); + /*! * \brief PrimFunc specific attribute names. * diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index f75e61c8859e..b1081d436150 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -89,7 +89,7 @@ def with_body(self, new_body, span=None): return PrimFunc(self.params, new_body, self.ret_type, self.buffer_map, self.attrs, span) def specialize(self, param_map: Mapping[Var, Union[PrimExpr, Buffer]]): - """Metaprogramming usage: specialize parameters of PrimFunc + """Specialize parameters of PrimFunc Parameters ---------- diff --git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc index c6d21ca70760..aa5f271c20c2 100644 --- a/src/tir/ir/specialize.cc +++ b/src/tir/ir/specialize.cc @@ -218,6 +218,23 @@ class PrimFuncSpecializer : public StmtExprMutator { std::unordered_map buffer_map_; }; +/*! + * \brief Update Specialize var map with buffer matching. + * \param func The function to be specialized. 
+ * \param param The given function parameter. + * \param specific_buf The matching buffer. + * \param var_map The var mapping to be updated. + * \note This function will match target buffer's shape, strides and elem_offset. + * For example, we define a buffer in PrimFunc: + * A = tir.match_buffer(a, [m, n]) + * + * Then we match it with a buffer B = tir.decl_buffer((8, 16)) + * + * It means we have two var mappings here: m = 8 and n = 16 + * + * If the buffer signature is not a Var, the mapping will fail. + * e.g. A = tir.match_buffer(a, [m * 2, n + 1]) + */ void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const Buffer& specific_buf, VarMap* var_map) { // preliminaries
+ */ void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const PrimExpr& specific_expr, VarMap* var_map) { // check param is in PrimFunc's parameters @@ -286,26 +310,28 @@ void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const PrimEx (*var_map)[param] = specific_expr; } +/**************** Implementation ****************/ + +PrimFunc Specialize(PrimFunc func, const Map& param_map) { + VarMap var_map; + for (const auto& kv : param_map) { + const Var& param = kv.first; + const ObjectRef& instance = kv.second; + if (instance->IsInstance()) { + UpdateSpecializeVarMap(func, param, Downcast(instance), &var_map); + } else if (instance->IsInstance()) { + UpdateSpecializeVarMap(func, param, Downcast(instance), &var_map); + } else { + LOG(FATAL) << "TypeError: specialize expected instance to be Buffer or PrimExpr, but got " + << instance->GetTypeKey(); + } + } + return PrimFuncSpecializer::Specialize(func, std::move(var_map)); +} + /**************** FFI ****************/ -TVM_REGISTER_GLOBAL("tir.Specialize") - .set_body_typed)>([](PrimFunc func, - Map param_map) { - VarMap var_map; - for (const auto& kv : param_map) { - const Var& param = kv.first; - const ObjectRef& instance = kv.second; - if (instance->IsInstance()) { - UpdateSpecializeVarMap(func, param, Downcast(instance), &var_map); - } else if (instance->IsInstance()) { - UpdateSpecializeVarMap(func, param, Downcast(instance), &var_map); - } else { - LOG(FATAL) << "TypeError: specialize expected instance to be Buffer or PrimExpr, but got " - << instance->GetTypeKey(); - } - } - return PrimFuncSpecializer::Specialize(std::move(func), std::move(var_map)); - }); +TVM_REGISTER_GLOBAL("tir.Specialize").set_body_typed(Specialize); } // namespace tir } // namespace tvm diff --git a/tests/python/unittest/test_tvmscript_meta_programming.py b/tests/python/unittest/test_tir_specialize.py similarity index 94% rename from tests/python/unittest/test_tvmscript_meta_programming.py rename to 
tests/python/unittest/test_tir_specialize.py index 5b304d42d785..2e9f1110732a 100644 --- a/tests/python/unittest/test_tvmscript_meta_programming.py +++ b/tests/python/unittest/test_tir_specialize.py @@ -150,7 +150,7 @@ def test_specialize_nothing(): assert func.same_as(matmul) # Pointer the same -def test_tensor_dimension_invariant_code_matmul(): +def test_specialize_matmul(): a, _, _, n = matmul.params # fully specialized func = matmul.specialize({a: tir.decl_buffer((128, 128))}) @@ -163,7 +163,7 @@ def test_tensor_dimension_invariant_code_matmul(): tvm.ir.assert_structural_equal(func, matmul_m_8x) -def test_tensor_dimension_invariant_code_elemwise(): +def test_specialize_elemwise(): a, c = element_wise.params C = element_wise.buffer_map[c] # fully specialized @@ -174,7 +174,7 @@ def test_tensor_dimension_invariant_code_elemwise(): tvm.ir.assert_structural_equal(func, element_wise_128_n) -def test_tensor_dimension_invariant_code_mem_copy(): +def test_specialize_mem_copy(): a, _, m, n, p, q = mem_copy.params # fully specialized func = mem_copy.specialize({a: tir.decl_buffer((16, 16), strides=[8, 1], elem_offset=4)}) @@ -186,8 +186,14 @@ def test_tensor_dimension_invariant_code_mem_copy(): tvm.ir.assert_structural_equal(func, mem_copy_m_n_p_n) +def test_specialize_recursive_load(): + # TODO(Siyuan): add recursive Load testcase, e.g. A[C[i]] + pass + + if __name__ == "__main__": test_specialize_nothing() - test_tensor_dimension_invariant_code_matmul() - test_tensor_dimension_invariant_code_elemwise() - test_tensor_dimension_invariant_code_mem_copy() + test_specialize_matmul() + test_specialize_elemwise() + test_specialize_mem_copy() + test_specialize_recursive_load()