Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-4394][SQL] Data Sources API Improvements #3260

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,10 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
* Optimized version of In clause, when all filter values of In clause are
* static.
*/
case class InSet(value: Expression, hset: HashSet[Any], child: Seq[Expression])
case class InSet(value: Expression, hset: Set[Any])
extends Predicate {

def children = child
def children = value :: Nil

def nullable = true // TODO: Figure out correct nullability semantics of IN.
override def toString = s"$value INSET ${hset.mkString("(", ",", ")")}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ object OptimizeIn extends Rule[LogicalPlan] {
case q: LogicalPlan => q transformExpressionsDown {
case In(v, list) if !list.exists(!_.isInstanceOf[Literal]) =>
val hSet = list.map(e => e.eval(null))
InSet(v, HashSet() ++ hSet, v +: list)
InSet(v, HashSet() ++ hSet)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,13 +158,13 @@ class ExpressionEvaluationSuite extends FunSuite {
val nl = Literal(null)
val s = Seq(one, two)
val nullS = Seq(one, two, null)
checkEvaluation(InSet(one, hS, one +: s), true)
checkEvaluation(InSet(two, hS, two +: s), true)
checkEvaluation(InSet(two, nS, two +: nullS), true)
checkEvaluation(InSet(nl, nS, nl +: nullS), true)
checkEvaluation(InSet(three, hS, three +: s), false)
checkEvaluation(InSet(three, nS, three +: nullS), false)
checkEvaluation(InSet(one, hS, one +: s) && InSet(two, hS, two +: s), true)
checkEvaluation(InSet(one, hS), true)
checkEvaluation(InSet(two, hS), true)
checkEvaluation(InSet(two, nS), true)
checkEvaluation(InSet(nl, nS), true)
checkEvaluation(InSet(three, hS), false)
checkEvaluation(InSet(three, nS), false)
checkEvaluation(InSet(one, hS) && InSet(two, hS), true)
}

test("MaxOf") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ class OptimizeInSuite extends PlanTest {
val optimized = Optimize(originalQuery.analyze)
val correctAnswer =
testRelation
.where(InSet(UnresolvedAttribute("a"), HashSet[Any]()+1+2,
UnresolvedAttribute("a") +: Seq(Literal(1),Literal(2))))
.where(InSet(UnresolvedAttribute("a"), HashSet[Any]()+1+2))
.analyze

comparePlans(optimized, correctAnswer)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,5 +108,7 @@ private[sql] object DataSourceStrategy extends Strategy {

case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) => LessThanOrEqual(a.name, v)
case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) => GreaterThanOrEqual(a.name, v)

case expressions.InSet(a: Attribute, set) => In(a.name, set.toArray)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ private[sql] case class LogicalRelation(relation: BaseRelation)
}

@transient override lazy val statistics = Statistics(
// TODO: Allow datasources to provide statistics as well.
sizeInBytes = BigInt(relation.sqlContext.defaultSizeInBytes)
sizeInBytes = BigInt(relation.sizeInBytes)
)

/** Used to lookup original attribute capitalization */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ case class GreaterThan(attribute: String, value: Any) extends Filter
/**
 * A pushed-down filter requiring the column named `attribute` to be greater than or
 * equal to `value`.
 */
case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter
/**
 * A pushed-down filter comparing the column named `attribute` against `value`.
 * NOTE(review): presumably a strict less-than comparison, by analogy with the sibling
 * filters; confirm against the strategy that constructs it.
 */
case class LessThan(attribute: String, value: Any) extends Filter
/**
 * A pushed-down filter requiring the column named `attribute` to be less than or
 * equal to `value`.
 */
case class LessThanOrEqual(attribute: String, value: Any) extends Filter
/**
 * A pushed-down filter requiring the column named `attribute` to hold one of `values`.
 *
 * `values` is an `Array`, and arrays compare by reference. The compiler-generated
 * case-class `equals`/`hashCode` would therefore treat two `In` filters with identical
 * contents as different, so both are overridden to compare the array element-wise.
 */
case class In(attribute: String, values: Array[Any]) extends Filter {
  override def equals(other: Any): Boolean = other match {
    case In(thatAttribute, thatValues) =>
      attribute == thatAttribute && values.sameElements(thatValues)
    case _ => false
  }

  // Hash the array contents (not its identity) so hashCode stays consistent with equals.
  override def hashCode(): Int = (attribute, values.toSeq).##
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package org.apache.spark.sql.sources

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext, StructType}
import org.apache.spark.sql.{SQLConf, Row, SQLContext, StructType}
import org.apache.spark.sql.catalyst.expressions.{Expression, Attribute}

/**
Expand Down Expand Up @@ -53,6 +53,15 @@ trait RelationProvider {
abstract class BaseRelation {
  /** The context this relation belongs to; also supplies the default size estimate below. */
  def sqlContext: SQLContext

  /** The schema of this relation, as a [[StructType]]. */
  def schema: StructType

  /**
   * Returns an estimated size of this relation in bytes. This information is used by the planner
   * to decide when it is safe to broadcast a relation and can be overridden by sources that
   * know the size ahead of time. By default, the system will assume that tables are too
   * large to broadcast. This method will be called multiple times during query planning
   * and thus should not perform expensive operations for each invocation.
   */
  def sizeInBytes: Long = sqlContext.defaultSizeInBytes
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL
case LessThanOrEqual("a", v: Int) => (a: Int) => a <= v
case GreaterThan("a", v: Int) => (a: Int) => a > v
case GreaterThanOrEqual("a", v: Int) => (a: Int) => a >= v
case In("a", values) => (a: Int) => values.map(_.asInstanceOf[Int]).toSet.contains(a)
}

def eval(a: Int) = !filterFunctions.map(_(a)).contains(false)
Expand Down Expand Up @@ -121,6 +122,10 @@ class FilteredScanSuite extends DataSourceTest {
"SELECT * FROM oneToTenFiltered WHERE a = 1",
Seq(1).map(i => Row(i, i * 2)).toSeq)

sqlTest(
"SELECT * FROM oneToTenFiltered WHERE a IN (1,3,5)",
Seq(1,3,5).map(i => Row(i, i * 2)).toSeq)

sqlTest(
"SELECT * FROM oneToTenFiltered WHERE A = 1",
Seq(1).map(i => Row(i, i * 2)).toSeq)
Expand Down Expand Up @@ -150,6 +155,8 @@ class FilteredScanSuite extends DataSourceTest {

testPushDown("SELECT * FROM oneToTenFiltered WHERE a > 1 AND a < 10", 8)

testPushDown("SELECT * FROM oneToTenFiltered WHERE a IN (1,3,5)", 3)

testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 20", 0)
testPushDown("SELECT * FROM oneToTenFiltered WHERE b = 1", 10)

Expand Down