Skip to content

Commit

Permalink
Iterate once to create two iterators in partition (#2577)
Browse files Browse the repository at this point in the history
* Reproduce the problem

* Iterate once to create two iterators in partition

* Avoid using io.vavr.collection.Stream

* Test behavior of `partition` on different classes

* Test that Stream.partition() is lazy

* Create Iterator.duplicate() and add tests

* Change the implementation of Iterator.partition()

* Fix Set

* Fix Map

* Fix Multimap

* Move duplicate to IteratorModule

* Remove synchronized keyword

* Remove hashCode and equals

* Avoid using isEqualTo

* Remove redundant tests
  • Loading branch information
mincong-h authored May 23, 2020
1 parent f59cd60 commit 1c90106
Show file tree
Hide file tree
Showing 20 changed files with 360 additions and 31 deletions.
8 changes: 6 additions & 2 deletions src/main/java/io/vavr/collection/AbstractMultimap.java
Original file line number Diff line number Diff line change
Expand Up @@ -472,8 +472,12 @@ public M orElse(Supplier<? extends Iterable<? extends Tuple2<K, V>>> supplier) {
@Override
public Tuple2<M, M> partition(Predicate<? super Tuple2<K, V>> predicate) {
Objects.requireNonNull(predicate, "predicate is null");
final Tuple2<Iterator<Tuple2<K, V>>, Iterator<Tuple2<K, V>>> p = iterator().partition(predicate);
return Tuple.of((M) createFromEntries(p._1), (M) createFromEntries(p._2));
final java.util.List<Tuple2<K, V>> left = new java.util.ArrayList<>();
final java.util.List<Tuple2<K, V>> right = new java.util.ArrayList<>();
for (Tuple2<K, V> entry : this) {
(predicate.test(entry) ? left : right).add(entry);
}
return Tuple.of((M) createFromEntries(left), (M) createFromEntries(right));
}

@SuppressWarnings("unchecked")
Expand Down
3 changes: 1 addition & 2 deletions src/main/java/io/vavr/collection/BitSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -831,8 +831,7 @@ public BitSet<T> scan(T zero, BiFunction<? super T, ? super T, ? extends T> oper

@Override
public Tuple2<BitSet<T>, BitSet<T>> partition(Predicate<? super T> predicate) {
Objects.requireNonNull(predicate, "predicate is null");
return iterator().partition(predicate).map(this::createFromAll, this::createFromAll);
return Collections.partition(this, this::createFromAll, predicate);
}

@Override
Expand Down
15 changes: 14 additions & 1 deletion src/main/java/io/vavr/collection/Collections.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
*/
package io.vavr.collection;

import io.vavr.Tuple;
import io.vavr.Tuple2;
import io.vavr.collection.JavaConverters.ChangePolicy;
import io.vavr.collection.JavaConverters.ListView;
import io.vavr.control.Option;
Expand Down Expand Up @@ -294,6 +296,17 @@ static <K, V, K2, U extends Map<K2, V>> U mapKeys(Map<K, V> source, U zero, Func
});
}

static <C extends Traversable<T>, T> Tuple2<C, C> partition(C collection, Function<Iterable<T>, C> creator,
Predicate<? super T> predicate) {
Objects.requireNonNull(predicate, "predicate is null");
final java.util.List<T> left = new java.util.ArrayList<>();
final java.util.List<T> right = new java.util.ArrayList<>();
for (T element : collection) {
(predicate.test(element) ? left : right).add(element);
}
return Tuple.of(creator.apply(left), creator.apply(right));
}

@SuppressWarnings("unchecked")
static <C extends Traversable<T>, T> C removeAll(C source, Iterable<? extends T> elements) {
Objects.requireNonNull(elements, "elements is null");
Expand Down Expand Up @@ -556,7 +569,7 @@ private static <T> IterableWithSize<T> withSizeTraversable(Iterable<? extends T>
return new IterableWithSize<>(iterable, ((Traversable<?>) iterable).size());
}
}

static class IterableWithSize<T> {
private final Iterable<? extends T> iterable;
private final int size;
Expand Down
6 changes: 2 additions & 4 deletions src/main/java/io/vavr/collection/HashSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,7 @@ public boolean isAsync() {
public boolean isEmpty() {
return tree.isEmpty();
}

/**
* A {@code HashSet} is computed eagerly.
*
Expand Down Expand Up @@ -730,9 +730,7 @@ public HashSet<T> orElse(Supplier<? extends Iterable<? extends T>> supplier) {

@Override
public Tuple2<HashSet<T>, HashSet<T>> partition(Predicate<? super T> predicate) {
Objects.requireNonNull(predicate, "predicate is null");
final Tuple2<Iterator<T>, Iterator<T>> p = iterator().partition(predicate);
return Tuple.of(HashSet.ofAll(p._1), HashSet.ofAll(p._2));
return Collections.partition(this, HashSet::ofAll, predicate);
}

@Override
Expand Down
48 changes: 44 additions & 4 deletions src/main/java/io/vavr/collection/Iterator.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import java.math.BigDecimal;
import java.util.*;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.*;

import static io.vavr.collection.BigDecimalHelper.areEqual;
Expand Down Expand Up @@ -1722,10 +1723,8 @@ default Tuple2<Iterator<T>, Iterator<T>> partition(Predicate<? super T> predicat
if (!hasNext()) {
return Tuple.of(empty(), empty());
} else {
final Stream<T> that = Stream.ofAll(this);
final Iterator<T> first = that.iterator().filter(predicate);
final Iterator<T> second = that.iterator().filter(predicate.negate());
return Tuple.of(first, second);
final Tuple2<Iterator<T>, Iterator<T>> dup = IteratorModule.duplicate(this);
return Tuple.of(dup._1.filter(predicate), dup._2.filterNot(predicate));
}
}

Expand Down Expand Up @@ -1952,6 +1951,7 @@ default Tuple2<Iterator<T>, Iterator<T>> span(Predicate<? super T> predicate) {
}
}


@Override
default String stringPrefix() {
return "Iterator";
Expand Down Expand Up @@ -2182,6 +2182,46 @@ public String toString() {
}
}

interface IteratorModule {
/**
* Creates two new iterators that both iterates over the same elements as
* this iterator and in the same order. The duplicate iterators are
* considered equal if they are positioned at the same element.
* <p>
* Given that most methods on iterators will make the original iterator
* unfit for further use, this methods provides a reliable way of calling
* multiple such methods on an iterator.
*
* @return a pair of iterators
*/
static <T> Tuple2<Iterator<T>, Iterator<T>> duplicate(Iterator<T> iterator) {
final java.util.Queue<T> gap = new java.util.LinkedList<>();
final AtomicReference<Iterator<T>> ahead = new AtomicReference<>();
class Partner implements Iterator<T> {

@Override
public boolean hasNext() {
return (this != ahead.get() && !gap.isEmpty()) || iterator.hasNext();
}

@Override
public T next() {
if (gap.isEmpty()) {
ahead.set(this);
}
if (this == ahead.get()) {
final T element = iterator.next();
gap.add(element);
return element;
} else {
return gap.poll();
}
}
}
return Tuple.of(new Partner(), new Partner());
}
}

final class ConcatIterator<T> implements Iterator<T> {

private static final class Iterators<T> {
Expand Down
4 changes: 1 addition & 3 deletions src/main/java/io/vavr/collection/LinkedHashSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -751,9 +751,7 @@ public LinkedHashSet<T> orElse(Supplier<? extends Iterable<? extends T>> supplie

@Override
public Tuple2<LinkedHashSet<T>, LinkedHashSet<T>> partition(Predicate<? super T> predicate) {
Objects.requireNonNull(predicate, "predicate is null");
final Tuple2<Iterator<T>, Iterator<T>> p = iterator().partition(predicate);
return Tuple.of(LinkedHashSet.ofAll(p._1), LinkedHashSet.ofAll(p._2));
return Collections.partition(this, LinkedHashSet::ofAll, predicate);
}

@Override
Expand Down
8 changes: 6 additions & 2 deletions src/main/java/io/vavr/collection/Maps.java
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,12 @@ static <T, K, V, M extends Map<K, V>> M ofStream(M map, java.util.stream.Stream<
static <K, V, M extends Map<K, V>> Tuple2<M, M> partition(M map, OfEntries<K, V, M> ofEntries,
Predicate<? super Tuple2<K, V>> predicate) {
Objects.requireNonNull(predicate, "predicate is null");
final Tuple2<Iterator<Tuple2<K, V>>, Iterator<Tuple2<K, V>>> p = map.iterator().partition(predicate);
return Tuple.of(ofEntries.apply(p._1), ofEntries.apply(p._2));
final java.util.List<Tuple2<K, V>> left = new java.util.ArrayList<>();
final java.util.List<Tuple2<K, V>> right = new java.util.ArrayList<>();
for (Tuple2<K, V> entry : map) {
(predicate.test(entry) ? left : right).add(entry);
}
return Tuple.of(ofEntries.apply(left), ofEntries.apply(right));
}

static <K, V, M extends Map<K, V>> M peek(M map, Consumer<? super Tuple2<K, V>> action) {
Expand Down
18 changes: 9 additions & 9 deletions src/main/java/io/vavr/collection/Traversable.java
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ public interface Traversable<T> extends Iterable<T>, Foldable<T>, io.vavr.Value<
static <T> Traversable<T> narrow(Traversable<? extends T> traversable) {
return (Traversable<T>) traversable;
}

/**
* Matches each element with a unique key that you extract from it.
* If the same key is present twice, the function will return {@code None}.
Expand Down Expand Up @@ -223,7 +223,7 @@ default Option<Double> average() {
throw new UnsupportedOperationException("not numeric", x);
}
}

/**
* Collects all elements that are in the domain of the given {@code partialFunction} by mapping the elements to type {@code R}.
* <p>
Expand Down Expand Up @@ -341,7 +341,7 @@ default int count(Predicate<? super T> predicate) {
Traversable<T> dropUntil(Predicate<? super T> predicate);

/**
* Drops elements while the predicate holds for the current element.
* Drops elements while the predicate holds for the current element.
* <p>
* Note: This is essentially the same as {@code dropUntil(predicate.negate())}.
* It is intended to be used with method references, which cannot be negated directly.
Expand Down Expand Up @@ -370,7 +370,7 @@ default int count(Predicate<? super T> predicate) {
* <li>contain the same elements</li>
* <li>have the same element order, if the collections are of type Seq</li>
* </ul>
*
*
* Two Map/Multimap elements, resp. entries, (key1, value1) and (key2, value2) are equal,
* if the keys are equal and the values are equal.
* <p>
Expand Down Expand Up @@ -685,7 +685,7 @@ default T get() {
default Option<T> headOption() {
return isEmpty() ? Option.none() : Option.some(head());
}

/**
* Returns the hash code of this collection.
* <br>
Expand Down Expand Up @@ -1055,7 +1055,7 @@ default <U extends Comparable<? super U>> Option<T> minBy(Function<? super T, ?
return Option.some(tm);
}
}

/**
* Joins the elements of this by concatenating their string representations.
* <p>
Expand Down Expand Up @@ -1372,7 +1372,7 @@ default Option<T> reduceRightOption(BiFunction<? super T, ? super T, ? extends T
* @throws NullPointerException if {@code operation} is null.
*/
<U> Traversable<U> scanRight(U zero, BiFunction<? super T, ? super U, ? extends U> operation);

/**
* Returns the single element of this Traversable or throws, if this is empty or contains more than one element.
*
Expand Down Expand Up @@ -1536,7 +1536,7 @@ default Number sum() {
}
}
}

/**
* Drops the first element of a non-empty Traversable.
*
Expand Down Expand Up @@ -1666,7 +1666,7 @@ default Number sum() {
* @throws NullPointerException if {@code that} is null
*/
<U> Traversable<Tuple2<T, U>> zipAll(Iterable<? extends U> that, T thisElem, U thatElem);

/**
* Returns a traversable formed from this traversable and another Iterable collection by mapping elements.
* If one of the two iterables is longer than the other, its remaining elements are ignored.
Expand Down
4 changes: 1 addition & 3 deletions src/main/java/io/vavr/collection/TreeSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -800,9 +800,7 @@ public TreeSet<T> orElse(Supplier<? extends Iterable<? extends T>> supplier) {

@Override
public Tuple2<TreeSet<T>, TreeSet<T>> partition(Predicate<? super T> predicate) {
Objects.requireNonNull(predicate, "predicate is null");
return iterator().partition(predicate).map(i1 -> TreeSet.ofAll(tree.comparator(), i1),
i2 -> TreeSet.ofAll(tree.comparator(), i2));
return Collections.partition(this, values -> TreeSet.ofAll(tree.comparator(), values), predicate);
}

@Override
Expand Down
15 changes: 15 additions & 0 deletions src/test/java/io/vavr/collection/AbstractMapTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -1461,6 +1461,21 @@ public void shouldReturnDefaultValue() {
assertThat(map.getOrElse("3", "3")).isEqualTo("3");
}

// -- partition

@Test
public void shouldPartitionInOneIteration() {
final AtomicInteger count = new AtomicInteger(0);
final Map<String, Integer> map = mapOf("1", 1, "2", 2, "3", 3);
final Tuple2<? extends Map<String, Integer>, ? extends Map<String, Integer>> results = map.partition(entry -> {
count.incrementAndGet();
return true;
});
assertThat(results._1).isEqualTo(mapOf("1", 1, "2", 2, "3", 3));
assertThat(results._2).isEmpty();
assertThat(count.get()).isEqualTo(3);
}

// -- spliterator

@Test
Expand Down
14 changes: 14 additions & 0 deletions src/test/java/io/vavr/collection/AbstractMultimapTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -809,6 +809,20 @@ public void shouldPartitionIntsInOddAndEvenHavingOddAndEvenNumbers() {
mapOfTuples(Tuple.of(1, 2), Tuple.of(3, 4))));
}

@Test
@SuppressWarnings("unchecked")
public void shouldPartitionInOneIteration() {
final AtomicInteger count = new AtomicInteger(0);
final Multimap<String, Integer> map = mapOfTuples(Tuple.of("1", 1), Tuple.of("2", 2), Tuple.of("3", 3));
final Tuple2<? extends Multimap<String, Integer>, ? extends Multimap<String, Integer>> results = map.partition(entry -> {
count.incrementAndGet();
return true;
});
assertThat(results._1).isEqualTo(mapOfTuples(Tuple.of("1", 1), Tuple.of("2", 2), Tuple.of("3", 3)));
assertThat(results._2).isEmpty();
assertThat(count.get()).isEqualTo(3);
}

// -- put

@Test
Expand Down
18 changes: 17 additions & 1 deletion src/test/java/io/vavr/collection/AbstractSetTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
*/
package io.vavr.collection;

import io.vavr.Tuple2;
import org.junit.Test;

import java.math.BigDecimal;
import java.util.Spliterator;
import java.util.concurrent.atomic.AtomicInteger;

public abstract class AbstractSetTest extends AbstractTraversableRangeTest {

Expand Down Expand Up @@ -189,6 +191,20 @@ public void shouldRemoveElement() {
assertThat(empty().remove(5)).isEqualTo(empty());
}

// -- partition

@Test
public void shouldPartitionInOneIteration() {
final AtomicInteger count = new AtomicInteger(0);
final Tuple2<? extends Set<Integer>, ? extends Set<Integer>> results = of(1, 2, 3).partition(i -> {
count.incrementAndGet();
return true;
});
assertThat(results._1).isEqualTo(of(1, 2, 3));
assertThat(results._2).isEqualTo(of());
assertThat(count.get()).isEqualTo(3);
}

// -- removeAll

@Test
Expand Down Expand Up @@ -227,7 +243,7 @@ public void shouldReturnSameSetWhenEmptyUnionNonEmpty() {
assertThat(empty().union(set)).isSameAs(set);
}
}

@Test
public void shouldReturnSameSetWhenNonEmptyUnionEmpty() {
final Set<Integer> set = of(1, 2);
Expand Down
15 changes: 15 additions & 0 deletions src/test/java/io/vavr/collection/ArrayTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collector;
Expand Down Expand Up @@ -243,6 +244,20 @@ public void shouldThrowExceptionWhenGetIndexEqualToLength() {
.isInstanceOf(IndexOutOfBoundsException.class).hasMessage("get(1)");
}

// -- partition

@Test
public void shouldPartitionInOneIteration() {
final AtomicInteger count = new AtomicInteger(0);
final Tuple2<Array<Integer>, Array<Integer>> results = of(1, 2, 3).partition(i -> {
count.incrementAndGet();
return true;
});
assertThat(results._1).isEqualTo(of(1, 2, 3));
assertThat(results._2).isEqualTo(of());
assertThat(count.get()).isEqualTo(3);
}

// -- transform()

@Test
Expand Down
Loading

0 comments on commit 1c90106

Please sign in to comment.