Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Iterate once to create two iterators in partition #2577

Merged
merged 15 commits into from
May 23, 2020
8 changes: 6 additions & 2 deletions src/main/java/io/vavr/collection/AbstractMultimap.java
Original file line number Diff line number Diff line change
Expand Up @@ -472,8 +472,12 @@ public M orElse(Supplier<? extends Iterable<? extends Tuple2<K, V>>> supplier) {
@Override
public Tuple2<M, M> partition(Predicate<? super Tuple2<K, V>> predicate) {
Objects.requireNonNull(predicate, "predicate is null");
final Tuple2<Iterator<Tuple2<K, V>>, Iterator<Tuple2<K, V>>> p = iterator().partition(predicate);
return Tuple.of((M) createFromEntries(p._1), (M) createFromEntries(p._2));
final java.util.List<Tuple2<K, V>> left = new java.util.ArrayList<>();
final java.util.List<Tuple2<K, V>> right = new java.util.ArrayList<>();
for (Tuple2<K, V> entry : this) {
(predicate.test(entry) ? left : right).add(entry);
}
return Tuple.of((M) createFromEntries(left), (M) createFromEntries(right));
}

@SuppressWarnings("unchecked")
Expand Down
3 changes: 1 addition & 2 deletions src/main/java/io/vavr/collection/BitSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -831,8 +831,7 @@ public BitSet<T> scan(T zero, BiFunction<? super T, ? super T, ? extends T> oper

@Override
public Tuple2<BitSet<T>, BitSet<T>> partition(Predicate<? super T> predicate) {
Objects.requireNonNull(predicate, "predicate is null");
return iterator().partition(predicate).map(this::createFromAll, this::createFromAll);
return Collections.partition(this, this::createFromAll, predicate);
}

@Override
Expand Down
15 changes: 14 additions & 1 deletion src/main/java/io/vavr/collection/Collections.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
*/
package io.vavr.collection;

import io.vavr.Tuple;
import io.vavr.Tuple2;
import io.vavr.collection.JavaConverters.ChangePolicy;
import io.vavr.collection.JavaConverters.ListView;
import io.vavr.control.Option;
Expand Down Expand Up @@ -294,6 +296,17 @@ static <K, V, K2, U extends Map<K2, V>> U mapKeys(Map<K, V> source, U zero, Func
});
}

static <C extends Traversable<T>, T> Tuple2<C, C> partition(C collection, Function<Iterable<T>, C> creator,
Predicate<? super T> predicate) {
Objects.requireNonNull(predicate, "predicate is null");
final java.util.List<T> left = new java.util.ArrayList<>();
final java.util.List<T> right = new java.util.ArrayList<>();
for (T element : collection) {
(predicate.test(element) ? left : right).add(element);
}
return Tuple.of(creator.apply(left), creator.apply(right));
}

@SuppressWarnings("unchecked")
static <C extends Traversable<T>, T> C removeAll(C source, Iterable<? extends T> elements) {
Objects.requireNonNull(elements, "elements is null");
Expand Down Expand Up @@ -556,7 +569,7 @@ private static <T> IterableWithSize<T> withSizeTraversable(Iterable<? extends T>
return new IterableWithSize<>(iterable, ((Traversable<?>) iterable).size());
}
}

static class IterableWithSize<T> {
private final Iterable<? extends T> iterable;
private final int size;
Expand Down
6 changes: 2 additions & 4 deletions src/main/java/io/vavr/collection/HashSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,7 @@ public boolean isAsync() {
public boolean isEmpty() {
return tree.isEmpty();
}

/**
* A {@code HashSet} is computed eagerly.
*
Expand Down Expand Up @@ -730,9 +730,7 @@ public HashSet<T> orElse(Supplier<? extends Iterable<? extends T>> supplier) {

@Override
public Tuple2<HashSet<T>, HashSet<T>> partition(Predicate<? super T> predicate) {
Objects.requireNonNull(predicate, "predicate is null");
final Tuple2<Iterator<T>, Iterator<T>> p = iterator().partition(predicate);
return Tuple.of(HashSet.ofAll(p._1), HashSet.ofAll(p._2));
return Collections.partition(this, HashSet::ofAll, predicate);
}

@Override
Expand Down
48 changes: 44 additions & 4 deletions src/main/java/io/vavr/collection/Iterator.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import java.math.BigDecimal;
import java.util.*;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.*;

import static io.vavr.collection.BigDecimalHelper.areEqual;
Expand Down Expand Up @@ -1722,10 +1723,8 @@ default Tuple2<Iterator<T>, Iterator<T>> partition(Predicate<? super T> predicat
if (!hasNext()) {
return Tuple.of(empty(), empty());
} else {
final Stream<T> that = Stream.ofAll(this);
Copy link

@kefasb kefasb May 8, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is also partition method implemented by Stream which works wrongly too.

I do not know another places. Maybe @danieldietrich could point them out.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was not aware of the problem of Stream#partition. Thanks for pointing it out! 🙇 I checked all data-structures in VAVR, apart from Stream, all other implementations work correctly. About the Stream itself, I'm afraid that I don't know how to fix it and I'm not sure if it's really a bug. In Scala 2.13.2, the Stream implementation uses the predicates twice (link) as below:

override def partition(p: A => Boolean): (Stream[A], Stream[A]) = (filter(p(_)), filterNot(p(_)))

If I modify your example in issue #2559 by changing Set into Stream, you will see result in Scala as follow:

  "Stream" should "partition correctly" in {
    val fruitsToEat = Stream("apple", "banana")
    val partition = fruitsToEat.partition(name => biteAndCheck(name))

    partition._1 shouldEqual Stream()
    partition._2 shouldEqual Stream()  // not Stream("apple", "banana")

    fruitsBeingEaten.get("apple").get.name shouldEqual "apple"
    fruitsBeingEaten.get("apple").get.bites shouldEqual 2  // not 1
    fruitsBeingEaten.get("banana").get.name shouldEqual "banana"
    fruitsBeingEaten.get("banana").get.bites shouldEqual 2  // not 1
  }

So I would say the VAVR is aligned with Scala on Stream's behaviors. Also, my current approach (ArrayList) does not fit the Stream requirement, because Stream is lazy sequence of elements which may be infinitely long. So I don't know how to use only one predicate to achieve that... So my suggestion is let's keep stream as it is and avoid modifying it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Stream is lazy sequence of elements which may be infinitely long

Oh sorry guys I completely forgot about it :(

final Iterator<T> first = that.iterator().filter(predicate);
final Iterator<T> second = that.iterator().filter(predicate.negate());
return Tuple.of(first, second);
final Tuple2<Iterator<T>, Iterator<T>> dup = IteratorModule.duplicate(this);
return Tuple.of(dup._1.filter(predicate), dup._2.filterNot(predicate));
}
}

Expand Down Expand Up @@ -1952,6 +1951,7 @@ default Tuple2<Iterator<T>, Iterator<T>> span(Predicate<? super T> predicate) {
}
}


@Override
default String stringPrefix() {
return "Iterator";
Expand Down Expand Up @@ -2182,6 +2182,46 @@ public String toString() {
}
}

interface IteratorModule {
/**
* Creates two new iterators that both iterates over the same elements as
* this iterator and in the same order. The duplicate iterators are
* considered equal if they are positioned at the same element.
* <p>
* Given that most methods on iterators will make the original iterator
* unfit for further use, this methods provides a reliable way of calling
* multiple such methods on an iterator.
*
* @return a pair of iterators
*/
static <T> Tuple2<Iterator<T>, Iterator<T>> duplicate(Iterator<T> iterator) {
final java.util.Queue<T> gap = new java.util.LinkedList<>();
final AtomicReference<Iterator<T>> ahead = new AtomicReference<>();
class Partner implements Iterator<T> {

@Override
public boolean hasNext() {
return (this != ahead.get() && !gap.isEmpty()) || iterator.hasNext();
}

@Override
public T next() {
if (gap.isEmpty()) {
ahead.set(this);
}
if (this == ahead.get()) {
final T element = iterator.next();
gap.add(element);
return element;
} else {
return gap.poll();
}
}
}
return Tuple.of(new Partner(), new Partner());
}
}

final class ConcatIterator<T> implements Iterator<T> {

private static final class Iterators<T> {
Expand Down
4 changes: 1 addition & 3 deletions src/main/java/io/vavr/collection/LinkedHashSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -751,9 +751,7 @@ public LinkedHashSet<T> orElse(Supplier<? extends Iterable<? extends T>> supplie

@Override
public Tuple2<LinkedHashSet<T>, LinkedHashSet<T>> partition(Predicate<? super T> predicate) {
Objects.requireNonNull(predicate, "predicate is null");
final Tuple2<Iterator<T>, Iterator<T>> p = iterator().partition(predicate);
return Tuple.of(LinkedHashSet.ofAll(p._1), LinkedHashSet.ofAll(p._2));
return Collections.partition(this, LinkedHashSet::ofAll, predicate);
}

@Override
Expand Down
8 changes: 6 additions & 2 deletions src/main/java/io/vavr/collection/Maps.java
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,12 @@ static <T, K, V, M extends Map<K, V>> M ofStream(M map, java.util.stream.Stream<
static <K, V, M extends Map<K, V>> Tuple2<M, M> partition(M map, OfEntries<K, V, M> ofEntries,
Predicate<? super Tuple2<K, V>> predicate) {
Objects.requireNonNull(predicate, "predicate is null");
final Tuple2<Iterator<Tuple2<K, V>>, Iterator<Tuple2<K, V>>> p = map.iterator().partition(predicate);
return Tuple.of(ofEntries.apply(p._1), ofEntries.apply(p._2));
final java.util.List<Tuple2<K, V>> left = new java.util.ArrayList<>();
final java.util.List<Tuple2<K, V>> right = new java.util.ArrayList<>();
for (Tuple2<K, V> entry : map) {
(predicate.test(entry) ? left : right).add(entry);
}
return Tuple.of(ofEntries.apply(left), ofEntries.apply(right));
}

static <K, V, M extends Map<K, V>> M peek(M map, Consumer<? super Tuple2<K, V>> action) {
Expand Down
18 changes: 9 additions & 9 deletions src/main/java/io/vavr/collection/Traversable.java
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ public interface Traversable<T> extends Iterable<T>, Foldable<T>, io.vavr.Value<
static <T> Traversable<T> narrow(Traversable<? extends T> traversable) {
return (Traversable<T>) traversable;
}

/**
* Matches each element with a unique key that you extract from it.
* If the same key is present twice, the function will return {@code None}.
Expand Down Expand Up @@ -223,7 +223,7 @@ default Option<Double> average() {
throw new UnsupportedOperationException("not numeric", x);
}
}

/**
* Collects all elements that are in the domain of the given {@code partialFunction} by mapping the elements to type {@code R}.
* <p>
Expand Down Expand Up @@ -341,7 +341,7 @@ default int count(Predicate<? super T> predicate) {
Traversable<T> dropUntil(Predicate<? super T> predicate);

/**
* Drops elements while the predicate holds for the current element.
* Drops elements while the predicate holds for the current element.
* <p>
* Note: This is essentially the same as {@code dropUntil(predicate.negate())}.
* It is intended to be used with method references, which cannot be negated directly.
Expand Down Expand Up @@ -370,7 +370,7 @@ default int count(Predicate<? super T> predicate) {
* <li>contain the same elements</li>
* <li>have the same element order, if the collections are of type Seq</li>
* </ul>
*
*
* Two Map/Multimap elements, resp. entries, (key1, value1) and (key2, value2) are equal,
* if the keys are equal and the values are equal.
* <p>
Expand Down Expand Up @@ -685,7 +685,7 @@ default T get() {
default Option<T> headOption() {
return isEmpty() ? Option.none() : Option.some(head());
}

/**
* Returns the hash code of this collection.
* <br>
Expand Down Expand Up @@ -1055,7 +1055,7 @@ default <U extends Comparable<? super U>> Option<T> minBy(Function<? super T, ?
return Option.some(tm);
}
}

/**
* Joins the elements of this by concatenating their string representations.
* <p>
Expand Down Expand Up @@ -1372,7 +1372,7 @@ default Option<T> reduceRightOption(BiFunction<? super T, ? super T, ? extends T
* @throws NullPointerException if {@code operation} is null.
*/
<U> Traversable<U> scanRight(U zero, BiFunction<? super T, ? super U, ? extends U> operation);

/**
* Returns the single element of this Traversable or throws, if this is empty or contains more than one element.
*
Expand Down Expand Up @@ -1536,7 +1536,7 @@ default Number sum() {
}
}
}

/**
* Drops the first element of a non-empty Traversable.
*
Expand Down Expand Up @@ -1666,7 +1666,7 @@ default Number sum() {
* @throws NullPointerException if {@code that} is null
*/
<U> Traversable<Tuple2<T, U>> zipAll(Iterable<? extends U> that, T thisElem, U thatElem);

/**
* Returns a traversable formed from this traversable and another Iterable collection by mapping elements.
* If one of the two iterables is longer than the other, its remaining elements are ignored.
Expand Down
4 changes: 1 addition & 3 deletions src/main/java/io/vavr/collection/TreeSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -800,9 +800,7 @@ public TreeSet<T> orElse(Supplier<? extends Iterable<? extends T>> supplier) {

@Override
public Tuple2<TreeSet<T>, TreeSet<T>> partition(Predicate<? super T> predicate) {
Objects.requireNonNull(predicate, "predicate is null");
return iterator().partition(predicate).map(i1 -> TreeSet.ofAll(tree.comparator(), i1),
i2 -> TreeSet.ofAll(tree.comparator(), i2));
return Collections.partition(this, values -> TreeSet.ofAll(tree.comparator(), values), predicate);
}

@Override
Expand Down
15 changes: 15 additions & 0 deletions src/test/java/io/vavr/collection/AbstractMapTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -1461,6 +1461,21 @@ public void shouldReturnDefaultValue() {
assertThat(map.getOrElse("3", "3")).isEqualTo("3");
}

// -- partition

@Test
public void shouldPartitionInOneIteration() {
final AtomicInteger count = new AtomicInteger(0);
final Map<String, Integer> map = mapOf("1", 1, "2", 2, "3", 3);
final Tuple2<? extends Map<String, Integer>, ? extends Map<String, Integer>> results = map.partition(entry -> {
count.incrementAndGet();
return true;
});
assertThat(results._1).isEqualTo(mapOf("1", 1, "2", 2, "3", 3));
assertThat(results._2).isEmpty();
assertThat(count.get()).isEqualTo(3);
}

// -- spliterator

@Test
Expand Down
14 changes: 14 additions & 0 deletions src/test/java/io/vavr/collection/AbstractMultimapTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -809,6 +809,20 @@ public void shouldPartitionIntsInOddAndEvenHavingOddAndEvenNumbers() {
mapOfTuples(Tuple.of(1, 2), Tuple.of(3, 4))));
}

@Test
@SuppressWarnings("unchecked")
public void shouldPartitionInOneIteration() {
final AtomicInteger count = new AtomicInteger(0);
final Multimap<String, Integer> map = mapOfTuples(Tuple.of("1", 1), Tuple.of("2", 2), Tuple.of("3", 3));
final Tuple2<? extends Multimap<String, Integer>, ? extends Multimap<String, Integer>> results = map.partition(entry -> {
count.incrementAndGet();
return true;
});
assertThat(results._1).isEqualTo(mapOfTuples(Tuple.of("1", 1), Tuple.of("2", 2), Tuple.of("3", 3)));
assertThat(results._2).isEmpty();
assertThat(count.get()).isEqualTo(3);
}

// -- put

@Test
Expand Down
18 changes: 17 additions & 1 deletion src/test/java/io/vavr/collection/AbstractSetTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
*/
package io.vavr.collection;

import io.vavr.Tuple2;
import org.junit.Test;

import java.math.BigDecimal;
import java.util.Spliterator;
import java.util.concurrent.atomic.AtomicInteger;

public abstract class AbstractSetTest extends AbstractTraversableRangeTest {

Expand Down Expand Up @@ -189,6 +191,20 @@ public void shouldRemoveElement() {
assertThat(empty().remove(5)).isEqualTo(empty());
}

// -- partition

@Test
public void shouldPartitionInOneIteration() {
final AtomicInteger count = new AtomicInteger(0);
final Tuple2<? extends Set<Integer>, ? extends Set<Integer>> results = of(1, 2, 3).partition(i -> {
count.incrementAndGet();
return true;
});
assertThat(results._1).isEqualTo(of(1, 2, 3));
assertThat(results._2).isEqualTo(of());
assertThat(count.get()).isEqualTo(3);
}

// -- removeAll

@Test
Expand Down Expand Up @@ -227,7 +243,7 @@ public void shouldReturnSameSetWhenEmptyUnionNonEmpty() {
assertThat(empty().union(set)).isSameAs(set);
}
}

@Test
public void shouldReturnSameSetWhenNonEmptyUnionEmpty() {
final Set<Integer> set = of(1, 2);
Expand Down
15 changes: 15 additions & 0 deletions src/test/java/io/vavr/collection/ArrayTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collector;
Expand Down Expand Up @@ -243,6 +244,20 @@ public void shouldThrowExceptionWhenGetIndexEqualToLength() {
.isInstanceOf(IndexOutOfBoundsException.class).hasMessage("get(1)");
}

// -- partition

@Test
public void shouldPartitionInOneIteration() {
final AtomicInteger count = new AtomicInteger(0);
final Tuple2<Array<Integer>, Array<Integer>> results = of(1, 2, 3).partition(i -> {
count.incrementAndGet();
return true;
});
assertThat(results._1).isEqualTo(of(1, 2, 3));
assertThat(results._2).isEqualTo(of());
assertThat(count.get()).isEqualTo(3);
}

// -- transform()

@Test
Expand Down
Loading