Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: optimize Rcpp #411

Merged
merged 13 commits into from
Sep 6, 2024
Merged
2 changes: 1 addition & 1 deletion .editorconfig
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ indent_size = 2
indent_size = 4

[*.{cpp,hpp}]
indent_size = 4
indent_size = 2

[{NEWS.md,DESCRIPTION,LICENSE}]
max_line_length = 80
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: mlr3proba
Title: Probabilistic Supervised Learning for 'mlr3'
Version: 0.6.7
Version: 0.6.8
Authors@R:
c(person(given = "Raphael",
family = "Sonabend",
Expand Down
6 changes: 6 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# mlr3proba 0.6.8

- `Rcpp` code optimizations
- Fixed ERV scoring to comply with `mlr3` dev version (no bugs before)
- Skipping `survtoregr` pipelines due to bugs (to be refactored in the future)

# mlr3proba 0.6.7

- Deprecate `crank` to `distr` composition in `distrcompose` pipeop (only from `lp` => `distr` works now)
Expand Down
4 changes: 2 additions & 2 deletions R/LearnerDens.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@ LearnerDens = R6::R6Class("LearnerDens",
#' @description Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(id, param_set = ps(),
predict_types = "cdf", feature_types = character(),
properties = character(), data_formats = "data.table",
properties = character(),
packages = character(),
label = NA_character_,
man = NA_character_) {
super$initialize(
id = id, task_type = "dens", param_set = param_set,
predict_types = predict_types, feature_types = feature_types, properties = properties,
data_formats = data_formats, packages = c("mlr3proba", packages), label = label, man = man)
packages = c("mlr3proba", packages), label = label, man = man)
}
)
)
16 changes: 9 additions & 7 deletions R/scoring_rule_erv.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,20 @@
if (!is.null(ps$se) && ps$se) {
stop("Only one of `ERV` and `se` can be TRUE")
}
measure$param_set$values$ERV = FALSE

measure$param_set$set_values(ERV = FALSE)
# compute score for the learner
learner_score = measure$score(prediction, task = task,
train_set = train_set)
learner_score = measure$score(prediction, task = task, train_set = train_set)

# compute score for the baseline (Kaplan-Meier)
# train KM
km = lrn("surv.kaplan")$train(task = task, row_ids = train_set)
km = km$predict(as_task_surv(data.frame(
as.matrix(prediction$truth)), event = "status"))
base_score = measure$score(km, task = task, train_set = train_set)
# predict KM on the test set (= not train ids)
test_set = setdiff(task$row_ids, train_set)
km_pred = km$predict(task, row_ids = test_set)
base_score = measure$score(km_pred, task = task, train_set = train_set)

measure$param_set$values$ERV = TRUE
measure$param_set$set_values(ERV = TRUE)

# return percentage decrease
1 - (learner_score / base_score)
Expand Down
48 changes: 24 additions & 24 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@ Rcpp::Rostream<false>& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get();
#endif

// c_assert_surv
bool c_assert_surv(NumericMatrix mat);
bool c_assert_surv(const NumericMatrix& mat);
RcppExport SEXP _mlr3proba_c_assert_surv(SEXP matSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< NumericMatrix >::type mat(matSEXP);
Rcpp::traits::input_parameter< const NumericMatrix& >::type mat(matSEXP);
rcpp_result_gen = Rcpp::wrap(c_assert_surv(mat));
return rcpp_result_gen;
END_RCPP
Expand All @@ -34,74 +34,74 @@ BEGIN_RCPP
END_RCPP
}
// c_score_intslogloss
NumericMatrix c_score_intslogloss(NumericVector truth, NumericVector unique_times, NumericMatrix cdf, double eps);
NumericMatrix c_score_intslogloss(const NumericVector& truth, const NumericVector& unique_times, const NumericMatrix& cdf, double eps);
RcppExport SEXP _mlr3proba_c_score_intslogloss(SEXP truthSEXP, SEXP unique_timesSEXP, SEXP cdfSEXP, SEXP epsSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< NumericVector >::type truth(truthSEXP);
Rcpp::traits::input_parameter< NumericVector >::type unique_times(unique_timesSEXP);
Rcpp::traits::input_parameter< NumericMatrix >::type cdf(cdfSEXP);
Rcpp::traits::input_parameter< const NumericVector& >::type truth(truthSEXP);
Rcpp::traits::input_parameter< const NumericVector& >::type unique_times(unique_timesSEXP);
Rcpp::traits::input_parameter< const NumericMatrix& >::type cdf(cdfSEXP);
Rcpp::traits::input_parameter< double >::type eps(epsSEXP);
rcpp_result_gen = Rcpp::wrap(c_score_intslogloss(truth, unique_times, cdf, eps));
return rcpp_result_gen;
END_RCPP
}
// c_score_graf_schmid
NumericMatrix c_score_graf_schmid(NumericVector truth, NumericVector unique_times, NumericMatrix cdf, int power);
NumericMatrix c_score_graf_schmid(const NumericVector& truth, const NumericVector& unique_times, const NumericMatrix& cdf, int power);
RcppExport SEXP _mlr3proba_c_score_graf_schmid(SEXP truthSEXP, SEXP unique_timesSEXP, SEXP cdfSEXP, SEXP powerSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< NumericVector >::type truth(truthSEXP);
Rcpp::traits::input_parameter< NumericVector >::type unique_times(unique_timesSEXP);
Rcpp::traits::input_parameter< NumericMatrix >::type cdf(cdfSEXP);
Rcpp::traits::input_parameter< const NumericVector& >::type truth(truthSEXP);
Rcpp::traits::input_parameter< const NumericVector& >::type unique_times(unique_timesSEXP);
Rcpp::traits::input_parameter< const NumericMatrix& >::type cdf(cdfSEXP);
Rcpp::traits::input_parameter< int >::type power(powerSEXP);
rcpp_result_gen = Rcpp::wrap(c_score_graf_schmid(truth, unique_times, cdf, power));
return rcpp_result_gen;
END_RCPP
}
// c_weight_survival_score
NumericMatrix c_weight_survival_score(NumericMatrix score, NumericMatrix truth, NumericVector unique_times, NumericMatrix cens, bool proper, double eps);
NumericMatrix c_weight_survival_score(const NumericMatrix& score, const NumericMatrix& truth, const NumericVector& unique_times, const NumericMatrix& cens, bool proper, double eps);
RcppExport SEXP _mlr3proba_c_weight_survival_score(SEXP scoreSEXP, SEXP truthSEXP, SEXP unique_timesSEXP, SEXP censSEXP, SEXP properSEXP, SEXP epsSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< NumericMatrix >::type score(scoreSEXP);
Rcpp::traits::input_parameter< NumericMatrix >::type truth(truthSEXP);
Rcpp::traits::input_parameter< NumericVector >::type unique_times(unique_timesSEXP);
Rcpp::traits::input_parameter< NumericMatrix >::type cens(censSEXP);
Rcpp::traits::input_parameter< const NumericMatrix& >::type score(scoreSEXP);
Rcpp::traits::input_parameter< const NumericMatrix& >::type truth(truthSEXP);
Rcpp::traits::input_parameter< const NumericVector& >::type unique_times(unique_timesSEXP);
Rcpp::traits::input_parameter< const NumericMatrix& >::type cens(censSEXP);
Rcpp::traits::input_parameter< bool >::type proper(properSEXP);
Rcpp::traits::input_parameter< double >::type eps(epsSEXP);
rcpp_result_gen = Rcpp::wrap(c_weight_survival_score(score, truth, unique_times, cens, proper, eps));
return rcpp_result_gen;
END_RCPP
}
// c_concordance
float c_concordance(NumericVector time, NumericVector status, NumericVector crank, double t_max, std::string weight_meth, NumericMatrix cens, NumericMatrix surv, float tiex);
float c_concordance(const NumericVector& time, const NumericVector& status, const NumericVector& crank, double t_max, const std::string& weight_meth, const NumericMatrix& cens, const NumericMatrix& surv, float tiex);
RcppExport SEXP _mlr3proba_c_concordance(SEXP timeSEXP, SEXP statusSEXP, SEXP crankSEXP, SEXP t_maxSEXP, SEXP weight_methSEXP, SEXP censSEXP, SEXP survSEXP, SEXP tiexSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< NumericVector >::type time(timeSEXP);
Rcpp::traits::input_parameter< NumericVector >::type status(statusSEXP);
Rcpp::traits::input_parameter< NumericVector >::type crank(crankSEXP);
Rcpp::traits::input_parameter< const NumericVector& >::type time(timeSEXP);
Rcpp::traits::input_parameter< const NumericVector& >::type status(statusSEXP);
Rcpp::traits::input_parameter< const NumericVector& >::type crank(crankSEXP);
Rcpp::traits::input_parameter< double >::type t_max(t_maxSEXP);
Rcpp::traits::input_parameter< std::string >::type weight_meth(weight_methSEXP);
Rcpp::traits::input_parameter< NumericMatrix >::type cens(censSEXP);
Rcpp::traits::input_parameter< NumericMatrix >::type surv(survSEXP);
Rcpp::traits::input_parameter< const std::string& >::type weight_meth(weight_methSEXP);
Rcpp::traits::input_parameter< const NumericMatrix& >::type cens(censSEXP);
Rcpp::traits::input_parameter< const NumericMatrix& >::type surv(survSEXP);
Rcpp::traits::input_parameter< float >::type tiex(tiexSEXP);
rcpp_result_gen = Rcpp::wrap(c_concordance(time, status, crank, t_max, weight_meth, cens, surv, tiex));
return rcpp_result_gen;
END_RCPP
}
// c_gonen
double c_gonen(NumericVector crank, float tiex);
double c_gonen(const NumericVector& crank, float tiex);
RcppExport SEXP _mlr3proba_c_gonen(SEXP crankSEXP, SEXP tiexSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< NumericVector >::type crank(crankSEXP);
Rcpp::traits::input_parameter< const NumericVector& >::type crank(crankSEXP);
Rcpp::traits::input_parameter< float >::type tiex(tiexSEXP);
rcpp_result_gen = Rcpp::wrap(c_gonen(crank, tiex));
return rcpp_result_gen;
Expand Down
5 changes: 2 additions & 3 deletions src/survival_assert.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
using namespace Rcpp;

// [[Rcpp::export]]
bool c_assert_surv(NumericMatrix mat) {
bool c_assert_surv(const NumericMatrix& mat) {
for (int i = 0; i < mat.nrow(); i++) {
// check first element
if (mat(i, 0) < 0 || mat(i, 0) > 1) {
// check first element
return false;
}

Expand All @@ -24,4 +24,3 @@ bool c_assert_surv(NumericMatrix mat) {

return true;
}

Loading