From 857f70c906bd4beaf9e4e61007650ab2a5ffbb81 Mon Sep 17 00:00:00 2001
From: Sun Rui
Date: Wed, 2 Sep 2015 13:40:58 +0800
Subject: [PATCH 1/6] [SPARK-10050][SPARKR] Support collecting data of MapType
 in DataFrame.

---
 R/pkg/R/SQLContext.R                          |  5 +-
 R/pkg/R/deserialize.R                         | 14 +++++
 R/pkg/R/schema.R                              | 32 +++++++++---
 R/pkg/inst/tests/test_sparkSQL.R              | 52 +++++++++++++++----
 .../scala/org/apache/spark/api/r/SerDe.scala  | 35 +++++++++++++
 .../org/apache/spark/sql/api/r/SQLUtils.scala |  3 ++
 6 files changed, 119 insertions(+), 22 deletions(-)

diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R
index 4ac057d0f2d83..1c58fd96d750a 100644
--- a/R/pkg/R/SQLContext.R
+++ b/R/pkg/R/SQLContext.R
@@ -41,10 +41,7 @@ infer_type <- function(x) {
   if (type == "map") {
     stopifnot(length(x) > 0)
     key <- ls(x)[[1]]
-    list(type = "map",
-         keyType = "string",
-         valueType = infer_type(get(key, x)),
-         valueContainsNull = TRUE)
+    paste0("map<string,", infer_type(get(key, x)), ">")
   } else if (type == "array") {
     stopifnot(length(x) > 0)
     names <- names(x)
diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R
index d1858ec227b56..ce88d0b071b72 100644
--- a/R/pkg/R/deserialize.R
+++ b/R/pkg/R/deserialize.R
@@ -50,6 +50,7 @@ readTypedObject <- function(con, type) {
     "t" = readTime(con),
     "a" = readArray(con),
     "l" = readList(con),
+    "e" = readEnv(con),
     "n" = NULL,
     "j" = getJobj(readString(con)),
     stop(paste("Unsupported type for deserialization", type)))
@@ -121,6 +122,19 @@ readList <- function(con) {
   }
 }
 
+readEnv <- function(con) {
+  env <- new.env()
+  len <- readInt(con)
+  if (len > 0) {
+    for (i in 1:len) {
+      key <- readString(con)
+      value <- readObject(con)
+      env[[key]] <- value
+    }
+  }
+  env
+}
+
 readRaw <- function(con) {
   dataLen <- readInt(con)
   readBin(con, raw(), as.integer(dataLen), endian = "big")
diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R
index 62d4f73878d29..2f6b0ad2637a3 100644
--- a/R/pkg/R/schema.R
+++ b/R/pkg/R/schema.R
@@ -131,13 +131,31 @@ checkType <- function(type) {
   if (type %in% primtiveTypes) {
     return()
   } else {
-    m <- regexec("^array<(.*)>$", type)
-    matchedStrings <- regmatches(type, m)
-    if (length(matchedStrings[[1]]) >= 2) {
-      elemType <- matchedStrings[[1]][2]
-      checkType(elemType)
-      return()
-    }
+    # Array type
+    firstChar <- substr(type, 1, 1)
+    switch (firstChar,
+            a = {
+              # Array type
+              m <- regexec("^array<(.*)>$", type)
+              matchedStrings <- regmatches(type, m)
+              if (length(matchedStrings[[1]]) >= 2) {
+                elemType <- matchedStrings[[1]][2]
+                checkType(elemType)
+                return()
+              }
+            },
+            m = {
+              # Map type
+              m <- regexec("^map<(.*),(.*)>$", type)
+              matchedStrings <- regmatches(type, m)
+              if (length(matchedStrings[[1]]) >= 3) {
+                keyType <- matchedStrings[[1]][2]
+                valueType <- matchedStrings[[1]][3]
+                checkType(keyType)
+                checkType(valueType)
+                return()
+              }
+            })
   }
 
   stop(paste("Unsupported type for Dataframe:", type))
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index 1ccfde59176f5..af9e668d9eb13 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -72,9 +72,7 @@ test_that("infer types", {
   checkStructField(testStruct$fields()[[2]], "b", "StringType", TRUE)
   e <- new.env()
   assign("a", 1L, envir = e)
-  expect_equal(infer_type(e),
-               list(type = "map", keyType = "string", valueType = "integer",
-                    valueContainsNull = TRUE))
+  expect_equal(infer_type(e), "map<string,integer>")
 })
 
 test_that("structType and structField", {
@@ -242,7 +240,7 @@ test_that("create DataFrame with different data types", {
   expect_equal(collect(df), data.frame(l, stringsAsFactors = FALSE))
 })
 
-test_that("create DataFrame with nested array and struct", {
+test_that("create DataFrame with nested array and map", {
   # e <- new.env()
   # assign("n", 3L, envir = e)
   # l <- list(1:10, list("a", "b"), e, list(a="aa", b=3L))
@@ -253,21 +251,35 @@
   # ldf <- collect(df)
   # expect_equal(ldf[1,], l[[1]])
 
+  # ArrayType and MapType
+  e <- new.env()
+  assign("n", 3L, envir = e)
 
-  # ArrayType only for now
-  l <- list(as.list(1:10), list("a", "b"))
-  df <- createDataFrame(sqlContext, list(l), c("a", "b"))
-  expect_equal(dtypes(df), list(c("a", "array<int>"), c("b", "array<string>")))
+  l <- list(as.list(1:10), list("a", "b"), e)
+  df <- createDataFrame(sqlContext, list(l), c("a", "b", "c"))
+  expect_equal(dtypes(df), list(c("a", "array<int>"),
+                                c("b", "array<string>"),
+                                c("c", "map<string,int>")))
   expect_equal(count(df), 1)
   ldf <- collect(df)
-  expect_equal(names(ldf), c("a", "b"))
+  expect_equal(names(ldf), c("a", "b", "c"))
   expect_equal(ldf[1, 1][[1]], l[[1]])
   expect_equal(ldf[1, 2][[1]], l[[2]])
+  e <- ldf$c[[1]]
+  expect_equal(class(e), "environment")
+  expect_equal(ls(e), "n")
+  expect_equal(e$n, 3L)
 })
 
+# For test map type in DataFrame
+mockLinesMapType <- c("{\"name\":\"Bob\",\"info\":{\"age\":16,\"height\":176.5}}",
+                      "{\"name\":\"Alice\",\"info\":{\"age\":20,\"height\":164.3}}",
+                      "{\"name\":\"David\",\"info\":{\"age\":60,\"height\":180}}")
+mapTypeJsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp")
+writeLines(mockLinesMapType, mapTypeJsonPath)
+
 test_that("Collect DataFrame with complex types", {
-  # only ArrayType now
-  # TODO: tests for StructType and MapType after they are supported
+  # ArrayType
   df <- jsonFile(sqlContext, complexTypeJsonPath)
   ldf <- collect(df)
@@ -277,6 +289,24 @@
   expect_equal(ldf$c1, list(list(1, 2, 3), list(4, 5, 6), list (7, 8, 9)))
   expect_equal(ldf$c2, list(list("a", "b", "c"), list("d", "e", "f"), list ("g", "h", "i")))
   expect_equal(ldf$c3, list(list(1.0, 2.0, 3.0), list(4.0, 5.0, 6.0), list (7.0, 8.0, 9.0)))
+
+  # MapType
+  schema <- structType(structField("name", "string"),
+                       structField("info", "map<string,double>"))
+  df <- read.df(sqlContext, mapTypeJsonPath, "json", schema)
+  expect_equal(dtypes(df), list(c("name", "string"),
+                                c("info", "map<string,double>")))
+  ldf <- collect(df)
+  expect_equal(nrow(ldf), 3)
+  expect_equal(ncol(ldf), 2)
+  expect_equal(names(ldf), c("name", "info"))
+  expect_equal(ldf$name, c("Bob", "Alice", "David"))
+  bob <- ldf$info[[1]]
+  expect_equal(class(bob), "environment")
+  expect_equal(bob$age, 16)
+  expect_equal(bob$height, 176.5)
+
+  # TODO: tests for StructType after it is supported
 })
 
 test_that("jsonFile() on a local file returns a DataFrame", {
diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
index 3c92bb7a1c73c..6c2d37d2ee002 100644
--- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
@@ -21,6 +21,7 @@ import java.io.{DataInputStream, DataOutputStream}
 import java.sql.{Timestamp, Date, Time}
 
 import scala.collection.JavaConverters._
+import scala.collection.JavaConversions._
 import scala.collection.mutable.WrappedArray
 
 /**
@@ -209,6 +210,7 @@
       case "array" => dos.writeByte('a')
       // Array of objects
       case "list" => dos.writeByte('l')
+      case "map" => dos.writeByte('e')
       case "jobj" => dos.writeByte('j')
       case _ => throw new IllegalArgumentException(s"Invalid type $typeStr")
     }
@@ -306,6 +308,39 @@ private[spark] object SerDe {
       writeInt(dos, v.length)
       v.foreach(elem => writeObject(dos, elem))
 
+      // Handle map
+      case v: java.util.Map[_, _] =>
+        writeType(dos, "map")
+        writeInt(dos, v.size)
+        val iter = v.entrySet.iterator
+        while(iter.hasNext) {
+          val entry = iter.next
+          val key = entry.getKey
+
+          if (key == null) {
+            throw new IllegalArgumentException("Key in map can't be null.")
+          } else if (!key.isInstanceOf[String]) {
+            throw new IllegalArgumentException(s"Invalid map key type: ${key.getClass.getName}")
+          }
+          writeString(dos, key.asInstanceOf[String])
+
+          val value = entry.getValue
+          writeObject(dos, value.asInstanceOf[Object])
+        }
+      case v: scala.collection.Map[_, _] =>
+        writeType(dos, "map")
+        writeInt(dos, v.size)
+        v.foreach { case (key, value) =>
+          if (key == null) {
+            throw new IllegalArgumentException("Key in map can't be null.")
+          } else if (!key.isInstanceOf[String]) {
+            throw new IllegalArgumentException(s"Invalid map key type: ${key.getClass.getName}")
+          }
+
+          writeString(dos, key.asInstanceOf[String])
+          writeObject(dos, value.asInstanceOf[Object])
+        }
+
       case _ =>
         writeType(dos, "jobj")
         writeJObj(dos, value)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index d4b834adb6e39..19e1d6c8db436 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -64,6 +64,9 @@ private[r] object SQLUtils {
     case r"\Aarray<(.*)${elemType}>\Z" => {
       org.apache.spark.sql.types.ArrayType(getSQLDataType(elemType))
     }
+    case r"\Amap<(.*)${keyType},(.*)${valueType}>\Z" => {
+      org.apache.spark.sql.types.MapType(getSQLDataType(keyType), getSQLDataType(valueType))
+    }
     case _ => throw new IllegalArgumentException(s"Invaid type $dataType")
   }
 }

From 09e239637d632310321ebeb22ca80c82b67c24fb Mon Sep 17 00:00:00 2001
From: Sun Rui
Date: Sun, 6 Sep 2015 13:45:45 +0800
Subject: [PATCH 2/6] Extract common logic into a new private method:
 writeKeyValue().
---
 .../scala/org/apache/spark/api/r/SerDe.scala | 31 +++++++++----------
 1 file changed, 14 insertions(+), 17 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
index 6c2d37d2ee002..d1d7a73a73739 100644
--- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
@@ -216,6 +216,17 @@ private[spark] object SerDe {
     }
   }
 
+  private def writeKeyValue(dos: DataOutputStream, key: Object, value: Object): Unit = {
+    if (key == null) {
+      throw new IllegalArgumentException("Key in map can't be null.")
+    } else if (!key.isInstanceOf[String]) {
+      throw new IllegalArgumentException(s"Invalid map key type: ${key.getClass.getName}")
+    }
+
+    writeString(dos, key.asInstanceOf[String])
+    writeObject(dos, value)
+  }
+
   def writeObject(dos: DataOutputStream, obj: Object): Unit = {
     if (obj == null) {
       writeType(dos, "void")
@@ -316,29 +327,15 @@ private[spark] object SerDe {
         while(iter.hasNext) {
           val entry = iter.next
           val key = entry.getKey
-
-          if (key == null) {
-            throw new IllegalArgumentException("Key in map can't be null.")
-          } else if (!key.isInstanceOf[String]) {
-            throw new IllegalArgumentException(s"Invalid map key type: ${key.getClass.getName}")
-          }
-          writeString(dos, key.asInstanceOf[String])
-
           val value = entry.getValue
-          writeObject(dos, value.asInstanceOf[Object])
+
+          writeKeyValue(dos, key.asInstanceOf[Object], value.asInstanceOf[Object])
         }
       case v: scala.collection.Map[_, _] =>
         writeType(dos, "map")
         writeInt(dos, v.size)
         v.foreach { case (key, value) =>
-          if (key == null) {
-            throw new IllegalArgumentException("Key in map can't be null.")
-          } else if (!key.isInstanceOf[String]) {
-            throw new IllegalArgumentException(s"Invalid map key type: ${key.getClass.getName}")
-          }
-
-          writeString(dos, key.asInstanceOf[String])
-          writeObject(dos, value.asInstanceOf[Object])
+          writeKeyValue(dos, key.asInstanceOf[Object], value.asInstanceOf[Object])
         }
 
       case _ =>

From 74558ebdfb55b799a663043d4e1d101eec0f09d8 Mon Sep 17 00:00:00 2001
From: Sun Rui
Date: Fri, 11 Sep 2015 12:51:06 +0800
Subject: [PATCH 3/6] Fix scala style.

---
 core/src/main/scala/org/apache/spark/api/r/SerDe.scala | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
index d1d7a73a73739..0c78613e406e1 100644
--- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
@@ -21,7 +21,6 @@ import java.io.{DataInputStream, DataOutputStream}
 import java.sql.{Timestamp, Date, Time}
 
 import scala.collection.JavaConverters._
-import scala.collection.JavaConversions._
 import scala.collection.mutable.WrappedArray
 
 /**
@@ -337,7 +336,7 @@ private[spark] object SerDe {
         v.foreach { case (key, value) =>
           writeKeyValue(dos, key.asInstanceOf[Object], value.asInstanceOf[Object])
         }
-      
+
       case _ =>
         writeType(dos, "jobj")
         writeJObj(dos, value)

From 748fb070a4da9a920d43bb332c6a1827b183d21c Mon Sep 17 00:00:00 2001
From: Sun Rui
Date: Mon, 14 Sep 2015 16:18:47 +0800
Subject: [PATCH 4/6] Check if key type of a map is string when creating
 structField in R.
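
For illustration, the behavior this check enforces (a sketch, not part of the
diff below; it assumes a SparkR session and relies on structField() calling
checkType() internally, per this commit's subject):

    structField("info", "map<string,double>")    # accepted: map keys are strings
    structField("info", "map<integer,integer>")  # signals an error: map keys
                                                 # must be string/character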
---
 R/pkg/R/schema.R                                             | 4 ++--
 .../src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala | 3 +++
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R
index 2f6b0ad2637a3..d184436a65bfa 100644
--- a/R/pkg/R/schema.R
+++ b/R/pkg/R/schema.R
@@ -131,7 +131,7 @@ checkType <- function(type) {
   if (type %in% primtiveTypes) {
     return()
   } else {
-    # Array type
+    # Check complex types
     firstChar <- substr(type, 1, 1)
     switch (firstChar,
             a = {
@@ -150,8 +150,8 @@ checkType <- function(type) {
               matchedStrings <- regmatches(type, m)
               if (length(matchedStrings[[1]]) >= 3) {
                 keyType <- matchedStrings[[1]][2]
+                stopifnot (keyType == "string" || keyType == "character")
                 valueType <- matchedStrings[[1]][3]
-                checkType(keyType)
                 checkType(valueType)
                 return()
               }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index 19e1d6c8db436..f45d119c8cfdf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -65,6 +65,9 @@ private[r] object SQLUtils {
       org.apache.spark.sql.types.ArrayType(getSQLDataType(elemType))
     }
     case r"\Amap<(.*)${keyType},(.*)${valueType}>\Z" => {
+      if (keyType != "string" && keyType != "character") {
+        throw new IllegalArgumentException("Key type of a map must be string or character")
+      }
       org.apache.spark.sql.types.MapType(getSQLDataType(keyType), getSQLDataType(valueType))
     }
     case _ => throw new IllegalArgumentException(s"Invaid type $dataType")

From f0e52e096530eb7c31a5c468b886fe3b3be7a5d2 Mon Sep 17 00:00:00 2001
From: Sun Rui
Date: Tue, 15 Sep 2015 13:16:47 +0800
Subject: [PATCH 5/6] Add a test case for checking map type.

---
 R/pkg/R/schema.R                 | 4 +++-
 R/pkg/inst/tests/test_sparkSQL.R | 4 +++-
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R
index d184436a65bfa..8df1563f8ebc0 100644
--- a/R/pkg/R/schema.R
+++ b/R/pkg/R/schema.R
@@ -150,7 +150,9 @@ checkType <- function(type) {
               matchedStrings <- regmatches(type, m)
               if (length(matchedStrings[[1]]) >= 3) {
                 keyType <- matchedStrings[[1]][2]
-                stopifnot (keyType == "string" || keyType == "character")
+                if (keyType != "string" && keyType != "character") {
+                  stop("Key type in a map must be string or character")
+                }
                 valueType <- matchedStrings[[1]][3]
                 checkType(valueType)
                 return()
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index af9e668d9eb13..b17d6379304ad 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -57,7 +57,7 @@ mockLinesComplexType <-
 complexTypeJsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp")
 writeLines(mockLinesComplexType, complexTypeJsonPath)
 
-test_that("infer types", {
+test_that("infer types and check types", {
   expect_equal(infer_type(1L), "integer")
   expect_equal(infer_type(1.0), "double")
   expect_equal(infer_type("abc"), "string")
@@ -73,6 +73,8 @@ test_that("infer types and check types", {
   e <- new.env()
   assign("a", 1L, envir = e)
   expect_equal(infer_type(e), "map<string,integer>")
+  
+  expect_error(checkType("map<integer,integer>"), "Key type in a map must be string or character")
 })

From 6daf126099d6ed716ce10a1928f0224bd17e077d Mon Sep 17 00:00:00 2001
From: Sun Rui
Date: Tue, 15 Sep 2015 16:16:35 +0800
Subject: [PATCH 6/6] Fix R style.
---
 R/pkg/inst/tests/test_sparkSQL.R | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index b17d6379304ad..39d048c6cb09a 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -73,7 +73,7 @@ test_that("infer types and check types", {
   e <- new.env()
   assign("a", 1L, envir = e)
   expect_equal(infer_type(e), "map<string,integer>")
-  
+
   expect_error(checkType("map<integer,integer>"), "Key type in a map must be string or character")
 })
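
Illustrative usage of what this series enables (a sketch, not part of the
patches themselves): in a SparkR shell where `sqlContext` has been created,
map columns now round-trip between R environments and Spark MapType values.
`infer_type` and `checkType` are internal helpers, called here only for
demonstration:

    # An R environment with string keys is inferred as a map<string, ...>
    e <- new.env()
    assign("n", 3L, envir = e)
    infer_type(e)                                # "map<string,integer>"

    # One-row DataFrame with a map column, collected back into R
    df <- createDataFrame(sqlContext, list(list(e)), c("m"))
    dtypes(df)                                   # list(c("m", "map<string,int>"))
    m <- collect(df)$m[[1]]                      # map values arrive as environments
    class(m)                                     # "environment"
    m$n                                          # 3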