3 changes: 3 additions & 0 deletions R/pkg/NAMESPACE
@@ -204,7 +204,9 @@ exportMethods("%<=>%",
"array_max",
"array_min",
"array_position",
"array_repeat",
"array_sort",
"arrays_overlap",
"asc",
"ascii",
"asin",
@@ -302,6 +304,7 @@ exportMethods("%<=>%",
"lower",
"lpad",
"ltrim",
"map_entries",
"map_keys",
"map_values",
"max",
2 changes: 2 additions & 0 deletions R/pkg/R/DataFrame.R
@@ -2297,6 +2297,8 @@ setMethod("rename",

setClassUnion("characterOrColumn", c("character", "Column"))

setClassUnion("numericOrColumn", c("numeric", "Column"))

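A side note on the new class union: it is what lets a single S4 method accept either a plain numeric or a Column (array_repeat() in functions.R below relies on it). A minimal standalone sketch of the mechanism, using a hypothetical FakeColumn class in place of SparkR's Column:

library(methods)

# Hypothetical stand-in for SparkR's Column class (not SparkR code).
setClass("FakeColumn", slots = c(jc = "character"))
setClassUnion("numericOrFakeColumn", c("numeric", "FakeColumn"))

setGeneric("normalizeCount", function(count) standardGeneric("normalizeCount"))

# One method handles both member classes of the union, mirroring the coercion
# that array_repeat() performs on its count argument.
setMethod("normalizeCount", signature(count = "numericOrFakeColumn"), function(count) {
  if (is(count, "FakeColumn")) count@jc else as.integer(count)
})

normalizeCount(3.7)                            # 3L
normalizeCount(new("FakeColumn", jc = "ref"))  # "ref"
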
#' Arrange Rows by Variables
#'
#' Sort a SparkDataFrame by the specified column(s).
58 changes: 53 additions & 5 deletions R/pkg/R/functions.R
@@ -189,6 +189,7 @@ NULL
#' the map or array of maps.
#' \item \code{from_json}: it is the column containing the JSON string.
#' }
#' @param y Column to compute on.
#' @param value A value to compute on.
#' \itemize{
#' \item \code{array_contains}: a value to be checked if contained in the column.
@@ -207,7 +208,7 @@ NULL
#' tmp <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp))
#' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1)))
#' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1)))
#' head(select(tmp, array_position(tmp$v1, 21), array_sort(tmp$v1)))
#' head(select(tmp, array_position(tmp$v1, 21), array_repeat(df$mpg, 3), array_sort(tmp$v1)))
#' head(select(tmp, flatten(tmp$v1), reverse(tmp$v1)))
#' tmp2 <- mutate(tmp, v2 = explode(tmp$v1))
#' head(tmp2)
@@ -216,11 +217,10 @@ NULL
#' head(select(tmp, sort_array(tmp$v1)))
#' head(select(tmp, sort_array(tmp$v1, asc = FALSE)))
#' tmp3 <- mutate(df, v3 = create_map(df$model, df$cyl))
#' head(select(tmp3, map_keys(tmp3$v3)))
#' head(select(tmp3, map_values(tmp3$v3)))
#' head(select(tmp3, map_entries(tmp3$v3), map_keys(tmp3$v3), map_values(tmp3$v3)))
#' head(select(tmp3, element_at(tmp3$v3, "Valiant")))
#' tmp4 <- mutate(df, v4 = create_array(df$mpg, df$cyl), v5 = create_array(df$hp))
#' head(select(tmp4, concat(tmp4$v4, tmp4$v5)))
#' tmp4 <- mutate(df, v4 = create_array(df$mpg, df$cyl), v5 = create_array(df$cyl, df$hp))
#' head(select(tmp4, concat(tmp4$v4, tmp4$v5), arrays_overlap(tmp4$v4, tmp4$v5)))
#' head(select(tmp, concat(df$mpg, df$cyl, df$hp)))}
NULL

@@ -3048,6 +3048,26 @@ setMethod("array_position",
            column(jc)
          })

#' @details
#' \code{array_repeat}: Creates an array containing \code{x} repeated the number of times
#' given by \code{count}.
#'
#' @param count a Column or a numeric constant determining the number of repetitions.
#' @rdname column_collection_functions
#' @aliases array_repeat array_repeat,Column,numericOrColumn-method
#' @note array_repeat since 2.4.0
setMethod("array_repeat",
signature(x = "Column", count = "numericOrColumn"),
function(x, count) {
if (class(count) == "Column") {
count <- count@jc
} else {
count <- as.integer(count)
}
jc <- callJStatic("org.apache.spark.sql.functions", "array_repeat", x@jc, count)
column(jc)
})

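A quick usage sketch for array_repeat (not part of the diff; assumes an active SparkR session). The count can be a plain number, which is coerced to an integer, or an integer Column:

df <- createDataFrame(list(list("a", 3L), list("b", 2L)))
head(select(df, array_repeat(df[[1]], 2)))        # list("a", "a"), list("b", "b")
head(select(df, array_repeat(df[[1]], df[[2]])))  # list("a", "a", "a"), list("b", "b")
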
#' @details
#' \code{array_sort}: Sorts the input array in ascending order. The elements of the input array
#' must be orderable. NA elements will be placed at the end of the returned array.
@@ -3062,6 +3082,21 @@ setMethod("array_sort",
            column(jc)
          })

#' @details
#' \code{arrays_overlap}: Returns true if the input arrays have at least one non-null element in
#' common. If there is no common element but both arrays are non-empty and either of them
#' contains a null, it returns null. Otherwise, it returns false.
#'
#' @rdname column_collection_functions
#' @aliases arrays_overlap arrays_overlap,Column-method
#' @note arrays_overlap since 2.4.0
setMethod("arrays_overlap",
signature(x = "Column", y = "Column"),
function(x, y) {
jc <- callJStatic("org.apache.spark.sql.functions", "arrays_overlap", x@jc, y@jc)
column(jc)
})

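A quick usage sketch for arrays_overlap (not part of the diff; assumes an active SparkR session): TRUE when a non-null element is shared, FALSE when there is clearly no overlap, and NA when a null makes the answer indeterminate.

df <- createDataFrame(list(list(list(1L, 2L), list(3L, 1L)),
                           list(list(1L, 2L), list(3L, 4L)),
                           list(list(1L, NA), list(3L, 4L))))
head(select(df, arrays_overlap(df[[1]], df[[2]])))  # TRUE, FALSE, NA
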
#' @details
#' \code{flatten}: Creates a single array from an array of arrays.
#' If a structure of nested arrays is deeper than two levels, only one level of nesting is removed.
@@ -3076,6 +3111,19 @@ setMethod("flatten",
            column(jc)
          })

#' @details
#' \code{map_entries}: Returns an unordered array of all entries in the given map.
#'
#' @rdname column_collection_functions
#' @aliases map_entries map_entries,Column-method
#' @note map_entries since 2.4.0
setMethod("map_entries",
signature(x = "Column"),
function(x) {
jc <- callJStatic("org.apache.spark.sql.functions", "map_entries", x@jc)
column(jc)
})

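A quick usage sketch for map_entries (not part of the diff; assumes an active SparkR session): it exposes a map column as an array of key/value structs, complementing map_keys() and map_values().

df <- createDataFrame(list(list(map = as.environment(list(x = 1, y = 2)))))
head(select(df, map_entries(df$map), map_keys(df$map), map_values(df$map)))
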
#' @details
#' \code{map_keys}: Returns an unordered array containing the keys of the map.
#'
12 changes: 12 additions & 0 deletions R/pkg/R/generics.R
@@ -769,10 +769,18 @@ setGeneric("array_min", function(x) { standardGeneric("array_min") })
#' @name NULL
setGeneric("array_position", function(x, value) { standardGeneric("array_position") })

#' @rdname column_collection_functions
#' @name NULL
setGeneric("array_repeat", function(x, count) { standardGeneric("array_repeat") })

#' @rdname column_collection_functions
#' @name NULL
setGeneric("array_sort", function(x) { standardGeneric("array_sort") })

#' @rdname column_collection_functions
#' @name NULL
setGeneric("arrays_overlap", function(x, y) { standardGeneric("arrays_overlap") })

#' @rdname column_string_functions
#' @name NULL
setGeneric("ascii", function(x) { standardGeneric("ascii") })
@@ -1034,6 +1042,10 @@ setGeneric("lpad", function(x, len, pad) { standardGeneric("lpad") })
#' @name NULL
setGeneric("ltrim", function(x, trimString) { standardGeneric("ltrim") })

#' @rdname column_collection_functions
#' @name NULL
setGeneric("map_entries", function(x) { standardGeneric("map_entries") })

#' @rdname column_collection_functions
#' @name NULL
setGeneric("map_keys", function(x) { standardGeneric("map_keys") })
22 changes: 21 additions & 1 deletion R/pkg/tests/fulltests/test_sparkSQL.R
@@ -1503,6 +1503,21 @@ test_that("column functions", {
  result <- collect(select(df2, reverse(df2[[1]])))[[1]]
  expect_equal(result, "cba")

  # Test array_repeat()
  df <- createDataFrame(list(list("a", 3L), list("b", 2L)))
  result <- collect(select(df, array_repeat(df[[1]], df[[2]])))[[1]]
  expect_equal(result, list(list("a", "a", "a"), list("b", "b")))

  result <- collect(select(df, array_repeat(df[[1]], 2L)))[[1]]
  expect_equal(result, list(list("a", "a"), list("b", "b")))

  # Test arrays_overlap()
  df <- createDataFrame(list(list(list(1L, 2L), list(3L, 1L)),
                             list(list(1L, 2L), list(3L, 4L)),
                             list(list(1L, NA), list(3L, 4L))))
  result <- collect(select(df, arrays_overlap(df[[1]], df[[2]])))[[1]]
  expect_equal(result, c(TRUE, FALSE, NA))

  # Test array_sort() and sort_array()
  df <- createDataFrame(list(list(list(2L, 1L, 3L, NA)), list(list(NA, 6L, 5L, NA, 4L))))

@@ -1531,8 +1546,13 @@ test_that("column functions", {
  result <- collect(select(df, flatten(df[[1]])))[[1]]
  expect_equal(result, list(list(1L, 2L, 3L, 4L), list(5L, 6L, 7L, 8L)))

  # Test map_keys(), map_values() and element_at()
  # Test map_entries(), map_keys(), map_values() and element_at()
  df <- createDataFrame(list(list(map = as.environment(list(x = 1, y = 2)))))
  result <- collect(select(df, map_entries(df$map)))[[1]]
  expected_entries <- list(listToStruct(list(key = "x", value = 1)),
                           listToStruct(list(key = "y", value = 2)))
  expect_equal(result, list(expected_entries))

  result <- collect(select(df, map_keys(df$map)))[[1]]
  expect_equal(result, list(list("x", "y")))
