Skip to content

Commit

Permalink
Tasks: don't advance task RNG on task spawn (#49110)
Browse files Browse the repository at this point in the history
  • Loading branch information
StefanKarpinski authored Mar 31, 2023
1 parent 8327e85 commit 7618e64
Show file tree
Hide file tree
Showing 10 changed files with 259 additions and 43 deletions.
1 change: 1 addition & 0 deletions base/sysimg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ let
task.rngState1 = 0x7431eaead385992c
task.rngState2 = 0x503e1d32781c2608
task.rngState3 = 0x3a77f7189200c20b
task.rngState4 = 0x5502376d099035ae

# Stdlibs sorted in dependency, then alphabetical, order by contrib/print_sorted_stdlibs.jl
# Run with the `--exclude-jlls` option to filter out all JLL packages
Expand Down
6 changes: 3 additions & 3 deletions src/gc.c
Original file line number Diff line number Diff line change
Expand Up @@ -501,9 +501,9 @@ static void jl_gc_run_finalizers_in_list(jl_task_t *ct, arraylist_t *list) JL_NO
ct->sticky = sticky;
}

static uint64_t finalizer_rngState[4];
static uint64_t finalizer_rngState[JL_RNG_SIZE];

void jl_rng_split(uint64_t to[4], uint64_t from[4]) JL_NOTSAFEPOINT;
void jl_rng_split(uint64_t dst[JL_RNG_SIZE], uint64_t src[JL_RNG_SIZE]) JL_NOTSAFEPOINT;

JL_DLLEXPORT void jl_gc_init_finalizer_rng_state(void)
{
Expand Down Expand Up @@ -532,7 +532,7 @@ static void run_finalizers(jl_task_t *ct)
jl_atomic_store_relaxed(&jl_gc_have_pending_finalizers, 0);
arraylist_new(&to_finalize, 0);

uint64_t save_rngState[4];
uint64_t save_rngState[JL_RNG_SIZE];
memcpy(&save_rngState[0], &ct->rngState[0], sizeof(save_rngState));
jl_rng_split(ct->rngState, finalizer_rngState);

Expand Down
6 changes: 4 additions & 2 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -2769,7 +2769,7 @@ void jl_init_types(void) JL_GC_DISABLED
NULL,
jl_any_type,
jl_emptysvec,
jl_perm_symsvec(15,
jl_perm_symsvec(16,
"next",
"queue",
"storage",
Expand All @@ -2781,11 +2781,12 @@ void jl_init_types(void) JL_GC_DISABLED
"rngState1",
"rngState2",
"rngState3",
"rngState4",
"_state",
"sticky",
"_isexception",
"priority"),
jl_svec(15,
jl_svec(16,
jl_any_type,
jl_any_type,
jl_any_type,
Expand All @@ -2797,6 +2798,7 @@ void jl_init_types(void) JL_GC_DISABLED
jl_uint64_type,
jl_uint64_type,
jl_uint64_type,
jl_uint64_type,
jl_uint8_type,
jl_bool_type,
jl_bool_type,
Expand Down
4 changes: 3 additions & 1 deletion src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -1911,6 +1911,8 @@ typedef struct _jl_handler_t {
size_t world_age;
} jl_handler_t;

#define JL_RNG_SIZE 5 // xoshiro 4 + splitmix 1

typedef struct _jl_task_t {
JL_DATA_TYPE
jl_value_t *next; // invasive linked list for scheduler
Expand All @@ -1922,7 +1924,7 @@ typedef struct _jl_task_t {
jl_function_t *start;
// 4 byte padding on 32-bit systems
// uint32_t padding0;
uint64_t rngState[4];
uint64_t rngState[JL_RNG_SIZE];
_Atomic(uint8_t) _state;
uint8_t sticky; // record whether this Task can be migrated to a new thread
_Atomic(uint8_t) _isexception; // set if `result` is an exception to throw or that we exited with
Expand Down
201 changes: 180 additions & 21 deletions src/task.c
Original file line number Diff line number Diff line change
Expand Up @@ -866,28 +866,187 @@ uint64_t jl_genrandom(uint64_t rngState[4]) JL_NOTSAFEPOINT
return res;
}

void jl_rng_split(uint64_t to[4], uint64_t from[4]) JL_NOTSAFEPOINT
/*
The jl_rng_split function forks a task's RNG state in a way that is essentially
guaranteed to avoid collisions between the RNG streams of all tasks. The main
RNG is the xoshiro256++ RNG whose state is stored in rngState[0..3]. There is
also a small internal RNG used for task forking stored in rngState[4]. This
state is used to iterate a LCG (linear congruential generator), which is then
put through four different variations of the strongest PCG output function,
referred to as PCG-RXS-M-XS-64 [1]. This output function is invertible: it maps
a 64-bit state to 64-bit output; which is one of the reasons it's not
recommended for general purpose RNGs unless space is at a premium, but in our
usage invertibility is actually a benefit, as is explained below.
The goal of jl_rng_split is to perturb the state of each child task's RNG in
such a way each that for an entire tree of tasks spawned starting with a given
state in a root task, no two tasks have the same RNG state. Moreover, we want to
do this in a way that is deterministic and repeatable based on (1) the root
task's seed, (2) how many random numbers are generated, and (3) the task tree
structure. The RNG state of a parent task is allowed to affect the initial RNG
state of a child task, but the mere fact that a child was spawned should not
alter the RNG output of the parent. This second requirement rules out using the
main RNG to seed children -- some separate state must be maintained and changed
upon forking a child task while leaving the main RNG state unchanged.
The basic approach is that used by the DotMix [2] and SplitMix [3] RNG systems:
each task is uniquely identified by a sequence of "pedigree" numbers, indicating
where in the task tree it was spawned. This vector of pedigree coordinates is
then reduced to a single value by computing a dot product with a common vector
of random weights. The DotMix paper provides a proof that this dot product hash
value (referred to as a "compression function") is collision resistant in the
sense the the pairwise collision probability of two distinct tasks is 1/N where
N is the number of possible weight values. Both DotMix and SplitMix use a prime
value of N because the proof requires that the difference between two distinct
pedigree coordinates must be invertible, which is guaranteed by N being prime.
We take a different approach: we instead limit pedigree coordinates to being
binary instead -- when a task spawns a child, both tasks share the same pedigree
prefix, with the parent appending a zero and the child appending a one. This way
a binary pedigree vector uniquely identifies each task. Moreover, since the
coordinates are binary, the difference between coordinates is always one which
is its own inverse regardless of whether N is prime or not. This allows us to
compute the dot product modulo 2^64 using native machine arithmetic, which is
considerably more efficient and simpler to implement than arithmetic in a prime
modulus. It also means that when accumulating the dot product incrementally, as
described in SplitMix, we don't need to multiply weights by anything, we simply
add the random weight for the current task tree depth to the parent's dot
product to derive the child's dot product.
We use the LCG in rngState[4] to derive generate pseudorandom weights for the
dot product. Each time a child is forked, we update the LCG in both parent and
child tasks. In the parent, that's all we have to do -- the main RNG state
remains unchanged (recall that spawning a child should *not* affect subsequence
RNG draws in the parent). The next time the parent forks a child, the dot
product weight used will be different, corresponding to being a level deeper in
the binary task tree. In the child, we use the LCG state to generate four
pseudorandom 64-bit weights (more below) and add each weight to one of the
xoshiro256 state registers, rngState[0..3]. If we assume the main RNG remains
unused in all tasks, then each register rngState[0..3] accumulates a different
Dot/SplitMix dot product hash as additional child tasks are spawned. Each one is
collision resistant with a pairwise collision chance of only 1/2^64. Assuming
that the four pseudorandom 64-bit weight streams are sufficiently independent,
the pairwise collision probability for distinct tasks is 1/2^256. If we somehow
managed to spawn a trillion tasks, the probability of a collision would be on
the order of 1/10^54. Practically impossible. Put another way, this is the same
as the probability of two SHA256 hash values accidentally colliding, which we
generally consider so unlikely as not to be worth worrying about.
What about the random "junk" that's in the xoshiro256 state registers from
normal use of the RNG? For a tree of tasks spawned with no intervening samples
taken from the main RNG, all tasks start with the same junk which doesn't affect
the chance of collision. The Dot/SplitMix papers even suggest adding a random
base value to the dot product, so we can consider whatever happens to be in the
xoshiro256 registers to be that. What if the main RNG gets used between task
forks? In that case, the initial state registers will be different. The DotMix
collision resistance proof doesn't apply without modification, but we can
generalize the setup by adding a different base constant to each compression
function and observe that we still have a 1/N chance of the weight value
matching that exact difference. This proves collision resistance even between
tasks whose dot product hashes are computed with arbitrary offsets. We can
conclude that this scheme provides collision resistance even in the face of
different starting states of the main RNG. Does this seem too good to be true?
Perhaps another way of thinking about it will help. Suppose we seeded each task
completely randomly. Then there would also be a 1/2^256 chance of collision,
just as the DotMix proof gives. Essentially what the proof is telling us is that
if the weights are chosen uniformly and uncorrelated with the rest of the
compression function, then the dot product construction is a good enough way to
pseudorandomly seed each task. From that perspective, it's easier to believe
that adding an arbitrary constant to each seed doesn't worsen its randomness.
This leaves us with the question of how to generate four pseudorandom weights to
add to the rngState[0..3] registers at each depth of the task tree. The scheme
used here is that a single 64-bit LCG state is iterated in both parent and child
at each task fork, and four different variations of the PCG-RXS-M-XS-64 output
function are applied to that state to generate four different pseudorandom
weights. Another obvious way to generate four weights would be to iterate the
LCG four times per task split. There are two main reasons we've chosen to use
four output variants instead:
1. Advancing four times per fork reduces the set of possible weights that each
register can be perturbed by from 2^64 to 2^60. Since collision resistance is
proportional to the number of possible weight values, that would reduce
collision resistance.
2. It's easier to compute four PCG output variants in parallel. Iterating the
LCG is inherently sequential. Each PCG variant can be computed independently
from the LCG state. All four can even be computed at once with SIMD vector
instructions, but the compiler doesn't currently choose to do that.
A key question is whether the approach of using four variations of PCG-RXS-M-XS
is sufficiently random both within and between streams to provide the collision
resistance we expect. We obviously can't test that with 256 bits, but we have
tested it with a reduced state analogue using four PCG-RXS-M-XS-8 output
variations applied to a common 8-bit LCG. Test results do indicate sufficient
independence: a single register has collisions at 2^5 while four registers only
start having collisions at 2^20, which is actually better scaling of collision
resistance than we expect in theory. In theory, with one byte of resistance we
have a 50% chance of some collision at 20, which matches, but four bytes gives a
50% chance of collision at 2^17 and our (reduced size analogue) construction is
still collision free at 2^19. This may be due to the next observation, which guarantees collision avoidance for certain shapes of task trees as a result of using an
invertible RNG to generate weights.
In the specific case where a parent task spawns a sequence of child tasks with
no intervening usage of its main RNG, the parent and child tasks are actually
_guaranteed_ to have different RNG states. This is true because the four PCG
streams each produce every possible 2^64 bit output exactly once in the full
2^64 period of the LCG generator. This is considered a weakness of PCG-RXS-M-XS
when used as a general purpose RNG, but is quite beneficial in this application.
Since each of up to 2^64 children will be perturbed by different weights, they
cannot have hash collisions. What about parent colliding with child? That can
only happen if all four main RNG registers are perturbed by exactly zero. This
seems unlikely, but could it occur? Consider this part of each output function:
p ^= p >> ((p >> 59) + 5);
p *= m[i];
p ^= p >> 43
It's easy to check that this maps zero to zero. An unchanged parent RNG can only
happen if all four `p` values are zero at the end of this, which implies that
they were all zero at the beginning. However, that is impossible since the four
`p` values differ from `x` by different additive constants, so they cannot all
be zero. Stated more generally, this non-collision property: assuming the main
RNG isn't used between task forks, sibling and parent tasks cannot have RNG
collisions. If the task tree structure is more deeply nested or if there are
intervening uses of the main RNG, we're back to relying on "merely" 256 bits of
collision resistance, but it's nice to know that in what is likely the most
common case, RNG collisions are actually impossible. This fact may also explain
better-than-theoretical collision resistance observed in our experiment with a
reduced size analogue of our hashing system.
[1]: https://www.pcg-random.org/pdf/hmc-cs-2014-0905.pdf
[2]: http://supertech.csail.mit.edu/papers/dprng.pdf
[3]: https://gee.cs.oswego.edu/dl/papers/oopsla14.pdf
*/
void jl_rng_split(uint64_t dst[JL_RNG_SIZE], uint64_t src[JL_RNG_SIZE]) JL_NOTSAFEPOINT
{
/* TODO: consider a less ad-hoc construction
Ideally we could just use the output of the random stream to seed the initial
state of the child. Out of an overabundance of caution we multiply with
effectively random coefficients, to break possible self-interactions.
It is not the goal to mix bits -- we work under the assumption that the
source is well-seeded, and its output looks effectively random.
However, xoshiro has never been studied in the mode where we seed the
initial state with the output of another xoshiro instance.
Constants have nothing up their sleeve:
0x02011ce34bce797f == hash(UInt(1))|0x01
0x5a94851fb48a6e05 == hash(UInt(2))|0x01
0x3688cf5d48899fa7 == hash(UInt(3))|0x01
0x867b4bb4c42e5661 == hash(UInt(4))|0x01
*/
to[0] = 0x02011ce34bce797f * jl_genrandom(from);
to[1] = 0x5a94851fb48a6e05 * jl_genrandom(from);
to[2] = 0x3688cf5d48899fa7 * jl_genrandom(from);
to[3] = 0x867b4bb4c42e5661 * jl_genrandom(from);
// load and advance the internal LCG state
uint64_t x = src[4];
src[4] = dst[4] = x * 0xd1342543de82ef95 + 1;
// high spectrum multiplier from https://arxiv.org/abs/2001.05304

static const uint64_t a[4] = {
0xe5f8fa077b92a8a8, // random additive offsets...
0x7a0cd918958c124d,
0x86222f7d388588d4,
0xd30cbd35f2b64f52
};
static const uint64_t m[4] = {
0xaef17502108ef2d9, // standard PCG multiplier
0xf34026eeb86766af, // random odd multipliers...
0x38fd70ad58dd9fbb,
0x6677f9b93ab0c04d
};

// PCG-RXS-M-XS output with four variants
for (int i = 0; i < 4; i++) {
uint64_t p = x + a[i];
p ^= p >> ((p >> 59) + 5);
p *= m[i];
p ^= p >> 43;
dst[i] = src[i] + p; // SplitMix dot product
}
}

JL_DLLEXPORT jl_task_t *jl_new_task(jl_function_t *start, jl_value_t *completion_future, size_t ssize)
Expand Down
27 changes: 16 additions & 11 deletions stdlib/Random/src/Xoshiro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,17 @@ struct TaskLocalRNG <: AbstractRNG end
TaskLocalRNG(::Nothing) = TaskLocalRNG()
rng_native_52(::TaskLocalRNG) = UInt64

function setstate!(x::TaskLocalRNG, s0::UInt64, s1::UInt64, s2::UInt64, s3::UInt64)
function setstate!(
x::TaskLocalRNG,
s0::UInt64, s1::UInt64, s2::UInt64, s3::UInt64, # xoshiro256 state
s4::UInt64 = 1s0 + 3s1 + 5s2 + 7s3, # internal splitmix state
)
t = current_task()
t.rngState0 = s0
t.rngState1 = s1
t.rngState2 = s2
t.rngState3 = s3
t.rngState4 = s4
x
end

Expand All @@ -128,11 +133,11 @@ end
tmp = s0 + s3
res = ((tmp << 23) | (tmp >> 41)) + s0
t = s1 << 17
s2 = xor(s2, s0)
s3 = xor(s3, s1)
s1 = xor(s1, s2)
s0 = xor(s0, s3)
s2 = xor(s2, t)
s2 ⊻= s0
s3 ⊻= s1
s1 ⊻= s2
s0 ⊻= s3
s2 ⊻= t
s3 = s3 << 45 | s3 >> 19
task.rngState0, task.rngState1, task.rngState2, task.rngState3 = s0, s1, s2, s3
res
Expand All @@ -159,7 +164,7 @@ seed!(rng::Union{TaskLocalRNG, Xoshiro}, seed::Integer) = seed!(rng, make_seed(s
@inline function rand(rng::Union{TaskLocalRNG, Xoshiro}, ::SamplerType{UInt128})
first = rand(rng, UInt64)
second = rand(rng,UInt64)
second + UInt128(first)<<64
second + UInt128(first) << 64
end

@inline rand(rng::Union{TaskLocalRNG, Xoshiro}, ::SamplerType{Int128}) = rand(rng, UInt128) % Int128
Expand All @@ -178,14 +183,14 @@ end

function copy!(dst::TaskLocalRNG, src::Xoshiro)
t = current_task()
t.rngState0, t.rngState1, t.rngState2, t.rngState3 = src.s0, src.s1, src.s2, src.s3
dst
setstate!(dst, src.s0, src.s1, src.s2, src.s3)
return dst
end

function copy!(dst::Xoshiro, src::TaskLocalRNG)
t = current_task()
dst.s0, dst.s1, dst.s2, dst.s3 = t.rngState0, t.rngState1, t.rngState2, t.rngState3
dst
setstate!(dst, t.rngState0, t.rngState1, t.rngState2, t.rngState3)
return dst
end

function ==(a::Xoshiro, b::TaskLocalRNG)
Expand Down
Loading

0 comments on commit 7618e64

Please sign in to comment.