Skip to content

Commit

Permalink
feat: add $list$sample() (#1204)
Browse files Browse the repository at this point in the history
  • Loading branch information
etiennebacher authored Aug 24, 2024
1 parent b8e7903 commit 7588ca3
Show file tree
Hide file tree
Showing 8 changed files with 241 additions and 84 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
- New method `$gather_every()` for `LazyFrame` and `DataFrame` (#1199).
- `$glimpse()` for `DataFrame` has two new arguments `max_items_per_column` and
`max_colname_length` (#1200).
- New method `$list$sample()` (#1204).

### Other changes

Expand Down
29 changes: 29 additions & 0 deletions R/expr__list.R
Original file line number Diff line number Diff line change
Expand Up @@ -612,3 +612,32 @@ ExprList_set_symmetric_difference = function(other) {
ExprList_explode = function() {
.pr$Expr$explode(self)
}

#' Sample from this list
#'
#' @inheritParams Expr_sample
#'
#' @return Expr
#' @examples
#' df = pl$DataFrame(
#' values = list(1:3, NA_integer_, c(NA_integer_, 3L), 5:7),
#' n = c(1, 1, 1, 2)
#' )
#'
#' df$with_columns(
#' sample = pl$col("values")$list$sample(n = pl$col("n"), seed = 1)
#' )
ExprList_sample = function(
n = NULL, ..., fraction = NULL, with_replacement = FALSE, shuffle = FALSE,
seed = NULL) {
pcase(
!is.null(n) && !is.null(fraction), {
Err(.pr$Err$new()$plain("either arg `n` or `fraction` must be NULL"))
},
!is.null(n), .pr$Expr$list_sample_n(self, n, with_replacement, shuffle, seed),
or_else = {
.pr$Expr$list_sample_frac(self, fraction %||% 1, with_replacement, shuffle, seed)
}
) |>
unwrap("in $list$sample():")
}
4 changes: 4 additions & 0 deletions R/extendr-wrappers.R
Original file line number Diff line number Diff line change
Expand Up @@ -768,6 +768,10 @@ RPolarsExpr$list_any <- function() .Call(wrap__RPolarsExpr__list_any, self)

RPolarsExpr$list_set_operation <- function(other, operation) .Call(wrap__RPolarsExpr__list_set_operation, self, other, operation)

RPolarsExpr$list_sample_n <- function(n, with_replacement, shuffle, seed) .Call(wrap__RPolarsExpr__list_sample_n, self, n, with_replacement, shuffle, seed)

RPolarsExpr$list_sample_frac <- function(frac, with_replacement, shuffle, seed) .Call(wrap__RPolarsExpr__list_sample_frac, self, frac, with_replacement, shuffle, seed)

RPolarsExpr$arr_max <- function() .Call(wrap__RPolarsExpr__arr_max, self)

RPolarsExpr$arr_min <- function() .Call(wrap__RPolarsExpr__arr_min, self)
Expand Down
48 changes: 48 additions & 0 deletions man/ExprList_sample.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src/rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ features = [
"list_any_all",
"list_eval",
"list_gather",
"list_sample",
"list_sets",
"list_to_struct",
"log",
Expand Down
42 changes: 41 additions & 1 deletion src/rust/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1109,7 +1109,7 @@ impl RPolarsExpr {
Ok(self.0.clone().rle_id().into())
}

//arr/list methods
// list methods

fn list_len(&self) -> Self {
self.0.clone().list().len().into()
Expand Down Expand Up @@ -1292,6 +1292,46 @@ impl RPolarsExpr {
.into())
}

pub fn list_sample_n(
&self,
n: Robj,
with_replacement: Robj,
shuffle: Robj,
seed: Robj,
) -> RResult<Self> {
Ok(self
.0
.clone()
.list()
.sample_n(
robj_to!(PLExpr, n)?,
robj_to!(bool, with_replacement)?,
robj_to!(bool, shuffle)?,
robj_to!(Option, u64, seed)?,
)
.into())
}

pub fn list_sample_frac(
&self,
frac: Robj,
with_replacement: Robj,
shuffle: Robj,
seed: Robj,
) -> RResult<Self> {
Ok(self
.0
.clone()
.list()
.sample_fraction(
robj_to!(PLExpr, frac)?,
robj_to!(bool, with_replacement)?,
robj_to!(bool, shuffle)?,
robj_to!(Option, u64, seed)?,
)
.into())
}

// array methods

fn arr_max(&self) -> Self {
Expand Down
167 changes: 84 additions & 83 deletions tests/testthat/_snaps/after-wrappers.md
Original file line number Diff line number Diff line change
Expand Up @@ -353,89 +353,90 @@
[163] "list_len" "list_max"
[165] "list_mean" "list_min"
[167] "list_n_unique" "list_reverse"
[169] "list_set_operation" "list_shift"
[171] "list_slice" "list_sort"
[173] "list_sum" "list_to_struct"
[175] "list_unique" "lit"
[177] "log" "log10"
[179] "lower_bound" "lt"
[181] "lt_eq" "map_batches"
[183] "map_batches_in_background" "map_elements_in_background"
[185] "max" "mean"
[187] "median" "meta_eq"
[189] "meta_has_multiple_outputs" "meta_is_regex_projection"
[191] "meta_output_name" "meta_pop"
[193] "meta_root_names" "meta_tree_format"
[195] "meta_undo_aliases" "min"
[197] "mode" "mul"
[199] "n_unique" "name_keep"
[201] "name_map" "name_prefix"
[203] "name_prefix_fields" "name_suffix"
[205] "name_suffix_fields" "name_to_lowercase"
[207] "name_to_uppercase" "nan_max"
[209] "nan_min" "neq"
[211] "neq_missing" "new_first"
[213] "new_last" "new_len"
[215] "not" "null_count"
[217] "or" "over"
[219] "pct_change" "peak_max"
[221] "peak_min" "pow"
[223] "print" "product"
[225] "qcut" "qcut_uniform"
[227] "quantile" "rank"
[229] "rechunk" "reinterpret"
[231] "rem" "rep"
[233] "repeat_by" "replace"
[235] "replace_strict" "reshape"
[237] "reverse" "rle"
[239] "rle_id" "rolling"
[241] "rolling_corr" "rolling_cov"
[243] "rolling_max" "rolling_max_by"
[245] "rolling_mean" "rolling_mean_by"
[247] "rolling_median" "rolling_median_by"
[249] "rolling_min" "rolling_min_by"
[251] "rolling_quantile" "rolling_quantile_by"
[253] "rolling_skew" "rolling_std"
[255] "rolling_std_by" "rolling_sum"
[257] "rolling_sum_by" "rolling_var"
[259] "rolling_var_by" "round"
[261] "sample_frac" "sample_n"
[263] "search_sorted" "shift"
[265] "shrink_dtype" "shuffle"
[267] "sign" "sin"
[269] "sinh" "skew"
[271] "slice" "sort_by"
[273] "sort_with" "std"
[275] "str_base64_decode" "str_base64_encode"
[277] "str_contains" "str_contains_any"
[279] "str_count_matches" "str_ends_with"
[281] "str_extract" "str_extract_all"
[283] "str_extract_groups" "str_extract_many"
[285] "str_find" "str_head"
[287] "str_hex_decode" "str_hex_encode"
[289] "str_join" "str_json_decode"
[291] "str_json_path_match" "str_len_bytes"
[293] "str_len_chars" "str_pad_end"
[295] "str_pad_start" "str_replace"
[297] "str_replace_all" "str_replace_many"
[299] "str_reverse" "str_slice"
[301] "str_split" "str_split_exact"
[303] "str_splitn" "str_starts_with"
[305] "str_strip_chars" "str_strip_chars_end"
[307] "str_strip_chars_start" "str_tail"
[309] "str_to_date" "str_to_datetime"
[311] "str_to_integer" "str_to_lowercase"
[313] "str_to_time" "str_to_titlecase"
[315] "str_to_uppercase" "str_zfill"
[317] "struct_field_by_name" "struct_rename_fields"
[319] "struct_with_fields" "sub"
[321] "sum" "tail"
[323] "tan" "tanh"
[325] "to_physical" "top_k"
[327] "unique" "unique_counts"
[329] "unique_stable" "upper_bound"
[331] "value_counts" "var"
[333] "xor"
[169] "list_sample_frac" "list_sample_n"
[171] "list_set_operation" "list_shift"
[173] "list_slice" "list_sort"
[175] "list_sum" "list_to_struct"
[177] "list_unique" "lit"
[179] "log" "log10"
[181] "lower_bound" "lt"
[183] "lt_eq" "map_batches"
[185] "map_batches_in_background" "map_elements_in_background"
[187] "max" "mean"
[189] "median" "meta_eq"
[191] "meta_has_multiple_outputs" "meta_is_regex_projection"
[193] "meta_output_name" "meta_pop"
[195] "meta_root_names" "meta_tree_format"
[197] "meta_undo_aliases" "min"
[199] "mode" "mul"
[201] "n_unique" "name_keep"
[203] "name_map" "name_prefix"
[205] "name_prefix_fields" "name_suffix"
[207] "name_suffix_fields" "name_to_lowercase"
[209] "name_to_uppercase" "nan_max"
[211] "nan_min" "neq"
[213] "neq_missing" "new_first"
[215] "new_last" "new_len"
[217] "not" "null_count"
[219] "or" "over"
[221] "pct_change" "peak_max"
[223] "peak_min" "pow"
[225] "print" "product"
[227] "qcut" "qcut_uniform"
[229] "quantile" "rank"
[231] "rechunk" "reinterpret"
[233] "rem" "rep"
[235] "repeat_by" "replace"
[237] "replace_strict" "reshape"
[239] "reverse" "rle"
[241] "rle_id" "rolling"
[243] "rolling_corr" "rolling_cov"
[245] "rolling_max" "rolling_max_by"
[247] "rolling_mean" "rolling_mean_by"
[249] "rolling_median" "rolling_median_by"
[251] "rolling_min" "rolling_min_by"
[253] "rolling_quantile" "rolling_quantile_by"
[255] "rolling_skew" "rolling_std"
[257] "rolling_std_by" "rolling_sum"
[259] "rolling_sum_by" "rolling_var"
[261] "rolling_var_by" "round"
[263] "sample_frac" "sample_n"
[265] "search_sorted" "shift"
[267] "shrink_dtype" "shuffle"
[269] "sign" "sin"
[271] "sinh" "skew"
[273] "slice" "sort_by"
[275] "sort_with" "std"
[277] "str_base64_decode" "str_base64_encode"
[279] "str_contains" "str_contains_any"
[281] "str_count_matches" "str_ends_with"
[283] "str_extract" "str_extract_all"
[285] "str_extract_groups" "str_extract_many"
[287] "str_find" "str_head"
[289] "str_hex_decode" "str_hex_encode"
[291] "str_join" "str_json_decode"
[293] "str_json_path_match" "str_len_bytes"
[295] "str_len_chars" "str_pad_end"
[297] "str_pad_start" "str_replace"
[299] "str_replace_all" "str_replace_many"
[301] "str_reverse" "str_slice"
[303] "str_split" "str_split_exact"
[305] "str_splitn" "str_starts_with"
[307] "str_strip_chars" "str_strip_chars_end"
[309] "str_strip_chars_start" "str_tail"
[311] "str_to_date" "str_to_datetime"
[313] "str_to_integer" "str_to_lowercase"
[315] "str_to_time" "str_to_titlecase"
[317] "str_to_uppercase" "str_zfill"
[319] "struct_field_by_name" "struct_rename_fields"
[321] "struct_with_fields" "sub"
[323] "sum" "tail"
[325] "tan" "tanh"
[327] "to_physical" "top_k"
[329] "unique" "unique_counts"
[331] "unique_stable" "upper_bound"
[333] "value_counts" "var"
[335] "xor"

# public and private methods of each class When

Expand Down
33 changes: 33 additions & 0 deletions tests/testthat/test-expr_list.R
Original file line number Diff line number Diff line change
Expand Up @@ -626,3 +626,36 @@ test_that("$list$explode() works", {
"lengths don't match"
)
})

test_that("$list$sample() works", {
df = pl$DataFrame(
values = list(1:3, NA_integer_, c(NA_integer_, 3L), 5:7),
n = c(1, 1, 1, 2)
)

expect_identical(
df$select(
sample = pl$col("values")$list$sample(n = pl$col("n"), seed = 1)
)$to_list(),
list(sample = list(3L, NA_integer_, 3L, 6:5))
)

expect_grepl_error(
df$select(pl$col("values")$list$sample(fraction = 2)),
"cannot take a larger sample than the total population when `with_replacement=false`"
)

expect_identical(
df$select(
sample = pl$col("values")$list$sample(fraction = 2, with_replacement = TRUE, seed = 1)
)$to_list(),
list(
sample = list(
c(3L, 1L, 1L, 2L, 2L, 3L),
c(NA_integer_, NA_integer_),
c(3L, NA, NA, NA),
c(7L, 5L, 5L, 6L, 6L, 7L)
)
)
)
})

0 comments on commit 7588ca3

Please sign in to comment.