[R] Cannot plot trees with categorical splits #9925

Open
david-cortes opened this issue Dec 26, 2023 · 10 comments

Comments

@david-cortes
Contributor

ref #9810

Currently, attempting to plot trees that have categorical splits in R will result in an error:

library(xgboost)
set.seed(123)
y <- rnorm(100)
x <- sample(3, size=100*3, replace=TRUE) |> matrix(nrow=100)
x <- x - 1
dm <- xgb.DMatrix(data=x, label=y)
setinfo(dm, "feature_type", c("c", "c", "c"))
model <- xgb.train(
    data=dm,
    params=list(
        tree_method="hist",
        max_depth=3
    ),
    nrounds=2
)
xgb.plot.tree(model=model)
Error in do.call(rbind, matches)[, c(2, 3, 5, 6, 7, 8, 10), drop = FALSE] :
subscript out of bounds

This is because the regexes used to parse the dumps have not been updated to handle the format used for categorical splits:

branch_rx <- paste0("f(\\d+)<(", anynumber_regex, ")\\] yes=(\\d+),no=(\\d+),missing=(\\d+),",

@david-cortes
Contributor Author

@mayer79 Perhaps you would like to work on this issue?

@trivialfis
Member

I hope we can remove the regex if possible. XGBoost can output a Graphviz dump. I can help add other formats if necessary.

@david-cortes
Contributor Author

I hope we can remove the regex if possible. XGBoost can output a Graphviz dump. I can help add other formats if necessary.

@trivialfis It would be very helpful to add a "table" dump format that outputs the same information as the Python function trees_to_dataframe. Then we could get rid of the regexes in both interfaces and avoid having to update them when vector leaves are implemented.

@trivialfis
Member

Yes, I did a proof of concept before, but didn't submit a PR because at the time I was wondering how to export the data to arrow. I can make another attempt.

@david-cortes
Contributor Author

Yes, I did a proof of concept before, but didn't submit a PR because at the time I was wondering how to export the data to arrow. I can make another attempt.

I don't think Arrow is necessary here: these tables will be rather small in most cases, so perhaps plain JSON with one entry per table column would suffice.

@trivialfis
Member

trivialfis commented Jan 2, 2024

I don't think arrow is necessary here

It's more future-proof. We have already had feature requests for representing the model as a table, which means XGBoost needs to be able to both save and load models as tables. Currently, the to_table method doesn't have a corresponding from_table implementation.

Another point about Arrow is that performance is just a bonus; I believe the goal is to have a protocol-like class that can be shared with other projects. For example, Spark uses Arrow as the underlying representation of a table and uses it to transfer dataframes from Java processes to Python processes, and presumably to R as well. So if we are dealing with dataframes, exporting directly to Arrow might be the most efficient and useful approach.

@trivialfis
Member

This is fixed in the latest version by #10989. Will look into dataframe export separately.

@david-cortes
Contributor Author

@trivialfis It looks like the option with_stats=TRUE doesn't work. Passing it in the example at the top of this issue has no effect:

library(xgboost)
set.seed(123)
y <- rnorm(100)
x <- sample(3, size=100*3, replace=TRUE) |> matrix(nrow=100)
x <- x - 1
dm <- xgb.DMatrix(data=x, label=y)
setinfo(dm, "feature_type", c("c", "c", "c"))
model <- xgb.train(
    data=dm,
    params=list(
        tree_method="hist",
        max_depth=3
    ),
    nrounds=2
)
xgb.plot.tree(model=model, with_stats=TRUE)

Also, xgb.plot.multi.trees still has the same problem with categorical columns. Perhaps it could throw an error saying that they are not supported.

@trivialfis
Member

trivialfis commented Dec 4, 2024

@david-cortes I just ran your script and saved the result to HTML; the categories and the leaf node hessian (cover) look correct:
[Screenshot: rendered tree plot, 2024-12-04]

gr <- xgb.plot.tree(model=model, with_stats=TRUE)
htmlwidgets::saveWidget(gr, 'plot.html')

Stats for categorical internal splits are not yet available in the tree dump; I need to add them for all types of tree dump (JSON, text, Graphviz).

trivialfis reopened this Dec 4, 2024
@trivialfis
Member

Also xgb.plot.multi.trees still has the issue with categorical columns. Perhaps it could throw an error saying that it doesn't support them.

Thank you for sharing. Yes, we can throw an error. We also need to check plot.deepness and plot.shap. I have added categorical support to the Python shap package, but I'm not sure how I can help with the R package.
