diff --git a/src/TiledArray/external/btas.h b/src/TiledArray/external/btas.h index fe84e6f0c6..c22afd3813 100644 --- a/src/TiledArray/external/btas.h +++ b/src/TiledArray/external/btas.h @@ -62,6 +62,13 @@ class boxrange_iteration_order { static constexpr int value = row_major; }; +template +class is_tensor> : public std::true_type {}; + +template +class is_tensor> + : public std::true_type {}; + } // namespace btas namespace TiledArray { diff --git a/tests/btas.cpp b/tests/btas.cpp index 9c15540e9a..4e972cfc28 100644 --- a/tests/btas.cpp +++ b/tests/btas.cpp @@ -256,6 +256,21 @@ BOOST_AUTO_TEST_CASE_TEMPLATE(tensor_ctor, Tensor, tensor_types) { BOOST_REQUIRE_NO_THROW(Tensor t1 = t0); Tensor t1 = t0; BOOST_CHECK(t1.empty()); + + // can copy TA::Tensor to btas::Tensor + TA::Tensor ta_tensor(r); + BOOST_REQUIRE_NO_THROW(Tensor(ta_tensor)); + Tensor t2(ta_tensor); + for (auto i : r) { + BOOST_CHECK_EQUAL(ta_tensor(i), t2(i)); + } + + // can copy TA::TensorInterface to btas::Tensor + BOOST_REQUIRE_NO_THROW(Tensor(ta_tensor.block(r.lobound(), r.upbound()))); + Tensor t3(ta_tensor.block(r.lobound(), r.upbound())); + for (auto i : r) { + BOOST_CHECK_EQUAL(ta_tensor(i), t3(i)); + } } BOOST_AUTO_TEST_CASE_TEMPLATE(copy, Array, array_types) { diff --git a/tests/expressions_btas.cpp b/tests/expressions_btas.cpp index 83ff4b1ed0..7b1ae422ce 100644 --- a/tests/expressions_btas.cpp +++ b/tests/expressions_btas.cpp @@ -23,6 +23,8 @@ * */ +#include + #ifdef TILEDARRAY_HAS_BTAS #include "expressions_fixture.h"