Skip to content

Commit

Permalink
Replace wrapTransversable generators to prevent memory leaks (#2709)
Browse files Browse the repository at this point in the history
The generator hold a circular reference to the iterator instance.
  • Loading branch information
GromNaN authored Dec 19, 2024
1 parent 4d2e51a commit ad7827c
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 88 deletions.
49 changes: 16 additions & 33 deletions lib/Doctrine/ODM/MongoDB/Iterator/CachingIterator.php
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
namespace Doctrine\ODM\MongoDB\Iterator;

use Countable;
use Generator;
use Iterator as SPLIterator;
use IteratorIterator;
use ReturnTypeWillChange;
use RuntimeException;
use Traversable;
Expand Down Expand Up @@ -33,13 +34,11 @@ final class CachingIterator implements Countable, Iterator
/** @var array<mixed, TValue> */
private array $items = [];

/** @var Generator<mixed, TValue>|null */
private ?Generator $iterator;
/** @var SPLIterator<mixed, TValue>|null */
private ?SPLIterator $iterator;

private bool $iteratorAdvanced = false;

private bool $iteratorExhausted = false;

/**
* Initialize the iterator and stores the first item in the cache. This
* effectively rewinds the Traversable and the wrapping Generator, which
Expand All @@ -51,7 +50,8 @@ final class CachingIterator implements Countable, Iterator
*/
public function __construct(Traversable $iterator)
{
$this->iterator = $this->wrapTraversable($iterator);
$this->iterator = new IteratorIterator($iterator);
$this->iterator->rewind();
$this->storeCurrentItem();
}

Expand Down Expand Up @@ -94,9 +94,10 @@ public function key()
/** @see http://php.net/iterator.next */
public function next(): void
{
if (! $this->iteratorExhausted) {
$this->getIterator()->next();
if ($this->iterator !== null) {
$this->iterator->next();
$this->storeCurrentItem();
$this->iteratorAdvanced = true;
}

next($this->items);
Expand Down Expand Up @@ -126,15 +127,13 @@ public function valid(): bool
*/
private function exhaustIterator(): void
{
while (! $this->iteratorExhausted) {
while ($this->iterator !== null) {
$this->next();
}

$this->iterator = null;
}

/** @return Generator<mixed, TValue> */
private function getIterator(): Generator
/** @return SPLIterator<mixed, TValue> */
private function getIterator(): SPLIterator
{
if ($this->iterator === null) {
throw new RuntimeException('Iterator has already been destroyed');
Expand All @@ -148,28 +147,12 @@ private function getIterator(): Generator
*/
private function storeCurrentItem(): void
{
$key = $this->getIterator()->key();
$key = $this->iterator->key();

if ($key === null) {
return;
$this->iterator = null;
} else {
$this->items[$key] = $this->getIterator()->current();
}

$this->items[$key] = $this->getIterator()->current();
}

/**
* @param Traversable<mixed, TValue> $traversable
*
* @return Generator<mixed, TValue>
*/
private function wrapTraversable(Traversable $traversable): Generator
{
foreach ($traversable as $key => $value) {
yield $key => $value;

$this->iteratorAdvanced = true;
}

$this->iteratorExhausted = true;
}
}
25 changes: 7 additions & 18 deletions lib/Doctrine/ODM/MongoDB/Iterator/HydratingIterator.php
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

use Doctrine\ODM\MongoDB\Mapping\ClassMetadata;
use Doctrine\ODM\MongoDB\UnitOfWork;
use Generator;
use Iterator;
use IteratorIterator;
use ReturnTypeWillChange;
use RuntimeException;
use Traversable;
Expand All @@ -24,8 +24,8 @@
*/
final class HydratingIterator implements Iterator
{
/** @var Generator<mixed, array<string, mixed>>|null */
private ?Generator $iterator;
/** @var Iterator<mixed, array<string, mixed>>|null */
private ?Iterator $iterator;

/**
* @param Traversable<mixed, array<string, mixed>> $traversable
Expand All @@ -34,7 +34,8 @@ final class HydratingIterator implements Iterator
*/
public function __construct(Traversable $traversable, private UnitOfWork $unitOfWork, private ClassMetadata $class, private array $unitOfWorkHints = [])
{
$this->iterator = $this->wrapTraversable($traversable);
$this->iterator = new IteratorIterator($traversable);
$this->iterator->rewind();
}

public function __destruct()
Expand Down Expand Up @@ -74,8 +75,8 @@ public function valid(): bool
return $this->key() !== null;
}

/** @return Generator<mixed, array<string, mixed>> */
private function getIterator(): Generator
/** @return Iterator<mixed, array<string, mixed>> */
private function getIterator(): Iterator
{
if ($this->iterator === null) {
throw new RuntimeException('Iterator has already been destroyed');
Expand All @@ -93,16 +94,4 @@ private function hydrate(?array $document): ?object
{
return $document !== null ? $this->unitOfWork->getOrCreateDocument($this->class->name, $document, $this->unitOfWorkHints) : null;
}

/**
* @param Traversable<mixed, array<string, mixed>> $traversable
*
* @return Generator<mixed, array<string, mixed>>
*/
private function wrapTraversable(Traversable $traversable): Generator
{
foreach ($traversable as $key => $value) {
yield $key => $value;
}
}
}
61 changes: 24 additions & 37 deletions lib/Doctrine/ODM/MongoDB/Iterator/UnrewindableIterator.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

namespace Doctrine\ODM\MongoDB\Iterator;

use Generator;
use Iterator as SPLIterator;
use IteratorIterator;
use LogicException;
use ReturnTypeWillChange;
use RuntimeException;
Expand All @@ -23,39 +24,34 @@
*/
final class UnrewindableIterator implements Iterator
{
/** @var Generator<mixed, TValue>|null */
private ?Generator $iterator;
/** @var SPLIterator<mixed, TValue>|null */
private ?SPLIterator $iterator;

private bool $iteratorAdvanced = false;

/**
* Initialize the iterator. This effectively rewinds the Traversable and
* the wrapping Generator, which will execute up to its first yield statement.
* Additionally, this mimics behavior of the SPL iterators and allows users
* to omit an explicit call to rewind() before using the other methods.
* Initialize the iterator. This effectively rewinds the Traversable.
* This mimics behavior of the SPL iterators and allows users to omit an
* explicit call to rewind() before using the other methods.
*
* @param Traversable<mixed, TValue> $iterator
*/
public function __construct(Traversable $iterator)
{
$this->iterator = $this->wrapTraversable($iterator);
$this->iterator->key();
$this->iterator = new IteratorIterator($iterator);
$this->iterator->rewind();
}

public function toArray(): array
{
$this->preventRewinding(__METHOD__);

$toArray = function () {
if (! $this->valid()) {
return;
}

yield $this->key() => $this->current();
yield from $this->getIterator();
};

return iterator_to_array($toArray());
try {
return iterator_to_array($this->getIterator());
} finally {
$this->iteratorAdvanced = true;
$this->iterator = null;
}
}

/** @return TValue|null */
Expand Down Expand Up @@ -84,6 +80,13 @@ public function next(): void
}

$this->iterator->next();
$this->iteratorAdvanced = true;

if ($this->iterator->valid()) {
return;
}

$this->iterator = null;
}

/** @see http://php.net/iterator.rewind */
Expand All @@ -108,29 +111,13 @@ private function preventRewinding(string $method): void
}
}

/** @return Generator<mixed, TValue> */
private function getIterator(): Generator
/** @return SPLIterator<mixed, TValue> */
private function getIterator(): SPLIterator
{
if ($this->iterator === null) {
throw new RuntimeException('Iterator has already been destroyed');
}

return $this->iterator;
}

/**
* @param Traversable<mixed, TValue> $traversable
*
* @return Generator<mixed, TValue>
*/
private function wrapTraversable(Traversable $traversable): Generator
{
foreach ($traversable as $key => $value) {
yield $key => $value;

$this->iteratorAdvanced = true;
}

$this->iterator = null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,19 @@ public function testIterationWithEmptySet(): void
self::assertFalse($iterator->valid());
}

public function testIterationWithInvalidIterator(): void
{
$mock = $this->createMock(Iterator::class);
// The method next() should not be called on a dead cursor.
$mock->expects(self::never())->method('next');
// The method valid() return false on a dead cursor.
$mock->expects(self::once())->method('valid')->willReturn(false);

$iterator = new CachingIterator($mock);

$this->assertEquals([], $iterator->toArray());
}

public function testPartialIterationDoesNotExhaust(): void
{
$traversable = $this->getTraversableThatThrows([1, 2, new Exception()]);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,15 @@ public function testRewindAfterPartialIteration(): void
iterator_to_array($iterator);
}

public function testRewindAfterToArray(): void
{
$iterator = new UnrewindableIterator($this->getTraversable([1, 2, 3]));

$iterator->toArray();
$this->expectException(LogicException::class);
$iterator->rewind();
}

public function testToArray(): void
{
$iterator = new UnrewindableIterator($this->getTraversable([1, 2, 3]));
Expand Down

0 comments on commit ad7827c

Please sign in to comment.