diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/client/AsyncTableImpl.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/client/AsyncTableImpl.java
index e75a9411efb5..e19ed207c24b 100644
--- a/hbase-client/src/main/java/org/apache/hadoop/hbase/client/AsyncTableImpl.java
+++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/client/AsyncTableImpl.java
@@ -28,6 +28,7 @@
 import java.util.Map;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Phaser;
 import java.util.concurrent.TimeUnit;
 import java.util.function.Function;
 import org.apache.hadoop.conf.Configuration;
@@ -298,19 +299,39 @@ public <S, R> CoprocessorServiceBuilder<S, R> coprocessorService(
     final Context context = Context.current();
     CoprocessorCallback<R> wrappedCallback = new CoprocessorCallback<R>() {
 
+      private final Phaser regionCompletesInProgress = new Phaser(1);
+
       @Override
       public void onRegionComplete(RegionInfo region, R resp) {
-        pool.execute(context.wrap(() -> callback.onRegionComplete(region, resp)));
+        regionCompletesInProgress.register();
+        pool.execute(context.wrap(() -> {
+          try {
+            callback.onRegionComplete(region, resp);
+          } finally {
+            regionCompletesInProgress.arriveAndDeregister();
+          }
+        }));
       }
 
       @Override
       public void onRegionError(RegionInfo region, Throwable error) {
-        pool.execute(context.wrap(() -> callback.onRegionError(region, error)));
+        regionCompletesInProgress.register();
+        pool.execute(context.wrap(() -> {
+          try {
+            callback.onRegionError(region, error);
+          } finally {
+            regionCompletesInProgress.arriveAndDeregister();
+          }
+        }));
       }
 
       @Override
       public void onComplete() {
-        pool.execute(context.wrap(callback::onComplete));
+        pool.execute(context.wrap(() -> {
+          // Guarantee that onComplete() runs after all onRegionComplete()/onRegionError() tasks
+          regionCompletesInProgress.arriveAndAwaitAdvance();
+          callback.onComplete();
+        }));
       }
 
       @Override
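The Phaser above acts as a completion barrier: `new Phaser(1)` registers one party up front on behalf of the completion task, each region callback registers an extra party on the submitting thread before its task is queued, and `arriveAndAwaitAdvance()` blocks the pooled onComplete() task until every registered task has arrived. A minimal, self-contained sketch of the same pattern (class and variable names below are illustrative, not part of this patch):

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Phaser;

public class PhaserCompletionBarrierSketch {
  public static void main(String[] args) {
    ExecutorService pool = Executors.newFixedThreadPool(2);
    // One party is registered up front on behalf of the completion task.
    Phaser inProgress = new Phaser(1);
    for (int i = 0; i < 5; i++) {
      final int id = i;
      // Register on the submitting thread, before the task is queued, so the completion
      // task can never observe a task that has been submitted but not yet counted.
      inProgress.register();
      pool.execute(() -> {
        try {
          System.out.println("per-region callback " + id);
        } finally {
          inProgress.arriveAndDeregister();
        }
      });
    }
    pool.execute(() -> {
      // Arrive the initial party and block until every registered task has arrived.
      inProgress.arriveAndAwaitAdvance();
      System.out.println("completion task runs last");
    });
    pool.shutdown();
  }
}

Registering before submission, rather than inside the submitted task, is what closes the race: once the completion task is queued, the party count already reflects every in-flight region callback.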
diff --git a/hbase-endpoint/src/test/java/org/apache/hadoop/hbase/client/TestAsyncAggregationClientWithCallbackThreadPool.java b/hbase-endpoint/src/test/java/org/apache/hadoop/hbase/client/TestAsyncAggregationClientWithCallbackThreadPool.java
new file mode 100644
index 000000000000..7b37ddfd1555
--- /dev/null
+++ b/hbase-endpoint/src/test/java/org/apache/hadoop/hbase/client/TestAsyncAggregationClientWithCallbackThreadPool.java
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.client;
+
+import static org.junit.Assert.assertEquals;
+
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.stream.Collectors;
+import java.util.stream.LongStream;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.hbase.HBaseClassTestRule;
+import org.apache.hadoop.hbase.HBaseTestingUtility;
+import org.apache.hadoop.hbase.TableName;
+import org.apache.hadoop.hbase.client.coprocessor.AsyncAggregationClient;
+import org.apache.hadoop.hbase.client.coprocessor.LongColumnInterpreter;
+import org.apache.hadoop.hbase.coprocessor.AggregateImplementation;
+import org.apache.hadoop.hbase.coprocessor.CoprocessorHost;
+import org.apache.hadoop.hbase.testclassification.CoprocessorTests;
+import org.apache.hadoop.hbase.testclassification.MediumTests;
+import org.apache.hadoop.hbase.util.Bytes;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.ClassRule;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+
+/**
+ * Same as TestAsyncAggregationClient, except that {@link AsyncTableImpl} is involved in addition
+ * to {@link RawAsyncTableImpl}. Exercises the code paths in
+ * {@link AsyncTableImpl#coprocessorService}.
+ */
+@Category({ MediumTests.class, CoprocessorTests.class })
+public class TestAsyncAggregationClientWithCallbackThreadPool {
+
+  @ClassRule
+  public static final HBaseClassTestRule CLASS_RULE =
+    HBaseClassTestRule.forClass(TestAsyncAggregationClientWithCallbackThreadPool.class);
+
+  private static HBaseTestingUtility UTIL = new HBaseTestingUtility();
+
+  private static TableName TABLE_NAME = TableName.valueOf("TestAsyncAggregationClient");
+
+  private static byte[] CF = Bytes.toBytes("CF");
+
+  private static byte[] CQ = Bytes.toBytes("CQ");
+
+  private static byte[] CQ2 = Bytes.toBytes("CQ2");
+
+  private static long COUNT = 1000;
+
+  private static AsyncConnection CONN;
+
+  private static AsyncTable<ScanResultConsumer> TABLE;
+
+  private static ExecutorService EXECUTOR_SERVICE;
+
+  @BeforeClass
+  public static void setUp() throws Exception {
+    Configuration conf = UTIL.getConfiguration();
+    conf.setStrings(CoprocessorHost.REGION_COPROCESSOR_CONF_KEY,
+      AggregateImplementation.class.getName());
+    UTIL.startMiniCluster(3);
+    byte[][] splitKeys = new byte[8][];
+    for (int i = 111; i < 999; i += 111) {
+      splitKeys[i / 111 - 1] = Bytes.toBytes(String.format("%03d", i));
+    }
+    UTIL.createTable(TABLE_NAME, CF, splitKeys);
+    CONN = ConnectionFactory.createAsyncConnection(UTIL.getConfiguration()).get();
+    EXECUTOR_SERVICE = Executors.newFixedThreadPool(1);
+    TABLE = CONN.getTable(TABLE_NAME, EXECUTOR_SERVICE);
+    TABLE.putAll(LongStream.range(0, COUNT)
+      .mapToObj(l -> new Put(Bytes.toBytes(String.format("%03d", l)))
+        .addColumn(CF, CQ, Bytes.toBytes(l)).addColumn(CF, CQ2, Bytes.toBytes(l * l)))
+      .collect(Collectors.toList())).get();
+  }
+
+  @AfterClass
+  public static void tearDown() throws Exception {
+    CONN.close();
+    UTIL.shutdownMiniCluster();
+    EXECUTOR_SERVICE.shutdownNow();
+  }
+
+  @Test
+  public void testMax() throws InterruptedException, ExecutionException {
+    assertEquals(COUNT - 1, AsyncAggregationClient
+      .max(TABLE, new LongColumnInterpreter(), new Scan().addColumn(CF, CQ)).get().longValue());
+  }
+
+  @Test
+  public void testMin() throws InterruptedException, ExecutionException {
+    assertEquals(0, AsyncAggregationClient
+      .min(TABLE, new LongColumnInterpreter(), new Scan().addColumn(CF, CQ)).get().longValue());
+  }
+
+  @Test
+  public void testRowCount() throws InterruptedException, ExecutionException {
+    assertEquals(COUNT,
+      AsyncAggregationClient
+        .rowCount(TABLE, new LongColumnInterpreter(), new Scan().addColumn(CF, CQ)).get()
+        .longValue());
+
+    // Run the count twice in case some state doesn't get cleaned up inside AsyncTableImpl
+    // after the first call.
+    assertEquals(COUNT,
+      AsyncAggregationClient
+        .rowCount(TABLE, new LongColumnInterpreter(), new Scan().addColumn(CF, CQ)).get()
+        .longValue());
+  }
+
+  @Test
+  public void testSum() throws InterruptedException, ExecutionException {
+    assertEquals(COUNT * (COUNT - 1) / 2, AsyncAggregationClient
+      .sum(TABLE, new LongColumnInterpreter(), new Scan().addColumn(CF, CQ)).get().longValue());
+  }
+
+  private static final double DELTA = 1E-3;
+
+  @Test
+  public void testAvg() throws InterruptedException, ExecutionException {
+    assertEquals((COUNT - 1) / 2.0, AsyncAggregationClient
+      .avg(TABLE, new LongColumnInterpreter(), new Scan().addColumn(CF, CQ)).get().doubleValue(),
+      DELTA);
+  }
+
+  @Test
+  public void testStd() throws InterruptedException, ExecutionException {
+    double avgSq =
+      LongStream.range(0, COUNT).map(l -> l * l).reduce((l1, l2) -> l1 + l2).getAsLong()
+        / (double) COUNT;
+    double avg = (COUNT - 1) / 2.0;
+    double std = Math.sqrt(avgSq - avg * avg);
+    assertEquals(std, AsyncAggregationClient
+      .std(TABLE, new LongColumnInterpreter(), new Scan().addColumn(CF, CQ)).get().doubleValue(),
+      DELTA);
+  }
+}
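The aggregation calls above rely on the ordering guarantee the patch adds: a typical CoprocessorCallback accumulates one partial result per region and publishes the aggregate from onComplete(), which is only correct if onComplete() runs after every onRegionComplete(). A sketch of such a consumer, assuming Long as the per-region result type (the class below is illustrative, not taken from the patch):

import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CopyOnWriteArrayList;
import org.apache.hadoop.hbase.client.AsyncTable.CoprocessorCallback;
import org.apache.hadoop.hbase.client.RegionInfo;

// Illustrative consumer of AsyncTable.CoprocessorCallback: collects one
// partial result per region and aggregates them on completion.
class SummingCallback implements CoprocessorCallback<Long> {
  private final List<Long> partials = new CopyOnWriteArrayList<>();
  private final CompletableFuture<Long> future = new CompletableFuture<>();

  @Override
  public void onRegionComplete(RegionInfo region, Long partial) {
    partials.add(partial);
  }

  @Override
  public void onRegionError(RegionInfo region, Throwable error) {
    future.completeExceptionally(error);
  }

  @Override
  public void onComplete() {
    // Before the Phaser fix, a pooled onComplete() could run while some
    // onRegionComplete() tasks were still queued, summing a partial list.
    future.complete(partials.stream().mapToLong(Long::longValue).sum());
  }

  @Override
  public void onError(Throwable error) {
    future.completeExceptionally(error);
  }

  CompletableFuture<Long> future() {
    return future;
  }
}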