Skip to content

Commit

Permalink
Performance optimizations for RandomStringUtils
Browse files Browse the repository at this point in the history
This commit improves the performance of RandomStringUtils:

* Reduces the number of random bytes generated and the number of calls to the random number generator, by using a cache system `AmortizedRandomBits`.
* Optimizes the case of alphanumerical strings, reducing the number of rejections in the rejection sampling.

See comments in code for details.
  • Loading branch information
Fabrice Benhamouda committed Jun 14, 2024
1 parent 23cb811 commit 55a70c8
Show file tree
Hide file tree
Showing 4 changed files with 266 additions and 2 deletions.
92 changes: 92 additions & 0 deletions src/main/java/org/apache/commons/lang3/AmortizedRandomBits.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.lang3;

import java.util.Random;

/**
 * Generates random integers of a requested bit length from a cache of random bytes.
 *
 * <p>This is more efficient than calling {@code Random.nextInt(1 << nbBits)}: a cache of
 * {@code cacheSize} random bytes is filled in one call and only replenished once every bit in it
 * has been consumed. This is especially beneficial for SecureRandom DRBG implementations that
 * incur a constant cost at each randomness generation.
 *
 * <p>This class is not thread safe.
 *
 * <p>Used internally by RandomStringUtils.
 */
class AmortizedRandomBits {
    /** Largest number of bits a single {@link #nextBits(int)} call can return (width of an int). */
    private static final int MAX_BITS = 32;

    /** Mask extracting the bit offset inside a byte, i.e. {@code bitIndex % 8}. */
    private static final int BIT_INDEX_MASK = 0x7;

    /**
     * Largest cache size in bytes. The position in the cache is tracked in bits using an int, so
     * the number of cached bits ({@code 8 * cache.length}) must not exceed Integer.MAX_VALUE.
     */
    private static final int MAX_CACHE_SIZE = Integer.MAX_VALUE >> 3;

    private final Random random;

    private final byte[] cache;

    // bitIndex is the index of the next bit in the cache to be used
    // bitIndex=0 means the cache is fully random and none of the bits have been used yet
    // bitIndex=1 means that only the LSB of cache[0] has been used and all other bits can be used
    // bitIndex=8 means that only the 8 bits of cache[0] have been used
    private int bitIndex;

    /**
     * Creates a generator and immediately fills its cache from {@code random}.
     *
     * @param cacheSize number of bytes cached (only affects performance); values above
     *                  {@code Integer.MAX_VALUE / 8} are capped to that limit
     * @param random random source
     * @throws IllegalArgumentException if {@code cacheSize} is not positive
     * @throws NullPointerException if {@code random} is null
     */
    AmortizedRandomBits(final int cacheSize, final Random random) {
        if (cacheSize <= 0) {
            throw new IllegalArgumentException("cacheSize must be positive");
        }
        if (random == null) {
            // fail before allocating the cache, with a message naming the argument
            throw new NullPointerException("random");
        }
        // Cap the cache so that bitIndex (counted in bits) can never overflow an int.
        this.cache = new byte[Math.min(cacheSize, MAX_CACHE_SIZE)];
        this.random = random;
        this.random.nextBytes(this.cache);
        this.bitIndex = 0;
    }

    /**
     * Returns a random integer made of the requested number of bits.
     *
     * <p>Bits are taken from the cache one byte at a time; within the result, bits taken from
     * earlier cache bytes end up in more significant positions.
     *
     * @param bits number of bits to generate, MUST be between 1 and 32
     * @return random integer with {@code bits} bits
     * @throws IllegalArgumentException if {@code bits} is not between 1 and 32
     */
    public int nextBits(final int bits) {
        if (bits > MAX_BITS || bits <= 0) {
            throw new IllegalArgumentException("number of bits must be between 1 and 32");
        }

        int result = 0;
        int generatedBits = 0; // number of generated bits up to now

        while (generatedBits < bits) {
            if ((bitIndex >> 3) >= cache.length) {
                // we exhausted the number of bits in the cache
                // this should only happen if the bitIndex is exactly matching the cache length
                assert bitIndex == (cache.length << 3);
                random.nextBytes(cache);
                bitIndex = 0;
            }

            // offset of the next unused bit inside the current cache byte
            final int bitOffset = bitIndex & BIT_INDEX_MASK;
            // generatedBitsInIteration is the number of bits that we will generate
            // in this iteration of the while loop: never more than what is left in the
            // current byte, never more than what is still missing from the result
            final int generatedBitsInIteration = Math.min(8 - bitOffset, bits - generatedBits);

            result = result << generatedBitsInIteration;
            // the mask also discards the sign extension of the (signed) cache byte
            result |= (cache[bitIndex >> 3] >> bitOffset) & ((1 << generatedBitsInIteration) - 1);

            generatedBits += generatedBitsInIteration;
            bitIndex += generatedBitsInIteration;
        }

        return result;
    }
}
63 changes: 61 additions & 2 deletions src/main/java/org/apache/commons/lang3/RandomStringUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,14 @@ private static SecureRandom random() {
return RANDOM.get();
}

/**
 * The 62 ASCII alphanumerical characters: lowercase {@code a-z}, uppercase {@code A-Z},
 * then the digits {@code 0-9}.
 *
 * <p>Passed as the explicit character set by the alphanumerical fast path of
 * {@code random(...)}, so that an index can be drawn from 6 random bits instead of 7.
 */
private static final char[] ALPHANUMERICAL_CHARS = {
'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9'
};

/**
* Creates a random string whose length is the number of characters
* specified.
Expand Down Expand Up @@ -226,6 +234,42 @@ public static String random(int count, int start, int end, final boolean letters
throw new IllegalArgumentException("Parameter end (" + end + ") must be greater than start (" + start + ")");
}

if (end > Character.MAX_CODE_POINT) {
// Technically, it should be `Character.MAX_CODE_POINT+1` as `end` is excluded
// But the character `Character.MAX_CODE_POINT` is private use, so it would anyway be excluded
end = Character.MAX_CODE_POINT;
}

// Optimize generation of full alphanumerical characters
// Normally, we would need to pick a 7-bit integer, since gap = 'z' - '0' + 1 = 75 > 64
// In turn, this would make us reject the sampling with probability 1 - 62 / 2^7 > 1 / 2
// Instead we can pick directly from the right set of 62 characters, which requires
// picking a 6-bit integer and only rejecting with probability 2 / 64 = 1 / 32
if (chars == null && letters && numbers && start <= '0' && end >= 'z' + 1) {
return random(count, 0, 0, false, false, ALPHANUMERICAL_CHARS, random);
}

// Optimize start and end when filtering by letters and/or numbers:
// The range provided may be too large since we filter anyway afterward.
// Note the use of Math.min/max (as opposed to setting start to '0' for example),
// since it is possible the range start/end excludes some of the letters/numbers,
// e.g., it is possible that start already is '1' when numbers = true, and start
// needs to stay equal to '1' in that case.
if (chars == null) {
if (letters && numbers) {
start = Math.max('0', start);
end = Math.min('z' + 1, end);
} else if (numbers) {
// just numbers, no letters
start = Math.max('0', start);
end = Math.min('9' + 1, end);
} else if (letters) {
// just letters, no numbers
start = Math.max('A', start);
end = Math.min('z' + 1, end);
}
}

final int zeroDigitAscii = 48;
final int firstLetterAscii = 65;

Expand All @@ -237,11 +281,26 @@ public static String random(int count, int start, int end, final boolean letters

final StringBuilder builder = new StringBuilder(count);
final int gap = end - start;
final int gapBits = Integer.SIZE - Integer.numberOfLeadingZeros(gap);
// The size of the cache we use is a heuristic:
// about twice the number of bytes required if there is no rejection.
// Ideally the cache size would depend on multiple factors, including the cost of generating x bytes
// of randomness as well as the probability of rejection. It is however not easy to know
// those values programmatically for the general case.
final AmortizedRandomBits arb = new AmortizedRandomBits((count * gapBits + 3) / 5 + 10, random);

while (count-- != 0) {
// Generate a random value between start (included) and end (excluded)
final int randomValue = arb.nextBits(gapBits) + start;
// Rejection sampling if value too large
if (randomValue >= end) {
count++;
continue;
}

final int codePoint;
if (chars == null) {
codePoint = random.nextInt(gap) + start;
codePoint = randomValue;

switch (Character.getType(codePoint)) {
case Character.UNASSIGNED:
Expand All @@ -252,7 +311,7 @@ public static String random(int count, int start, int end, final boolean letters
}

} else {
codePoint = chars[random.nextInt(gap) + start];
codePoint = chars[randomValue];
}

final int numberOfChars = Character.charCount(codePoint);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.lang3;

import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;

import java.util.Random;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

public class AmortizedRandomBitsTest {
    /**
     * A {@link Random} whose {@code nextBytes} replays a fixed script of bytes, in order,
     * and fails as soon as more bytes are requested than the script still holds.
     */
    private static class MockRandom extends Random {
        private final byte[] outputs;
        private int index;

        MockRandom(final byte[] outputs) {
            super();
            // private copy so later changes to the caller's array cannot alter the script
            this.outputs = outputs.clone();
            this.index = 0;
        }

        @Override
        public void nextBytes(byte[] bytes) {
            final int endIndex = index + bytes.length;
            if (endIndex > outputs.length) {
                throw new RuntimeException("not enough outputs given in MockRandom");
            }
            System.arraycopy(outputs, index, bytes, 0, bytes.length);
            index = endIndex;
        }
    }

    /**
     * Checks that the same scripted byte stream yields the same bit sequence whatever the
     * cache size, including reads that straddle byte and cache-refill boundaries.
     */
    @ParameterizedTest
    @ValueSource(ints = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 32})
    public void testNext(int cacheSize) {
        final byte[] scriptedBytes = {
            0x11, 0x12, 0x13, 0x25,
            (byte) 0xab, (byte) 0xcd, (byte) 0xef, (byte) 0xff,
            0x55, 0x44, 0x12, 0x34,
            0x56, 0x78, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00,
        };
        final MockRandom source = new MockRandom(scriptedBytes);
        final AmortizedRandomBits bitSource = new AmortizedRandomBits(cacheSize, source);

        // out-of-range bit counts are rejected and consume nothing
        assertThrows(IllegalArgumentException.class, () -> bitSource.nextBits(0));
        assertThrows(IllegalArgumentException.class, () -> bitSource.nextBits(33));

        // whole-byte reads
        assertEquals(0x11, bitSource.nextBits(8));
        assertEquals(0x12, bitSource.nextBits(8));
        assertEquals(0x1325, bitSource.nextBits(16));

        assertEquals(0xabcdefff, bitSource.nextBits(32));

        // sub-byte reads, walking through 0x55 from its least significant bit
        assertEquals(0x5, bitSource.nextBits(4));
        assertEquals(0x1, bitSource.nextBits(1));
        assertEquals(0x0, bitSource.nextBits(1));
        assertEquals(0x1, bitSource.nextBits(2));

        assertEquals(0x4, bitSource.nextBits(6));

        // 32-bit read starting in the middle of a byte
        assertEquals(0x40000000 | (0x12345600 >> 2) | 0x38, bitSource.nextBits(32));

        assertEquals(1, bitSource.nextBits(1));
        assertEquals(0, bitSource.nextBits(1));
        assertEquals(0, bitSource.nextBits(9));
        assertEquals(0, bitSource.nextBits(31));
    }
}
25 changes: 25 additions & 0 deletions src/test/java/org/apache/commons/lang3/RandomStringUtilsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.fail;

import java.lang.reflect.Constructor;
Expand Down Expand Up @@ -518,4 +519,28 @@ public void testRandomStringUtilsHomog() {
// critical value: from scipy.stats import chi2; chi2(2).isf(1e-5)
assertThat("test homogeneity -- will fail about 1 in 100,000 times", chiSquare(expected, counts), lessThan(23.025850929940457d));
}

/**
 * Verifies that {@code RandomStringUtils.random} draws only from an explicitly supplied
 * character set, and that successive draws differ.
 */
@Test
void testRandomWithChars() {
    final char[] digits = {'0', '1', '2', '3', '4', '5', '6', '7', '8', '9'};

    // first draw from the digit set: expected length, digits only
    final String first = RandomStringUtils.random(50, 0, 0, true, true, digits);
    assertEquals(50, first.length(), "randomNumeric(50)");
    for (final char c : first.toCharArray()) {
        final boolean digitOnly = Character.isDigit(c) && !Character.isLetter(c);
        assertTrue(digitOnly, "r1 contains numeric");
    }

    // an independent numeric string should not repeat the first draw
    final String second = RandomStringUtils.randomNumeric(50);
    assertNotEquals(first, second);

    // neither should a second draw from the same digit set
    final String third = RandomStringUtils.random(50, 0, 0, true, true, digits);
    assertNotEquals(first, third);
    assertNotEquals(second, third);
}
}

0 comments on commit 55a70c8

Please sign in to comment.