From b3b383e21a24f16c0a9afc7b4ef7466a43f43baf Mon Sep 17 00:00:00 2001 From: Eldritch Cheese Date: Fri, 10 Feb 2023 08:29:24 -0600 Subject: [PATCH 1/8] [TIR][Analysis] Implement IdentifyMemCpy analysis function This commit adds a utility function `tir::IdentifyMemCpy`, which analyzes a `tir::For` loop and determines whether it is equivalent to a memcpy. If it is equivalent, it will additionally return the source and destination regions of the memcpy. This utility is initially intended for use in the `LowerAsyncDMA` pass, but may be useful in other areas in the future. (e.g. Identifying an entire TIR function as equivalent to memcpy would allow calls to that TIR function to be removed from end-to-end models.) --- include/tvm/tir/analysis.h | 29 ++ python/tvm/tir/analysis/analysis.py | 1 + src/tir/analysis/identify_memcpy.cc | 302 ++++++++++++++++ .../test_tir_analysis_identify_memcpy.py | 325 ++++++++++++++++++ 4 files changed, 657 insertions(+) create mode 100644 src/tir/analysis/identify_memcpy.cc create mode 100644 tests/python/unittest/test_tir_analysis_identify_memcpy.py diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index a8edc2675fc4..ec8e32526abb 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -31,9 +31,15 @@ #include #include +#include #include namespace tvm { + +namespace arith { +class Analyzer; +} + namespace tir { /*! @@ -203,6 +209,29 @@ TVM_DLL Array> GetBlockAccessRegion(const Block& block, TVM_DLL Array> GetBlockReadWriteRegion(const Block& block, const Map& buffer_var_map); +/*! \brief Helper struct for return value of IdentifyMemCpy + * + * This helper struct is not strictly necessary, as `IdentifyMemCpy` + * could instead return a `std::pair`. + * However, that would introduce ambiguity between the two unnamed + * regions. + */ +struct MemCpyDetails { + BufferRegion source; + BufferRegion dest; +}; + +/*! \brief Identify whether a For loop is semantically equivalent to MemCpy + * + * \param loop The loop to be checked + * + * \param analyzer The analyzer with which to check any algebraic expressions + * + * \returns The source and destination regions being copied, if the + * loop is equivalent to memcpy. Otherwise, returns nullopt. + */ +TVM_DLL std::optional IdentifyMemCpy(const For& loop, arith::Analyzer* analyzer); + /*! * \brief Calculate the expresion complexity based on number of symbols it contains. * \param expr The expr to be calculated. diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py index 45b1f745c3de..c023976ad9c8 100644 --- a/python/tvm/tir/analysis/analysis.py +++ b/python/tvm/tir/analysis/analysis.py @@ -18,6 +18,7 @@ # pylint: disable=invalid-name from typing import Dict, List, Union +import tvm from tvm import Object from tvm.ir import IRModule from tvm.tir.expr import Var diff --git a/src/tir/analysis/identify_memcpy.cc b/src/tir/analysis/identify_memcpy.cc new file mode 100644 index 000000000000..a9605c903b3b --- /dev/null +++ b/src/tir/analysis/identify_memcpy.cc @@ -0,0 +1,302 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir/analysis/identify_memcpy.cc + * \brief Check if a loopnest is equivalent to memcpy + */ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "../../arith/ir_visitor_with_analyzer.h" + +namespace tvm { +namespace tir { + +std::variant IdentifyMemCpyImpl(const For& loop, + arith::Analyzer* analyzer) { + Map loop_intervals; + Map loop_ranges; + PrimExpr total_loop_iterations = 1; + + // Walk through the loopnest, stopping at the first loop whose body + // is not a loop. + Stmt stmt = loop; + while (auto* for_node = stmt.as()) { + loop_ranges.Set(for_node->loop_var, Range::FromMinExtent(for_node->min, for_node->extent)); + loop_intervals.Set(for_node->loop_var, + arith::IntSet::FromMinExtent(for_node->min, for_node->extent)); + total_loop_iterations = total_loop_iterations * for_node->extent; + + stmt = for_node->body; + } + + BufferStore store; + if (auto* ptr = stmt.as()) { + store = GetRef(ptr); + } else { + return ( + std::stringstream() + << "Expected innermost loop to have BufferStore body, but instead found " << stmt) + .str(); + } + + BufferLoad load; + if (auto* ptr = store->value.as()) { + load = GetRef(ptr); + } else { + return ( + std::stringstream() + << "Expected BufferStore's value to be BufferLoad, but instead found " + << store->value) + .str(); + } + + // Now, we have a BufferStore whose value is a BufferLoad. Because + // non-flat physical indices are target-dependent, only handle cases + // where the buffer will be flattened to a 1-d physical buffer. + Array flattened_dst = store->buffer.OffsetOf(store->indices); + Array flattened_src = load->buffer.OffsetOf(load->indices); + + if (flattened_dst.size() != 1 || flattened_src.size() != 1) { + return (std::stringstream() << "Expected flattened dimension of src/dest to be 1, but found" + << flattened_src.size() << "-d src and " << flattened_dst.size() + << "-d dst") + .str(); + } + PrimExpr src_index = flattened_src[0]; + PrimExpr dst_index = flattened_dst[0]; + + // First check, do the input/output form affine subsets of their + // respective buffers? + // + // For example, should exclude the following, indices are not affine + // + // for i in T.serial(16): + // B[i] = A[T.abs(i-8)] + + auto src_iter_map = arith::DetectIterMap({src_index}, loop_ranges, Bool(true), + arith::IterMapLevel::Bijective, analyzer); + if (src_iter_map->errors.size()) { + return (std::stringstream() << "arith::DetectIterMap(src) returned " + << src_iter_map->errors.size() << " errors: [" + << src_iter_map->errors << "]" + << " for src_index = " << src_index) + .str(); + } + auto dst_iter_map = arith::DetectIterMap({dst_index}, loop_ranges, Bool(true), + arith::IterMapLevel::Bijective, analyzer); + if (dst_iter_map->errors.size()) { + return (std::stringstream() << "arith::DetectIterMap(dst) returned " + << dst_iter_map->errors.size() << " errors: [" + << dst_iter_map->errors << "]" + << " for dst_index = " << dst_index) + .str(); + } + + // Second check, are those affine subsets contiguous? If so, then + // the index expressions will visit every location between the min + // and the max. This checks surjectivity over a linear region, + // which may not be the same as DetectIterMap's check of + // surjectivity over the affine subset. + // + // For example, should exclude the following, doesn't touch all + // output locations within the output region touched. + // + // for i in T.serial(16): + // B[2*i] = A[i] + // + // Similarly, should exclude the following, doesn't touch all + // input locations within the input region touched. + // + // for i in T.serial(16): + // B[i] = A[2*i] + total_loop_iterations = analyzer->Simplify(total_loop_iterations); + auto src_interval = analyzer->int_set(src_index, loop_intervals); + auto dst_interval = analyzer->int_set(dst_index, loop_intervals); + + if (!src_interval.HasLowerBound() || !src_interval.HasUpperBound()) { + return (std::stringstream() << "Expected known bounds for src, but found " << src_interval + << " for expression " << src_index) + .str(); + } + if (!dst_interval.HasLowerBound() || !dst_interval.HasUpperBound()) { + return (std::stringstream() << "Expected known bounds for dst, but found " << dst_interval + << " for expression " << dst_index) + .str(); + } + + { + PrimExpr must_prove = total_loop_iterations == src_interval.max() - src_interval.min() + 1; + PrimExpr simplified = analyzer->Simplify(must_prove); + if (!analyzer->CanProve(simplified)) { + return (std::stringstream() << "Mismatch between loop iterations (" << total_loop_iterations + << ") and number of src indices touched (" << src_interval + << ". Equality to prove simplified to " << simplified) + .str(); + } + } + { + PrimExpr must_prove = total_loop_iterations == dst_interval.max() - dst_interval.min() + 1; + PrimExpr simplified = analyzer->Simplify(must_prove); + if (!analyzer->CanProve(simplified)) { + return (std::stringstream() << "Mismatch between loop iterations (" << total_loop_iterations + << ") and number of dst indices touched (" << dst_interval + << ". Equality to prove simplified to " << simplified) + .str(); + } + } + + // Thrid check, is there a transformation applied between the input + // and output iterators? + // + // For example, the following would pass all checks so far, but + // converts between row-major and column-major layouts, and could + // not be specified as a memcpy. + // + // for i,j in T.grid(4,4): + // B[i,j] = A[j,i] + + auto src_iter_sum = src_iter_map->indices[0]; + auto dst_iter_sum = dst_iter_map->indices[0]; + + if (src_iter_sum->args.size() != dst_iter_sum->args.size()) { + return ( + std::stringstream() + << "IterMap for src/dst unpacked to different number of IterSplitExpr: " + << src_iter_sum->args.size() << " for src, " << dst_iter_sum->args.size() + << " for dst. " + << "IterMaps were detected as src = " << src_iter_sum << ", dst = " << dst_iter_sum) + .str(); + } + std::vector src_iter_terms(src_iter_sum->args.begin(), + src_iter_sum->args.end()); + std::vector dst_iter_terms(dst_iter_sum->args.begin(), + dst_iter_sum->args.end()); + + auto make_comparison_tuple = [](const arith::IterSplitExpr& expr) { + auto as_int_or_zero = [](auto& val) -> int64_t { + if (auto* as_int = val.template as()) { + return as_int->value; + } else { + return 0; + } + }; + return std::tuple{ + bool(expr->scale.as()), as_int_or_zero(expr->scale), + bool(expr->extent.as()), as_int_or_zero(expr->lower_factor), + bool(expr->lower_factor.as()), as_int_or_zero(expr->lower_factor), + }; + }; + auto sorting_function = [&make_comparison_tuple](const arith::IterSplitExpr& lhs, + const arith::IterSplitExpr& rhs) -> bool { + return make_comparison_tuple(lhs) < make_comparison_tuple(rhs); + }; + std::sort(src_iter_terms.begin(), src_iter_terms.end(), sorting_function); + std::sort(dst_iter_terms.begin(), dst_iter_terms.end(), sorting_function); + + for (size_t i = 0; i < src_iter_terms.size(); i++) { + const arith::IterSplitExpr& src_term = src_iter_terms[i]; + const arith::IterSplitExpr& dst_term = dst_iter_terms[i]; + + if (!analyzer->CanProve( + arith::NormalizeIterMapToExpr(src_term->source->source == dst_term->source->source))) { + return (std::stringstream() << "Term " << i << " had different source, src_term->source = " + << src_term->source + << ", dst_term->source = " << dst_term->source) + .str(); + } + if (!analyzer->CanProve(src_term->lower_factor == dst_term->lower_factor)) { + return (std::stringstream() << "Term " << i + << " had different lower_factor, src_term->lower_factor = " + << src_term->lower_factor + << ", dst_term->lower_factor = " << dst_term->lower_factor) + .str(); + } + if (!analyzer->CanProve(src_term->extent == dst_term->extent)) { + return (std::stringstream() << "Term " << i << " had different extent, src_term->extent = " + << src_term->extent + << ", dst_term->extent = " << dst_term->extent) + .str(); + } + if (!analyzer->CanProve(src_term->scale == dst_term->scale)) { + return (std::stringstream() << "Term " << i << " had different scale, src_term->scale = " + << src_term->scale << ", dst_term->scale = " << dst_term->scale) + .str(); + } + } + + BufferRegion src_region(load->buffer, arith::DomainTouched(loop, load->buffer, true, true)); + BufferRegion dst_region(store->buffer, arith::DomainTouched(loop, store->buffer, true, true)); + + return MemCpyDetails{src_region, dst_region}; +} + +std::optional IdentifyMemCpy(const For& loop, arith::Analyzer* analyzer) { + auto result = IdentifyMemCpyImpl(loop, analyzer); + if (auto* ptr = std::get_if(&result)) { + return *ptr; + } else { + return std::nullopt; + } +} + +// Expose the IdentifyMemCpy functionality to Python API for purpose +// of unit testing. +TVM_REGISTER_GLOBAL("tir.analysis._identify_memcpy").set_body_typed([](const Stmt& stmt) { + Array output; + + struct Visitor : arith::IRVisitorWithAnalyzer { + Visitor(Array* output) : output(output) {} + Array* output; + + private: + using IRVisitorWithAnalyzer::VisitStmt_; + void VisitStmt_(const ForNode* op) override { + For loop = GetRef(op); + auto result = IdentifyMemCpyImpl(loop, &analyzer_); + if (auto* ptr = std::get_if(&result)) { + output->push_back(Array{ptr->source, ptr->dest}); + } else if (auto* ptr = std::get_if(&result)) { + output->push_back(StringImm(*ptr)); + } else { + LOG(FATAL) << "Internal error, unhandled std::variant type"; + } + + IRVisitorWithAnalyzer::VisitStmt_(op); + } + }; + + Visitor{&output}(stmt); + + return output; +}); + +} // namespace tir +} // namespace tvm diff --git a/tests/python/unittest/test_tir_analysis_identify_memcpy.py b/tests/python/unittest/test_tir_analysis_identify_memcpy.py new file mode 100644 index 000000000000..39807b3191e0 --- /dev/null +++ b/tests/python/unittest/test_tir_analysis_identify_memcpy.py @@ -0,0 +1,325 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import re + +import pytest + +import tvm +import tvm.testing +from tvm import te, topi +from tvm.tir import BufferRegion, StringImm + +from tvm.script import tir as T + +identify_memcpy = tvm.tir.analysis._ffi_api._identify_memcpy + + +class BaseTest: + """Utility class for defining unit tests for memcpy """ + + def __init_subclass__(cls): + cls.func = tvm.testing.CompareBeforeAfter._normalize_before(cls.func) + cls.expected = pytest.fixture(cls.expected) + + def test_identify_memcpy(self, func, expected): + results = identify_memcpy(func.body) + + if isinstance(expected, str) or ( + isinstance(expected, tuple) and isinstance(expected[0], BufferRegion) + ): + expected = [expected] + + assert len(expected) == len(results) + for expected, result in zip(expected, results): + if isinstance(expected, str): + assert isinstance(result, StringImm) + assert re.search(expected, result.value) + else: + tvm.ir.assert_structural_equal(result, expected) + + +class Test1D(BaseTest): + """ Simplest test case """ + + def func(A: T.Buffer[1024, "float32"], B: T.Buffer[1024, "float32"]): + for i in T.serial(1024): + B[i] = A[i] + + def expected(self, func): + A, B = func.buffer_map.values() + return A[0:1024], B[0:1024] + + +class Test1DCompute(BaseTest): + """ Like Test1D, but a computation prevents this being a memcpy """ + + def func(A: T.Buffer[1024, "float32"], B: T.Buffer[1024, "float32"]): + for i in T.serial(1024): + B[i] = A[i] + 1.0 + + def expected(self, func): + return "Expected BufferStore's value to be BufferLoad" + + +class Test1DConditional(BaseTest): + """ Like Test1D, but a conditionals prevents this being a memcpy """ + + def func(A: T.Buffer[1024, "float32"], B: T.Buffer[1024, "float32"]): + for i in T.serial(1024): + if i < 1024: + B[i] = A[i] + + def expected(self, func): + A, B = func.buffer_map.values() + return "Expected innermost loop to have BufferStore body" + + +class Test1DStridedInput(BaseTest): + """ Like Test1D, but strided input prevents this being a memcpy """ + + def func(A: T.Buffer[2048, "float32"], B: T.Buffer[1024, "float32"]): + for i in T.serial(1024): + B[i] = A[i * 2] + + def expected(self, func): + return "Mismatch between loop iterations (.*) and number of src indices" + + +class Test1DStridedOutput(BaseTest): + """ Like Test1D, but strided output prevents this being a memcpy """ + + def func(A: T.Buffer[1024, "float32"], B: T.Buffer[2048, "float32"]): + for i in T.serial(1024): + B[i * 2] = A[i] + + def expected(self, func): + return "Mismatch between loop iterations (.*) and number of dst indices" + + +class Test1DInput2DOutputFusedLoop(BaseTest): + """ Like Test1D, but the output is written as a 2-d buffer """ + + def func(A: T.Buffer[1024, "float32"], B: T.Buffer[(32, 32), "float32"]): + for i in T.serial(1024): + B[i // 32, i % 32] = A[i] + + def expected(self, func): + A, B = func.buffer_map.values() + return A[0:1024], B[0:32, 0:32] + + +class Test2DInput1DOutputFusedLoop(BaseTest): + """ Like Test1D, but the input is written as a 2-d buffer """ + + def func(A: T.Buffer[(32, 32), "float32"], B: T.Buffer[1024, "float32"]): + for i in T.serial(1024): + B[i] = A[i // 32, i % 32] + + def expected(self, func): + A, B = func.buffer_map.values() + return A[0:32, 0:32], B[0:1024] + + +class Test1DInput1DOutputNestedLoop(BaseTest): + """Like Test1D, but the iterator is written as a nested loop + + In test cases with more than one loop, each loop is checked to see + if could be written as a memcpy. The C++ utility function + operates on individual loops, but for unit testing in Python, it + is more convenient to return the results for all loops. + """ + + def func(A: T.Buffer[1024, "float32"], B: T.Buffer[1024, "float32"]): + for i, j in T.grid(32, 32): + B[i * 32 + j] = A[i * 32 + j] + + def expected(self, func): + A, B = func.buffer_map.values() + i = func.body.loop_var + return [ + (A[0:1024], B[0:1024]), + (A[i * 32 : i * 32 + 32], B[i * 32 : i * 32 + 32]), + ] + + +class Test1DInput1DOutputNestedLoopEquivalentExpressions(BaseTest): + """Like Test1DInput1DOutputNestedLoop, but with equivalent indices + + If the expressions are not identical, the loops may still be + recognizable as a memcpy, so long as the expressions are + equivalent. + """ + + def func(A: T.Buffer[1024, "float32"], B: T.Buffer[1024, "float32"]): + for i, j in T.grid(32, 32): + B[i * 32 + j] = A[j + i * 32] + + def expected(self, func): + A, B = func.buffer_map.values() + i = func.body.loop_var + return [ + (A[0:1024], B[0:1024]), + (A[i * 32 : i * 32 + 32], B[i * 32 : i * 32 + 32]), + ] + + +class Test1DInput2DOutputNestedLoop(BaseTest): + """ Like Test1DInput1DOutputNestedLoop, but with a 2-d output buffer """ + + def func(A: T.Buffer[1024, "float32"], B: T.Buffer[(32, 32), "float32"]): + for i, j in T.grid(32, 32): + B[i, j] = A[i * 32 + j] + + def expected(self, func): + A, B = func.buffer_map.values() + i = func.body.loop_var + return [ + (A[0:1024], B[0:32, 0:32]), + (A[i * 32 : i * 32 + 32], B[i, 0:32]), + ] + + +class Test2DInput1DOutputNestedLoop(BaseTest): + """ Like Test1DInput1DOutputNestedLoop, but with a 2-d input buffer """ + + def func(A: T.Buffer[(32, 32), "float32"], B: T.Buffer[1024, "float32"]): + for i, j in T.grid(32, 32): + B[i * 32 + j] = A[i, j] + + def expected(self, func): + A, B = func.buffer_map.values() + i = func.body.loop_var + return [ + (A[0:32, 0:32], B[0:1024]), + (A[i, 0:32], B[i * 32 : i * 32 + 32]), + ] + + +class Test2DInput2DOutputNestedLoop(BaseTest): + """ Like Test1DInput1DOutputNestedLoop, but with 2-d input/output buffers """ + + def func(A: T.Buffer[(32, 32), "float32"], B: T.Buffer[(32, 32), "float32"]): + for i, j in T.grid(32, 32): + B[i, j] = A[i, j] + + def expected(self, func): + A, B = func.buffer_map.values() + i = func.body.loop_var + return [ + (A[0:32, 0:32], B[0:32, 0:32]), + (A[i, 0:32], B[i, 0:32]), + ] + + +class Test2DInput2DOutputTransposeOutput(BaseTest): + """Test2DInput2DOutputNestedLoop, but with a transposed output + + This is not recognized as a memcpy, because it results in a transpose. + """ + + def func(A: T.Buffer[(32, 32), "float32"], B: T.Buffer[(32, 32), "float32"]): + for i, j in T.grid(32, 32): + B[j, i] = A[i, j] + + def expected(self, func): + return [ + "different source", + "Mismatch .* number of dst indices touched", + ] + + +class Test2DInput2DOutputTransposeInput(BaseTest): + """Test2DInput2DOutputNestedLoop, but with a transposed input + + This is not recognized as a memcpy, because it results in a transpose. + """ + + def func(A: T.Buffer[(32, 32), "float32"], B: T.Buffer[(32, 32), "float32"]): + for i, j in T.grid(32, 32): + B[i, j] = A[j, i] + + def expected(self, func): + return [ + "different source", + "Mismatch .* number of src indices touched", + ] + + +class Test2DInput2DOutputTransposeBoth(BaseTest): + """Test2DInput2DOutputNestedLoop, but with a transposed input + + The inner loop is not recognized as a memcpy, because it has + strided access of both the input and output buffers. However, the + outer loop is still recognized as a memcpy, because the full + region has been copied over, even though it occurs out of order. + """ + + def func(A: T.Buffer[(32, 32), "float32"], B: T.Buffer[(32, 32), "float32"]): + for i, j in T.grid(32, 32): + B[j, i] = A[j, i] + + def expected(self, func): + A, B = func.buffer_map.values() + return [ + (A[0:32, 0:32], B[0:32, 0:32]), + "Mismatch .* number of src indices touched", + ] + + +class TestCacheRead(BaseTest): + """Like Test2DInput2DOutputNestedLoop, but with a 1-d + + The inner loop is a memcpy of a single row at a time. This + pattern would appear when B is a read cache of A. + """ + + def func(A: T.Buffer[(32, 32), "float32"], B: T.Buffer[32, "float32"]): + for i, j in T.grid(32, 32): + B[j] = A[i, j] + + def expected(self, func): + A, B = func.buffer_map.values() + i = func.body.loop_var + return [ + "does not form a bijective transform", + (A[i, 0:32], B[0:32]), + ] + + +class TestCacheWrite(BaseTest): + """Like Test2DInput2DOutputNestedLoop, but with a 1-d + + The inner loop is a memcpy of a single row at a time. This + pattern would appear when A is a write cache of B. + """ + + def func(A: T.Buffer[32, "float32"], B: T.Buffer[(32, 32), "float32"]): + for i, j in T.grid(32, 32): + B[i, j] = A[j] + + def expected(self, func): + A, B = func.buffer_map.values() + i = func.body.loop_var + return [ + "does not form a bijective transform", + (A[0:32], B[i, 0:32]), + ] + + +if __name__ == "__main__": + tvm.testing.main() From 105fb8c643b682669135b6dd40049a77b6a4bf3f Mon Sep 17 00:00:00 2001 From: Eldritch Cheese Date: Fri, 10 Feb 2023 13:37:42 -0600 Subject: [PATCH 2/8] Remove accidental use of C++20 feature Looks like there's some confusion over whether the LWG1203 proposals should be retroactively applied to C++11 onward. GCC 11 implements it for C++17, but GCC 9 does not, so this failed to compile in CI. https://stackoverflow.com/q/69325059/2689797 --- src/tir/analysis/identify_memcpy.cc | 80 +++++++++++++++++------------ 1 file changed, 47 insertions(+), 33 deletions(-) diff --git a/src/tir/analysis/identify_memcpy.cc b/src/tir/analysis/identify_memcpy.cc index a9605c903b3b..a9b2b9080ea0 100644 --- a/src/tir/analysis/identify_memcpy.cc +++ b/src/tir/analysis/identify_memcpy.cc @@ -85,9 +85,10 @@ std::variant IdentifyMemCpyImpl(const For& loop, Array flattened_src = load->buffer.OffsetOf(load->indices); if (flattened_dst.size() != 1 || flattened_src.size() != 1) { - return (std::stringstream() << "Expected flattened dimension of src/dest to be 1, but found" - << flattened_src.size() << "-d src and " << flattened_dst.size() - << "-d dst") + return static_cast( + std::stringstream() + << "Expected flattened dimension of src/dest to be 1, but found" + << flattened_src.size() << "-d src and " << flattened_dst.size() << "-d dst") .str(); } PrimExpr src_index = flattened_src[0]; @@ -104,19 +105,21 @@ std::variant IdentifyMemCpyImpl(const For& loop, auto src_iter_map = arith::DetectIterMap({src_index}, loop_ranges, Bool(true), arith::IterMapLevel::Bijective, analyzer); if (src_iter_map->errors.size()) { - return (std::stringstream() << "arith::DetectIterMap(src) returned " - << src_iter_map->errors.size() << " errors: [" - << src_iter_map->errors << "]" - << " for src_index = " << src_index) + return static_cast(std::stringstream() + << "arith::DetectIterMap(src) returned " + << src_iter_map->errors.size() << " errors: [" + << src_iter_map->errors << "]" + << " for src_index = " << src_index) .str(); } auto dst_iter_map = arith::DetectIterMap({dst_index}, loop_ranges, Bool(true), arith::IterMapLevel::Bijective, analyzer); if (dst_iter_map->errors.size()) { - return (std::stringstream() << "arith::DetectIterMap(dst) returned " - << dst_iter_map->errors.size() << " errors: [" - << dst_iter_map->errors << "]" - << " for dst_index = " << dst_index) + return static_cast(std::stringstream() + << "arith::DetectIterMap(dst) returned " + << dst_iter_map->errors.size() << " errors: [" + << dst_iter_map->errors << "]" + << " for dst_index = " << dst_index) .str(); } @@ -142,13 +145,15 @@ std::variant IdentifyMemCpyImpl(const For& loop, auto dst_interval = analyzer->int_set(dst_index, loop_intervals); if (!src_interval.HasLowerBound() || !src_interval.HasUpperBound()) { - return (std::stringstream() << "Expected known bounds for src, but found " << src_interval - << " for expression " << src_index) + return static_cast(std::stringstream() + << "Expected known bounds for src, but found " + << src_interval << " for expression " << src_index) .str(); } if (!dst_interval.HasLowerBound() || !dst_interval.HasUpperBound()) { - return (std::stringstream() << "Expected known bounds for dst, but found " << dst_interval - << " for expression " << dst_index) + return static_cast(std::stringstream() + << "Expected known bounds for dst, but found " + << dst_interval << " for expression " << dst_index) .str(); } @@ -156,9 +161,11 @@ std::variant IdentifyMemCpyImpl(const For& loop, PrimExpr must_prove = total_loop_iterations == src_interval.max() - src_interval.min() + 1; PrimExpr simplified = analyzer->Simplify(must_prove); if (!analyzer->CanProve(simplified)) { - return (std::stringstream() << "Mismatch between loop iterations (" << total_loop_iterations - << ") and number of src indices touched (" << src_interval - << ". Equality to prove simplified to " << simplified) + return static_cast( + std::stringstream() + << "Mismatch between loop iterations (" << total_loop_iterations + << ") and number of src indices touched (" << src_interval + << ". Equality to prove simplified to " << simplified) .str(); } } @@ -166,9 +173,11 @@ std::variant IdentifyMemCpyImpl(const For& loop, PrimExpr must_prove = total_loop_iterations == dst_interval.max() - dst_interval.min() + 1; PrimExpr simplified = analyzer->Simplify(must_prove); if (!analyzer->CanProve(simplified)) { - return (std::stringstream() << "Mismatch between loop iterations (" << total_loop_iterations - << ") and number of dst indices touched (" << dst_interval - << ". Equality to prove simplified to " << simplified) + return static_cast( + std::stringstream() + << "Mismatch between loop iterations (" << total_loop_iterations + << ") and number of dst indices touched (" << dst_interval + << ". Equality to prove simplified to " << simplified) .str(); } } @@ -227,27 +236,32 @@ std::variant IdentifyMemCpyImpl(const For& loop, if (!analyzer->CanProve( arith::NormalizeIterMapToExpr(src_term->source->source == dst_term->source->source))) { - return (std::stringstream() << "Term " << i << " had different source, src_term->source = " - << src_term->source - << ", dst_term->source = " << dst_term->source) + return static_cast( + std::stringstream() + << "Term " << i << " had different source, src_term->source = " << src_term->source + << ", dst_term->source = " << dst_term->source) .str(); } if (!analyzer->CanProve(src_term->lower_factor == dst_term->lower_factor)) { - return (std::stringstream() << "Term " << i - << " had different lower_factor, src_term->lower_factor = " - << src_term->lower_factor - << ", dst_term->lower_factor = " << dst_term->lower_factor) + return static_cast( + std::stringstream() + << "Term " << i << " had different lower_factor, src_term->lower_factor = " + << src_term->lower_factor + << ", dst_term->lower_factor = " << dst_term->lower_factor) .str(); } if (!analyzer->CanProve(src_term->extent == dst_term->extent)) { - return (std::stringstream() << "Term " << i << " had different extent, src_term->extent = " - << src_term->extent - << ", dst_term->extent = " << dst_term->extent) + return static_cast( + std::stringstream() + << "Term " << i << " had different extent, src_term->extent = " << src_term->extent + << ", dst_term->extent = " << dst_term->extent) .str(); } if (!analyzer->CanProve(src_term->scale == dst_term->scale)) { - return (std::stringstream() << "Term " << i << " had different scale, src_term->scale = " - << src_term->scale << ", dst_term->scale = " << dst_term->scale) + return static_cast( + std::stringstream() + << "Term " << i << " had different scale, src_term->scale = " << src_term->scale + << ", dst_term->scale = " << dst_term->scale) .str(); } } From c3c39f934c836f35e2d70a7ea775a5cf0c677b5f Mon Sep 17 00:00:00 2001 From: Eldritch Cheese Date: Mon, 13 Feb 2023 11:25:45 -0600 Subject: [PATCH 3/8] Fixup 3 more cases that would require C++20 --- src/tir/analysis/identify_memcpy.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/tir/analysis/identify_memcpy.cc b/src/tir/analysis/identify_memcpy.cc index a9b2b9080ea0..e0a8c9f88cce 100644 --- a/src/tir/analysis/identify_memcpy.cc +++ b/src/tir/analysis/identify_memcpy.cc @@ -61,7 +61,7 @@ std::variant IdentifyMemCpyImpl(const For& loop, if (auto* ptr = stmt.as()) { store = GetRef(ptr); } else { - return ( + return static_cast( std::stringstream() << "Expected innermost loop to have BufferStore body, but instead found " << stmt) .str(); @@ -71,7 +71,7 @@ std::variant IdentifyMemCpyImpl(const For& loop, if (auto* ptr = store->value.as()) { load = GetRef(ptr); } else { - return ( + return static_cast( std::stringstream() << "Expected BufferStore's value to be BufferLoad, but instead found " << store->value) @@ -196,7 +196,7 @@ std::variant IdentifyMemCpyImpl(const For& loop, auto dst_iter_sum = dst_iter_map->indices[0]; if (src_iter_sum->args.size() != dst_iter_sum->args.size()) { - return ( + return static_cast( std::stringstream() << "IterMap for src/dst unpacked to different number of IterSplitExpr: " << src_iter_sum->args.size() << " for src, " << dst_iter_sum->args.size() From f617960c0e17c235beb4b4825a2deda82cb03aef Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 16 Feb 2023 14:44:49 -0600 Subject: [PATCH 4/8] Fix linting errors --- .../test_tir_analysis_identify_memcpy.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/python/unittest/test_tir_analysis_identify_memcpy.py b/tests/python/unittest/test_tir_analysis_identify_memcpy.py index 39807b3191e0..185a101f819b 100644 --- a/tests/python/unittest/test_tir_analysis_identify_memcpy.py +++ b/tests/python/unittest/test_tir_analysis_identify_memcpy.py @@ -30,7 +30,7 @@ class BaseTest: - """Utility class for defining unit tests for memcpy """ + """Utility class for defining unit tests for memcpy""" def __init_subclass__(cls): cls.func = tvm.testing.CompareBeforeAfter._normalize_before(cls.func) @@ -54,7 +54,7 @@ def test_identify_memcpy(self, func, expected): class Test1D(BaseTest): - """ Simplest test case """ + """Simplest test case""" def func(A: T.Buffer[1024, "float32"], B: T.Buffer[1024, "float32"]): for i in T.serial(1024): @@ -66,7 +66,7 @@ def expected(self, func): class Test1DCompute(BaseTest): - """ Like Test1D, but a computation prevents this being a memcpy """ + """Like Test1D, but a computation prevents this being a memcpy""" def func(A: T.Buffer[1024, "float32"], B: T.Buffer[1024, "float32"]): for i in T.serial(1024): @@ -77,7 +77,7 @@ def expected(self, func): class Test1DConditional(BaseTest): - """ Like Test1D, but a conditionals prevents this being a memcpy """ + """Like Test1D, but a conditionals prevents this being a memcpy""" def func(A: T.Buffer[1024, "float32"], B: T.Buffer[1024, "float32"]): for i in T.serial(1024): @@ -90,7 +90,7 @@ def expected(self, func): class Test1DStridedInput(BaseTest): - """ Like Test1D, but strided input prevents this being a memcpy """ + """Like Test1D, but strided input prevents this being a memcpy""" def func(A: T.Buffer[2048, "float32"], B: T.Buffer[1024, "float32"]): for i in T.serial(1024): @@ -101,7 +101,7 @@ def expected(self, func): class Test1DStridedOutput(BaseTest): - """ Like Test1D, but strided output prevents this being a memcpy """ + """Like Test1D, but strided output prevents this being a memcpy""" def func(A: T.Buffer[1024, "float32"], B: T.Buffer[2048, "float32"]): for i in T.serial(1024): @@ -112,7 +112,7 @@ def expected(self, func): class Test1DInput2DOutputFusedLoop(BaseTest): - """ Like Test1D, but the output is written as a 2-d buffer """ + """Like Test1D, but the output is written as a 2-d buffer""" def func(A: T.Buffer[1024, "float32"], B: T.Buffer[(32, 32), "float32"]): for i in T.serial(1024): @@ -124,7 +124,7 @@ def expected(self, func): class Test2DInput1DOutputFusedLoop(BaseTest): - """ Like Test1D, but the input is written as a 2-d buffer """ + """Like Test1D, but the input is written as a 2-d buffer""" def func(A: T.Buffer[(32, 32), "float32"], B: T.Buffer[1024, "float32"]): for i in T.serial(1024): @@ -179,7 +179,7 @@ def expected(self, func): class Test1DInput2DOutputNestedLoop(BaseTest): - """ Like Test1DInput1DOutputNestedLoop, but with a 2-d output buffer """ + """Like Test1DInput1DOutputNestedLoop, but with a 2-d output buffer""" def func(A: T.Buffer[1024, "float32"], B: T.Buffer[(32, 32), "float32"]): for i, j in T.grid(32, 32): @@ -195,7 +195,7 @@ def expected(self, func): class Test2DInput1DOutputNestedLoop(BaseTest): - """ Like Test1DInput1DOutputNestedLoop, but with a 2-d input buffer """ + """Like Test1DInput1DOutputNestedLoop, but with a 2-d input buffer""" def func(A: T.Buffer[(32, 32), "float32"], B: T.Buffer[1024, "float32"]): for i, j in T.grid(32, 32): @@ -211,7 +211,7 @@ def expected(self, func): class Test2DInput2DOutputNestedLoop(BaseTest): - """ Like Test1DInput1DOutputNestedLoop, but with 2-d input/output buffers """ + """Like Test1DInput1DOutputNestedLoop, but with 2-d input/output buffers""" def func(A: T.Buffer[(32, 32), "float32"], B: T.Buffer[(32, 32), "float32"]): for i, j in T.grid(32, 32): From d86c1a538f820b387887bc678f2e9c281f6385dc Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 21 Feb 2023 08:20:23 -0600 Subject: [PATCH 5/8] Remove unused import --- python/tvm/tir/analysis/analysis.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py index c023976ad9c8..45b1f745c3de 100644 --- a/python/tvm/tir/analysis/analysis.py +++ b/python/tvm/tir/analysis/analysis.py @@ -18,7 +18,6 @@ # pylint: disable=invalid-name from typing import Dict, List, Union -import tvm from tvm import Object from tvm.ir import IRModule from tvm.tir.expr import Var From 9d5dfccaef2354c7d5ef612e1d3d2a4b3c940708 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 21 Feb 2023 09:21:40 -0600 Subject: [PATCH 6/8] lint fixes --- src/tir/analysis/identify_memcpy.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/tir/analysis/identify_memcpy.cc b/src/tir/analysis/identify_memcpy.cc index e0a8c9f88cce..e2bf2fbb0c27 100644 --- a/src/tir/analysis/identify_memcpy.cc +++ b/src/tir/analysis/identify_memcpy.cc @@ -218,9 +218,9 @@ std::variant IdentifyMemCpyImpl(const For& loop, } }; return std::tuple{ - bool(expr->scale.as()), as_int_or_zero(expr->scale), - bool(expr->extent.as()), as_int_or_zero(expr->lower_factor), - bool(expr->lower_factor.as()), as_int_or_zero(expr->lower_factor), + static_cast(expr->scale.as()), as_int_or_zero(expr->scale), + static_cast(expr->extent.as()), as_int_or_zero(expr->lower_factor), + static_cast(expr->lower_factor.as()), as_int_or_zero(expr->lower_factor), }; }; auto sorting_function = [&make_comparison_tuple](const arith::IterSplitExpr& lhs, From c55eaca557bfa85fb2f0e6273a73c7217f738ca3 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 21 Feb 2023 09:57:29 -0600 Subject: [PATCH 7/8] lint fixes --- src/tir/analysis/identify_memcpy.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/tir/analysis/identify_memcpy.cc b/src/tir/analysis/identify_memcpy.cc index e2bf2fbb0c27..a8b7b9b34659 100644 --- a/src/tir/analysis/identify_memcpy.cc +++ b/src/tir/analysis/identify_memcpy.cc @@ -287,7 +287,7 @@ TVM_REGISTER_GLOBAL("tir.analysis._identify_memcpy").set_body_typed([](const Stm Array output; struct Visitor : arith::IRVisitorWithAnalyzer { - Visitor(Array* output) : output(output) {} + explicit Visitor(Array* output) : output(output) {} Array* output; private: @@ -307,7 +307,8 @@ TVM_REGISTER_GLOBAL("tir.analysis._identify_memcpy").set_body_typed([](const Stm } }; - Visitor{&output}(stmt); + Visitor visitor(&output); + visitor(stmt); return output; }); From 492caab7150e327759caeb2d559579227dc7cee6 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 3 Mar 2023 11:42:03 -0600 Subject: [PATCH 8/8] Updates from review comments --- src/tir/analysis/identify_memcpy.cc | 9 ++++----- .../python/unittest/test_tir_analysis_identify_memcpy.py | 1 - 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/tir/analysis/identify_memcpy.cc b/src/tir/analysis/identify_memcpy.cc index a8b7b9b34659..0d3b48dbc2c6 100644 --- a/src/tir/analysis/identify_memcpy.cc +++ b/src/tir/analysis/identify_memcpy.cc @@ -19,7 +19,7 @@ /*! * \file tir/analysis/identify_memcpy.cc - * \brief Check if a loopnest is equivalent to memcpy + * \brief Check if a loop nest is equivalent to memcpy */ #include @@ -45,7 +45,7 @@ std::variant IdentifyMemCpyImpl(const For& loop, Map loop_ranges; PrimExpr total_loop_iterations = 1; - // Walk through the loopnest, stopping at the first loop whose body + // Walk through the loop nest, stopping at the first loop whose body // is not a loop. Stmt stmt = loop; while (auto* for_node = stmt.as()) { @@ -182,7 +182,7 @@ std::variant IdentifyMemCpyImpl(const For& loop, } } - // Thrid check, is there a transformation applied between the input + // Third check, is there a transformation applied between the input // and output iterators? // // For example, the following would pass all checks so far, but @@ -281,8 +281,7 @@ std::optional IdentifyMemCpy(const For& loop, arith::Analyzer* an } } -// Expose the IdentifyMemCpy functionality to Python API for purpose -// of unit testing. +// Expose the IdentifyMemCpy functionality to Python API for purpose of unit testing. TVM_REGISTER_GLOBAL("tir.analysis._identify_memcpy").set_body_typed([](const Stmt& stmt) { Array output; diff --git a/tests/python/unittest/test_tir_analysis_identify_memcpy.py b/tests/python/unittest/test_tir_analysis_identify_memcpy.py index 185a101f819b..b69d3aea3ea3 100644 --- a/tests/python/unittest/test_tir_analysis_identify_memcpy.py +++ b/tests/python/unittest/test_tir_analysis_identify_memcpy.py @@ -21,7 +21,6 @@ import tvm import tvm.testing -from tvm import te, topi from tvm.tir import BufferRegion, StringImm from tvm.script import tir as T