Skip to content

Commit

Permalink
add regr.logloss
Browse files Browse the repository at this point in the history
  • Loading branch information
RaphaelS1 committed Feb 10, 2023
1 parent d92822f commit babf866
Show file tree
Hide file tree
Showing 8 changed files with 157 additions and 22 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: mlr3proba
Title: Probabilistic Supervised Learning for 'mlr3'
Version: 0.5.0
Version: 0.5.1
Authors@R:
c(person(given = "Raphael",
family = "Sonabend",
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ export(LearnerSurvKaplan)
export(LearnerSurvRpart)
export(MeasureDens)
export(MeasureDensLogloss)
export(MeasureRegrLogloss)
export(MeasureSurv)
export(MeasureSurvAUC)
export(MeasureSurvCalibrationAlpha)
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# mlr3proba 0.5.1

* Add `regr.logloss`

# mlr3proba 0.5.0

* Possibly small breaking change, renamed `PipeOpProbregrCompositor` to `PipeOpProbregr` and default distribution now `"Uniform"`.
Expand Down
75 changes: 55 additions & 20 deletions R/MeasureRegrLogloss.R
Original file line number Diff line number Diff line change
@@ -1,21 +1,56 @@
#' MeasureRegrLogloss = R6::R6Class("MeasureRegrLogloss",
#' inherit = MeasureRegr,
#' public = list(
#' #' @description Creates a new instance of this [R6][R6::R6Class] class.
#' initialize = function() {
#' super$initialize(
#' id = "regr.logloss",
#' range = c(0, Inf),
#' minimize = TRUE,
#' predict_type = "distr"
#' # task_properties = "twoclass",
#' # packages = "Metrics"
#' )
#' },
#' @template regr_measure
#' @templateVar title Log loss
#' @templateVar inherit [MeasureRegr]
#' @templateVar fullname MeasureRegrLogloss
#' @templateVar pars eps = 1e-15
#' @templateVar eps_par TRUE
#'
#' .score = function(prediction, ...) {
#' return(mean(-log(as.numeric(do.call(prediction$prob$pdf,
#' as.list(prediction$truth))))))
#' }
#' )
#' )
#' @template param_eps
#'
#' @description
#' Calculates the cross-entropy, or logarithmic (log), loss.
#'
#' The logloss, in the context of probabilistic predictions, is defined as the negative log
#' probability density function, \eqn{f}, evaluated at the observed value, \eqn{y},
#' \deqn{L(f, y) = -\log(f(y))}{L(f, y) = -log(f(y))}
#'
#' @export
MeasureRegrLogloss = R6::R6Class("MeasureRegrLogloss",
inherit = MeasureRegr,
public = list(
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function() {
ps = ps(
eps = p_dbl(0, 1, default = 1e-15)
)
ps$values$eps = 1e-15

super$initialize(
id = "regr.logloss",
range = c(0, Inf),
minimize = TRUE,
predict_type = "distr",
man = "mlr3proba::mlr_measures_regr.logloss",
label = "Log Loss",
param_set = ps
)
}
),

private = list(
.score = function(prediction, ...) {
distr = prediction$distr
truth = prediction$truth

if (inherits(distr, "Matdist")) {
pdf = diag(distr$pdf(truth))
} else {
pdf = as.numeric(distr$pdf(data = matrix(truth, nrow = 1)))
}

pdf[pdf == 0] = self$param_set$values$eps
mean(-log(pdf))
}
)
)
2 changes: 1 addition & 1 deletion R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ register_mlr3 = function() {

x$add("dens.logloss", MeasureDensLogloss)

# x$add("regr.logloss", MeasureRegrLogloss)
x$add("regr.logloss", MeasureRegrLogloss)

x$add("surv.graf", MeasureSurvGraf)
x$add("surv.brier", MeasureSurvGraf)
Expand Down
15 changes: 15 additions & 0 deletions man-roxygen/regr_measure.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#' <% meas = get(fullname)$new() %>
#' <% shortname = meas$id %>
#'
#' @include MeasureRegr.R
#' @title <%=title%> Regression Measure
#' @name <%= paste("mlr_measures", shortname, sep = "_")%>
#'
#' @section Meta Information:
#' * Type: `"regr"`
#' * Range: <%= format_range(meas$range) %>
#' * Minimize: `<%=meas$minimize%>`
#' * Required prediction: `<%=meas$predict_type%>`
#'
#' @family regression measures
#' @template seealso_measure
73 changes: 73 additions & 0 deletions man/mlr_measures_regr.logloss.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions tests/testthat/test_mlr_measures_regr.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
library(mlr3pipelines)
l = as_learner(ppl("probregr", learner = lrn("regr.featureless")))
task = tsk("mtcars")
p = l$train(task)$predict(task)

expect_measure(msr("regr.logloss"))
expect_numeric(p$score(msr("regr.logloss")))

0 comments on commit babf866

Please sign in to comment.