diff --git a/src/main/java/javaslang/collection/HashSet.java b/src/main/java/javaslang/collection/HashSet.java index fd11468399..91880b8d25 100644 --- a/src/main/java/javaslang/collection/HashSet.java +++ b/src/main/java/javaslang/collection/HashSet.java @@ -446,8 +446,8 @@ public HashSet takeRight(int n) { @Override public HashSet takeWhile(Predicate predicate) { Objects.requireNonNull(predicate, "predicate is null"); - List taken = list.get().takeWhile(predicate); - return taken.length() == list.get().length() ? this : HashSet.ofAll(taken); + HashSet taken = HashSet.ofAll(iterator().takeWhile(predicate)); + return taken.length() == length() ? this : taken; } @Override diff --git a/src/main/java/javaslang/collection/Iterator.java b/src/main/java/javaslang/collection/Iterator.java index 259ee92d60..01c436bdbf 100644 --- a/src/main/java/javaslang/collection/Iterator.java +++ b/src/main/java/javaslang/collection/Iterator.java @@ -727,7 +727,35 @@ default Iterator takeRight(int n) { @Override default Iterator takeWhile(Predicate predicate) { Objects.requireNonNull(predicate, "predicate is null"); - return null; + final Iterator that = this; + return new Iterator() { + + private T next = null; + private boolean finished = false; + + @Override + public boolean hasNext() { + while (!finished && next == null && that.hasNext()) { + final T value = that.next(); + if (predicate.test(value)) { + next = value; + } else { + finished = true; + } + } + return next != null; + } + + @Override + public T next() { + if (!hasNext()) { + EMPTY.next(); + } + final T result = next; + next = null; + return result; + } + }; } default Iterator> zip(Iterable that) {