From 2209cac9282b56d47f04851976d4f83ec2e87aee Mon Sep 17 00:00:00 2001 From: LTLA Date: Tue, 21 May 2024 17:16:15 -0700 Subject: [PATCH] Generalized the SVT parsing code for external use, specifically in beachmat. --- docs/Doxyfile | 1 + include/tatami_r/sparse_matrix.hpp | 175 ++++++++++++++++++----------- 2 files changed, 113 insertions(+), 63 deletions(-) diff --git a/docs/Doxyfile b/docs/Doxyfile index f2ef70e..ebe9d0a 100644 --- a/docs/Doxyfile +++ b/docs/Doxyfile @@ -908,6 +908,7 @@ WARN_LOGFILE = INPUT = ../include/tatami_r/tatami_r.hpp \ ../include/tatami_r/parallelize.hpp \ + ../include/tatami_r/sparse_matrix.hpp \ ../include/tatami_r/UnknownMatrix.hpp \ ../README.md parallel.md diff --git a/include/tatami_r/sparse_matrix.hpp b/include/tatami_r/sparse_matrix.hpp index c51e3d3..fa877a4 100644 --- a/include/tatami_r/sparse_matrix.hpp +++ b/include/tatami_r/sparse_matrix.hpp @@ -5,24 +5,43 @@ #include "tatami/tatami.hpp" #include +/** + * @file sparse_matrix.hpp + * @brief Parse sparse matrices from block processing. + */ + namespace tatami_r { -template -void parse_sparse_matrix_internal( - Rcpp::RObject seed, - bool row, - std::vector& value_ptrs, - std::vector& index_ptrs, - Index_* counts) -{ - Rcpp::RObject raw_svt = seed.slot("SVT"); +/** + * Parse the contents of a `SVT_SparseMatrix` from the **DelayedArray** package. + * This accounts for different versions of the class definition, different types of the values, and the presence of lacunar leaf nodes. + * + * @tparam Function_ Function to be applied at each leaf node. + * + * @param matrix The `SVT_SparseMatrix` object. + * @param fun Function to apply to each leaf node, accepting four arguments: + * . + * 1. `c`, an integer specifying the index of the leaf node, i.e., the column index. + * 2. `indices`, an `Rcpp::IntegerVector` containing the sorted, zero-based indices of the structural non-zero elements in this node (i.e., column). + * 3. `all_ones`, a boolean indicating whether all values in this node/column are equal to 1. + * 4. `values`, an `Rcpp::IntegerVector`, `Rcpp::LogicalVector` or `Rcpp::NumericVector` containing the values of the structural non-zeros. + * This should be of the same length as `indices`. + * It should be ignored if `all_ones = true`. + * . + * The return value of this function is ignored. + * Note that `fun` may not be called for all `c` - if leaf nodes do not contain any data, they will be skipped. + */ +template +void parse_SVT_SparseMatrix(Rcpp::RObject matrix, Function_ fun) { + Rcpp::RObject raw_svt = matrix.slot("SVT"); if (raw_svt == R_NilValue) { return; } - Rcpp::IntegerVector svt_version(seed.slot(".svt_version")); + Rcpp::IntegerVector svt_version(matrix.slot(".svt_version")); if (svt_version.size() != 1) { - throw std::runtime_error("'.svt_version' should be an integer scalar"); + auto ctype = get_class_name(matrix); + throw std::runtime_error("'.svt_version' slot of a " + ctype + " should be an integer scalar"); } int version = svt_version[0]; int index_x = (version == 0 ? 0 : 1); @@ -30,11 +49,6 @@ void parse_sparse_matrix_internal( Rcpp::List svt(raw_svt); int NC = svt.size(); - bool needs_value = !value_ptrs.empty(); - bool needs_index = !index_ptrs.empty(); - - // Note that non-empty value_ptrs and index_ptrs may be longer than the - // number of rows/columns in the SVT matrix, due to the reuse of slabs. for (int c = 0; c < NC; ++c) { Rcpp::RObject raw_inner(svt[c]); @@ -44,46 +58,108 @@ void parse_sparse_matrix_internal( Rcpp::List inner(raw_inner); if (inner.size() != 2) { - auto ctype = get_class_name(seed); + auto ctype = get_class_name(matrix); throw std::runtime_error("each entry of the 'SVT' slot of a " + ctype + " object should be a list of length 2 or NULL"); } // Verify type to ensure that we're not making a view on a temporary array. Rcpp::RObject raw_indices = inner[index_x]; if (raw_indices.sexp_type() != INTSXP) { - auto ctype = get_class_name(seed); + auto ctype = get_class_name(matrix); throw std::runtime_error("indices of each element of the 'SVT' slot in a " + ctype + " object should be an integer vector"); } Rcpp::IntegerVector curindices(raw_indices); - size_t nnz = curindices.size(); + auto nnz = curindices.size(); - InputObject_ curvalues; Rcpp::RObject raw_values(inner[value_x]); + auto vsexp = raw_values.sexp_type(); bool has_values = raw_values != R_NilValue; + Rcpp::IntegerVector curvalues_i; + Rcpp::NumericVector curvalues_n; + Rcpp::LogicalVector curvalues_l; + if (has_values) { - if (raw_values.sexp_type() != desired_sexp_) { - auto ctype = get_class_name(seed); - throw std::runtime_error("value vector of an element of the 'SVT' slot in a " + ctype + " object has an unexpected type"); + decltype(nnz) vsize; + switch (vsexp) { + case INTSXP: + curvalues_i = Rcpp::IntegerVector(raw_values); + vsize = curvalues_i.size(); + break; + case REALSXP: + curvalues_n = Rcpp::NumericVector(raw_values); + vsize = curvalues_n.size(); + break; + case LGLSXP: + curvalues_l = Rcpp::LogicalVector(raw_values); + vsize = curvalues_l.size(); + break; + default: + { + auto ctype = get_class_name(matrix); + throw std::runtime_error("value vector of an element of the 'SVT' slot in a " + ctype + " object is not a numeric or logical type"); + } } - curvalues = InputObject_(raw_values); - if (nnz != static_cast(curvalues.size())) { - auto ctype = get_class_name(seed); + if (nnz != vsize) { + auto ctype = get_class_name(matrix); throw std::runtime_error("both vectors of an element of the 'SVT' slot in a " + ctype + " object should have the same length"); } } + switch (vsexp) { + case INTSXP: + fun(c, curindices, !has_values, curvalues_i); + break; + case REALSXP: + fun(c, curindices, !has_values, curvalues_n); + break; + default: + fun(c, curindices, !has_values, curvalues_l); + break; + } + } +} + +/** + * @cond + */ +template +void parse_sparse_matrix( + Rcpp::RObject matrix, + bool row, + std::vector& value_ptrs, + std::vector& index_ptrs, + Index_* counts) +{ + auto ctype = get_class_name(matrix); + if (ctype != "SVT_SparseMatrix") { + // Can't be bothered to write a parser for COO_SparseMatrix objects, + // which are soon-to-be-superceded by SVT_SparseMatrix anyway; so we + // just forcibly coerce it. + auto methods_env = Rcpp::Environment::namespace_env("methods"); + Rcpp::Function converter(methods_env["as"]); + matrix = converter(matrix, Rcpp::CharacterVector::create("SVT_SparseMatrix")); + } + + bool needs_value = !value_ptrs.empty(); + bool needs_index = !index_ptrs.empty(); + + parse_SVT_SparseMatrix(matrix, [&](int c, const auto& curindices, bool all_ones, const auto& curvalues) { + size_t nnz = curindices.size(); + + // Note that non-empty value_ptrs and index_ptrs may be longer than the + // number of rows/columns in the SVT matrix, due to the reuse of slabs. if (row) { if (needs_value) { - if (has_values) { + if (all_ones) { for (size_t i = 0; i < nnz; ++i) { auto ix = curindices[i]; - value_ptrs[ix][counts[ix]] = curvalues[i]; + value_ptrs[ix][counts[ix]] = 1; } } else { for (size_t i = 0; i < nnz; ++i) { auto ix = curindices[i]; - value_ptrs[ix][counts[ix]] = 1; + value_ptrs[ix][counts[ix]] = curvalues[i]; } } } @@ -99,10 +175,10 @@ void parse_sparse_matrix_internal( } else { if (needs_value) { - if (has_values) { - std::copy(curvalues.begin(), curvalues.end(), value_ptrs[c]); - } else { + if (all_ones) { std::fill_n(value_ptrs[c], nnz, 1); + } else { + std::copy(curvalues.begin(), curvalues.end(), value_ptrs[c]); } } if (needs_index) { @@ -110,38 +186,11 @@ void parse_sparse_matrix_internal( } counts[c] = nnz; } - } -} - -template -void parse_sparse_matrix( - Rcpp::RObject seed, - bool row, - std::vector& value_ptrs, - std::vector& index_ptrs, - Index_* counts) -{ - auto ctype = get_class_name(seed); - if (ctype != "SVT_SparseMatrix") { - // Can't be bothered to write a parser for COO_SparseMatrix objects, - // which are soon-to-be-superceded by SVT_SparseMatrix anyway; so we - // just forcibly coerce it. - auto methods_env = Rcpp::Environment::namespace_env("methods"); - Rcpp::Function converter(methods_env["as"]); - seed = converter(seed, Rcpp::CharacterVector::create("SVT_SparseMatrix")); - } - - std::string type = Rcpp::as(seed.slot("type")); - if (type == "double") { - parse_sparse_matrix_internal(seed, row, value_ptrs, index_ptrs, counts); - } else if (type == "integer") { - parse_sparse_matrix_internal(seed, row, value_ptrs, index_ptrs, counts); - } else if (type == "logical") { - parse_sparse_matrix_internal(seed, row, value_ptrs, index_ptrs, counts); - } else { - throw std::runtime_error("unsupported type '" + type + "' for a " + ctype); - } + }); } +/** + * @endcond + */ }