diff --git a/3rdparty/cutlass_fpA_intB_gemm b/3rdparty/cutlass_fpA_intB_gemm index 412a22bded66..3e07e778d78f 160000 --- a/3rdparty/cutlass_fpA_intB_gemm +++ b/3rdparty/cutlass_fpA_intB_gemm @@ -1 +1 @@ -Subproject commit 412a22bded6631d02fa40e3994a8096a5b8a6c7c +Subproject commit 3e07e778d78f0fcd047533c1fdaed571a68a396f diff --git a/ffi/include/tvm/ffi/container/container_details.h b/ffi/include/tvm/ffi/container/container_details.h index cfc5590f5404..bb29a14f7cb8 100644 --- a/ffi/include/tvm/ffi/container/container_details.h +++ b/ffi/include/tvm/ffi/container/container_details.h @@ -199,6 +199,16 @@ class IterAdapter { IterAdapter operator-(difference_type offset) const { return IterAdapter(iter_ - offset); } + IterAdapter& operator+=(difference_type offset) { + iter_ += offset; + return *this; + } + + IterAdapter& operator-=(difference_type offset) { + iter_ -= offset; + return *this; + } + template typename std::enable_if::value, typename T::difference_type>::type inline diff --git a/ffi/include/tvm/ffi/container/tuple.h b/ffi/include/tvm/ffi/container/tuple.h index 9fba4b0e0691..10303f0aecab 100644 --- a/ffi/include/tvm/ffi/container/tuple.h +++ b/ffi/include/tvm/ffi/container/tuple.h @@ -55,10 +55,13 @@ class Tuple : public ObjectRef { template && ...), int>> Tuple(Tuple&& other) : ObjectRef(std::move(other)) {} - template - explicit Tuple(UTypes&&... args) : ObjectRef(MakeTupleNode(std::forward(args)...)) { - static_assert(sizeof...(Types) == sizeof...(UTypes), "Tuple size mismatch"); - } + + template , Tuple> && + ...))>> + explicit Tuple(UTypes&&... args) : ObjectRef(MakeTupleNode(std::forward(args)...)) {} TVM_FFI_INLINE Tuple& operator=(const Tuple& other) { data_ = other.data_; diff --git a/ffi/tests/cpp/test_tuple.cc b/ffi/tests/cpp/test_tuple.cc index 79eeb488643d..e0f69d820018 100644 --- a/ffi/tests/cpp/test_tuple.cc +++ b/ffi/tests/cpp/test_tuple.cc @@ -136,4 +136,32 @@ TEST(Tuple, Upcast) { static_assert(details::type_contains_v, Tuple>); static_assert(details::type_contains_v, Tuple>); } + +TEST(Tuple, ArrayIterForwarding) { + Tuple t0(1, 2); + Tuple t1(3, 4); + Array> arr0 = {t0, t1}; + std::vector> vec0 = {t0}; + vec0.insert(vec0.end(), arr0.begin(), arr0.end()); + EXPECT_EQ(vec0.size(), 3); + EXPECT_EQ(vec0[0].get<0>()->value, 1); + EXPECT_EQ(vec0[0].get<1>()->value, 2); + EXPECT_EQ(vec0[1].get<0>()->value, 1); + EXPECT_EQ(vec0[1].get<1>()->value, 2); + EXPECT_EQ(vec0[2].get<0>()->value, 3); + EXPECT_EQ(vec0[2].get<1>()->value, 4); +} + +TEST(Tuple, ArrayIterForwardSingleElem) { + Tuple t0(1); + Tuple t1(2); + Array> arr0 = {t0, t1}; + std::vector> vec0 = {t0}; + vec0.insert(vec0.end(), arr0.begin(), arr0.end()); + EXPECT_EQ(vec0.size(), 3); + EXPECT_EQ(vec0[0].get<0>()->value, 1); + EXPECT_EQ(vec0[1].get<0>()->value, 1); + EXPECT_EQ(vec0[2].get<0>()->value, 2); +} + } // namespace