Skip to content

Commit

Permalink
slices: rework the APIs of BinarySearch*
Browse files Browse the repository at this point in the history
For golang/go#50340

Change-Id: If115b2b66d463d5f3788d017924f8dd38867551c
Reviewed-on: https://go-review.googlesource.com/c/exp/+/395414
Reviewed-by: Ian Lance Taylor <iant@golang.org>
Trust: Eli Bendersky‎ <eliben@golang.org>
  • Loading branch information
eliben committed Mar 28, 2022
1 parent 054d857 commit 053ad81
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 35 deletions.
43 changes: 26 additions & 17 deletions slices/sort.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,25 +46,34 @@ func IsSortedFunc[E any](x []E, less func(a, b E) bool) bool {
return true
}

// BinarySearch searches for target in a sorted slice and returns the smallest
// index at which target is found. If the target is not found, the index at
// which it could be inserted into the slice is returned; therefore, if the
// intention is to find target itself a separate check for equality with the
// element at the returned index is required.
func BinarySearch[E constraints.Ordered](x []E, target E) int {
return search(len(x), func(i int) bool { return x[i] >= target })
// BinarySearch searches for target in a sorted slice and returns the position
// where target is found, or the position where target would appear in the
// sort order; it also returns a bool saying whether the target is really found
// in the slice. The slice must be sorted in increasing order.
func BinarySearch[E constraints.Ordered](x []E, target E) (int, bool) {
// search returns the leftmost position where f returns true, or len(x) if f
// returns false for all x. This is the insertion position for target in x,
// and could point to an element that's either == target or not.
pos := search(len(x), func(i int) bool { return x[i] >= target })
if pos >= len(x) || x[pos] != target {
return pos, false
} else {
return pos, true
}
}

// BinarySearchFunc uses binary search to find and return the smallest index i
// in [0, n) at which ok(i) is true, assuming that on the range [0, n),
// ok(i) == true implies ok(i+1) == true. That is, BinarySearchFunc requires
// that ok is false for some (possibly empty) prefix of the input range [0, n)
// and then true for the (possibly empty) remainder; BinarySearchFunc returns
// the first true index. If there is no such index, BinarySearchFunc returns n.
// (Note that the "not found" return value is not -1 as in, for instance,
// strings.Index.) Search calls ok(i) only for i in the range [0, n).
func BinarySearchFunc[E any](x []E, ok func(E) bool) int {
return search(len(x), func(i int) bool { return ok(x[i]) })
// BinarySearchFunc works like BinarySearch, but uses a custom comparison
// function. The slice must be sorted in increasing order, where "increasing" is
// defined by cmp. cmp(a, b) is expected to return an integer comparing the two
// parameters: 0 if a == b, a negative number if a < b and a positive number if
// a > b.
func BinarySearchFunc[E any](x []E, target E, cmp func(E, E) int) (int, bool) {
pos := search(len(x), func(i int) bool { return cmp(x[i], target) >= 0 })
if pos >= len(x) || cmp(x[pos], target) != 0 {
return pos, false
} else {
return pos, true
}
}

// maxDepth returns a threshold at which quicksort should switch
Expand Down
119 changes: 101 additions & 18 deletions slices/sort_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ package slices
import (
"math"
"math/rand"
"strconv"
"strings"
"testing"
)

Expand Down Expand Up @@ -151,31 +153,112 @@ func TestStability(t *testing.T) {
}

func TestBinarySearch(t *testing.T) {
data := []string{"aa", "ad", "ca", "xy"}
str1 := []string{"foo"}
str2 := []string{"ab", "ca"}
str3 := []string{"mo", "qo", "vo"}
str4 := []string{"ab", "ad", "ca", "xy"}

// slice with repeating elements
strRepeats := []string{"ba", "ca", "da", "da", "da", "ka", "ma", "ma", "ta"}

// slice with all element equal
strSame := []string{"xx", "xx", "xx"}

tests := []struct {
target string
want int
data []string
target string
wantPos int
wantFound bool
}{
{"aa", 0},
{"ab", 1},
{"ad", 1},
{"ax", 2},
{"ca", 2},
{"cc", 3},
{"dd", 3},
{"xy", 3},
{"zz", 4},
{[]string{}, "foo", 0, false},
{[]string{}, "", 0, false},

{str1, "foo", 0, true},
{str1, "bar", 0, false},
{str1, "zx", 1, false},

{str2, "aa", 0, false},
{str2, "ab", 0, true},
{str2, "ad", 1, false},
{str2, "ca", 1, true},
{str2, "ra", 2, false},

{str3, "bb", 0, false},
{str3, "mo", 0, true},
{str3, "nb", 1, false},
{str3, "qo", 1, true},
{str3, "tr", 2, false},
{str3, "vo", 2, true},
{str3, "xr", 3, false},

{str4, "aa", 0, false},
{str4, "ab", 0, true},
{str4, "ac", 1, false},
{str4, "ad", 1, true},
{str4, "ax", 2, false},
{str4, "ca", 2, true},
{str4, "cc", 3, false},
{str4, "dd", 3, false},
{str4, "xy", 3, true},
{str4, "zz", 4, false},

{strRepeats, "da", 2, true},
{strRepeats, "db", 5, false},
{strRepeats, "ma", 6, true},
{strRepeats, "mb", 8, false},

{strSame, "xx", 0, true},
{strSame, "ab", 0, false},
{strSame, "zz", 3, false},
}
for _, tt := range tests {
t.Run(tt.target, func(t *testing.T) {
i := BinarySearch(data, tt.target)
if i != tt.want {
t.Errorf("BinarySearch want %d, got %d", tt.want, i)
{
pos, found := BinarySearch(tt.data, tt.target)
if pos != tt.wantPos || found != tt.wantFound {
t.Errorf("BinarySearch got (%v, %v), want (%v, %v)", pos, found, tt.wantPos, tt.wantFound)
}
}

{
pos, found := BinarySearchFunc(tt.data, tt.target, strings.Compare)
if pos != tt.wantPos || found != tt.wantFound {
t.Errorf("BinarySearchFunc got (%v, %v), want (%v, %v)", pos, found, tt.wantPos, tt.wantFound)
}
}
})
}
}

func TestBinarySearchInts(t *testing.T) {
data := []int{20, 30, 40, 50, 60, 70, 80, 90}
tests := []struct {
target int
wantPos int
wantFound bool
}{
{20, 0, true},
{23, 1, false},
{43, 3, false},
{80, 6, true},
}
for _, tt := range tests {
t.Run(strconv.Itoa(tt.target), func(t *testing.T) {
{
pos, found := BinarySearch(data, tt.target)
if pos != tt.wantPos || found != tt.wantFound {
t.Errorf("BinarySearch got (%v, %v), want (%v, %v)", pos, found, tt.wantPos, tt.wantFound)
}
}

j := BinarySearchFunc(data, func(s string) bool { return s >= tt.target })
if j != tt.want {
t.Errorf("BinarySearchFunc want %d, got %d", tt.want, j)
{
cmp := func(a, b int) int {
return a - b
}
pos, found := BinarySearchFunc(data, tt.target, cmp)
if pos != tt.wantPos || found != tt.wantFound {
t.Errorf("BinarySearchFunc got (%v, %v), want (%v, %v)", pos, found, tt.wantPos, tt.wantFound)
}
}
})
}
Expand Down

0 comments on commit 053ad81

Please sign in to comment.