@@ -17,9 +17,16 @@

package org.apache.spark.sql.connector.read;

import java.util.ArrayList;
import java.util.List;

import org.apache.spark.annotation.Experimental;
import org.apache.spark.sql.connector.expressions.NamedReference;
import org.apache.spark.sql.connector.expressions.filter.Predicate;
import org.apache.spark.sql.internal.connector.PredicateUtils;
import org.apache.spark.sql.sources.Filter;

import scala.Option;

/**
* A mix-in interface for {@link Scan}. Data sources can implement this interface if they can
@@ -30,7 +37,7 @@
* @since 3.2.0
*/
@Experimental
public interface SupportsRuntimeFiltering extends Scan {
public interface SupportsRuntimeFiltering extends Scan, SupportsRuntimeV2Filtering {
/**
* Returns attributes this scan can be filtered by at runtime.
* <p>
@@ -57,4 +64,18 @@ public interface SupportsRuntimeFiltering extends Scan {
* @param filters data source filters used to filter the scan at runtime
*/
void filter(Filter[] filters);

  default void filter(Predicate[] predicates) {
    // Convert each V2 predicate to its V1 counterpart, dropping predicates
    // that cannot be expressed as a V1 filter, then delegate to filter(Filter[]).
    List<Filter> filterList = new ArrayList<>();

    for (Predicate predicate : predicates) {
      Option<Filter> filter = PredicateUtils.toV1(predicate);
      if (filter.nonEmpty()) {
        filterList.add(filter.get());
      }
    }

    this.filter(filterList.toArray(new Filter[0]));
  }
}
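The effect of the new default method: an existing V1 connector needs no changes when Spark starts planning runtime filters through the V2 path, because filter(Predicate[]) converts the V2 predicates and delegates to the V1 overload. A minimal sketch under hypothetical names (LegacyScan, partitionKey, and the "part" column are not part of the PR):

import org.apache.spark.sql.connector.expressions.{FieldReference, NamedReference}
import org.apache.spark.sql.connector.read.{InputPartition, SupportsRuntimeFiltering}
import org.apache.spark.sql.sources.{Filter, In}
import org.apache.spark.sql.types.StructType

// Hypothetical V1-era scan: it implements only the V1 overload and picks up
// filter(Predicate[]) from the new default method, so V2 runtime predicates
// are converted and routed here automatically.
class LegacyScan(schema: StructType, var partitions: Seq[InputPartition])
  extends SupportsRuntimeFiltering {

  override def readSchema(): StructType = schema

  override def filterAttributes(): Array[NamedReference] =
    Array(FieldReference("part"))

  override def filter(filters: Array[Filter]): Unit = filters.foreach {
    case In("part", values) =>
      // Keep only partitions whose key appears in the runtime IN-list.
      partitions = partitions.filter(p => values.contains(partitionKey(p)))
    case _ => // ignore filters this source cannot use
  }

  // Hypothetical helper: extract the partition key of an InputPartition.
  private def partitionKey(p: InputPartition): Any = ???
}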
@@ -0,0 +1,67 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector.read;

import org.apache.spark.annotation.Experimental;
import org.apache.spark.sql.connector.expressions.NamedReference;
import org.apache.spark.sql.connector.expressions.filter.Predicate;
import org.apache.spark.sql.sources.Filter;

/**
* A mix-in interface for {@link Scan}. Data sources can implement this interface if they can
* filter initially planned {@link InputPartition}s using predicates Spark infers at runtime.
* This interface is very similar to {@link SupportsRuntimeFiltering} except it uses
* data source V2 {@link Predicate} instead of data source V1 {@link Filter}.
* {@link SupportsRuntimeV2Filtering} is preferred over {@link SupportsRuntimeFiltering},
* and a data source should implement only one of the two.
*
* <p>
* Note that Spark will push runtime filters only if they are beneficial.
*
* @since 3.4.0
*/
@Experimental
public interface SupportsRuntimeV2Filtering extends Scan {
/**
* Returns attributes this scan can be filtered by at runtime.
* <p>
* Spark will call {@link #filter(Predicate[])} if it can derive a runtime
* predicate for any of the filter attributes.
*/
NamedReference[] filterAttributes();

/**
* Filters this scan using runtime predicates.
* <p>
* The provided expressions must be interpreted as a set of predicates that are ANDed together.
* Implementations may use the predicates to prune initially planned {@link InputPartition}s.
* <p>
* If the scan also implements {@link SupportsReportPartitioning}, it must preserve
* the originally reported partitioning during runtime filtering. While applying runtime
* predicates, the scan may detect that some {@link InputPartition}s have no matching data. It
* can omit such partitions entirely only if it does not report a specific partitioning.
* Otherwise, the scan can replace the initially planned {@link InputPartition}s that have no
* matching data with empty {@link InputPartition}s but must preserve the overall number of
* partitions.
* <p>
* Note that Spark will call {@link Scan#toBatch()} again after filtering the scan at runtime.
*
* @param predicates data source V2 predicates used to filter the scan at runtime
*/
void filter(Predicate[] predicates);
}
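The partitioning-preservation rule in the filter() javadoc is the subtle requirement. A minimal sketch of a conforming implementation, under assumed names (PartitionedScan, EmptyPartition, and matches are illustrative stand-ins, not Spark APIs):

import org.apache.spark.sql.connector.expressions.NamedReference
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.read.{InputPartition, Scan, SupportsRuntimeV2Filtering}
import org.apache.spark.sql.types.StructType

// Placeholder partition that produces no rows (hypothetical).
case object EmptyPartition extends InputPartition

class PartitionedScan(
    schema: StructType,
    attrs: Array[NamedReference],
    var partitions: Seq[InputPartition],
    reportsPartitioning: Boolean)
  extends Scan with SupportsRuntimeV2Filtering {

  override def readSchema(): StructType = schema
  override def filterAttributes(): Array[NamedReference] = attrs

  override def filter(predicates: Array[Predicate]): Unit = {
    if (reportsPartitioning) {
      // Must keep the original partition count: swap pruned partitions
      // for empty placeholders instead of dropping them.
      partitions = partitions.map(p => if (matches(p, predicates)) p else EmptyPartition)
    } else {
      // No reported partitioning: pruned partitions can be dropped outright.
      partitions = partitions.filter(p => matches(p, predicates))
    }
  }

  private def matches(p: InputPartition, predicates: Array[Predicate]): Boolean =
    true // stub: evaluate the ANDed predicates against partition metadata
}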
@@ -68,6 +68,7 @@ object Literal {
case b: Byte => Literal(b, ByteType)
case s: Short => Literal(s, ShortType)
case s: String => Literal(UTF8String.fromString(s), StringType)
case s: UTF8String => Literal(s, StringType)
case c: Char => Literal(UTF8String.fromString(c.toString), StringType)
case ac: Array[Char] => Literal(UTF8String.fromString(String.valueOf(ac)), StringType)
case b: Boolean => Literal(b, BooleanType)
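The new UTF8String case matters for this PR because runtime-filter values are Catalyst internal values, so strings arrive as UTF8String rather than java.lang.String; previously Literal.apply would have rejected them as an unsupported literal type. For example:

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.UTF8String

// Now maps directly to a StringType literal; before this change it would
// have thrown an unsupported-literal-type error.
val lit = Literal(UTF8String.fromString("abc"))
assert(lit.dataType == StringType)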
@@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.internal.connector

import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.connector.expressions.{LiteralValue, NamedReference}
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.sources.{Filter, In}

private[sql] object PredicateUtils {

def toV1(predicate: Predicate): Option[Filter] = {
predicate.name() match {
// TODO: add conversion for other V2 Predicate
case "IN" if predicate.children()(0).isInstanceOf[NamedReference] =>
val attribute = predicate.children()(0).toString
val values = predicate.children().drop(1)
if (values.length > 0) {
if (!values.forall(_.isInstanceOf[LiteralValue[_]])) return None
val dataType = values(0).asInstanceOf[LiteralValue[_]].dataType
if (!values.forall(_.asInstanceOf[LiteralValue[_]].dataType.sameType(dataType))) {
return None
}
val inValues = values.map(v =>
CatalystTypeConverters.convertToScala(v.asInstanceOf[LiteralValue[_]].value, dataType))
Some(In(attribute, inValues))
} else {
Some(In(attribute, Array.empty[Any]))
}

case _ => None
}
}
}
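A quick illustration of the conversion, assuming the Spark 3.4 expression APIs (the column name i and the values are arbitrary):

import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, LiteralValue}
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.internal.connector.PredicateUtils
import org.apache.spark.sql.types.IntegerType

// Build the V2 predicate `i IN (1, 2)`: the first child is the column
// reference, the remaining children are the literal values.
val children: Array[Expression] = Array(
  FieldReference("i"),
  LiteralValue(1, IntegerType),
  LiteralValue(2, IntegerType))

PredicateUtils.toV1(new Predicate("IN", children))
// => Some(In("i", Array(1, 2)))
// Non-literal children, or literals of mixed types, yield None instead.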
@@ -268,12 +268,11 @@ class InMemoryTable(

case class InMemoryStats(sizeInBytes: OptionalLong, numRows: OptionalLong) extends Statistics

case class InMemoryBatchScan(
abstract class BatchScanBaseClass(
var data: Seq[InputPartition],
readSchema: StructType,
tableSchema: StructType)
extends Scan with Batch with SupportsRuntimeFiltering with SupportsReportStatistics
with SupportsReportPartitioning {
extends Scan with Batch with SupportsReportStatistics with SupportsReportPartitioning {

override def toBatch: Batch = this

@@ -308,6 +307,13 @@
val nonMetadataColumns = readSchema.filterNot(f => metadataColumns.contains(f.name))
new BufferedRowsReaderFactory(metadataColumns, nonMetadataColumns, tableSchema)
}
}

case class InMemoryBatchScan(
var _data: Seq[InputPartition],
readSchema: StructType,
tableSchema: StructType)
extends BatchScanBaseClass(_data, readSchema, tableSchema) with SupportsRuntimeFiltering {

override def filterAttributes(): Array[NamedReference] = {
val scanFields = readSchema.fields.map(_.name).toSet
@@ -0,0 +1,77 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector.catalog

import java.util

import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue, NamedReference, Transform}
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.read.{InputPartition, Scan, ScanBuilder, SupportsRuntimeV2Filtering}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

class InMemoryTableWithV2Filter(
name: String,
schema: StructType,
partitioning: Array[Transform],
properties: util.Map[String, String])
extends InMemoryTable(name, schema, partitioning, properties) {

override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
new InMemoryV2FilterScanBuilder(schema)
}

class InMemoryV2FilterScanBuilder(tableSchema: StructType)
extends InMemoryScanBuilder(tableSchema) {
override def build: Scan =
InMemoryV2FilterBatchScan(data.map(_.asInstanceOf[InputPartition]), schema, tableSchema)
}

case class InMemoryV2FilterBatchScan(
var _data: Seq[InputPartition],
readSchema: StructType,
tableSchema: StructType)
extends BatchScanBaseClass(_data, readSchema, tableSchema) with SupportsRuntimeV2Filtering {

override def filterAttributes(): Array[NamedReference] = {
val scanFields = readSchema.fields.map(_.name).toSet
partitioning.flatMap(_.references)
.filter(ref => scanFields.contains(ref.fieldNames.mkString(".")))
}

override def filter(filters: Array[Predicate]): Unit = {
if (partitioning.length == 1 && partitioning.head.references().length == 1) {
val ref = partitioning.head.references().head
filters.foreach {
case p: Predicate if p.name().equals("IN") =>

Review comment: feels like some unapply method to extract what you want would be preferable

Author reply: Predicate is a Java class; I don't think unapply can be used.
(A sketch of such an extractor follows this file.)

if (p.children().length > 1) {
val filterRef = p.children()(0).asInstanceOf[FieldReference].references.head
if (filterRef.toString.equals(ref.toString)) {
val matchingKeys =
p.children().drop(1).map(_.asInstanceOf[LiteralValue[_]].value.toString).toSet
data = data.filter(partition => {
val key = partition.asInstanceOf[BufferedRows].keyString
matchingKeys.contains(key)
})
}
}
}
}
}
}
}
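On the unapply question in the review thread above: Predicate being a Java class does not by itself rule out pattern matching, since a custom Scala extractor can wrap it. A hypothetical sketch of what the reviewer may have had in mind (InPredicate is not part of the PR):

import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue}
import org.apache.spark.sql.connector.expressions.filter.Predicate

// Hypothetical extractor: matches an IN predicate over a single column with
// all-literal values. Shown only to illustrate the idea.
object InPredicate {
  def unapply(p: Predicate): Option[(FieldReference, Seq[LiteralValue[_]])] = {
    val children = p.children()
    if (p.name() == "IN" && children.length > 1 &&
        children(0).isInstanceOf[FieldReference] &&
        children.drop(1).forall(_.isInstanceOf[LiteralValue[_]])) {
      Some((
        children(0).asInstanceOf[FieldReference],
        children.drop(1).map(_.asInstanceOf[LiteralValue[_]]).toSeq))
    } else {
      None
    }
  }
}

// The match in filter() could then read:
//   filters.foreach {
//     case InPredicate(ref, values) if ref.toString == partitionRef.toString => ...
//     case _ =>
//   }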
@@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector.catalog

import java.util

import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.types.StructType

class InMemoryTableWithV2FilterCatalog extends InMemoryTableCatalog {
import CatalogV2Implicits._

override def createTable(
ident: Identifier,
schema: StructType,
partitions: Array[Transform],
properties: util.Map[String, String]): Table = {
if (tables.containsKey(ident)) {
throw new TableAlreadyExistsException(ident)
}

InMemoryTableCatalog.maybeSimulateFailedTableCreation(properties)

val tableName = s"$name.${ident.quoted}"
val table = new InMemoryTableWithV2Filter(tableName, schema, partitions, properties)
tables.put(ident, table)
namespaces.putIfAbsent(ident.namespace.toList, Map())
table
}
}
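For context, a test catalog like this is typically wired into a session through the spark.sql.catalog.* configs. A sketch, assuming a running SparkSession named spark and that the class is on the classpath (it lives in Spark's test sources); the catalog name testcat is arbitrary:

// Register the catalog under an arbitrary name, then create a table in it.
spark.conf.set(
  "spark.sql.catalog.testcat",
  classOf[org.apache.spark.sql.connector.catalog.InMemoryTableWithV2FilterCatalog].getName)

spark.sql("CREATE TABLE testcat.ns.t (id BIGINT, part STRING) PARTITIONED BY (part)")
spark.sql("INSERT INTO testcat.ns.t VALUES (1, 'a'), (2, 'b')")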
@@ -44,7 +44,7 @@ import org.apache.spark.sql.connector.catalog.TableCapability._
import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, NullOrdering, SortDirection, SortOrder => V2SortOrder, SortValue}
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum, UserDefinedAggregateFunc}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.command._
import org.apache.spark.sql.execution.datasources.v2.PushedDownOperators
import org.apache.spark.sql.execution.streaming.StreamingRelation
@@ -652,25 +652,6 @@ object DataSourceStrategy
}
}

/**
* Translates a runtime filter into a data source filter.
*
* Runtime filters usually contain a subquery that must be evaluated before the translation.
* If the underlying subquery hasn't completed yet, this method will throw an exception.
*/
protected[sql] def translateRuntimeFilter(expr: Expression): Option[Filter] = expr match {
case in @ InSubqueryExec(e @ PushableColumnAndNestedColumn(name), _, _, _, _, _) =>
val values = in.values().getOrElse {
throw new IllegalStateException(s"Can't translate $in to source filter, no subquery result")
}
val toScala = CatalystTypeConverters.createToScalaConverter(e.dataType)
Some(sources.In(name, values.map(toScala)))

case other =>
logWarning(s"Can't translate $other to source filter, unsupported expression")
None
}

/**
* Selects Catalyst predicate [[Expression]]s which are convertible into data source [[Filter]]s
* and can be handled by `relation`.
@@ -27,8 +27,7 @@ import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, SinglePartition}
import org.apache.spark.sql.catalyst.util.InternalRowSet
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition, PartitionReaderFactory, Scan, SupportsRuntimeFiltering}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition, PartitionReaderFactory, Scan, SupportsRuntimeV2Filtering}

/**
* Physical plan node for scanning a batch of data from a data source v2.
@@ -56,15 +55,15 @@ case class BatchScanExec(

@transient private lazy val filteredPartitions: Seq[Seq[InputPartition]] = {
val dataSourceFilters = runtimeFilters.flatMap {
case DynamicPruningExpression(e) => DataSourceStrategy.translateRuntimeFilter(e)
case DynamicPruningExpression(e) => DataSourceV2Strategy.translateRuntimeFilterV2(e)
case _ => None
}

if (dataSourceFilters.nonEmpty) {
val originalPartitioning = outputPartitioning

// the cast is safe as runtime filters are only assigned if the scan can be filtered
val filterableScan = scan.asInstanceOf[SupportsRuntimeFiltering]
val filterableScan = scan.asInstanceOf[SupportsRuntimeV2Filtering]
filterableScan.filter(dataSourceFilters.toArray)

// call toBatch again to get filtered partitions