Skip to content

Commit

Permalink
Guard WindowPredicate with doOnDiscard(…).
Browse files Browse the repository at this point in the history
Windowed fluxes now properly discard ref-counted objects avoiding memory leaks upon cancellation.

[#492]

Signed-off-by: Mark Paluch <mpaluch@vmware.com>
  • Loading branch information
mp911de committed Feb 17, 2022
1 parent 323d57c commit abe7634
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 41 deletions.
26 changes: 14 additions & 12 deletions src/main/java/io/r2dbc/postgresql/PostgresqlStatement.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.ReferenceCounted;
import io.r2dbc.postgresql.client.Binding;
import io.r2dbc.postgresql.client.ConnectionContext;
import io.r2dbc.postgresql.client.EncodedParameter;
Expand Down Expand Up @@ -230,20 +232,20 @@ private Flux<io.r2dbc.postgresql.api.PostgresqlResult> execute(String sql) {
.doOnSubscribe(it -> bindings.emitNext(iterator.next(), Sinks.EmitFailureHandler.FAIL_FAST));

}).cast(io.r2dbc.postgresql.api.PostgresqlResult.class);
} else {
// Simple Query protocol
if (this.fetchSize != NO_LIMIT) {
return ExtendedFlowDelegate.runQuery(this.resources, factory, sql, Binding.EMPTY, Collections.emptyList(), this.fetchSize)
.windowUntil(WINDOW_UNTIL)
.map(messages -> PostgresqlResult.toResult(this.resources, messages, factory))
.as(Operators::discardOnCancel);
}
}

return SimpleQueryMessageFlow.exchange(this.resources.getClient(), sql)
.windowUntil(WINDOW_UNTIL)
.map(messages -> PostgresqlResult.toResult(this.resources, messages, factory))
.as(Operators::discardOnCancel);
Flux<BackendMessage> exchange;
// Simple Query protocol
if (this.fetchSize != NO_LIMIT) {
exchange = ExtendedFlowDelegate.runQuery(this.resources, factory, sql, Binding.EMPTY, Collections.emptyList(), this.fetchSize);
} else {
exchange = SimpleQueryMessageFlow.exchange(this.resources.getClient(), sql);
}

return exchange.windowUntil(WINDOW_UNTIL)
.doOnDiscard(ReferenceCounted.class, ReferenceCountUtil::release) // ensure release of rows within WindowPredicate
.map(messages -> PostgresqlResult.toResult(this.resources, messages, factory))
.as(Operators::discardOnCancel);
}

private static void tryNextBinding(Iterator<Binding> iterator, Sinks.Many<Binding> bindingSink, AtomicBoolean canceled) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,14 @@ public final class DataRow extends AbstractReferenceCounted implements BackendMe
* @throws IllegalArgumentException if {@code columns} is {@code null}
*/
public DataRow(ByteBuf... columns) {
this.columns = Assert.requireNonNull(columns, "columns must not be null");

if (columns == null) {
this.columns = new ByteBuf[0];
release();
throw new IllegalArgumentException("columns must not be null");
}

this.columns = columns;
}

@Override
Expand Down
13 changes: 11 additions & 2 deletions src/test/java/io/r2dbc/postgresql/PostgresqlRowUnitTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import io.r2dbc.postgresql.codec.MockCodecs;
import io.r2dbc.postgresql.message.backend.DataRow;
import io.r2dbc.postgresql.message.backend.RowDescription;
import io.r2dbc.postgresql.util.ReferenceCountedCleaner;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;

import java.util.Arrays;
Expand All @@ -41,6 +43,8 @@
*/
final class PostgresqlRowUnitTests {

private final ReferenceCountedCleaner cleaner = new ReferenceCountedCleaner();

private final List<RowDescription.Field> columns = Arrays.asList(
new RowDescription.Field((short) 100, 200, 300, (short) 400, FORMAT_BINARY, "test-name-1", 500),
new RowDescription.Field((short) 300, 400, 300, (short) 400, FORMAT_TEXT, "test-name-2", 500),
Expand All @@ -49,6 +53,11 @@ final class PostgresqlRowUnitTests {

private final ByteBuf[] data = new ByteBuf[]{TEST.buffer(4).writeInt(100), TEST.buffer(4).writeInt(300), null};

@AfterEach
void tearDown() {
cleaner.clean();
}

@Test
void constructorNoContext() {
assertThatIllegalArgumentException().isThrownBy(() -> new PostgresqlRow(null, null, Collections.emptyList(), null))
Expand Down Expand Up @@ -156,7 +165,7 @@ void toRow() {
.build();

RowDescription description = new RowDescription(Collections.singletonList(new RowDescription.Field((short) 200, 300, (short) 400, (short) 500, FORMAT_TEXT, "test-name-1", 600)));
PostgresqlRow row = PostgresqlRow.toRow(MockContext.builder().codecs(codecs).build(), new DataRow(TEST.buffer(4).writeInt(100)),
PostgresqlRow row = PostgresqlRow.toRow(MockContext.builder().codecs(codecs).build(), cleaner.capture(new DataRow(TEST.buffer(4).writeInt(100))),
codecs, description);

assertThat(row.get(0, Object.class)).isSameAs(value);
Expand All @@ -170,7 +179,7 @@ void toRowNoDataRow() {

@Test
void toRowNoRowDescription() {
assertThatIllegalArgumentException().isThrownBy(() -> PostgresqlRow.toRow(MockContext.empty(), new DataRow(TEST.buffer(4).writeInt(100)), MockCodecs.empty(), null))
assertThatIllegalArgumentException().isThrownBy(() -> PostgresqlRow.toRow(MockContext.empty(), cleaner.capture(new DataRow(TEST.buffer(4).writeInt(100))), MockCodecs.empty(), null))
.withMessage("rowDescription must not be null");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,13 @@
package io.r2dbc.postgresql.message.backend;

import io.netty.buffer.ByteBuf;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.ReferenceCounted;
import io.r2dbc.postgresql.util.ReferenceCountedCleaner;
import org.assertj.core.api.AbstractObjectAssert;
import org.assertj.core.api.ObjectAssert;
import org.springframework.util.ReflectionUtils;

import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.function.Function;

Expand All @@ -35,7 +34,7 @@
*/
final class BackendMessageAssert extends AbstractObjectAssert<BackendMessageAssert, Class<? extends BackendMessage>> {

private Cleaner cleaner = new Cleaner();
private ReferenceCountedCleaner cleaner = new ReferenceCountedCleaner();

private BackendMessageAssert(Class<? extends BackendMessage> actual) {
super(actual, BackendMessageAssert.class);
Expand All @@ -45,7 +44,7 @@ static BackendMessageAssert assertThat(Class<? extends BackendMessage> actual) {
return new BackendMessageAssert(actual);
}

BackendMessageAssert cleaner(Cleaner cleaner) {
BackendMessageAssert cleaner(ReferenceCountedCleaner cleaner) {
this.cleaner = cleaner;
return this;
}
Expand All @@ -61,28 +60,11 @@ <T extends BackendMessage> ObjectAssert<T> decoded(Function<ByteBuf, ByteBuf> de
ReflectionUtils.makeAccessible(method);
T actual = (T) ReflectionUtils.invokeMethod(method, null, decoded.apply(TEST.buffer()));

return new ObjectAssert<>(this.cleaner.capture(actual));
return new ObjectAssert<>((T) (actual instanceof ReferenceCounted ? this.cleaner.capture((ReferenceCounted) actual) : actual));
}

public Cleaner cleaner() {
public ReferenceCountedCleaner cleaner() {
return this.cleaner;
}

static class Cleaner {

private final List<Object> objects = new ArrayList<>();

public void clean() {
this.objects.forEach(ReferenceCountUtil::release);
this.objects.clear();
}

public <T> T capture(T object) {
this.objects.add(object);

return object;
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
package io.r2dbc.postgresql.message.backend;

import io.netty.buffer.ByteBuf;
import io.r2dbc.postgresql.util.ReferenceCountedCleaner;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;

import static io.r2dbc.postgresql.message.backend.BackendMessageAssert.Cleaner;
import static io.r2dbc.postgresql.message.backend.BackendMessageAssert.assertThat;
import static io.r2dbc.postgresql.util.TestByteBufAllocator.TEST;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
Expand All @@ -30,7 +30,7 @@
*/
final class DataRowUnitTests {

private final Cleaner cleaner = new Cleaner();
private final ReferenceCountedCleaner cleaner = new ReferenceCountedCleaner();

@AfterEach
void tearDown() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright 2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.r2dbc.postgresql.util;

import io.netty.util.ReferenceCountUtil;
import io.netty.util.ReferenceCounted;

import java.util.ArrayList;
import java.util.List;

public class ReferenceCountedCleaner {

private final List<Object> objects = new ArrayList<>();

public void clean() {
this.objects.forEach(ReferenceCountUtil::release);
this.objects.clear();
}

public <T extends ReferenceCounted> T capture(T object) {
this.objects.add(object);

return object;
}

}

0 comments on commit abe7634

Please sign in to comment.