From 6994163c27395b5247f92580953e0b2ca3697e29 Mon Sep 17 00:00:00 2001 From: Bimal Gaudel Date: Sun, 7 Jul 2024 15:16:06 -0400 Subject: [PATCH] ToT support for `math/linalg` functions and `concat` function. --- src/TiledArray/conversions/concat.h | 5 +++-- src/TiledArray/math/linalg/basic.h | 7 ++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/TiledArray/conversions/concat.h b/src/TiledArray/conversions/concat.h index 7c440c54e2..398a5dc7b3 100644 --- a/src/TiledArray/conversions/concat.h +++ b/src/TiledArray/conversions/concat.h @@ -92,8 +92,9 @@ DistArray concat( DistArray result(*target_world, tr); const auto annot = detail::dummy_annotation(r); for (auto i = 0ul; i != arrays.size(); ++i) { - result(annot).block(tile_begin_end[i].first, tile_begin_end[i].second) = - arrays[i](annot); + result.make_tsrexpr(annot).block(tile_begin_end[i].first, + tile_begin_end[i].second) = + arrays[i].make_tsrexpr(annot); } result.world().gop.fence(); diff --git a/src/TiledArray/math/linalg/basic.h b/src/TiledArray/math/linalg/basic.h index 856c915bbe..c00a363286 100644 --- a/src/TiledArray/math/linalg/basic.h +++ b/src/TiledArray/math/linalg/basic.h @@ -79,14 +79,14 @@ template inline void vec_multiply(DistArray& a1, const DistArray& a2) { auto vars = TiledArray::detail::dummy_annotation(rank(a1)); - a1(vars) = a1(vars) * a2(vars); + a1.make_tsrexpr(vars) = a1.make_tsrexpr(vars) * a2.make_tsrexpr(vars); } template inline void scale(DistArray& a, S scaling_factor) { using numeric_type = typename DistArray::numeric_type; auto vars = TiledArray::detail::dummy_annotation(rank(a)); - a(vars) = numeric_type(scaling_factor) * a(vars); + a.make_tsrexpr(vars) = numeric_type(scaling_factor) * a.make_tsrexpr(vars); } template @@ -99,7 +99,8 @@ inline void axpy(DistArray& y, S alpha, const DistArray& x) { using numeric_type = typename DistArray::numeric_type; auto vars = TiledArray::detail::dummy_annotation(rank(y)); - y(vars) = y(vars) + numeric_type(alpha) * x(vars); + y.make_tsrexpr(vars) = + y.make_tsrexpr(vars) + numeric_type(alpha) * x.make_tsrexpr(vars); } /// selector for concat