diff --git a/include/matx_tensor.h b/include/matx_tensor.h index b8657bf02..6653eff13 100644 --- a/include/matx_tensor.h +++ b/include/matx_tensor.h @@ -187,7 +187,6 @@ class tensor_t : public detail::tensor_impl_t { using stride_container = typename Desc::stride_container; using desc_type = Desc; ///< Descriptor type trait using self_type = tensor_t; - static constexpr bool PRINT_ON_DEVICE = false; ///< Print() uses printf on device /** * @brief Construct a new 0-D tensor t object @@ -1654,26 +1653,8 @@ class tensor_t : public detail::tensor_impl_t { std::enable_if_t<((std::is_integral_v)&&...) && (RANK == 0 || sizeof...(Args) > 0), bool> = true> - void Print(Args... dims) const { -#ifdef __CUDACC__ - auto kind = GetPointerKind(this->ldata_); - cudaDeviceSynchronize(); - if (HostPrintable(kind)) { - InternalPrint(dims...); - } - else if (DevicePrintable(kind) || kind == MATX_INVALID_MEMORY) { - if constexpr (PRINT_ON_DEVICE) { - PrintKernel<<<1, 1>>>(*this, dims...); - } - else { - auto tmpv = make_tensor(this->Shape()); - (tmpv = *this).run(); - tmpv.Print(dims...); - } - } -#else - InternalPrint(dims...); -#endif + [[deprecated("Use non-member function Print() instead")]] void Print(Args... dims) const { + matx::Print(*this, dims...); } /** @@ -1689,9 +1670,7 @@ class tensor_t : public detail::tensor_impl_t { template 0 && sizeof...(Args) == 0), bool> = true> void Print(Args... dims) const { - std::array arr = {0}; - auto tp = std::tuple_cat(arr); - std::apply([&](auto &&...args) { this->Print(args...); }, tp); + matx::Print(*this, dims...); } /** @@ -1718,7 +1697,7 @@ class tensor_t : public detail::tensor_impl_t { auto tup = std::tuple_cat(arr); std::apply( [&](auto&&... args) { - s.InternalPrint(args...); + detail::InternalPrint(s, args...); }, tup); } @@ -1747,7 +1726,7 @@ class tensor_t : public detail::tensor_impl_t { auto tup = std::tuple_cat(arr); std::apply( [&](auto&&... args) { - s.InternalPrint(args...); + detail::InternalPrint(s, args...); }, tup); } diff --git a/include/matx_tensor_utils.h b/include/matx_tensor_utils.h index 5bf29e6d1..30649a774 100644 --- a/include/matx_tensor_utils.h +++ b/include/matx_tensor_utils.h @@ -34,6 +34,7 @@ #include #include +#include "matx_make.h" namespace matx { @@ -359,5 +360,262 @@ namespace detail { } } + + /** + * Print a value + * + * Type-agnostic function to print a value to stdout + * + * @param val + */ + template + __MATX_INLINE__ __MATX_HOST__ void PrintVal(const T &val) + { + if constexpr (is_complex_v) { + printf("%.4e%+.4ej ", static_cast(val.real()), + static_cast(val.imag())); + } + else if constexpr (is_matx_half_v || is_half_v) { + printf("%.4e ", static_cast(val)); + } + else if constexpr (std::is_floating_point_v) { + printf("%.4e ", val); + } + else if constexpr (std::is_same_v) { + printf("%lld ", val); + } + else if constexpr (std::is_same_v) { + printf("%" PRId64 " ", val); + } + else if constexpr (std::is_same_v) { + printf("%" PRId32 " ", val); + } + else if constexpr (std::is_same_v) { + printf("%" PRId16 " ", val); + } + else if constexpr (std::is_same_v) { + printf("%" PRId8 " ", val); + } + else if constexpr (std::is_same_v) { + printf("%" PRIu64 " ", val); + } + else if constexpr (std::is_same_v) { + printf("%" PRIu32 " ", val); + } + else if constexpr (std::is_same_v) { + printf("%" PRIu16 " ", val); + } + else if constexpr (std::is_same_v) { + printf("%" PRIu8 " ", val); + } + else if constexpr (std::is_same_v) { + printf("%d ", val); + } + } + + /** + * Print a tensor + * + * Type-agnostic function to print a tensor to stdout + * + */ + template + __MATX_HOST__ void InternalPrint(const Op &op, Args ...dims) + { + MATX_STATIC_ASSERT(op.Rank() == sizeof...(Args), "Number of dimensions to print must match tensor rank"); + MATX_STATIC_ASSERT(op.Rank() <= 4, "Printing is only supported on tensors of rank 4 or lower currently"); + if constexpr (sizeof...(Args) == 0) { + PrintVal(op.operator()()); + printf("\n"); + } + else if constexpr (sizeof...(Args) == 1) { + auto& k =detail:: pp_get<0>(dims...); + for (index_t _k = 0; _k < ((k == 0) ? op.Size(0) : k); _k++) { + printf("%06lld: ", _k); + PrintVal(op.operator()(_k)); + printf("\n"); + } + } + else if constexpr (sizeof...(Args) == 2) { + auto& k = detail::pp_get<0>(dims...); + auto& l = detail::pp_get<1>(dims...); + for (index_t _k = 0; _k < ((k == 0) ? op.Size(0) : k); _k++) { + for (index_t _l = 0; _l < ((l == 0) ? op.Size(1) : l); _l++) { + if (_l == 0) + printf("%06lld: ", _k); + + PrintVal(op.operator()(_k, _l)); + } + printf("\n"); + } + } + else if constexpr (sizeof...(Args) == 3) { + auto& j = detail::pp_get<0>(dims...); + auto& k = detail::pp_get<1>(dims...); + auto& l = detail::pp_get<2>(dims...); + for (index_t _j = 0; _j < ((j == 0) ? op.Size(0) : j); _j++) { + printf("[%06lld,:,:]\n", _j); + for (index_t _k = 0; _k < ((k == 0) ? op.Size(1) : k); _k++) { + for (index_t _l = 0; _l < ((l == 0) ? op.Size(2) : l); _l++) { + if (_l == 0) + printf("%06lld: ", _k); + + PrintVal(op.operator()(_j, _k, _l)); + } + printf("\n"); + } + printf("\n"); + } + } + else if constexpr (sizeof...(Args) == 4) { + auto& i = detail::pp_get<0>(dims...); + auto& j = detail::pp_get<1>(dims...); + auto& k = detail::pp_get<2>(dims...); + auto& l = detail::pp_get<3>(dims...); + for (index_t _i = 0; _i < ((i == 0) ? op.Size(0) : i); _i++) { + for (index_t _j = 0; _j < ((j == 0) ? op.Size(1) : j); _j++) { + printf("[%06lld,%06lld,:,:]\n", _i, _j); + for (index_t _k = 0; _k < ((k == 0) ? op.Size(2) : k); _k++) { + for (index_t _l = 0; _l < ((l == 0) ? op.Size(3) : l); _l++) { + if (_l == 0) + printf("%06lld: ", _k); + + PrintVal(op.operator()(_i, _j, _k, _l)); + } + printf("\n"); + } + printf("\n"); + } + } + } + } } + +static constexpr bool PRINT_ON_DEVICE = false; ///< Print() uses printf on device + +/** + * @brief Print a tensor's values to stdout + * + * This form of `Print()` takes integral values for each index, and prints that as many values + * in each dimension as the arguments specify. For example: + * + * `a.Print(2, 3, 2);` + * + * Will print 2 values of the first, 3 values of the second, and 2 values of the third dimension + * of a 3D tensor. The number of parameters must match the rank of the tensor. A special value of + * 0 can be used if the entire tensor should be printed: + * + * `a.Print(0, 0, 0);` // Prints the whole tensor + * + * For more fine-grained printing, see the over `Print()` overloads. + * + * @tparam Args Integral argument types + * @param dims Number of values to print for each dimension + */ +template )&&...) && + (Op::Rank() == 0 || sizeof...(Args) > 0), + bool> = true> +void Print(const Op &op, Args... dims) { +#ifdef __CUDACC__ + if constexpr (is_tensor_view_v) { + auto kind = GetPointerKind(op.Data()); + cudaDeviceSynchronize(); + if (HostPrintable(kind)) { + detail::InternalPrint(op, dims...); + } + else if (DevicePrintable(kind) || kind == MATX_INVALID_MEMORY) { + if constexpr (PRINT_ON_DEVICE) { + PrintKernel<<<1, 1>>>(op, dims...); + } + else { + auto tmpv = make_tensor(op.Shape()); + (tmpv = op).run(); + tmpv.Print(dims...); + } + } + } + else { + InternalPrint(op, dims...); + } +#else + InternalPrint(op, dims...); +#endif +} + +/** + * @brief Print a tensor's all values to stdout + * + * This form of `Print()` is an alias of `Print(0)`, `Print(0, 0)`, + * `Print(0, 0, 0)` and `Print(0, 0, 0, 0)` for 1D, 2D, 3D and 4D tensor + * respectively. It passes the proper number of zeros to `Print(...)` + * automatically according to the rank of this tensor. The user only have to + * invoke `.Print()` to print the whole tensor, instead of passing zeros + * manually. + */ +template 0 && sizeof...(Args) == 0), bool> = true> +void Print(const Op &op, Args... dims) { + std::array arr = {0}; + auto tp = std::tuple_cat(arr); + std::apply([&](auto &&...args) { Print(op, args...); }, tp); +} + +/** + * @brief Print a tensor's values to stdout using start/end parameters + * + * This form of `Print()` takes two array-like lists for the start and end indices, respectively. For + * example: + * + * `a.Print({2, 3}, {matxEnd, 5});` + * + * Will print the 2D tensor `a` with the first dimension starting at index 2 and going to the end, and + * the second index starting at 3 and ending at 5 (exlusive). The format is identical to calling + * `Slice()` to get a sliced view, followed by `Print()` with the indices. + * + * @tparam NRANK Automatically-deduced rank of tensor + * @param start Start indices to print from + * @param end End indices to stop + */ +// template +// void Print(const Op &op, const index_t (&start)[NRANK], const index_t (&end)[NRANK]) const +// { +// auto s = this->Slice(start, end); +// std::array arr = {0}; +// auto tup = std::tuple_cat(arr); +// std::apply( +// [&](auto&&... args) { +// s.InternalPrint(args...); +// }, tup); +// } + +/** + * @brief Print a tensor's values to stdout using start/end/stride + * + * This form of `Print()` takes three array-like lists for the start, end, and stride indices, respectively. For + * example: + * + * `a.Print({2, 3}, {matxEnd, 5}, {1, 2});` + * + * Will print the 2D tensor `a` with the first dimension starting at index 2 and going to the end with a + * stride of 1, and the second index starting at 3 and ending at 5 (exlusive) with a stride of 2. The format is + * identical to calling `Slice()` to get a sliced view, followed by `Print()` with the indices. + * + * @tparam NRANK Automatically-deduced rank of tensor + * @param start Start indices to print from + * @param end End indices to stop + * @param strides Strides of each dimension + */ +// template +// void Print(const index_t (&start)[NRANK], const index_t (&end)[NRANK], const index_t (&strides)[NRANK]) const +// { +// auto s = this->Slice(start, end, strides); +// std::array arr = {0}; +// auto tup = std::tuple_cat(arr); +// std::apply( +// [&](auto&&... args) { +// s.InternalPrint(args...); +// }, tup); +// } + } \ No newline at end of file diff --git a/test/00_operators/OperatorTests.cu b/test/00_operators/OperatorTests.cu index 6e7e1e325..15c2aecee 100644 --- a/test/00_operators/OperatorTests.cu +++ b/test/00_operators/OperatorTests.cu @@ -2062,3 +2062,12 @@ TEST(OperatorTests, Cast) MATX_EXIT_HANDLER(); } +TYPED_TEST(OperatorTestsFloat, Print) +{ + MATX_ENTER_HANDLER(); + auto t = make_tensor({3}); + auto r = ones(t.Shape()); + + Print(r); + MATX_EXIT_HANDLER(); +} \ No newline at end of file diff --git a/test/00_tensor/BasicTensorTests.cu b/test/00_tensor/BasicTensorTests.cu index 75ecae707..93d7c6186 100644 --- a/test/00_tensor/BasicTensorTests.cu +++ b/test/00_tensor/BasicTensorTests.cu @@ -406,3 +406,14 @@ TYPED_TEST(BasicTensorTestsIntegral, InitAssign) MATX_EXIT_HANDLER(); } + +TYPED_TEST(BasicTensorTestsAll, Print) +{ + MATX_ENTER_HANDLER(); + + auto t = make_tensor({3}); + (t = ones(t.Shape())).run(); + t.Print(); + + MATX_EXIT_HANDLER(); +}