This repository was archived by the owner on Jul 31, 2021. It is now read-only.
---
title: "mlr3learners.lightgbm: Regression Example"
date: "`r Sys.Date()`"
output:
  rmarkdown::html_vignette:
    keep_md: true
vignette: >
  %\VignetteIndexEntry{mlr3learners_lightgbm_regression}
  %\VignetteEncoding{UTF-8}
  %\VignetteEngine{knitr::rmarkdown}
editor_options:
  chunk_output_type: console
---
```{r, include = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>"
)
```
```{r setup}
library(mlr3)
library(mlr3learners.lightgbm)
library(mlbench)
```
# Create the mlr3 task
```{r}
data("BostonHousing2")
dataset = data.table::as.data.table(BostonHousing2)
target_col = "medv"

dataset = backend_preprocessing(
  datatable = dataset,
  target_col = target_col,
  task_type = "regression"
)

task = mlr3::TaskRegr$new(
  id = "bostonhousing",
  backend = dataset,
  target = target_col
)
```
To hold out independent test data, we create a list `split` containing the row indices of the training set and the test set.
```{r}
set.seed(17)
split = list(
  # floor() makes the sample size an explicit whole number
  train_index = sample(seq_len(task$nrow), size = floor(0.7 * task$nrow))
)
split$test_index = setdiff(seq_len(task$nrow), split$train_index)
```
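
As a quick sanity check, the following sketch verifies that the two index sets are disjoint and together cover every row. It uses only base R and assumes a row count of 506 (the size of `BostonHousing2`) in place of `task$nrow`; the names `n_rows` and `split_check` are illustrative and not part of the vignette's workflow.

```r
# Stand-in for task$nrow (assumption: BostonHousing2 has 506 rows).
n_rows = 506L

set.seed(17)
split_check = list(
  train_index = sample(seq_len(n_rows), size = floor(0.7 * n_rows))
)
split_check$test_index = setdiff(seq_len(n_rows), split_check$train_index)

# No row may appear in both sets.
stopifnot(
  length(intersect(split_check$train_index, split_check$test_index)) == 0L
)
# Together, the two sets must cover every row.
stopifnot(
  setequal(c(split_check$train_index, split_check$test_index), seq_len(n_rows))
)

# Number of rows in each set.
lengths(split_check)
```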
# Instantiate the lightgbm learner
Then, the `regr.lightgbm` class needs to be instantiated:
```{r}
learner = mlr3::lrn("regr.lightgbm", objective = "regression")
```
# Configure the learner
In the next step, some of the learner's parameters need to be set, for example `num_iterations` and `early_stopping_round`. Please refer to the [LightGBM manual](https://lightgbm.readthedocs.io) for further details on these parameters. Almost all LightGBM parameters are available through the learner. You can inspect them using the following command:
```{r eval=FALSE}
learner$param_set
```
Set the training parameters, using "rmse" as the evaluation metric:
```{r}
learner$param_set$values = mlr3misc::insert_named(
  learner$param_set$values,
  list(
    "early_stopping_round" = 10,
    "learning_rate" = 0.1,
    "seed" = 17L,
    "num_iterations" = 100,
    "metric" = "rmse"
  )
)
```
# Train the learner
The learner is now ready to be trained by using its `train` function.
```{r results='hide', message=FALSE, warning=FALSE, error=FALSE}
learner$train(task, row_ids = split$train_index)
```
# Evaluate the model performance
Basic information, such as the number of boosting iterations performed, can be read directly from the fitted model:
```{r}
learner$model$current_iter()
```
The learner's `predict` function returns an object of mlr3's class `PredictionRegr`.
```{r}
predictions = learner$predict(task, row_ids = split$test_index)
head(predictions$response)
```
Further metrics can be calculated by using mlr3 measures:
```{r}
predictions$score(mlr3::msr("regr.rmsle"))
```
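
For reference, RMSLE is the root mean squared error computed on log-transformed values, that is `sqrt(mean((log(1 + truth) - log(1 + response))^2))`. The sketch below computes it by hand in base R. The helper name `rmsle_manual` and the toy vectors are illustrative only; they are not part of mlr3 or of the model output above.

```r
# Hypothetical helper: root mean squared logarithmic error.
rmsle_manual = function(truth, response) {
  stopifnot(length(truth) == length(response))
  # log1p(x) = log(1 + x) is only defined for x > -1.
  stopifnot(all(truth > -1), all(response > -1))
  sqrt(mean((log1p(truth) - log1p(response))^2))
}

# Toy values for illustration only (not model output).
truth_toy = c(24.0, 21.6, 34.7)
response_toy = c(24.0, 21.6, 34.7)

# Identical vectors have zero error.
rmsle_manual(truth_toy, response_toy)
```

With the objects from this vignette, `rmsle_manual(predictions$truth, predictions$response)` should agree with the score above, assuming `regr.rmsle` follows this definition.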
The variable importance can be calculated by using the learner's `importance` function:
```{r}
importance = learner$importance()
importance
```
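
To rank features, the importance scores can be sorted. The sketch below assumes the scores are available as a named numeric vector, which is the usual convention for mlr3 learners but is not confirmed here for this package; if `importance` has a different structure, it would need to be converted first. The values in `importance_toy` are placeholders, not results from the model above.

```r
# Placeholder scores for illustration only; not results from the model above.
importance_toy = c(rm = 0.2, lstat = 0.5, nox = 0.3)

# Sort from most to least important and keep the top two features.
top_features = head(sort(importance_toy, decreasing = TRUE), 2)
names(top_features)
```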