Skip to content

Commit 7bfa5ad

Browse files
committed
Merge pull request #2 from isnotinvain/alexlevenson/simplify-udp-state
Simplify user defined predicates with state, Add more test cases
2 parents 51952f8 + 0187376 commit 7bfa5ad

File tree

6 files changed

+113
-88
lines changed

6 files changed

+113
-88
lines changed

parquet-column/src/main/java/parquet/filter2/predicate/FilterApi.java

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import parquet.filter2.predicate.Operators.BinaryColumn;
88
import parquet.filter2.predicate.Operators.BooleanColumn;
99
import parquet.filter2.predicate.Operators.Column;
10-
import parquet.filter2.predicate.Operators.ConfiguredUserDefined;
1110
import parquet.filter2.predicate.Operators.DoubleColumn;
1211
import parquet.filter2.predicate.Operators.Eq;
1312
import parquet.filter2.predicate.Operators.FloatColumn;
@@ -22,8 +21,9 @@
2221
import parquet.filter2.predicate.Operators.Or;
2322
import parquet.filter2.predicate.Operators.SupportsEqNotEq;
2423
import parquet.filter2.predicate.Operators.SupportsLtGt;
25-
import parquet.filter2.predicate.Operators.SimpleUserDefined;
2624
import parquet.filter2.predicate.Operators.UserDefined;
25+
import parquet.filter2.predicate.Operators.UserDefinedByClass;
26+
import parquet.filter2.predicate.Operators.UserDefinedByInstance;
2727

2828
/**
2929
* The Filter API is expressed through these static methods.
@@ -148,18 +148,23 @@ public static <T extends Comparable<T>, C extends Column<T> & SupportsLtGt> GtEq
148148

149149
/**
150150
* Keeps records that pass the provided {@link UserDefinedPredicate}
151+
*
152+
* The provided class must have a default constructor. To use an instance
153+
* of a UserDefinedPredicate instead, see {@link #userDefined(Column, UserDefinedPredicate)} below.
151154
*/
152155
public static <T extends Comparable<T>, U extends UserDefinedPredicate<T>>
153156
UserDefined<T, U> userDefined(Column<T> column, Class<U> clazz) {
154-
return new SimpleUserDefined<T, U>(column, clazz);
157+
return new UserDefinedByClass<T, U>(column, clazz);
155158
}
156159

157160
/**
158-
* Similar to above but allows to pass Serializable {@link UserDefinedPredicate}
161+
* Keeps records that pass the provided {@link UserDefinedPredicate}
162+
*
163+
* The provided instance of UserDefinedPredicate must be serializable.
159164
*/
160165
public static <T extends Comparable<T>, U extends UserDefinedPredicate<T> & Serializable>
161166
UserDefined<T, U> userDefined(Column<T> column, U udp) {
162-
return new ConfiguredUserDefined<T, U> (column, udp);
167+
return new UserDefinedByInstance<T, U>(column, udp);
163168
}
164169

165170
/**

parquet-column/src/main/java/parquet/filter2/predicate/Operators.java

Lines changed: 20 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -342,14 +342,9 @@ public int hashCode() {
342342

343343
public static abstract class UserDefined<T extends Comparable<T>, U extends UserDefinedPredicate<T>> implements FilterPredicate, Serializable {
344344
protected final Column<T> column;
345-
protected String toString;
346-
private static final String INSTANTIATION_ERROR_MESSAGE =
347-
"Could not instantiate custom filter: %s. User defined predicates must be static classes with a default constructor.";
348345

349346
UserDefined(Column<T> column) {
350347
this.column = checkNotNull(column, "column");
351-
String name = getClass().getSimpleName().toLowerCase();
352-
this.toString = name + "(" + column.getColumnPath().toDotString() + ", UserDefined)";
353348
}
354349

355350
public Column<T> getColumn() {
@@ -359,39 +354,18 @@ public Column<T> getColumn() {
359354
public abstract U getUserDefinedPredicate();
360355

361356
@Override
362-
public abstract <R> R accept(Visitor<R> visitor);
363-
364-
@Override
365-
public String toString() {
366-
return toString;
367-
}
368-
369-
@Override
370-
public boolean equals(Object o) {
371-
if (this == o) return true;
372-
if (o == null || getClass() != o.getClass()) return false;
373-
374-
UserDefined that = (UserDefined) o;
375-
376-
if (!column.equals(that.column)) return false;
377-
378-
return true;
379-
}
380-
381-
@Override
382-
public int hashCode() {
383-
int result = column.hashCode();
384-
result = result * 31 + getClass().hashCode();
385-
return result;
357+
public <R> R accept(Visitor<R> visitor) {
358+
return visitor.visit(this);
386359
}
387360
}
388-
389-
public static final class SimpleUserDefined<T extends Comparable<T>, U extends UserDefinedPredicate<T>> extends UserDefined<T, U> {
361+
362+
public static final class UserDefinedByClass<T extends Comparable<T>, U extends UserDefinedPredicate<T>> extends UserDefined<T, U> {
390363
private final Class<U> udpClass;
364+
private final String toString;
391365
private static final String INSTANTIATION_ERROR_MESSAGE =
392-
"Could not instantiate custom filter: %s. User defined predicates must be static classes with a default constructor.";
366+
"Could not instantiate custom filter: %s. User defined predicates must be static classes with a default constructor.";
393367

394-
SimpleUserDefined(Column<T> column, Class<U> udpClass) {
368+
UserDefinedByClass(Column<T> column, Class<U> udpClass) {
395369
super(column);
396370
this.udpClass = checkNotNull(udpClass, "udpClass");
397371
String name = getClass().getSimpleName().toLowerCase();
@@ -401,14 +375,11 @@ public static final class SimpleUserDefined<T extends Comparable<T>, U extends U
401375
getUserDefinedPredicate();
402376
}
403377

404-
public Column<T> getColumn() {
405-
return column;
406-
}
407-
408378
public Class<U> getUserDefinedPredicateClass() {
409379
return udpClass;
410380
}
411381

382+
@Override
412383
public U getUserDefinedPredicate() {
413384
try {
414385
return udpClass.newInstance();
@@ -419,11 +390,6 @@ public U getUserDefinedPredicate() {
419390
}
420391
}
421392

422-
@Override
423-
public <R> R accept(Visitor<R> visitor) {
424-
return visitor.visit(this);
425-
}
426-
427393
@Override
428394
public String toString() {
429395
return toString;
@@ -434,7 +400,7 @@ public boolean equals(Object o) {
434400
if (this == o) return true;
435401
if (o == null || getClass() != o.getClass()) return false;
436402

437-
SimpleUserDefined that = (SimpleUserDefined) o;
403+
UserDefinedByClass that = (UserDefinedByClass) o;
438404

439405
if (!column.equals(that.column)) return false;
440406
if (!udpClass.equals(that.udpClass)) return false;
@@ -450,28 +416,21 @@ public int hashCode() {
450416
return result;
451417
}
452418
}
453-
454-
public static final class ConfiguredUserDefined<T extends Comparable<T>, U extends UserDefinedPredicate<T> & Serializable > extends UserDefined<T, U> {
455-
//private final Column<T> column;
456-
private final U udp;
419+
420+
public static final class UserDefinedByInstance<T extends Comparable<T>, U extends UserDefinedPredicate<T> & Serializable> extends UserDefined<T, U> {
457421
private final String toString;
422+
private final U udpInstance;
458423

459-
ConfiguredUserDefined(Column<T> column, U udp) {
460-
//column = checkNotNull(column, "column");
424+
UserDefinedByInstance(Column<T> column, U udpInstance) {
461425
super(column);
462-
this.udp = checkNotNull(udp, "udp");
426+
this.udpInstance = checkNotNull(udpInstance, "udpInstance");
463427
String name = getClass().getSimpleName().toLowerCase();
464-
this.toString = name + "(" + column.getColumnPath().toDotString() + ", " + udp.getClass().getName() + ")";
428+
this.toString = name + "(" + column.getColumnPath().toDotString() + ", " + udpInstance + ")";
465429
}
466430

467431
@Override
468432
public U getUserDefinedPredicate() {
469-
return udp;
470-
}
471-
472-
@Override
473-
public <R> R accept(Visitor<R> visitor) {
474-
return visitor.visit(this);
433+
return udpInstance;
475434
}
476435

477436
@Override
@@ -484,18 +443,18 @@ public boolean equals(Object o) {
484443
if (this == o) return true;
485444
if (o == null || getClass() != o.getClass()) return false;
486445

487-
ConfiguredUserDefined that = (ConfiguredUserDefined) o;
446+
UserDefinedByInstance that = (UserDefinedByInstance) o;
488447

489448
if (!column.equals(that.column)) return false;
490-
if (!udp.equals(that.udp)) return false;
449+
if (!udpInstance.equals(that.udpInstance)) return false;
491450

492451
return true;
493452
}
494453

495454
@Override
496455
public int hashCode() {
497456
int result = column.hashCode();
498-
result = 31 * result + udp.hashCode();
457+
result = 31 * result + udpInstance.hashCode();
499458
result = result * 31 + getClass().hashCode();
500459
return result;
501460
}
@@ -545,4 +504,5 @@ public int hashCode() {
545504
return result;
546505
}
547506
}
507+
548508
}

parquet-column/src/test/java/parquet/filter2/predicate/TestFilterApiMethods.java

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import java.io.ByteArrayOutputStream;
55
import java.io.ObjectInputStream;
66
import java.io.ObjectOutputStream;
7+
import java.io.Serializable;
78

89
import org.junit.Test;
910

@@ -14,10 +15,11 @@
1415
import parquet.filter2.predicate.Operators.Eq;
1516
import parquet.filter2.predicate.Operators.Gt;
1617
import parquet.filter2.predicate.Operators.IntColumn;
18+
import parquet.filter2.predicate.Operators.LongColumn;
1719
import parquet.filter2.predicate.Operators.Not;
1820
import parquet.filter2.predicate.Operators.Or;
19-
import parquet.filter2.predicate.Operators.SimpleUserDefined;
2021
import parquet.filter2.predicate.Operators.UserDefined;
22+
import parquet.filter2.predicate.Operators.UserDefinedByClass;
2123
import parquet.io.api.Binary;
2224

2325
import static org.junit.Assert.assertEquals;
@@ -28,6 +30,7 @@
2830
import static parquet.filter2.predicate.FilterApi.eq;
2931
import static parquet.filter2.predicate.FilterApi.gt;
3032
import static parquet.filter2.predicate.FilterApi.intColumn;
33+
import static parquet.filter2.predicate.FilterApi.longColumn;
3134
import static parquet.filter2.predicate.FilterApi.not;
3235
import static parquet.filter2.predicate.FilterApi.notEq;
3336
import static parquet.filter2.predicate.FilterApi.or;
@@ -37,6 +40,7 @@
3740
public class TestFilterApiMethods {
3841

3942
private static final IntColumn intColumn = intColumn("a.b.c");
43+
private static final LongColumn longColumn = longColumn("a.b.l");
4044
private static final DoubleColumn doubleColumn = doubleColumn("x.y.z");
4145
private static final BinaryColumn binColumn = binaryColumn("a.string.column");
4246

@@ -83,15 +87,15 @@ public void testUdp() {
8387
FilterPredicate predicate = or(eq(doubleColumn, 12.0), userDefined(intColumn, DummyUdp.class));
8488
assertTrue(predicate instanceof Or);
8589
FilterPredicate ud = ((Or) predicate).getRight();
86-
assertTrue(ud instanceof SimpleUserDefined);
87-
assertEquals(DummyUdp.class, ((SimpleUserDefined) ud).getUserDefinedPredicateClass());
90+
assertTrue(ud instanceof UserDefinedByClass);
91+
assertEquals(DummyUdp.class, ((UserDefinedByClass) ud).getUserDefinedPredicateClass());
8892
assertTrue(((UserDefined) ud).getUserDefinedPredicate() instanceof DummyUdp);
8993
}
9094

9195
@Test
92-
public void testSerializable() throws Exception {
96+
public void testSerializable() throws Exception {
9397
BinaryColumn binary = binaryColumn("foo");
94-
FilterPredicate p = or(and(userDefined(intColumn, DummyUdp.class), predicate), eq(binary, Binary.fromString("hi")));
98+
FilterPredicate p = and(or(and(userDefined(intColumn, DummyUdp.class), predicate), eq(binary, Binary.fromString("hi"))), userDefined(longColumn, new IsMultipleOf(7)));
9599
ByteArrayOutputStream baos = new ByteArrayOutputStream();
96100
ObjectOutputStream oos = new ObjectOutputStream(baos);
97101
oos.writeObject(p);
@@ -101,4 +105,50 @@ public void testSerializable() throws Exception {
101105
FilterPredicate read = (FilterPredicate) is.readObject();
102106
assertEquals(p, read);
103107
}
108+
109+
public static class IsMultipleOf extends UserDefinedPredicate<Long> implements Serializable {
110+
111+
private long of;
112+
113+
public IsMultipleOf(long of) {
114+
this.of = of;
115+
}
116+
117+
@Override
118+
public boolean keep(Long value) {
119+
if (value == null) {
120+
return false;
121+
}
122+
return value % of == 0;
123+
}
124+
125+
@Override
126+
public boolean canDrop(Statistics<Long> statistics) {
127+
return false;
128+
}
129+
130+
@Override
131+
public boolean inverseCanDrop(Statistics<Long> statistics) {
132+
return false;
133+
}
134+
135+
@Override
136+
public boolean equals(Object o) {
137+
if (this == o) return true;
138+
if (o == null || getClass() != o.getClass()) return false;
139+
140+
IsMultipleOf that = (IsMultipleOf) o;
141+
return this.of == that.of;
142+
}
143+
144+
@Override
145+
public int hashCode() {
146+
return new Long(of).hashCode();
147+
}
148+
149+
@Override
150+
public String toString() {
151+
return "IsMultipleOf(" + of + ")";
152+
}
153+
}
104154
}

parquet-hadoop/src/test/java/parquet/filter2/recordlevel/TestRecordLevelFilters.java

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ public boolean inverseCanDrop(Statistics<Binary> statistics) {
168168

169169
public static class SetInFilter extends UserDefinedPredicate<Long> implements Serializable {
170170

171-
HashSet<Long> hSet;
171+
private HashSet<Long> hSet;
172172

173173
public SetInFilter(HashSet<Long> phSet) {
174174
hSet = phSet;
@@ -211,23 +211,22 @@ public boolean keep(User u) {
211211
}
212212

213213
@Test
214-
public void testIdIn() throws Exception {
214+
public void testUserDefinedByInstance() throws Exception {
215215
LongColumn name = longColumn("id");
216216

217-
HashSet<Long> h = new HashSet<Long>() {{
218-
add(20L); add(27L); add(28L);
219-
}};
217+
final HashSet<Long> h = new HashSet<Long>();
218+
h.add(20L);
219+
h.add(27L);
220+
h.add(28L);
221+
220222
FilterPredicate pred = userDefined(name, new SetInFilter(h));
221223

222224
List<Group> found = PhoneBookWriter.readFile(phonebookFile, FilterCompat.get(pred));
223225

224226
assertFilter(found, new UserFilter() {
225227
@Override
226228
public boolean keep(User u) {
227-
Set<Long> h = new HashSet<Long>() {{
228-
add(20L); add(27L); add(28L);
229-
}};
230-
return h.contains(u.getId());
229+
return u != null && h.contains(u.getId());
231230
}
232231
});
233232
}

parquet-scala/src/main/scala/parquet/filter2/dsl/Dsl.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package parquet.filter2.dsl
22

33
import java.lang.{Boolean => JBoolean, Double => JDouble, Float => JFloat, Integer => JInt, Long => JLong}
4+
import java.io.Serializable
45

56
import parquet.filter2.predicate.{FilterApi, FilterPredicate, Operators, UserDefinedPredicate}
67
import parquet.io.api.Binary
@@ -30,6 +31,8 @@ object Dsl {
3031
val javaColumn: C
3132

3233
def filterBy[U <: UserDefinedPredicate[T]](clazz: Class[U]) = FilterApi.userDefined(javaColumn, clazz)
34+
35+
def filterBy[U <: UserDefinedPredicate[T] with Serializable](udp: U) = FilterApi.userDefined(javaColumn, udp)
3336

3437
// this is not supported because it allows for easy mistakes. For example:
3538
// val pred = IntColumn("foo") == "hello"

0 commit comments

Comments
 (0)