Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Enum datatype #1061

Merged
merged 7 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
### New features

- `$cut()` and `$qcut()` to bin continuous values into discrete categories (#1057).
- Add support for the `Enum` datatype via `pl$Enum()` (#1061).

### Bug fixes

Expand Down
55 changes: 55 additions & 0 deletions R/datatype.R
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ DataType_constructors = function() {
list(
Array = DataType_Array,
Categorical = DataType_Categorical,
Enum = DataType_Enum,
Datetime = DataType_Datetime,
Duration = DataType_Duration,
List = DataType_List,
Expand Down Expand Up @@ -362,6 +363,60 @@ DataType_Categorical = function(ordering = "physical") {
.pr$DataType$new_categorical(ordering) |> unwrap()
}

#' Create Enum DataType
#'
#' An `Enum` is a fixed set categorical encoding of a set of strings. It is
#' similar to the [`Categorical`][DataType_Categorical] data type, but the
#' categories are explicitly provided by the user and cannot be modified.
#'
#' This functionality is **unstable**. It is a work-in-progress feature and may
#' not always work as expected. It may be changed at any point without it being
#' considered a breaking change.
#'
#' @param categories A character vector specifying the categories of the variable.
#'
#' @return An Enum DataType
#' @examples
#' pl$DataFrame(
#' x = c("Polar", "Panda", "Brown", "Brown", "Polar"),
#' schema = list(x = pl$Enum(c("Polar", "Panda", "Brown")))
#' )
#'
#' # All values of the variable have to be in the categories
#' dtype = pl$Enum(c("Polar", "Panda", "Brown"))
#' tryCatch(
#' pl$DataFrame(
#' x = c("Polar", "Panda", "Brown", "Brown", "Polar", "Black"),
#' schema = list(x = dtype)
#' ),
#' error = function(e) e
#' )
#'
#' # Comparing two Enum is only valid if they have the same categories
#' df = pl$DataFrame(
#' x = c("Polar", "Panda", "Brown", "Brown", "Polar"),
#' y = c("Polar", "Polar", "Polar", "Brown", "Brown"),
#' z = c("Polar", "Polar", "Polar", "Brown", "Brown"),
#' schema = list(
#' x = pl$Enum(c("Polar", "Panda", "Brown")),
#' y = pl$Enum(c("Polar", "Panda", "Brown")),
#' z = pl$Enum(c("Polar", "Black", "Brown"))
#' )
#' )
#'
#' # Same categories
#' df$with_columns(x_eq_y = pl$col("x") == pl$col("y"))
#'
#' # Different categories
#' tryCatch(
#' df$with_columns(x_eq_z = pl$col("x") == pl$col("z")),
#' error = function(e) e
#' )
DataType_Enum = function(categories) {
.pr$DataType$new_enum(categories) |> unwrap()
}


#' Check whether the data type is a temporal type
#'
#' @return A logical value
Expand Down
2 changes: 2 additions & 0 deletions R/extendr-wrappers.R
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,8 @@ RPolarsDataType$new <- function(s) .Call(wrap__RPolarsDataType__new, s)

RPolarsDataType$new_categorical <- function(ordering) .Call(wrap__RPolarsDataType__new_categorical, ordering)

RPolarsDataType$new_enum <- function(categories) .Call(wrap__RPolarsDataType__new_enum, categories)

RPolarsDataType$new_datetime <- function(tu, tz) .Call(wrap__RPolarsDataType__new_datetime, tu, tz)

RPolarsDataType$new_duration <- function(tu) .Call(wrap__RPolarsDataType__new_duration, tu)
Expand Down
61 changes: 61 additions & 0 deletions man/DataType_Enum.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/pl_pl.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions src/rust/src/conversion_s_to_r.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ pub fn pl_series_to_list(
.set_class(["rpolars_raw_list", "list"])
.expect("this class label is always valid")
}),
Enum(_, _) => s
.categorical()
.map(|ca| extendr_api::call!("factor", ca.iter_str().collect_robj()).unwrap()),
Categorical(_, _) => s
.categorical()
.map(|ca| extendr_api::call!("factor", ca.iter_str().collect_robj()).unwrap()),
Expand Down
25 changes: 25 additions & 0 deletions src/rust/src/rdatatype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ impl RPolarsDataType {
"Time" | "time" => pl::DataType::Time,
"Null" | "null" => pl::DataType::Null,
"Categorical" | "factor" => pl::DataType::Categorical(None, Default::default()),
"Enum" => pl::DataType::Enum(None, Default::default()),
"Unknown" | "unknown" => pl::DataType::Unknown,

_ => panic!("data type not recgnized "),
Expand All @@ -91,6 +92,16 @@ impl RPolarsDataType {
Ok(RPolarsDataType(pl::DataType::Categorical(None, ordering)))
}

pub fn new_enum(categories: Robj) -> RResult<RPolarsDataType> {
use crate::conversion_r_to_s::robjname2series;
let s = robjname2series(categories, "").unwrap();
let ca = s.str()?;
let categories = ca.downcast_iter().next().unwrap().clone();
Ok(RPolarsDataType(pl::datatypes::create_enum_data_type(
categories,
)))
}

pub fn new_datetime(tu: Robj, tz: Nullable<String>) -> RResult<RPolarsDataType> {
robj_to!(timeunit, tu)
.map(|dt| RPolarsDataType(pl::DataType::Datetime(dt, null_to_opt(tz))))
Expand Down Expand Up @@ -203,6 +214,20 @@ impl RPolarsDataType {
self.0.is_temporal()
}

// When rust-polars 0.40.0 is released:

// pub fn is_enum(&self) -> bool {
// self.0.is_enum()
// }

// pub fn is_categorical(&self) -> bool {
// self.0.is_categorical()
// }

// pub fn is_string(&self) -> bool {
// self.0.is_string()
// }

pub fn is_logical(&self) -> bool {
self.0.is_logical()
}
Expand Down
99 changes: 50 additions & 49 deletions tests/testthat/_snaps/after-wrappers.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,55 +7,56 @@
[3] "Boolean" "Categorical"
[5] "DataFrame" "Date"
[7] "Datetime" "Duration"
[9] "Field" "Float32"
[11] "Float64" "Int16"
[13] "Int32" "Int64"
[15] "Int8" "LazyFrame"
[17] "List" "Null"
[19] "PTime" "SQLContext"
[21] "Series" "String"
[23] "Struct" "Time"
[25] "UInt16" "UInt32"
[27] "UInt64" "UInt8"
[29] "Unknown" "Utf8"
[31] "all" "all_horizontal"
[33] "any_horizontal" "approx_n_unique"
[35] "arg_sort_by" "arg_where"
[37] "class_names" "coalesce"
[39] "col" "concat"
[41] "concat_list" "concat_str"
[43] "corr" "count"
[45] "cov" "date"
[47] "date_range" "date_ranges"
[49] "datetime" "datetime_range"
[51] "datetime_ranges" "disable_string_cache"
[53] "dtypes" "duration"
[55] "element" "enable_string_cache"
[57] "first" "fold"
[59] "from_epoch" "get_global_rpool_cap"
[61] "head" "implode"
[63] "int_range" "int_ranges"
[65] "is_schema" "last"
[67] "len" "lit"
[69] "max" "max_horizontal"
[71] "mean" "mean_horizontal"
[73] "median" "mem_address"
[75] "min" "min_horizontal"
[77] "n_unique" "numeric_dtypes"
[79] "raw_list" "read_csv"
[81] "read_ipc" "read_ndjson"
[83] "read_parquet" "reduce"
[85] "rolling_corr" "rolling_cov"
[87] "same_outer_dt" "scan_csv"
[89] "scan_ipc" "scan_ndjson"
[91] "scan_parquet" "select"
[93] "set_global_rpool_cap" "show_all_public_functions"
[95] "show_all_public_methods" "std"
[97] "struct" "sum"
[99] "sum_horizontal" "tail"
[101] "thread_pool_size" "time"
[103] "using_string_cache" "var"
[105] "when" "with_string_cache"
[9] "Enum" "Field"
[11] "Float32" "Float64"
[13] "Int16" "Int32"
[15] "Int64" "Int8"
[17] "LazyFrame" "List"
[19] "Null" "PTime"
[21] "SQLContext" "Series"
[23] "String" "Struct"
[25] "Time" "UInt16"
[27] "UInt32" "UInt64"
[29] "UInt8" "Unknown"
[31] "Utf8" "all"
[33] "all_horizontal" "any_horizontal"
[35] "approx_n_unique" "arg_sort_by"
[37] "arg_where" "class_names"
[39] "coalesce" "col"
[41] "concat" "concat_list"
[43] "concat_str" "corr"
[45] "count" "cov"
[47] "date" "date_range"
[49] "date_ranges" "datetime"
[51] "datetime_range" "datetime_ranges"
[53] "disable_string_cache" "dtypes"
[55] "duration" "element"
[57] "enable_string_cache" "first"
[59] "fold" "from_epoch"
[61] "get_global_rpool_cap" "head"
[63] "implode" "int_range"
[65] "int_ranges" "is_schema"
[67] "last" "len"
[69] "lit" "max"
[71] "max_horizontal" "mean"
[73] "mean_horizontal" "median"
[75] "mem_address" "min"
[77] "min_horizontal" "n_unique"
[79] "numeric_dtypes" "raw_list"
[81] "read_csv" "read_ipc"
[83] "read_ndjson" "read_parquet"
[85] "reduce" "rolling_corr"
[87] "rolling_cov" "same_outer_dt"
[89] "scan_csv" "scan_ipc"
[91] "scan_ndjson" "scan_parquet"
[93] "select" "set_global_rpool_cap"
[95] "show_all_public_functions" "show_all_public_methods"
[97] "std" "struct"
[99] "sum" "sum_horizontal"
[101] "tail" "thread_pool_size"
[103] "time" "using_string_cache"
[105] "var" "when"
[107] "with_string_cache"

---

Expand Down
39 changes: 39 additions & 0 deletions tests/testthat/test-datatype.R
Original file line number Diff line number Diff line change
Expand Up @@ -198,3 +198,42 @@ test_that("is_* functions for datatype work", {
expect_false(pl$Struct()$is_primitive())
expect_false(pl$List()$is_primitive())
})

test_that("Enum", {
expect_identical(
as_polars_series(c("z", "z", "k", "a"))$
cast(pl$Enum(c("z", "k", "a")))$
to_r(),
factor(c("z", "z", "k", "a"))
)

expect_grepl_error(pl$Enum(), "missing")
expect_grepl_error(pl$Enum(1), "invalid series dtype")
expect_grepl_error(pl$Enum(TRUE), "invalid series dtype")
expect_grepl_error(pl$Enum(factor("a")), "invalid series dtype")

expect_error(
as_polars_series(c("z", "z", "k", "a"))$
cast(pl$Enum(c("foo", "k", "a"))),
"Ensure that all values in the input column are present"
)

# Can compare two cols if same Enum categories only

df = pl$DataFrame(x = "a", y = "b", z = "c")$
with_columns(
pl$col("x")$cast(pl$Enum(c("a", "b", "c"))),
pl$col("y")$cast(pl$Enum(c("a", "b", "c"))),
pl$col("z")$cast(pl$Enum(c("a", "c")))
)

expect_identical(
df$select(x_eq_y = pl$col("x") == pl$col("y"))$to_list(),
list(x_eq_y = FALSE)
)

expect_grepl_error(
df$select(x_eq_z = pl$col("x") == pl$col("z")),
"cannot compare categoricals coming from different sources"
)
})
Loading