42 commits
9cf1ebf
Adds higher order functions to scala API
nvander1 Mar 28, 2019
efc6ba4
Add (Scala-specifc) note to higher order functions
nvander1 Mar 28, 2019
b9dceec
Follow style guide more closely
nvander1 Mar 28, 2019
1fb46a3
Fix scalastyle issues
nvander1 Mar 28, 2019
03d602f
Add java-specific version of higher order function api
nvander1 Mar 28, 2019
6bf07d8
Do not prematurely bind lambda variables
nvander1 Jun 14, 2019
b03399a
Resolve conflict between Java Function and Scala Function
HyukjinKwon Jul 25, 2019
79d6f84
Adds higher order functions to scala API
nvander1 Mar 28, 2019
7adaf9c
Add (Scala-specifc) note to higher order functions
nvander1 Mar 28, 2019
ac5c1c2
Follow style guide more closely
nvander1 Mar 28, 2019
40ac418
Fix scalastyle issues
nvander1 Mar 28, 2019
fb5f8ef
Add java-specific version of higher order function api
nvander1 Mar 28, 2019
85979d4
Do not prematurely bind lambda variables
nvander1 Jun 14, 2019
5d389d2
Merge branch 'fix-24232' of git://github.com/HyukjinKwon/spark into H…
nvander1 Aug 2, 2019
5d77d6b
Merge branch 'master' into feature/add_higher_order_functions_to_scal…
nvander1 Aug 6, 2019
a8c7ecd
Add forall to org.apache.spark.sql.functions
nvander1 Aug 6, 2019
96fb0ad
Add "@since 3.0.0" to new functions
nvander1 Aug 10, 2019
5fa3e71
Add tests for Java transform function
nvander1 Aug 10, 2019
0bfa483
Add tests for Java map_filter function
nvander1 Aug 10, 2019
815e9f6
Add tests for Java filter function
nvander1 Aug 10, 2019
47b100b
Add tests for Java exists function
nvander1 Aug 10, 2019
4baf084
Add test for Java API forall
nvander1 Aug 19, 2019
9c0f70e
Merge branch 'master' into feature/add_higher_order_functions_to_scal…
nvander1 Aug 19, 2019
06b4c82
Add test for Java API: aggregate
nvander1 Aug 19, 2019
412ece5
Add test for Java API: map_zip_with
nvander1 Aug 19, 2019
c49e7d3
Add java tests for transform_keys, transform_values
nvander1 Aug 21, 2019
182a08b
Add tests for java zip_with function
nvander1 Aug 21, 2019
ef6b6bb
Remove JavaFunction overloads and add Java transform test
nvander1 Aug 21, 2019
a543c90
Merge branch 'tmp' into feature/add_higher_order_functions_to_scala_api
nvander1 Aug 21, 2019
527c0cb
Remove (Scala-specifc) from higher order functions
nvander1 Aug 21, 2019
013187f
Remove java tests from DataFrameFunctionsSuite
nvander1 Sep 17, 2019
554a992
Add simple java test for filter
nvander1 Sep 18, 2019
0433756
Add simple java test for exists
nvander1 Sep 18, 2019
f371413
Add simple java test for forall
nvander1 Sep 18, 2019
c3e320c
Add java test for aggregate
nvander1 Sep 19, 2019
84ccf55
Add java aggregate test with finish
nvander1 Sep 19, 2019
e43033b
Add java test for zip_with
nvander1 Sep 19, 2019
c1c76a9
Add java test for transformKeys
nvander1 Sep 20, 2019
10a5f2e
Add java test for transform_values
nvander1 Sep 20, 2019
722f0e6
Add java test for map_filter and map_zip_with
nvander1 Sep 20, 2019
1bf2654
Fix style nits
nvander1 Oct 2, 2019
64c0f87
Fix linter errors in imports
nvander1 Oct 2, 2019
161 changes: 161 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3385,6 +3385,167 @@ object functions {
    ArrayExcept(col1.expr, col2.expr)
  }

  private def createLambda(f: Column => Column) = {
    val x = UnresolvedNamedLambdaVariable(Seq("x"))
    val function = f(Column(x)).expr
    LambdaFunction(function, Seq(x))
  }

  private def createLambda(f: (Column, Column) => Column) = {
    val x = UnresolvedNamedLambdaVariable(Seq("x"))
    val y = UnresolvedNamedLambdaVariable(Seq("y"))
    val function = f(Column(x), Column(y)).expr
    LambdaFunction(function, Seq(x, y))
  }

  private def createLambda(f: (Column, Column, Column) => Column) = {
    val x = UnresolvedNamedLambdaVariable(Seq("x"))
    val y = UnresolvedNamedLambdaVariable(Seq("y"))
    val z = UnresolvedNamedLambdaVariable(Seq("z"))
    val function = f(Column(x), Column(y), Column(z)).expr
    LambdaFunction(function, Seq(x, y, z))
  }
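Taken together, these helpers let an ordinary Scala function over Columns stand in for a Catalyst lambda: each overload allocates unresolved named variables, applies the caller's function to them, and wraps the resulting expression in a LambdaFunction that the analyzer later binds to the actual array or map elements. Roughly, for the one-argument case (a conceptual illustration, not code from this patch):

    val x = UnresolvedNamedLambdaVariable(Seq("x"))  // placeholder variable "x"
    val body = (Column(x) + 1).expr                  // caller's function applied to it
    LambdaFunction(body, Seq(x))                     // Catalyst lambda: x -> x + 1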

  /**
   * Returns an array of elements after applying a transformation to each element
   * in the input array.
   *
   * @group collection_funcs
   * @since 3.0.0
   */
  def transform(column: Column, f: Column => Column): Column = withExpr {
    ArrayTransform(column.expr, createLambda(f))
  }

  /**
   * Returns an array of elements after applying a transformation to each element
   * in the input array. The transformation function takes the element and its
   * (0-based) index as arguments.
   *
   * @group collection_funcs
   * @since 3.0.0
   */
  def transform(column: Column, f: (Column, Column) => Column): Column = withExpr {
    ArrayTransform(column.expr, createLambda(f))
  }
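As a usage sketch for both overloads (assuming a SparkSession with spark.implicits._ in scope; the data and column names are illustrative):

    import org.apache.spark.sql.functions._

    val df = Seq(Seq(1, 2, 3)).toDF("xs")
    df.select(transform($"xs", x => x + 1))       // [2, 3, 4]
    df.select(transform($"xs", (x, i) => x + i))  // [1, 3, 5], i being the 0-based index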

  /**
   * Returns whether a predicate holds for one or more elements in the array.
   *
   * @group collection_funcs
   * @since 3.0.0
   */
  def exists(column: Column, f: Column => Column): Column = withExpr {
Member:
But how do we support this in Java?

Contributor Author:
Could we change the signatures to accept scala.runtime.AbstractFunctions instead, to avoid using the Function traits?

Member:
Let's add (Scala-specific) at least for each doc. BTW, please take a look at the style guide: https://github.com/databricks/scala-style-guide

Contributor Author (@nvander1, Mar 28, 2019):
Actually, a better idea would probably be to use Java functional interfaces:

    @FunctionalInterface
    interface Function3<T1, T2, T3, R> {
      R apply(T1 t1, T2 t2, T3 t3);
    }

    Column map_zip_with(Column left, Column right, Function3<Column, Column, Column, Column> f) { ... }

Contributor Author:
And of course we would use the existing functional interfaces from java.util.function first, but I don't think there are any that accept three parameters, as some of the functions here require.

Contributor Author:
It appears these interfaces already exist in the source tree: https://github.com/apache/spark/blob/v2.4.0/core/src/main/java/org/apache/spark/api/java/function/Function3.java

I'll come back later to add Java-specific APIs that utilize these.

    ArrayExists(column.expr, createLambda(f))
  }
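For example, with the illustrative df from the transform sketch above:

    df.select(exists($"xs", x => x % 2 === 0))  // true: 2 is even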

  /**
   * Returns whether a predicate holds for every element in the array.
   *
   * @group collection_funcs
   * @since 3.0.0
   */
  def forall(column: Column, f: Column => Column): Column = withExpr {
    ArrayForAll(column.expr, createLambda(f))
  }
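Continuing the sketch:

    df.select(forall($"xs", x => x > 0))  // true: every element of [1, 2, 3] is positive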

  /**
   * Returns an array of elements for which a predicate holds in a given array.
   *
   * @group collection_funcs
   * @since 3.0.0
   */
  def filter(column: Column, f: Column => Column): Column = withExpr {
    ArrayFilter(column.expr, createLambda(f))
  }
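Continuing the sketch:

    df.select(filter($"xs", x => x > 1))  // [2, 3]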

  /**
   * Applies a binary operator to an initial state and all elements in the array,
   * and reduces this to a single state. The final state is converted into the final result
   * by applying a finish function.
   *
   * @group collection_funcs
   * @since 3.0.0
   */
  def aggregate(
      expr: Column,
      zero: Column,
      merge: (Column, Column) => Column,
      finish: Column => Column): Column = withExpr {
    ArrayAggregate(
      expr.expr,
      zero.expr,
      createLambda(merge),
      createLambda(finish)
    )
  }

  /**
   * Applies a binary operator to an initial state and all elements in the array,
   * and reduces this to a single state.
   *
   * @group collection_funcs
   * @since 3.0.0
   */
  def aggregate(expr: Column, zero: Column, merge: (Column, Column) => Column): Column =
    aggregate(expr, zero, merge, c => c)
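A sketch of both forms, using the illustrative df from above (lit(0) is the initial state):

    df.select(aggregate($"xs", lit(0), (acc, x) => acc + x))               // 6
    df.select(aggregate($"xs", lit(0), (acc, x) => acc + x, s => s * 10))  // 60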

  /**
   * Merge two given arrays, element-wise, into a single array using a function.
   * If one array is shorter, nulls are appended at the end to match the length of the longer
   * array, before applying the function.
   *
   * @group collection_funcs
   * @since 3.0.0
   */
  def zip_with(left: Column, right: Column, f: (Column, Column) => Column): Column = withExpr {
    ZipWith(left.expr, right.expr, createLambda(f))
  }
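A sketch showing the null padding (illustrative data):

    val pair = Seq((Seq(1, 2, 3), Seq(10, 20))).toDF("a", "b")
    pair.select(zip_with($"a", $"b", (x, y) => x + y))  // [11, 22, null]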

  /**
   * Applies a function to every key-value pair in a map and returns
   * a map with the results of those applications as the new keys for the pairs.
   *
   * @group collection_funcs
   * @since 3.0.0
   */
  def transform_keys(expr: Column, f: (Column, Column) => Column): Column = withExpr {
    TransformKeys(expr.expr, createLambda(f))
  }

  /**
   * Applies a function to every key-value pair in a map and returns
   * a map with the results of those applications as the new values for the pairs.
   *
   * @group collection_funcs
   * @since 3.0.0
   */
  def transform_values(expr: Column, f: (Column, Column) => Column): Column = withExpr {
    TransformValues(expr.expr, createLambda(f))
  }
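A sketch covering both map transformations (illustrative data):

    val m = Seq(Map(1 -> 10, 2 -> 20)).toDF("m")
    m.select(transform_keys($"m", (k, v) => k + 1))    // Map(2 -> 10, 3 -> 20)
    m.select(transform_values($"m", (k, v) => k * v))  // Map(1 -> 10, 2 -> 40)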

  /**
   * Returns a map whose key-value pairs satisfy a predicate.
   *
   * @group collection_funcs
   * @since 3.0.0
   */
  def map_filter(expr: Column, f: (Column, Column) => Column): Column = withExpr {
    MapFilter(expr.expr, createLambda(f))
  }
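Continuing the map sketch:

    m.select(map_filter($"m", (k, v) => v > 10))  // Map(2 -> 20)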

  /**
   * Merge two given maps, key-wise, into a single map using a function.
   *
   * @group collection_funcs
   * @since 3.0.0
   */
  def map_zip_with(
      left: Column,
      right: Column,
      f: (Column, Column, Column) => Column): Column = withExpr {
    MapZipWith(left.expr, right.expr, createLambda(f))
  }
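A sketch (illustrative data; for keys present on only one side, the missing value comes through as null):

    val mm = Seq((Map(1 -> 10), Map(1 -> 2))).toDF("m1", "m2")
    mm.select(map_zip_with($"m1", $"m2", (k, v1, v2) => v1 + v2))  // Map(1 -> 12)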

  /**
   * Creates a new row for each element in the given array or map column.
   * Uses the default column name `col` for elements in the array and
228 changes: 228 additions & 0 deletions sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java
@@ -0,0 +1,228 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package test.org.apache.spark.sql;

import java.util.HashMap;
import java.util.List;

import static scala.collection.JavaConverters.mapAsScalaMap;

import org.junit.After;
import org.junit.Before;
import org.junit.Test;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.*;
import static org.apache.spark.sql.types.DataTypes.*;
import static org.apache.spark.sql.functions.*;
import org.apache.spark.sql.test.TestSparkSession;
import static test.org.apache.spark.sql.JavaTestUtils.*;

public class JavaHigherOrderFunctionsSuite {
  private transient TestSparkSession spark;
  private Dataset<Row> arrDf;
  private Dataset<Row> mapDf;

  private void setUpArrDf() {
    List<Row> data = toRows(
      makeArray(1, 9, 8, 7),
      makeArray(5, 8, 9, 7, 2),
      JavaTestUtils.<Integer>makeArray(),
      null
    );
    StructType schema = new StructType()
      .add("x", new ArrayType(IntegerType, true), true);
    arrDf = spark.createDataFrame(data, schema);
  }

  private void setUpMapDf() {
    List<Row> data = toRows(
      new HashMap<Integer, Integer>() {{
        put(1, 1);
        put(2, 2);
      }},
      null
    );
    StructType schema = new StructType()
      .add("x", new MapType(IntegerType, IntegerType, true));
    mapDf = spark.createDataFrame(data, schema);
  }

  @Before
  public void setUp() {
    spark = new TestSparkSession();
    setUpArrDf();
    setUpMapDf();
  }

  @After
  public void tearDown() {
    spark.stop();
    spark = null;
  }

  @Test
  public void testTransform() {
    checkAnswer(
      arrDf.select(transform(col("x"), x -> x.plus(1))),
      toRows(
        makeArray(2, 10, 9, 8),
        makeArray(6, 9, 10, 8, 3),
        JavaTestUtils.<Integer>makeArray(),
        null
      )
    );
    checkAnswer(
      arrDf.select(transform(col("x"), (x, i) -> x.plus(i))),
      toRows(
        makeArray(1, 10, 10, 10),
        makeArray(5, 9, 11, 10, 6),
        JavaTestUtils.<Integer>makeArray(),
        null
      )
    );
  }

  @Test
  public void testFilter() {
    checkAnswer(
      arrDf.select(filter(col("x"), x -> x.plus(1).equalTo(10))),
      toRows(
        makeArray(9),
        makeArray(9),
        JavaTestUtils.<Integer>makeArray(),
        null
      )
    );
  }

  @Test
  public void testExists() {
    checkAnswer(
      arrDf.select(exists(col("x"), x -> x.plus(1).equalTo(10))),
      toRows(
        true,
        true,
        false,
        null
      )
    );
  }

  @Test
  public void testForall() {
    checkAnswer(
      arrDf.select(forall(col("x"), x -> x.plus(1).equalTo(10))),
      toRows(
        false,
        false,
        true,
        null
      )
    );
  }

  @Test
  public void testAggregate() {
    checkAnswer(
      arrDf.select(aggregate(col("x"), lit(0), (acc, x) -> acc.plus(x))),
      toRows(
        25,
        31,
        0,
        null
      )
    );
    checkAnswer(
      arrDf.select(aggregate(col("x"), lit(0), (acc, x) -> acc.plus(x), x -> x)),
      toRows(
        25,
        31,
        0,
        null
      )
    );
  }

  @Test
  public void testZipWith() {
    checkAnswer(
      arrDf.select(zip_with(col("x"), col("x"), (a, b) -> lit(42))),
      toRows(
        makeArray(42, 42, 42, 42),
        makeArray(42, 42, 42, 42, 42),
        JavaTestUtils.<Integer>makeArray(),
        null
      )
    );
  }

  @Test
  public void testTransformKeys() {
    checkAnswer(
      mapDf.select(transform_keys(col("x"), (k, v) -> k.plus(v))),
      toRows(
        mapAsScalaMap(new HashMap<Integer, Integer>() {{
          put(2, 1);
          put(4, 2);
        }}),
        null
Member:
nit: style. one more indent?

      )
    );
  }

  @Test
  public void testTransformValues() {
    checkAnswer(
      mapDf.select(transform_values(col("x"), (k, v) -> k.plus(v))),
      toRows(
        mapAsScalaMap(new HashMap<Integer, Integer>() {{
          put(1, 2);
          put(2, 4);
        }}),
        null
Member:
ditto.

Contributor Author:

    @Test
    public void testTransformValues() {
        checkAnswer(
            mapDf.select(transform_values(col("x"), (k, v) -> k.plus(v))),
            toRows(
                mapAsScalaMap(new HashMap<Integer, Integer>() {{
                    put(1, 2);
                    put(2, 4);
                }}),
                null
            )
        );
    }

Does this work as well? I've moved the new HashMap up a line. @ueshin

Also, what is the general preference in the codebase: each paren and brace on a new line?

Or the more "lispy" style of every close on the same line:

    @Test
    public void testTransformValues() {
        checkAnswer(
            mapDf.select(transform_values(col("x"), (k, v) -> k.plus(v))),
            toRows(
                mapAsScalaMap(new HashMap<Integer, Integer>() {{
                    put(1, 2);
                    put(2, 4);}}),
                null));
    }

I've seen a mixture of the two, to various degrees, in the code; I edited this file to at least be consistent with itself (the exception here being the mapAsScalaMap / HashMap, since it really is its own entity just being converted to a Scala equivalent).

Member:
Maybe the first one is more preferred.
The second one needs a line break at the end of HashMap since it's a block:

                mapAsScalaMap(new HashMap<Integer, Integer>() {{
                    put(1, 2);
                    put(2, 4);
                }}),
                null));

I'm not quite sure about the parentheses after null. Maybe we need a line break as well.

As for my comment, sorry, maybe my pointer was wrong.
I meant new HashMap ... should be on one more indent:

                mapAsScalaMap(
                    new HashMap<Integer, Integer>() {{
                        put(1, 2);
                        put(2, 4);
                    }}
                ),
                null

      )
    );
  }

  @Test
  public void testMapFilter() {
    checkAnswer(
      mapDf.select(map_filter(col("x"), (k, v) -> lit(false))),
      toRows(
        mapAsScalaMap(new HashMap<Integer, Integer>()),
        null
      )
    );
  }

  @Test
  public void testMapZipWith() {
    checkAnswer(
      mapDf.select(map_zip_with(col("x"), col("x"), (k, v1, v2) -> lit(false))),
      toRows(
        mapAsScalaMap(new HashMap<Integer, Boolean>() {{
          put(1, false);
          put(2, false);
        }}),
        null
      )
    );
  }
}
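One thing worth noting about this suite: the tests pass Java lambdas straight to the Scala functions above. That works because, as of Scala 2.12, the Scala FunctionN traits are compiled as SAM interfaces, so a Java lambda such as x -> x.plus(1) can implement Column => Column directly; this is what made separate Java-specific overloads unnecessary (see the "Remove JavaFunction overloads" commit).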