Skip to content

Commit e95884d

Browse files
committed
math/rand: add Shuffle
Shuffle uses the Fisher-Yates algorithm. Since this is new API, it affords us the opportunity to use a much faster Int31n implementation that mostly avoids division. As a result, BenchmarkPerm30ViaShuffle is about 30% faster than BenchmarkPerm30, despite requiring a separate initialization loop and using function calls to swap elements. Fixes golang#20480 Updates golang#16213 Updates golang#21211 Change-Id: Ib8956c4bebed9d84f193eb98282ec16ee7c2b2d5
1 parent 93471a1 commit e95884d

File tree

3 files changed

+217
-0
lines changed

3 files changed

+217
-0
lines changed

Diff for: src/math/rand/example_test.go

+32
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"fmt"
99
"math/rand"
1010
"os"
11+
"strings"
1112
"text/tabwriter"
1213
)
1314

@@ -105,3 +106,34 @@ func ExamplePerm() {
105106
// 2
106107
// 0
107108
}
109+
110+
func ExampleShuffle() {
111+
words := strings.Fields("ink runs from the corners of my mouth")
112+
rand.Shuffle(len(words), func(i, j int) {
113+
words[i], words[j] = words[j], words[i]
114+
})
115+
fmt.Println(words)
116+
117+
// Output:
118+
// [mouth my the of runs corners from ink]
119+
}
120+
121+
func ExampleShuffle_slicesInUnison() {
122+
numbers := []byte("12345")
123+
letters := []byte("ABCDE")
124+
// Shuffle numbers, swapping corresponding entries in letters at the same time.
125+
rand.Shuffle(len(numbers), func(i, j int) {
126+
numbers[i], numbers[j] = numbers[j], numbers[i]
127+
letters[i], letters[j] = letters[j], letters[i]
128+
})
129+
for i := range numbers {
130+
fmt.Printf("%c: %c\n", letters[i], numbers[i])
131+
}
132+
133+
// Output:
134+
// C: 3
135+
// D: 4
136+
// A: 1
137+
// E: 5
138+
// B: 2
139+
}

Diff for: src/math/rand/rand.go

+54
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,30 @@ func (r *Rand) Int31n(n int32) int32 {
135135
return v % n
136136
}
137137

138+
// int31n returns, as an int32, a non-negative pseudo-random number in [0,n).
139+
// n must be > 0, but int31n does not check this; the caller must ensure it.
140+
// int31n exists because Int31n is inefficient, but Go 1 compatibility
141+
// requires that the stream of values produced by math/rand remain unchanged.
142+
// int31n can thus only be used internally, by newly introduced APIs.
143+
//
144+
// For implementation details, see:
145+
// http://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction
146+
// http://lemire.me/blog/2016/06/30/fast-random-shuffling
147+
func (r *Rand) int31n(n int32) int32 {
148+
v := r.Uint32()
149+
prod := uint64(v) * uint64(n)
150+
low := uint32(prod)
151+
if low < uint32(n) {
152+
thresh := uint32(-n) % uint32(n)
153+
for low < thresh {
154+
v = r.Uint32()
155+
prod = uint64(v) * uint64(n)
156+
low = uint32(prod)
157+
}
158+
}
159+
return int32(prod >> 32)
160+
}
161+
138162
// Intn returns, as an int, a non-negative pseudo-random number in [0,n).
139163
// It panics if n <= 0.
140164
func (r *Rand) Intn(n int) int {
@@ -202,6 +226,31 @@ func (r *Rand) Perm(n int) []int {
202226
return m
203227
}
204228

229+
// Shuffle pseudo-randomizes the order of elements.
230+
// n is the number of elements. Shuffle panics if n < 0.
231+
// swap swaps the elements with indexes i and j.
232+
func (r *Rand) Shuffle(n int, swap func(i, j int)) {
233+
if n < 0 {
234+
panic("invalid argument to Shuffle")
235+
}
236+
237+
// Fisher-Yates shuffle: https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle
238+
// Shuffle really ought not be called with n that doesn't fit in 32 bits.
239+
// Not only will it take a very long time, but with 2³¹! possible permutations,
240+
// there's no way that any PRNG can have a big enough internal state to
241+
// generate even a minuscule percentage of the possible permutations.
242+
// Nevertheless, the right API signature accepts an int n, so handle it as best we can.
243+
i := n - 1
244+
for ; i > 1<<31-1-1; i-- {
245+
j := int(r.Int63n(int64(i + 1)))
246+
swap(i, j)
247+
}
248+
for ; i > 0; i-- {
249+
j := int(r.int31n(int32(i + 1)))
250+
swap(i, j)
251+
}
252+
}
253+
205254
// Read generates len(p) random bytes and writes them into p. It
206255
// always returns len(p) and a nil error.
207256
// Read should not be called concurrently with any other Rand method.
@@ -288,6 +337,11 @@ func Float32() float32 { return globalRand.Float32() }
288337
// from the default Source.
289338
func Perm(n int) []int { return globalRand.Perm(n) }
290339

340+
// Shuffle pseudo-randomizes the order of elements using the default Source.
341+
// n is the number of elements. Shuffle panics if n <= 0.
342+
// swap swaps the elements with indexes i and j.
343+
func Shuffle(n int, swap func(i, j int)) { globalRand.Shuffle(n, swap) }
344+
291345
// Read generates len(p) random bytes from the default Source and
292346
// writes them into p. It always returns len(p) and a nil error.
293347
// Read, unlike the Rand.Read method, is safe for concurrent use.

Diff for: src/math/rand/rand_test.go

+131
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,113 @@ func TestReadSeedReset(t *testing.T) {
450450
}
451451
}
452452

453+
func TestShuffleSmall(t *testing.T) {
454+
// Check that Shuffle allows n=0 and n=1, but that swap is never called for them.
455+
r := New(NewSource(1))
456+
for n := 0; n <= 1; n++ {
457+
r.Shuffle(n, func(i, j int) { t.Fatalf("swap called, n=%d i=%d j=%d", n, i, j) })
458+
}
459+
}
460+
461+
// encodePerm converts from a permuted slice of length n, such as Perm generates, to an int in [0, n!).
462+
// See https://en.wikipedia.org/wiki/Lehmer_code.
463+
// encodePerm modifies the input slice.
464+
func encodePerm(s []int) int {
465+
// Convert to Lehmer code.
466+
for i, x := range s {
467+
r := s[i+1:]
468+
for j, y := range r {
469+
if y > x {
470+
r[j]--
471+
}
472+
}
473+
}
474+
// Convert to int in [0, n!).
475+
m := 0
476+
fact := 1
477+
for i := len(s) - 1; i >= 0; i-- {
478+
m += s[i] * fact
479+
fact *= len(s) - i
480+
}
481+
return m
482+
}
483+
484+
// TestUniformFactorial tests several ways of generating a uniform value in [0, n!).
485+
func TestUniformFactorial(t *testing.T) {
486+
r := New(NewSource(testSeeds[0]))
487+
top := 6
488+
if testing.Short() {
489+
top = 4
490+
}
491+
for n := 3; n <= top; n++ {
492+
t.Run(fmt.Sprintf("n=%d", n), func(t *testing.T) {
493+
// Calculate n!.
494+
nfact := 1
495+
for i := 2; i <= n; i++ {
496+
nfact *= i
497+
}
498+
499+
// Test a few different ways to generate a uniform distribution.
500+
p := make([]int, n) // re-usable slice for Shuffle generator
501+
tests := [...]struct {
502+
name string
503+
fn func() int
504+
}{
505+
{name: "Int31n", fn: func() int { return int(r.Int31n(int32(nfact))) }},
506+
{name: "int31n", fn: func() int { return int(r.int31n(int32(nfact))) }},
507+
{name: "Perm", fn: func() int { return encodePerm(r.Perm(n)) }},
508+
{name: "Shuffle", fn: func() int {
509+
// Generate permutation using Shuffle.
510+
for i := range p {
511+
p[i] = i
512+
}
513+
r.Shuffle(n, func(i, j int) { p[i], p[j] = p[j], p[i] })
514+
return encodePerm(p)
515+
}},
516+
}
517+
518+
for _, test := range tests {
519+
t.Run(test.name, func(t *testing.T) {
520+
// Gather chi-squared values and check that they follow
521+
// the expected normal distribution given n!-1 degrees of freedom.
522+
// See https://en.wikipedia.org/wiki/Pearson%27s_chi-squared_test and
523+
// https://www.johndcook.com/Beautiful_Testing_ch10.pdf.
524+
nsamples := 10 * nfact
525+
if nsamples < 200 {
526+
nsamples = 200
527+
}
528+
samples := make([]float64, nsamples)
529+
for i := range samples {
530+
// Generate some uniformly distributed values and count their occurrences.
531+
const iters = 1000
532+
counts := make([]int, nfact)
533+
for i := 0; i < iters; i++ {
534+
counts[test.fn()]++
535+
}
536+
// Calculate chi-squared and add to samples.
537+
want := iters / float64(nfact)
538+
var χ2 float64
539+
for _, have := range counts {
540+
err := float64(have) - want
541+
χ2 += err * err
542+
}
543+
χ2 /= want
544+
samples[i] = χ2
545+
}
546+
547+
// Check that our samples approximate the appropriate normal distribution.
548+
dof := float64(nfact - 1)
549+
expected := &statsResults{mean: dof, stddev: math.Sqrt(2 * dof)}
550+
errorScale := max(1.0, expected.stddev)
551+
expected.closeEnough = 0.10 * errorScale
552+
expected.maxError = 0.08 // TODO: What is the right value here? See issue 21211.
553+
checkSampleDistribution(t, samples, expected)
554+
})
555+
}
556+
})
557+
}
558+
}
559+
453560
// Benchmarks
454561

455562
func BenchmarkInt63Threadsafe(b *testing.B) {
@@ -514,6 +621,30 @@ func BenchmarkPerm30(b *testing.B) {
514621
}
515622
}
516623

624+
func BenchmarkPerm30ViaShuffle(b *testing.B) {
625+
r := New(NewSource(1))
626+
for n := b.N; n > 0; n-- {
627+
p := make([]int, 30)
628+
for i := range p {
629+
p[i] = i
630+
}
631+
r.Shuffle(30, func(i, j int) { p[i], p[j] = p[j], p[i] })
632+
}
633+
}
634+
635+
// BenchmarkShuffleOverhead uses a minimal swap function
636+
// to measure just the shuffling overhead.
637+
func BenchmarkShuffleOverhead(b *testing.B) {
638+
r := New(NewSource(1))
639+
for n := b.N; n > 0; n-- {
640+
r.Shuffle(52, func(i, j int) {
641+
if i < 0 || i >= 52 || j < 0 || j >= 52 {
642+
b.Fatalf("bad swap(%d, %d)", i, j)
643+
}
644+
})
645+
}
646+
}
647+
517648
func BenchmarkRead3(b *testing.B) {
518649
r := New(NewSource(1))
519650
buf := make([]byte, 3)

0 commit comments

Comments
 (0)