Skip to content

Commit

Permalink
Generalize TA::squared_norm to be applicable for tensor-of-tensor a…
Browse files Browse the repository at this point in the history
…rrays as well.
  • Loading branch information
bimalgaudel committed Jun 26, 2024
1 parent e8ab5e2 commit 051b3b9
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion src/TiledArray/dist_array.h
Original file line number Diff line number Diff line change
Expand Up @@ -1213,6 +1213,17 @@ class DistArray : public madness::archive::ParallelSerializableObject {
return TiledArray::expressions::TsrExpr<const DistArray>(*this, vars);
}

///
/// \brief This method creates a tensor expression but does not insist the
/// annotation to be bipartite (outer and inner tensor annotations).
/// \param vars Annotation for the tensor expression.
/// \note Only use for unary evaluations when the indexing of the inner
/// tensors is not significant, eg. norm computation.
///
auto index_unchecked_tensor_expression(const std::string& vars) const {
return TiledArray::expressions::TsrExpr<const DistArray>(*this, vars);
}

/// Create a tensor expression

/// \param vars A string with a comma-separated list of variables
Expand Down Expand Up @@ -1917,7 +1928,8 @@ auto inner_product(const DistArray<Tile, Policy>& a,

template <typename Tile, typename Policy>
auto squared_norm(const DistArray<Tile, Policy>& a) {
return a(detail::dummy_annotation(rank(a))).squared_norm();
return a.index_unchecked_tensor_expression(detail::dummy_annotation(rank(a)))
.squared_norm();
}

template <typename Tile, typename Policy>
Expand Down

0 comments on commit 051b3b9

Please sign in to comment.