diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index a8edc2675fc4..ec8e32526abb 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -31,9 +31,15 @@ #include #include +#include #include namespace tvm { + +namespace arith { +class Analyzer; +} + namespace tir { /*! @@ -203,6 +209,29 @@ TVM_DLL Array> GetBlockAccessRegion(const Block& block, TVM_DLL Array> GetBlockReadWriteRegion(const Block& block, const Map& buffer_var_map); +/*! \brief Helper struct for return value of IdentifyMemCpy + * + * This helper struct is not strictly necessary, as `IdentifyMemCpy` + * could instead return a `std::pair`. + * However, that would introduce ambiguity between the two unnamed + * regions. + */ +struct MemCpyDetails { + BufferRegion source; + BufferRegion dest; +}; + +/*! \brief Identify whether a For loop is semantically equivalent to MemCpy + * + * \param loop The loop to be checked + * + * \param analyzer The analyzer with which to check any algebraic expressions + * + * \returns The source and destination regions being copied, if the + * loop is equivalent to memcpy. Otherwise, returns nullopt. + */ +TVM_DLL std::optional IdentifyMemCpy(const For& loop, arith::Analyzer* analyzer); + /*! * \brief Calculate the expresion complexity based on number of symbols it contains. * \param expr The expr to be calculated. diff --git a/src/tir/analysis/identify_memcpy.cc b/src/tir/analysis/identify_memcpy.cc new file mode 100644 index 000000000000..0d3b48dbc2c6 --- /dev/null +++ b/src/tir/analysis/identify_memcpy.cc @@ -0,0 +1,316 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir/analysis/identify_memcpy.cc + * \brief Check if a loop nest is equivalent to memcpy + */ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "../../arith/ir_visitor_with_analyzer.h" + +namespace tvm { +namespace tir { + +std::variant IdentifyMemCpyImpl(const For& loop, + arith::Analyzer* analyzer) { + Map loop_intervals; + Map loop_ranges; + PrimExpr total_loop_iterations = 1; + + // Walk through the loop nest, stopping at the first loop whose body + // is not a loop. + Stmt stmt = loop; + while (auto* for_node = stmt.as()) { + loop_ranges.Set(for_node->loop_var, Range::FromMinExtent(for_node->min, for_node->extent)); + loop_intervals.Set(for_node->loop_var, + arith::IntSet::FromMinExtent(for_node->min, for_node->extent)); + total_loop_iterations = total_loop_iterations * for_node->extent; + + stmt = for_node->body; + } + + BufferStore store; + if (auto* ptr = stmt.as()) { + store = GetRef(ptr); + } else { + return static_cast( + std::stringstream() + << "Expected innermost loop to have BufferStore body, but instead found " << stmt) + .str(); + } + + BufferLoad load; + if (auto* ptr = store->value.as()) { + load = GetRef(ptr); + } else { + return static_cast( + std::stringstream() + << "Expected BufferStore's value to be BufferLoad, but instead found " + << store->value) + .str(); + } + + // Now, we have a BufferStore whose value is a BufferLoad. Because + // non-flat physical indices are target-dependent, only handle cases + // where the buffer will be flattened to a 1-d physical buffer. + Array flattened_dst = store->buffer.OffsetOf(store->indices); + Array flattened_src = load->buffer.OffsetOf(load->indices); + + if (flattened_dst.size() != 1 || flattened_src.size() != 1) { + return static_cast( + std::stringstream() + << "Expected flattened dimension of src/dest to be 1, but found" + << flattened_src.size() << "-d src and " << flattened_dst.size() << "-d dst") + .str(); + } + PrimExpr src_index = flattened_src[0]; + PrimExpr dst_index = flattened_dst[0]; + + // First check, do the input/output form affine subsets of their + // respective buffers? + // + // For example, should exclude the following, indices are not affine + // + // for i in T.serial(16): + // B[i] = A[T.abs(i-8)] + + auto src_iter_map = arith::DetectIterMap({src_index}, loop_ranges, Bool(true), + arith::IterMapLevel::Bijective, analyzer); + if (src_iter_map->errors.size()) { + return static_cast(std::stringstream() + << "arith::DetectIterMap(src) returned " + << src_iter_map->errors.size() << " errors: [" + << src_iter_map->errors << "]" + << " for src_index = " << src_index) + .str(); + } + auto dst_iter_map = arith::DetectIterMap({dst_index}, loop_ranges, Bool(true), + arith::IterMapLevel::Bijective, analyzer); + if (dst_iter_map->errors.size()) { + return static_cast(std::stringstream() + << "arith::DetectIterMap(dst) returned " + << dst_iter_map->errors.size() << " errors: [" + << dst_iter_map->errors << "]" + << " for dst_index = " << dst_index) + .str(); + } + + // Second check, are those affine subsets contiguous? If so, then + // the index expressions will visit every location between the min + // and the max. This checks surjectivity over a linear region, + // which may not be the same as DetectIterMap's check of + // surjectivity over the affine subset. + // + // For example, should exclude the following, doesn't touch all + // output locations within the output region touched. + // + // for i in T.serial(16): + // B[2*i] = A[i] + // + // Similarly, should exclude the following, doesn't touch all + // input locations within the input region touched. + // + // for i in T.serial(16): + // B[i] = A[2*i] + total_loop_iterations = analyzer->Simplify(total_loop_iterations); + auto src_interval = analyzer->int_set(src_index, loop_intervals); + auto dst_interval = analyzer->int_set(dst_index, loop_intervals); + + if (!src_interval.HasLowerBound() || !src_interval.HasUpperBound()) { + return static_cast(std::stringstream() + << "Expected known bounds for src, but found " + << src_interval << " for expression " << src_index) + .str(); + } + if (!dst_interval.HasLowerBound() || !dst_interval.HasUpperBound()) { + return static_cast(std::stringstream() + << "Expected known bounds for dst, but found " + << dst_interval << " for expression " << dst_index) + .str(); + } + + { + PrimExpr must_prove = total_loop_iterations == src_interval.max() - src_interval.min() + 1; + PrimExpr simplified = analyzer->Simplify(must_prove); + if (!analyzer->CanProve(simplified)) { + return static_cast( + std::stringstream() + << "Mismatch between loop iterations (" << total_loop_iterations + << ") and number of src indices touched (" << src_interval + << ". Equality to prove simplified to " << simplified) + .str(); + } + } + { + PrimExpr must_prove = total_loop_iterations == dst_interval.max() - dst_interval.min() + 1; + PrimExpr simplified = analyzer->Simplify(must_prove); + if (!analyzer->CanProve(simplified)) { + return static_cast( + std::stringstream() + << "Mismatch between loop iterations (" << total_loop_iterations + << ") and number of dst indices touched (" << dst_interval + << ". Equality to prove simplified to " << simplified) + .str(); + } + } + + // Third check, is there a transformation applied between the input + // and output iterators? + // + // For example, the following would pass all checks so far, but + // converts between row-major and column-major layouts, and could + // not be specified as a memcpy. + // + // for i,j in T.grid(4,4): + // B[i,j] = A[j,i] + + auto src_iter_sum = src_iter_map->indices[0]; + auto dst_iter_sum = dst_iter_map->indices[0]; + + if (src_iter_sum->args.size() != dst_iter_sum->args.size()) { + return static_cast( + std::stringstream() + << "IterMap for src/dst unpacked to different number of IterSplitExpr: " + << src_iter_sum->args.size() << " for src, " << dst_iter_sum->args.size() + << " for dst. " + << "IterMaps were detected as src = " << src_iter_sum << ", dst = " << dst_iter_sum) + .str(); + } + std::vector src_iter_terms(src_iter_sum->args.begin(), + src_iter_sum->args.end()); + std::vector dst_iter_terms(dst_iter_sum->args.begin(), + dst_iter_sum->args.end()); + + auto make_comparison_tuple = [](const arith::IterSplitExpr& expr) { + auto as_int_or_zero = [](auto& val) -> int64_t { + if (auto* as_int = val.template as()) { + return as_int->value; + } else { + return 0; + } + }; + return std::tuple{ + static_cast(expr->scale.as()), as_int_or_zero(expr->scale), + static_cast(expr->extent.as()), as_int_or_zero(expr->lower_factor), + static_cast(expr->lower_factor.as()), as_int_or_zero(expr->lower_factor), + }; + }; + auto sorting_function = [&make_comparison_tuple](const arith::IterSplitExpr& lhs, + const arith::IterSplitExpr& rhs) -> bool { + return make_comparison_tuple(lhs) < make_comparison_tuple(rhs); + }; + std::sort(src_iter_terms.begin(), src_iter_terms.end(), sorting_function); + std::sort(dst_iter_terms.begin(), dst_iter_terms.end(), sorting_function); + + for (size_t i = 0; i < src_iter_terms.size(); i++) { + const arith::IterSplitExpr& src_term = src_iter_terms[i]; + const arith::IterSplitExpr& dst_term = dst_iter_terms[i]; + + if (!analyzer->CanProve( + arith::NormalizeIterMapToExpr(src_term->source->source == dst_term->source->source))) { + return static_cast( + std::stringstream() + << "Term " << i << " had different source, src_term->source = " << src_term->source + << ", dst_term->source = " << dst_term->source) + .str(); + } + if (!analyzer->CanProve(src_term->lower_factor == dst_term->lower_factor)) { + return static_cast( + std::stringstream() + << "Term " << i << " had different lower_factor, src_term->lower_factor = " + << src_term->lower_factor + << ", dst_term->lower_factor = " << dst_term->lower_factor) + .str(); + } + if (!analyzer->CanProve(src_term->extent == dst_term->extent)) { + return static_cast( + std::stringstream() + << "Term " << i << " had different extent, src_term->extent = " << src_term->extent + << ", dst_term->extent = " << dst_term->extent) + .str(); + } + if (!analyzer->CanProve(src_term->scale == dst_term->scale)) { + return static_cast( + std::stringstream() + << "Term " << i << " had different scale, src_term->scale = " << src_term->scale + << ", dst_term->scale = " << dst_term->scale) + .str(); + } + } + + BufferRegion src_region(load->buffer, arith::DomainTouched(loop, load->buffer, true, true)); + BufferRegion dst_region(store->buffer, arith::DomainTouched(loop, store->buffer, true, true)); + + return MemCpyDetails{src_region, dst_region}; +} + +std::optional IdentifyMemCpy(const For& loop, arith::Analyzer* analyzer) { + auto result = IdentifyMemCpyImpl(loop, analyzer); + if (auto* ptr = std::get_if(&result)) { + return *ptr; + } else { + return std::nullopt; + } +} + +// Expose the IdentifyMemCpy functionality to Python API for purpose of unit testing. +TVM_REGISTER_GLOBAL("tir.analysis._identify_memcpy").set_body_typed([](const Stmt& stmt) { + Array output; + + struct Visitor : arith::IRVisitorWithAnalyzer { + explicit Visitor(Array* output) : output(output) {} + Array* output; + + private: + using IRVisitorWithAnalyzer::VisitStmt_; + void VisitStmt_(const ForNode* op) override { + For loop = GetRef(op); + auto result = IdentifyMemCpyImpl(loop, &analyzer_); + if (auto* ptr = std::get_if(&result)) { + output->push_back(Array{ptr->source, ptr->dest}); + } else if (auto* ptr = std::get_if(&result)) { + output->push_back(StringImm(*ptr)); + } else { + LOG(FATAL) << "Internal error, unhandled std::variant type"; + } + + IRVisitorWithAnalyzer::VisitStmt_(op); + } + }; + + Visitor visitor(&output); + visitor(stmt); + + return output; +}); + +} // namespace tir +} // namespace tvm diff --git a/tests/python/unittest/test_tir_analysis_identify_memcpy.py b/tests/python/unittest/test_tir_analysis_identify_memcpy.py new file mode 100644 index 000000000000..b69d3aea3ea3 --- /dev/null +++ b/tests/python/unittest/test_tir_analysis_identify_memcpy.py @@ -0,0 +1,324 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import re + +import pytest + +import tvm +import tvm.testing +from tvm.tir import BufferRegion, StringImm + +from tvm.script import tir as T + +identify_memcpy = tvm.tir.analysis._ffi_api._identify_memcpy + + +class BaseTest: + """Utility class for defining unit tests for memcpy""" + + def __init_subclass__(cls): + cls.func = tvm.testing.CompareBeforeAfter._normalize_before(cls.func) + cls.expected = pytest.fixture(cls.expected) + + def test_identify_memcpy(self, func, expected): + results = identify_memcpy(func.body) + + if isinstance(expected, str) or ( + isinstance(expected, tuple) and isinstance(expected[0], BufferRegion) + ): + expected = [expected] + + assert len(expected) == len(results) + for expected, result in zip(expected, results): + if isinstance(expected, str): + assert isinstance(result, StringImm) + assert re.search(expected, result.value) + else: + tvm.ir.assert_structural_equal(result, expected) + + +class Test1D(BaseTest): + """Simplest test case""" + + def func(A: T.Buffer[1024, "float32"], B: T.Buffer[1024, "float32"]): + for i in T.serial(1024): + B[i] = A[i] + + def expected(self, func): + A, B = func.buffer_map.values() + return A[0:1024], B[0:1024] + + +class Test1DCompute(BaseTest): + """Like Test1D, but a computation prevents this being a memcpy""" + + def func(A: T.Buffer[1024, "float32"], B: T.Buffer[1024, "float32"]): + for i in T.serial(1024): + B[i] = A[i] + 1.0 + + def expected(self, func): + return "Expected BufferStore's value to be BufferLoad" + + +class Test1DConditional(BaseTest): + """Like Test1D, but a conditionals prevents this being a memcpy""" + + def func(A: T.Buffer[1024, "float32"], B: T.Buffer[1024, "float32"]): + for i in T.serial(1024): + if i < 1024: + B[i] = A[i] + + def expected(self, func): + A, B = func.buffer_map.values() + return "Expected innermost loop to have BufferStore body" + + +class Test1DStridedInput(BaseTest): + """Like Test1D, but strided input prevents this being a memcpy""" + + def func(A: T.Buffer[2048, "float32"], B: T.Buffer[1024, "float32"]): + for i in T.serial(1024): + B[i] = A[i * 2] + + def expected(self, func): + return "Mismatch between loop iterations (.*) and number of src indices" + + +class Test1DStridedOutput(BaseTest): + """Like Test1D, but strided output prevents this being a memcpy""" + + def func(A: T.Buffer[1024, "float32"], B: T.Buffer[2048, "float32"]): + for i in T.serial(1024): + B[i * 2] = A[i] + + def expected(self, func): + return "Mismatch between loop iterations (.*) and number of dst indices" + + +class Test1DInput2DOutputFusedLoop(BaseTest): + """Like Test1D, but the output is written as a 2-d buffer""" + + def func(A: T.Buffer[1024, "float32"], B: T.Buffer[(32, 32), "float32"]): + for i in T.serial(1024): + B[i // 32, i % 32] = A[i] + + def expected(self, func): + A, B = func.buffer_map.values() + return A[0:1024], B[0:32, 0:32] + + +class Test2DInput1DOutputFusedLoop(BaseTest): + """Like Test1D, but the input is written as a 2-d buffer""" + + def func(A: T.Buffer[(32, 32), "float32"], B: T.Buffer[1024, "float32"]): + for i in T.serial(1024): + B[i] = A[i // 32, i % 32] + + def expected(self, func): + A, B = func.buffer_map.values() + return A[0:32, 0:32], B[0:1024] + + +class Test1DInput1DOutputNestedLoop(BaseTest): + """Like Test1D, but the iterator is written as a nested loop + + In test cases with more than one loop, each loop is checked to see + if could be written as a memcpy. The C++ utility function + operates on individual loops, but for unit testing in Python, it + is more convenient to return the results for all loops. + """ + + def func(A: T.Buffer[1024, "float32"], B: T.Buffer[1024, "float32"]): + for i, j in T.grid(32, 32): + B[i * 32 + j] = A[i * 32 + j] + + def expected(self, func): + A, B = func.buffer_map.values() + i = func.body.loop_var + return [ + (A[0:1024], B[0:1024]), + (A[i * 32 : i * 32 + 32], B[i * 32 : i * 32 + 32]), + ] + + +class Test1DInput1DOutputNestedLoopEquivalentExpressions(BaseTest): + """Like Test1DInput1DOutputNestedLoop, but with equivalent indices + + If the expressions are not identical, the loops may still be + recognizable as a memcpy, so long as the expressions are + equivalent. + """ + + def func(A: T.Buffer[1024, "float32"], B: T.Buffer[1024, "float32"]): + for i, j in T.grid(32, 32): + B[i * 32 + j] = A[j + i * 32] + + def expected(self, func): + A, B = func.buffer_map.values() + i = func.body.loop_var + return [ + (A[0:1024], B[0:1024]), + (A[i * 32 : i * 32 + 32], B[i * 32 : i * 32 + 32]), + ] + + +class Test1DInput2DOutputNestedLoop(BaseTest): + """Like Test1DInput1DOutputNestedLoop, but with a 2-d output buffer""" + + def func(A: T.Buffer[1024, "float32"], B: T.Buffer[(32, 32), "float32"]): + for i, j in T.grid(32, 32): + B[i, j] = A[i * 32 + j] + + def expected(self, func): + A, B = func.buffer_map.values() + i = func.body.loop_var + return [ + (A[0:1024], B[0:32, 0:32]), + (A[i * 32 : i * 32 + 32], B[i, 0:32]), + ] + + +class Test2DInput1DOutputNestedLoop(BaseTest): + """Like Test1DInput1DOutputNestedLoop, but with a 2-d input buffer""" + + def func(A: T.Buffer[(32, 32), "float32"], B: T.Buffer[1024, "float32"]): + for i, j in T.grid(32, 32): + B[i * 32 + j] = A[i, j] + + def expected(self, func): + A, B = func.buffer_map.values() + i = func.body.loop_var + return [ + (A[0:32, 0:32], B[0:1024]), + (A[i, 0:32], B[i * 32 : i * 32 + 32]), + ] + + +class Test2DInput2DOutputNestedLoop(BaseTest): + """Like Test1DInput1DOutputNestedLoop, but with 2-d input/output buffers""" + + def func(A: T.Buffer[(32, 32), "float32"], B: T.Buffer[(32, 32), "float32"]): + for i, j in T.grid(32, 32): + B[i, j] = A[i, j] + + def expected(self, func): + A, B = func.buffer_map.values() + i = func.body.loop_var + return [ + (A[0:32, 0:32], B[0:32, 0:32]), + (A[i, 0:32], B[i, 0:32]), + ] + + +class Test2DInput2DOutputTransposeOutput(BaseTest): + """Test2DInput2DOutputNestedLoop, but with a transposed output + + This is not recognized as a memcpy, because it results in a transpose. + """ + + def func(A: T.Buffer[(32, 32), "float32"], B: T.Buffer[(32, 32), "float32"]): + for i, j in T.grid(32, 32): + B[j, i] = A[i, j] + + def expected(self, func): + return [ + "different source", + "Mismatch .* number of dst indices touched", + ] + + +class Test2DInput2DOutputTransposeInput(BaseTest): + """Test2DInput2DOutputNestedLoop, but with a transposed input + + This is not recognized as a memcpy, because it results in a transpose. + """ + + def func(A: T.Buffer[(32, 32), "float32"], B: T.Buffer[(32, 32), "float32"]): + for i, j in T.grid(32, 32): + B[i, j] = A[j, i] + + def expected(self, func): + return [ + "different source", + "Mismatch .* number of src indices touched", + ] + + +class Test2DInput2DOutputTransposeBoth(BaseTest): + """Test2DInput2DOutputNestedLoop, but with a transposed input + + The inner loop is not recognized as a memcpy, because it has + strided access of both the input and output buffers. However, the + outer loop is still recognized as a memcpy, because the full + region has been copied over, even though it occurs out of order. + """ + + def func(A: T.Buffer[(32, 32), "float32"], B: T.Buffer[(32, 32), "float32"]): + for i, j in T.grid(32, 32): + B[j, i] = A[j, i] + + def expected(self, func): + A, B = func.buffer_map.values() + return [ + (A[0:32, 0:32], B[0:32, 0:32]), + "Mismatch .* number of src indices touched", + ] + + +class TestCacheRead(BaseTest): + """Like Test2DInput2DOutputNestedLoop, but with a 1-d + + The inner loop is a memcpy of a single row at a time. This + pattern would appear when B is a read cache of A. + """ + + def func(A: T.Buffer[(32, 32), "float32"], B: T.Buffer[32, "float32"]): + for i, j in T.grid(32, 32): + B[j] = A[i, j] + + def expected(self, func): + A, B = func.buffer_map.values() + i = func.body.loop_var + return [ + "does not form a bijective transform", + (A[i, 0:32], B[0:32]), + ] + + +class TestCacheWrite(BaseTest): + """Like Test2DInput2DOutputNestedLoop, but with a 1-d + + The inner loop is a memcpy of a single row at a time. This + pattern would appear when A is a write cache of B. + """ + + def func(A: T.Buffer[32, "float32"], B: T.Buffer[(32, 32), "float32"]): + for i, j in T.grid(32, 32): + B[i, j] = A[j] + + def expected(self, func): + A, B = func.buffer_map.values() + i = func.body.loop_var + return [ + "does not form a bijective transform", + (A[0:32], B[i, 0:32]), + ] + + +if __name__ == "__main__": + tvm.testing.main()