diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml
index 47fed02..001f254 100644
--- a/.github/workflows/go.yml
+++ b/.github/workflows/go.yml
@@ -19,7 +19,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v3
with:
- go-version: 1.21
+ go-version: 1.18
- name: Test
run: go test -v ./...
diff --git a/README.md b/README.md
index 5ce9712..5fc5c95 100644
--- a/README.md
+++ b/README.md
@@ -1,36 +1,99 @@
DAO
-
道生一, 一生二, 二生三, 三生万物; 万物负阴而抱阳, 冲气以为和.
+
道生一, 一生二, 二生三, 三生万物; 万物负阴而抱阳, 冲气以为和
-[![Build Status](https://github.com/lxzan/dao/workflows/Go%20Test/badge.svg?branch=main)](https://github.com/lxzan/dao/actions?query=branch%3Amain) [![go-version](https://img.shields.io/badge/go-%3E%3D1.21-30dff3?style=flat-square&logo=go)](https://github.com/lxzan/dao)
-
+[![Build Status](https://github.com/lxzan/dao/workflows/Go%20Test/badge.svg?branch=main)](https://github.com/lxzan/dao/actions?query=branch%3Amain) [![codecov](https://codecov.io/gh/lxzan/dao/graph/badge.svg?token=BQM1JHCDEE)](https://codecov.io/gh/lxzan/dao) [![go-version](https://img.shields.io/badge/go-%3E%3D1.18-30dff3?style=flat-square&logo=go)](https://github.com/lxzan/dao) [![license](https://img.shields.io/badge/license-MIT-blue.svg)](LICENSE)
### 简介
-Go 数据结构与算法库
+`DAO` 是一个基于泛型的数据结构与算法库
### 目录
- [简介](#简介)
- [目录](#目录)
+- [动态数组](#动态数组)
+ - [去重](#去重)
+ - [排序](#排序)
+ - [过滤](#过滤)
- [堆](#堆)
- [二叉堆](#二叉堆)
- [四叉堆](#四叉堆)
+ - [八叉堆](#八叉堆)
- [栈](#栈)
+- [队列](#队列)
- [双端队列](#双端队列)
- [双向链表](#双向链表)
- [红黑树](#红黑树)
- - [区间查询](#区间查询)
- - [极值查询](#极值查询)
-- [前缀树](#前缀树)
+- [字典树](#字典树)
- [哈希表](#哈希表)
- [线段树](#线段树)
- [基准测试](#基准测试)
+### 动态数组
+
+#### 去重
+
+```go
+package main
+
+import (
+ "fmt"
+ "github.com/lxzan/dao/vector"
+)
+
+func main() {
+ var v = vector.NewFromInts(1, 3, 5, 3)
+ v.Uniq()
+ fmt.Printf("%v", v.Elem())
+}
+
+```
+
+#### 排序
+
+```go
+package main
+
+import (
+ "fmt"
+ "github.com/lxzan/dao/vector"
+)
+
+func main() {
+ var v = vector.NewFromInts(1, 3, 5, 2, 4, 6)
+ v.Sort()
+ fmt.Printf("%v", v.Elem())
+}
+
+```
+
+#### 过滤
+
+```go
+package main
+
+import (
+ "fmt"
+ "github.com/lxzan/dao/vector"
+)
+
+func main() {
+ var v = vector.NewFromInts(1, 3, 5, 2, 4, 6)
+ v.Filter(func(i int, v vector.Int) bool {
+ return v.GetID()%2 == 0
+ })
+ fmt.Printf("%v", v.Elem())
+}
+
+```
+
### 堆
+**堆** 又称之为优先队列, 堆顶元素总是最大或最小的. 常用的是四叉堆, `Push/Pop` 性能较为均衡.
+
#### 二叉堆
```go
@@ -38,10 +101,11 @@ package main
import (
"github.com/lxzan/dao/heap"
+ "github.com/lxzan/dao/types/cmp"
)
func main() {
- var h = heap.New[int]().SetForkNumber(heap.Binary)
+ var h = heap.NewWithForks(heap.Binary, cmp.Less[int])
h.Push(1)
h.Push(3)
h.Push(5)
@@ -62,10 +126,11 @@ package main
import (
"github.com/lxzan/dao/heap"
+ "github.com/lxzan/dao/types/cmp"
)
func main() {
- var h = heap.New[int]().SetForkNumber(heap.Quadratic)
+ var h = heap.NewWithForks(heap.Quadratic, cmp.Less[int])
h.Push(1)
h.Push(3)
h.Push(5)
@@ -76,10 +141,38 @@ func main() {
println(h.Pop())
}
}
+
+```
+
+#### 八叉堆
+
+```go
+package main
+
+import (
+ "github.com/lxzan/dao/heap"
+ "github.com/lxzan/dao/types/cmp"
+)
+
+func main() {
+ var h = heap.NewWithForks(heap.Octal, cmp.Less[int])
+ h.Push(1)
+ h.Push(3)
+ h.Push(5)
+ h.Push(2)
+ h.Push(4)
+ h.Push(6)
+ for h.Len() > 0 {
+ println(h.Pop())
+ }
+}
+
```
### 栈
+**栈** 先进后出 (`LIFO`) 的数据结构
+
```go
package main
@@ -98,8 +191,35 @@ func main() {
}
```
+### 队列
+
+**队列** 先进先出 (`FIFO`) 的数据结构. `dao/queue` 在全部元素弹出后会自动重置, 复用内存空间
+
+```go
+package main
+
+import (
+ "github.com/lxzan/dao/queue"
+)
+
+func main() {
+ var s = queue.New[int](0)
+ s.Push(1)
+ s.Push(3)
+ s.Push(5)
+ for s.Len() > 0 {
+ println(s.Pop())
+ }
+}
+
+```
+
### 双端队列
+**双端队列** 类似于双向链表, 两端均可高效执行插入删除操作.
+
+`dao/deque` 基于数组下标模拟指针实现, 删除后的空间后续仍可复用, 且不依赖 `sync.Pool`
+
```go
package main
@@ -156,13 +276,12 @@ func main() {
### 红黑树
-#### 区间查询
+高性能红黑树实现, 可作为内存数据库使用.
```go
package main
import (
- "github.com/lxzan/dao"
"github.com/lxzan/dao/rbtree"
)
@@ -176,39 +295,18 @@ func main() {
NewQuery().
Left(func(key int) bool { return key >= 3 }).
Right(func(key int) bool { return key <= 5 }).
- Order(dao.DESC).
- Do()
+ Order(rbtree.ASC).
+ FindAll()
for _, item := range results {
println(item.Key)
}
}
-```
-
-#### 极值查询
-
-```go
-package main
-import (
- "fmt"
- "github.com/lxzan/dao/rbtree"
-)
-
-func main() {
- var tree = rbtree.New[int, struct{}]()
- for i := 0; i < 10; i++ {
- tree.Set(i, struct{}{})
- }
-
- minimum, _ := tree.GetMinKey(rbtree.TrueFunc[int])
- maximum, _ := tree.GetMaxKey(rbtree.TrueFunc[int])
- fmt.Printf("%v %v", minimum.Key, maximum.Key)
-}
```
-### 前缀树
+### 字典树
-可以动态配置槽位宽度的前缀树
+**字典树** 又叫前缀树, 可以高效匹配字符串前缀. `dao/dict` 可以动态配置槽位宽度(由索引控制).
注意: 合理设置索引, 超出索引长度的字符不能被索引优化
@@ -258,6 +356,8 @@ func main() {
### 线段树
+**线段树** 是一种二叉树,它的每个节点都表示一个区间。 线段树的特点是可以在`O(logn)`的时间内进行区间查询和区间更新。
+
```go
package main
@@ -268,7 +368,7 @@ import (
func main() {
var data = []tree.Int64{1, 3, 5, 7, 9, 2, 4, 6, 8, 10}
var lines = tree.New[tree.Int64Schema, tree.Int64](data)
- var result = lines.Query(0, 9)
+ var result = lines.Query(0, 10)
println(result.MinValue, result.MaxValue, result.Sum)
}
@@ -276,7 +376,42 @@ func main() {
### 基准测试
-- 10,000 elements
+- 1,000 elements
+
+```
+go test -benchmem -bench '^Benchmark' ./benchmark/
+goos: windows
+goarch: amd64
+pkg: github.com/lxzan/dao/benchmark
+cpu: AMD Ryzen 5 PRO 4650G with Radeon Graphics
+BenchmarkDict_Set-12 8647 124.1 ns/op 48190 B/op 1003 allocs/op
+BenchmarkDict_Get-12 9609 122.0 ns/op 48000 B/op 1000 allocs/op
+BenchmarkDict_Match-12 3270 349.3 ns/op 48000 B/op 1000 allocs/op
+BenchmarkHeap_Push_Binary-12 55809 21.0 ns/op 38912 B/op 3 allocs/op
+BenchmarkHeap_Push_Quadratic-12 65932 18.0 ns/op 38912 B/op 3 allocs/op
+BenchmarkHeap_Push_Octal-12 71005 16.1 ns/op 38912 B/op 3 allocs/op
+BenchmarkHeap_Pop_Binary-12 10000 100.1 ns/op 16384 B/op 1 allocs/op
+BenchmarkHeap_Pop_Quadratic-12 10000 100.7 ns/op 16384 B/op 1 allocs/op
+BenchmarkHeap_Pop_Octal-12 9681 124.9 ns/op 16384 B/op 1 allocs/op
+BenchmarkStdList_Push-12 24715 48.7 ns/op 54000 B/op 1745 allocs/op
+BenchmarkStdList_PushAndPop-12 22006 54.7 ns/op 54000 B/op 1745 allocs/op
+BenchmarkLinkedList_Push-12 38464 31.7 ns/op 24000 B/op 1000 allocs/op
+BenchmarkLinkedList_PushAndPop-12 36898 32.6 ns/op 24000 B/op 1000 allocs/op
+BenchmarkDeque_Push-12 100468 11.7 ns/op 24576 B/op 1 allocs/op
+BenchmarkDeque_PushAndPop-12 51649 21.7 ns/op 37496 B/op 12 allocs/op
+BenchmarkRBTree_Set-12 9999 113.9 ns/op 72048 B/op 2001 allocs/op
+BenchmarkRBTree_Get-12 51806 22.7 ns/op 0 B/op 0 allocs/op
+BenchmarkRBTree_FindAll-12 2808 421.3 ns/op 288001 B/op 5000 allocs/op
+BenchmarkRBTree_FindAOne-12 4722 252.2 ns/op 56000 B/op 5000 allocs/op
+BenchmarkSegmentTree_Query-12 7498 164.4 ns/op 20 B/op 0 allocs/op
+BenchmarkSegmentTree_Update-12 10000 108.6 ns/op 15 B/op 0 allocs/op
+BenchmarkSort_Quick-12 24488 48.5 ns/op 0 B/op 0 allocs/op
+BenchmarkSort_Std-12 21703 54.9 ns/op 8216 B/op 2 allocs/op
+PASS
+ok github.com/lxzan/dao/benchmark 31.100s
+```
+
+- 1,000,000 elements
```
go test -benchmem -bench '^Benchmark' ./benchmark/
@@ -284,28 +419,29 @@ goos: windows
goarch: amd64
pkg: github.com/lxzan/dao/benchmark
cpu: AMD Ryzen 5 PRO 4650G with Radeon Graphics
-BenchmarkDict_Set-12 423 244.9 ns/op 517811 B/op 10645 allocs/op
-BenchmarkDict_Get-12 499 241.9 ns/op 480001 B/op 10000 allocs/op
-BenchmarkDict_Match-12 265 456.1 ns/op 480000 B/op 10000 allocs/op
-BenchmarkHeap_Push_Binary-12 4455 28.7 ns/op 507905 B/op 4 allocs/op
-BenchmarkHeap_Push_Quadratic-12 5960 26.4 ns/op 507906 B/op 4 allocs/op
-BenchmarkHeap_Push_Octal-12 5793 22.5 ns/op 507907 B/op 4 allocs/op
-BenchmarkHeap_Pop_Binary-12 808 149.7 ns/op 163840 B/op 1 allocs/op
-BenchmarkHeap_Pop_Quadratic-12 846 145.4 ns/op 163840 B/op 1 allocs/op
-BenchmarkHeap_Pop_Octal-12 673 178.8 ns/op 163840 B/op 1 allocs/op
-BenchmarkStdList_Push-12 1958 59.6 ns/op 558002 B/op 19745 allocs/op
-BenchmarkStdList_PushAndPop-12 1729 65.2 ns/op 558001 B/op 19745 allocs/op
-BenchmarkLinkedList_Push-12 3770 31.9 ns/op 240001 B/op 10000 allocs/op
-BenchmarkLinkedList_PushAndPop-12 2539 46.5 ns/op 240002 B/op 10000 allocs/op
-BenchmarkDeque_Push-12 8560 12.2 ns/op 245761 B/op 1 allocs/op
-BenchmarkDeque_PushAndPop-12 5599 37.8 ns/op 386937 B/op 18 allocs/op
-BenchmarkRBTree_Set-12 540 219.4 ns/op 720051 B/op 20001 allocs/op
-BenchmarkRBTree_Get-12 3272 36.5 ns/op 0 B/op 0 allocs/op
-BenchmarkRBTree_Query-12 60 1809.6 ns/op 3680048 B/op 60000 allocs/op
-BenchmarkSegmentTree_Query-12 418 273.4 ns/op 3917 B/op 47 allocs/op
-BenchmarkSegmentTree_Update-12 686 174.5 ns/op 2387 B/op 29 allocs/op
-BenchmarkSort_Quick-12 1588 75.8 ns/op 81920 B/op 1 allocs/op
-BenchmarkSort_Std-12 1377 86.2 ns/op 81944 B/op 2 allocs/op
+BenchmarkDict_Set-12 1 2295.2 ns/op 1405087408 B/op 24868109 allocs/op
+BenchmarkDict_Get-12 2 784.0 ns/op 48000000 B/op 1000000 allocs/op
+BenchmarkDict_Match-12 2 961.0 ns/op 48000000 B/op 1000000 allocs/op
+BenchmarkHeap_Push_Binary-12 48 24.8 ns/op 65708034 B/op 5 allocs/op
+BenchmarkHeap_Push_Quadratic-12 58 19.4 ns/op 65708033 B/op 5 allocs/op
+BenchmarkHeap_Push_Octal-12 69 17.1 ns/op 65708033 B/op 5 allocs/op
+BenchmarkHeap_Pop_Binary-12 3 376.3 ns/op 16007168 B/op 1 allocs/op
+BenchmarkHeap_Pop_Quadratic-12 3 342.8 ns/op 16007168 B/op 1 allocs/op
+BenchmarkHeap_Pop_Octal-12 3 374.8 ns/op 16007168 B/op 1 allocs/op
+BenchmarkStdList_Push-12 21 55.0 ns/op 55998007 B/op 1999745 allocs/op
+BenchmarkStdList_PushAndPop-12 15 67.5 ns/op 55998008 B/op 1999745 allocs/op
+BenchmarkLinkedList_Push-12 43 29.5 ns/op 24000000 B/op 1000000 allocs/op
+BenchmarkLinkedList_PushAndPop-12 39 34.7 ns/op 24000002 B/op 1000000 allocs/op
+BenchmarkDeque_Push-12 123 9.4 ns/op 24002560 B/op 1 allocs/op
+BenchmarkDeque_PushAndPop-12 60 18.7 ns/op 45098876 B/op 37 allocs/op
+BenchmarkRBTree_Set-12 6 171.9 ns/op 72000064 B/op 2000001 allocs/op
+BenchmarkRBTree_Get-12 22 50.1 ns/op 0 B/op 0 allocs/op
+BenchmarkRBTree_FindAll-12 1 1936.8 ns/op 288000128 B/op 5000001 allocs/op
+BenchmarkRBTree_FindAOne-12 1 1793.4 ns/op 56000000 B/op 5000000 allocs/op
+BenchmarkSegmentTree_Query-12 1 1630.0 ns/op 169678048 B/op 2000038 allocs/op
+BenchmarkSegmentTree_Update-12 1 1025.0 ns/op 169678048 B/op 2000038 allocs/op
+BenchmarkSort_Quick-12 10 109.5 ns/op 8003584 B/op 1 allocs/op
+BenchmarkSort_Std-12 9 123.0 ns/op 8003608 B/op 2 allocs/op
PASS
-ok github.com/lxzan/dao/benchmark 32.279s
-```
\ No newline at end of file
+ok github.com/lxzan/dao/benchmark 47.376s
+```
diff --git a/algorithm/helper.go b/algorithm/helper.go
index 416f4c2..358eae1 100644
--- a/algorithm/helper.go
+++ b/algorithm/helper.go
@@ -1,14 +1,12 @@
package algorithm
import (
- "cmp"
- "github.com/lxzan/dao"
- "slices"
+ "github.com/lxzan/dao/types/cmp"
"strconv"
)
// ToString 数字转字符串
-func ToString[T dao.Integer](x T) string {
+func ToString[T cmp.Integer](x T) string {
return strconv.Itoa(int(x))
}
@@ -35,12 +33,12 @@ func Swap[T any](a, b *T) {
*b = temp
}
-func Unique[T cmp.Ordered](arr []T) []T {
+func Unique[T cmp.Ordered, A ~[]T](arr A) A {
if len(arr) == 0 {
return arr
}
- slices.Sort(arr)
+ Sort(arr)
var n = len(arr)
var j = 1
@@ -54,21 +52,13 @@ func Unique[T cmp.Ordered](arr []T) []T {
return arr
}
-func UniqueBy[T any, K cmp.Ordered](arr []T, getKey func(item T) K) []T {
+func UniqueBy[T any, K cmp.Ordered, A ~[]T](arr A, getKey func(item T) K) A {
if len(arr) == 0 {
return arr
}
- slices.SortFunc(arr, func(a, b T) int {
- x := getKey(a)
- y := getKey(b)
- if x < y {
- return -1
- } else if x > y {
- return 1
- } else {
- return 0
- }
+ SortBy(arr, func(a, b T) int {
+ return cmp.Compare(getKey(a), getKey(b))
})
var n = len(arr)
@@ -84,11 +74,12 @@ func UniqueBy[T any, K cmp.Ordered](arr []T, getKey func(item T) K) []T {
}
// Reverse 反转数组
-func Reverse[T any](arr []T) {
+func Reverse[T any, A ~[]T](arr A) A {
var n = len(arr)
for i := 0; i < n/2; i++ {
arr[i], arr[n-1-i] = arr[n-1-i], arr[i]
}
+ return arr
}
// SelectValue 选择一个值 三元操作符替代品
@@ -119,7 +110,7 @@ func Map[A any, B any](arr []A, transfer func(i int, v A) B) []B {
}
// Filter 过滤器
-func Filter[T any](arr []T, check func(i int, v T) bool) []T {
+func Filter[T any, A ~[]T](arr A, check func(i int, v T) bool) A {
var results = make([]T, 0, len(arr))
for i, v := range arr {
if check(i, v) {
diff --git a/algorithm/helper_test.go b/algorithm/helper_test.go
index e0a61df..5394dee 100644
--- a/algorithm/helper_test.go
+++ b/algorithm/helper_test.go
@@ -15,6 +15,8 @@ func TestToString(t *testing.T) {
}
func TestUnique(t *testing.T) {
+ Unique[int, []int](nil)
+
t.Run("", func(t *testing.T) {
arr := Unique([]int{})
assert.ElementsMatch(t, arr, []int{})
@@ -133,6 +135,8 @@ func TestSwap(t *testing.T) {
}
func TestReverse(t *testing.T) {
+ Reverse[int, []int](nil)
+
t.Run("", func(t *testing.T) {
var list = []int{1, 2, 3, 4}
Reverse(list)
diff --git a/algorithm/sort.go b/algorithm/sort.go
index 3bc0a88..3867fa4 100644
--- a/algorithm/sort.go
+++ b/algorithm/sort.go
@@ -1,17 +1,16 @@
package algorithm
import (
- "cmp"
- "github.com/lxzan/dao"
+ "github.com/lxzan/dao/types/cmp"
)
-func IsSorted[T any](arr []T, cmp dao.CompareFunc[T]) bool {
+func IsSorted[T any](arr []T, compare cmp.CompareFunc[T]) bool {
var n = len(arr)
if n <= 1 {
return true
}
for i := 1; i < n; i++ {
- if cmp(arr[i], arr[i-1]) < 0 {
+ if compare(arr[i], arr[i-1]) < 0 {
return false
}
}
@@ -26,22 +25,22 @@ func Sort[T cmp.Ordered](arr []T) {
QuickSort(arr, 0, len(arr)-1, f)
}
-func SortBy[T any](arr []T, cmp dao.CompareFunc[T]) {
- if IsSorted(arr, cmp) {
+func SortBy[T any](arr []T, compare cmp.CompareFunc[T]) {
+ if IsSorted(arr, compare) {
return
}
- QuickSort(arr, 0, len(arr)-1, cmp)
+ QuickSort(arr, 0, len(arr)-1, compare)
}
-func getMedium[T any](arr []T, begin int, end int, cmp dao.CompareFunc[T]) int {
+func getMedium[T any](arr []T, begin int, end int, compare cmp.CompareFunc[T]) int {
var mid = (begin + end) / 2
- var x = cmp(arr[begin], arr[mid])
- var y = cmp(arr[mid], arr[end])
+ var x = compare(arr[begin], arr[mid])
+ var y = compare(arr[mid], arr[end])
if x+y != 0 {
return mid
}
- var z = cmp(arr[begin], arr[end])
+ var z = compare(arr[begin], arr[end])
y *= -1
if y+z != 0 {
return end
@@ -49,9 +48,9 @@ func getMedium[T any](arr []T, begin int, end int, cmp dao.CompareFunc[T]) int {
return begin
}
-func insertionSort[T any](arr []T, a, b int, cmp dao.CompareFunc[T]) {
+func insertionSort[T any](arr []T, a, b int, compare cmp.CompareFunc[T]) {
for i := a + 1; i <= b; i++ {
- for j := i; j > a && cmp(arr[j], arr[j-1]) == dao.Less; j-- {
+ for j := i; j > a && compare(arr[j], arr[j-1]) == cmp.LT; j-- {
arr[j], arr[j-1] = arr[j-1], arr[j]
}
}
@@ -59,34 +58,34 @@ func insertionSort[T any](arr []T, a, b int, cmp dao.CompareFunc[T]) {
// QuickSort 快速排序 begin <= x <= end 区间
// 对于随机数据, 此算法比标准库稍快; 对于本身比较有序的数据, 标准库表现更佳.
-func QuickSort[T any](arr []T, begin int, end int, cmp dao.CompareFunc[T]) {
+func QuickSort[T any](arr []T, begin int, end int, compare cmp.CompareFunc[T]) {
if begin >= end {
return
}
if end-begin <= 15 {
- insertionSort(arr, begin, end, cmp)
+ insertionSort(arr, begin, end, compare)
return
}
var index = begin
- var mid = getMedium(arr, begin, end, cmp)
+ var mid = getMedium(arr, begin, end, compare)
arr[mid], arr[begin] = arr[begin], arr[mid]
for i := begin + 1; i <= end; i++ {
- var flag = cmp(arr[i], arr[begin])
- if flag == dao.Less || (flag == dao.Equal && i%2 == 0) {
+ var flag = compare(arr[i], arr[begin])
+ if flag == cmp.LT || (flag == cmp.EQ && i%2 == 0) {
index++
arr[index], arr[i] = arr[i], arr[index]
}
}
arr[index], arr[begin] = arr[begin], arr[index]
- QuickSort(arr, begin, index-1, cmp)
- QuickSort(arr, index+1, end, cmp)
+ QuickSort(arr, begin, index-1, compare)
+ QuickSort(arr, index+1, end, compare)
}
// BinarySearch 二分搜索
// @return 数组下标 如果不存在, 返回-1
-func BinarySearch[T any](arr []T, target T, cmp dao.CompareFunc[T]) int {
+func BinarySearch[T any](arr []T, target T, compare cmp.CompareFunc[T]) int {
var n = len(arr)
if n == 0 {
return -1
@@ -96,19 +95,19 @@ func BinarySearch[T any](arr []T, target T, cmp dao.CompareFunc[T]) int {
var right = n - 1
for right-left > 1 {
var mid = (left + right) / 2
- switch cmp(arr[mid], target) {
- case dao.Equal:
+ switch compare(arr[mid], target) {
+ case cmp.EQ:
return mid
- case dao.Greater:
+ case cmp.GT:
right = mid
default:
left = mid
}
}
- if cmp(arr[left], target) == dao.Equal {
+ if compare(arr[left], target) == cmp.EQ {
return left
- } else if cmp(arr[right], target) == dao.Equal {
+ } else if compare(arr[right], target) == cmp.EQ {
return right
} else {
return -1
diff --git a/algorithm/sort_test.go b/algorithm/sort_test.go
index 5b32ef7..0fc3048 100644
--- a/algorithm/sort_test.go
+++ b/algorithm/sort_test.go
@@ -1,10 +1,11 @@
package algorithm
import (
- "cmp"
"github.com/lxzan/dao/internal/utils"
+ "github.com/lxzan/dao/types/cmp"
"github.com/stretchr/testify/assert"
"math/rand"
+ "sort"
"testing"
)
@@ -64,9 +65,9 @@ func TestSort(t *testing.T) {
}
t.Run("", func(t *testing.T) {
- var a = []int{1, 2, 3}
+ var a = []int{2, 1}
SortBy(a, cmp.Compare[int])
- assert.True(t, utils.IsSameSlice(a, []int{1, 2, 3}))
+ assert.True(t, utils.IsSameSlice(a, []int{1, 2}))
})
t.Run("", func(t *testing.T) {
@@ -80,6 +81,42 @@ func TestSort(t *testing.T) {
Sort(a)
assert.True(t, utils.IsSameSlice(a, []int{1, 2, 3, 4}))
})
+
+ t.Run("", func(t *testing.T) {
+ Sort([]int{})
+ Sort([]int{1})
+ })
+
+ t.Run("", func(t *testing.T) {
+ for j := 0; j < 100; j++ {
+ var count = rand.Intn(100)
+ var arr0 []int
+ for i := 0; i < count; i++ {
+ arr0 = append(arr0, rand.Intn(count))
+ }
+ var arr1 = utils.Clone(arr0)
+
+ Sort(arr0)
+ sort.Ints(arr1)
+ assert.True(t, utils.IsSameSlice(arr0, arr1))
+ }
+
+ for j := 0; j < 100; j++ {
+ var count = rand.Intn(100)
+ var arr0 []int
+ for i := 0; i < count; i++ {
+ arr0 = append(arr0, rand.Intn(count))
+ }
+ var arr1 = utils.Clone(arr0)
+
+ SortBy(arr0, func(a, b int) int {
+ return -1 * cmp.Compare(a, b)
+ })
+ sort.Ints(arr1)
+ Reverse(arr1)
+ assert.True(t, utils.IsSameSlice(arr0, arr1))
+ }
+ })
}
func TestBinarySearch(t *testing.T) {
diff --git a/benchmark/heap_test.go b/benchmark/heap_test.go
index 21e9d93..f8cedce 100644
--- a/benchmark/heap_test.go
+++ b/benchmark/heap_test.go
@@ -2,12 +2,13 @@ package benchmark
import (
"github.com/lxzan/dao/heap"
+ "github.com/lxzan/dao/types/cmp"
"math/rand"
"testing"
)
func BenchmarkHeap_Push_Binary(b *testing.B) {
- var tpl = heap.New[int]().SetForkNumber(heap.Binary)
+ var tpl = heap.NewWithForks(heap.Binary, cmp.Less[int])
for j := 0; j < bench_count; j++ {
tpl.Push(rand.Int())
}
@@ -23,7 +24,7 @@ func BenchmarkHeap_Push_Binary(b *testing.B) {
}
func BenchmarkHeap_Push_Quadratic(b *testing.B) {
- var tpl = heap.New[int]().SetForkNumber(heap.Quadratic)
+ var tpl = heap.NewWithForks(heap.Quadratic, cmp.Less[int])
for j := 0; j < bench_count; j++ {
tpl.Push(rand.Int())
}
@@ -38,7 +39,7 @@ func BenchmarkHeap_Push_Quadratic(b *testing.B) {
}
}
func BenchmarkHeap_Push_Octal(b *testing.B) {
- var tpl = heap.New[int]().SetForkNumber(heap.Octal)
+ var tpl = heap.NewWithForks(heap.Octal, cmp.Less[int])
for j := 0; j < bench_count; j++ {
tpl.Push(rand.Int())
}
@@ -54,7 +55,7 @@ func BenchmarkHeap_Push_Octal(b *testing.B) {
}
func BenchmarkHeap_Pop_Binary(b *testing.B) {
- var tpl = heap.New[int]().SetForkNumber(heap.Binary)
+ var tpl = heap.NewWithForks(heap.Binary, cmp.Less[int])
for j := 0; j < bench_count*2; j++ {
tpl.Push(rand.Int())
}
@@ -69,7 +70,7 @@ func BenchmarkHeap_Pop_Binary(b *testing.B) {
}
func BenchmarkHeap_Pop_Quadratic(b *testing.B) {
- var tpl = heap.New[int]().SetForkNumber(heap.Quadratic)
+ var tpl = heap.NewWithForks(heap.Quadratic, cmp.Less[int])
for j := 0; j < bench_count*2; j++ {
tpl.Push(rand.Int())
}
@@ -84,7 +85,7 @@ func BenchmarkHeap_Pop_Quadratic(b *testing.B) {
}
func BenchmarkHeap_Pop_Octal(b *testing.B) {
- var tpl = heap.New[int]().SetForkNumber(heap.Octal)
+ var tpl = heap.NewWithForks(heap.Octal, cmp.Less[int])
for j := 0; j < bench_count*2; j++ {
tpl.Push(rand.Int())
}
diff --git a/benchmark/rbtree_test.go b/benchmark/rbtree_test.go
index cf6ba72..807d922 100644
--- a/benchmark/rbtree_test.go
+++ b/benchmark/rbtree_test.go
@@ -30,7 +30,7 @@ func BenchmarkRBTree_Get(b *testing.B) {
}
}
-func BenchmarkRBTree_Query(b *testing.B) {
+func BenchmarkRBTree_FindAll(b *testing.B) {
var tree = rbtree.New[int, string]()
for j := 0; j < bench_count; j++ {
tree.Set(j, "")
@@ -47,7 +47,32 @@ func BenchmarkRBTree_Query(b *testing.B) {
NewQuery().
Left(func(key int) bool { return key >= x }).
Right(func(key int) bool { return key <= y }).
- Do()
+ Order(rbtree.DESC).
+ Limit(10).
+ FindAll()
+ }
+ }
+}
+
+func BenchmarkRBTree_FindAOne(b *testing.B) {
+ var tree = rbtree.New[int, string]()
+ for j := 0; j < bench_count; j++ {
+ tree.Set(j, "")
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ for j := 0; j < bench_count; j++ {
+ x, y := rand.Intn(bench_count), rand.Intn(bench_count)
+ if x > y {
+ algorithm.Swap(&x, &y)
+ }
+ tree.
+ NewQuery().
+ Left(func(key int) bool { return key >= x }).
+ Right(func(key int) bool { return key <= y }).
+ Order(rbtree.ASC).
+ FindOne()
}
}
}
diff --git a/deque/deque.go b/deque/deque.go
index ecc33ae..073c269 100644
--- a/deque/deque.go
+++ b/deque/deque.go
@@ -45,7 +45,7 @@ func (c *Element[T]) Value() T {
}
// New 创建双端队列
-func New[T any](capacity uint32) *Deque[T] {
+func New[T any](capacity int) *Deque[T] {
return &Deque[T]{elements: make([]Element[T], 1, 1+capacity)}
}
@@ -83,7 +83,6 @@ func (c *Deque[T]) putElement(ele *Element[T]) {
// Reset 重置
func (c *Deque[T]) Reset() {
- clear(c.elements)
c.autoReset()
}
diff --git a/deque/deque_test.go b/deque/deque_test.go
index 9613ba9..da3cf1b 100644
--- a/deque/deque_test.go
+++ b/deque/deque_test.go
@@ -1,9 +1,9 @@
package deque
import (
- "cmp"
"container/list"
"github.com/lxzan/dao/internal/utils"
+ "github.com/lxzan/dao/types/cmp"
"github.com/stretchr/testify/assert"
"math/rand"
"testing"
diff --git a/dict/dict.go b/dict/dict.go
index 18dd5f7..783dd65 100644
--- a/dict/dict.go
+++ b/dict/dict.go
@@ -12,30 +12,33 @@ type element struct {
}
type Dict[T any] struct {
- indexes []uint8 // 索引
- root *element // 根节点
- storage *mlist.MList[string, T] // 存储
+ indexes []uint8 // 索引
+ binaryIndex bool // 是否使用2进制索引
+ root *element // 根节点
+ storage *mlist.MList[string, T] // 存储
}
// New 新建字典树
// 注意: key不能重复
func New[T any]() *Dict[T] {
return &Dict[T]{
- indexes: defaultIndexes,
- root: &element{Children: make([]*element, defaultIndexes[0])},
- storage: mlist.NewMList[string, T](8),
+ indexes: defaultIndexes,
+ binaryIndex: true,
+ root: &element{Children: make([]*element, defaultIndexes[0])},
+ storage: mlist.NewMList[string, T](8),
}
}
// WithIndexes 设置索引
-// 索引元素必须满足 y=2^x
+// 索引长度至少为2; 如果每个数字都满足y=pow(2,x), 索引效率更高.
func (c *Dict[T]) WithIndexes(indexes []uint8) *Dict[T] {
if len(indexes) < 2 {
panic("indexes length at least 2")
}
for _, item := range indexes {
if !utils.IsBinaryNumber(item) {
- panic("indexes contains elements that must satisfy y=2^x")
+ c.binaryIndex = false
+ break
}
}
c.indexes = indexes
diff --git a/dict/dict_test.go b/dict/dict_test.go
index 13363a0..606408d 100644
--- a/dict/dict_test.go
+++ b/dict/dict_test.go
@@ -186,7 +186,7 @@ func TestDict_WithIndexes(t *testing.T) {
err := f(func() {
New[uint8]().WithIndexes([]uint8{1, 2, 3, 4})
})
- assert.Error(t, err)
+ assert.NoError(t, err)
})
t.Run("", func(t *testing.T) {
@@ -199,7 +199,9 @@ func TestDict_WithIndexes(t *testing.T) {
func TestDict_Random(t *testing.T) {
var count = 1000000
- var d = New[int]()
+ var d = New[int]().WithIndexes(
+ []uint8{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1},
+ )
var m = make(map[string]int)
for i := 0; i < count; i++ {
key := strconv.Itoa(i)
diff --git a/dict/encode.go b/dict/encode.go
index c3ed003..74ec4ce 100644
--- a/dict/encode.go
+++ b/dict/encode.go
@@ -1,5 +1,7 @@
package dict
+import "github.com/lxzan/dao/algorithm"
+
var defaultIndexes = []uint8{32, 32, 16, 16, 16, 16, 16, 8, 8, 8, 8, 8, 8, 4, 4, 4, 4}
type iterator struct {
@@ -15,11 +17,14 @@ func (c *iterator) hit() bool {
}
func (c *Dict[T]) getIndex(iter *iterator) int {
- return int(iter.Key[iter.Cursor]) & int(c.indexes[iter.Cursor]-1)
+ if c.binaryIndex {
+ return int(iter.Key[iter.Cursor]) & int(c.indexes[iter.Cursor]-1)
+ }
+ return int(iter.Key[iter.Cursor]) % int(c.indexes[iter.Cursor])
}
func (c *Dict[T]) begin(key string, initialize bool) *iterator {
- var iter = &iterator{Node: c.root, Key: key, N: min(len(key), len(c.indexes)-1), Initialize: initialize}
+ var iter = &iterator{Node: c.root, Key: key, N: algorithm.Min(len(key), len(c.indexes)-1), Initialize: initialize}
var idx = c.getIndex(iter)
if iter.Node.Children[idx] == nil {
if !iter.Initialize {
diff --git a/go.mod b/go.mod
index 10d38f7..a9d68ac 100644
--- a/go.mod
+++ b/go.mod
@@ -1,6 +1,6 @@
module github.com/lxzan/dao
-go 1.21
+go 1.18
require github.com/stretchr/testify v1.8.4
diff --git a/hashmap/hashmap.go b/hashmap/hashmap.go
index 93c43f0..af56003 100644
--- a/hashmap/hashmap.go
+++ b/hashmap/hashmap.go
@@ -4,14 +4,17 @@ type HashMap[K comparable, V any] map[K]V
// New instantiates a hashmap
// at most one param, means initial capacity
-func New[K comparable, V any](capacity uint32) HashMap[K, V] {
+func New[K comparable, V any](capacity int) HashMap[K, V] {
return make(map[K]V, capacity)
}
-// Reset clear contents
-func (c HashMap[K, V]) Reset() {
- clear(c)
-}
+//Reset clear contents
+//func (c HashMap[K, V]) Reset() {
+// keys := c.Keys()
+// for _, key := range keys {
+// delete(c, key)
+// }
+//}
// Len get the length of hashmap
func (c HashMap[K, V]) Len() int {
diff --git a/hashmap/hashmap_test.go b/hashmap/hashmap_test.go
index b8df2df..471db56 100644
--- a/hashmap/hashmap_test.go
+++ b/hashmap/hashmap_test.go
@@ -118,11 +118,3 @@ func TestHashMap_Range(t *testing.T) {
})
assert.Equal(t, len(keys), 2)
}
-
-func TestHashMap_Reset(t *testing.T) {
- var m = New[string, int](8)
- m.Set("a", 1)
- assert.Equal(t, m.Len(), 1)
- m.Reset()
- assert.Equal(t, m.Len(), 0)
-}
diff --git a/heap/heap.go b/heap/heap.go
index e21c3a8..d732c08 100644
--- a/heap/heap.go
+++ b/heap/heap.go
@@ -1,8 +1,7 @@
package heap
import (
- "cmp"
- "github.com/lxzan/dao"
+ "github.com/lxzan/dao/types/cmp"
)
const (
@@ -13,35 +12,34 @@ const (
Octal = 8
)
-// New 新建一个最小堆
-// Create a new minimum heap
-func New[T cmp.Ordered]() *Heap[T] { return NewHeap(cmp.Less[T]) }
+// New 新建一个最小四叉堆
+// Create a new minimum quadratic heap
+func New[T cmp.Ordered]() *Heap[T] { return NewWithForks(Quadratic, cmp.Less[T]) }
-// NewHeap 新建一个堆
-// Create a new heap
-func NewHeap[T any](less dao.LessFunc[T]) *Heap[T] {
- h := &Heap[T]{cmp: less}
- h.SetForkNumber(Quadratic)
+// NewWithForks 新建堆
+// @forks 分叉数, 可选值为: 2,4,6
+// @lessFunc 比较函数
+func NewWithForks[T any](forks uint32, lessFunc cmp.LessFunc[T]) *Heap[T] {
+ h := &Heap[T]{lessFunc: lessFunc}
+ h.setForkNumber(forks)
return h
}
-// Heap 可以不使用New函数, 声明为值类型自动初始化
-// 如果使用值类型, 需要设置比较函数和分叉数
type Heap[T any] struct {
- bits uint32
- forks int
- data []T
- cmp func(a, b T) bool
+ bits uint32
+ forks int
+ data []T
+ lessFunc func(a, b T) bool
}
// SetCap 设置预分配容量
-func (c *Heap[T]) SetCap(n uint32) *Heap[T] {
+func (c *Heap[T]) SetCap(n int) *Heap[T] {
c.data = make([]T, 0, n)
return c
}
-// SetForkNumber 设置分叉数
-func (c *Heap[T]) SetForkNumber(n uint32) *Heap[T] {
+// setForkNumber 设置分叉数
+func (c *Heap[T]) setForkNumber(n uint32) *Heap[T] {
c.forks = int(n)
switch n {
case Quadratic, Binary:
@@ -54,19 +52,13 @@ func (c *Heap[T]) SetForkNumber(n uint32) *Heap[T] {
return c
}
-// SetLessFunc 设置比较函数
-func (c *Heap[T]) SetLessFunc(less dao.LessFunc[T]) *Heap[T] {
- c.cmp = less
- return c
-}
-
// Len 获取元素数量
func (c *Heap[T]) Len() int {
return len(c.data)
}
func (c *Heap[T]) less(i, j int) bool {
- return c.cmp(c.data[i], c.data[j])
+ return c.lessFunc(c.data[i], c.data[j])
}
func (c *Heap[T]) swap(i, j int) {
@@ -102,7 +94,6 @@ func (c *Heap[T]) down(i, n int) {
// Reset 重置堆
func (c *Heap[T]) Reset() {
- clear(c.data)
c.data = c.data[:0]
}
@@ -149,3 +140,8 @@ func (c *Heap[T]) Clone() *Heap[T] {
copy(v.data, c.data)
return &v
}
+
+// UnWrap 解包, 返回底层数组
+func (c *Heap[T]) UnWrap() []T {
+ return c.data
+}
diff --git a/heap/heap_test.go b/heap/heap_test.go
index 4001c50..f5e4344 100644
--- a/heap/heap_test.go
+++ b/heap/heap_test.go
@@ -2,20 +2,22 @@ package heap
import (
"fmt"
- "github.com/lxzan/dao"
"github.com/lxzan/dao/internal/utils"
+ "github.com/lxzan/dao/types/cmp"
"github.com/stretchr/testify/assert"
"sort"
"testing"
"unsafe"
)
+func desc[T cmp.Ordered](a, b T) bool {
+ return a > b
+}
+
func TestNew(t *testing.T) {
const count = 1000
{
- var h = NewHeap[string](func(a, b string) bool {
- return a < b
- })
+ var h = New[string]()
var arr1 = make([]string, 0)
var arr2 = make([]string, 0)
for i := 0; i < count; i++ {
@@ -34,9 +36,8 @@ func TestNew(t *testing.T) {
}
func TestDesc(t *testing.T) {
- var h = NewHeap(dao.DescFunc[int])
+ var h = NewWithForks(Octal, desc[int])
h.SetCap(8)
- h.SetForkNumber(Octal)
h.Push(1)
assert.Equal(t, h.Top(), 1)
h.Push(3)
@@ -52,9 +53,8 @@ func TestDesc(t *testing.T) {
}
func TestAsc(t *testing.T) {
- var h = New[int]()
+ var h = NewWithForks(Binary, cmp.Less[int])
h.SetCap(8)
- h.SetForkNumber(Binary)
h.Push(1)
h.Push(3)
h.Push(2)
@@ -69,8 +69,8 @@ func TestAsc(t *testing.T) {
}
func TestHeap_Range(t *testing.T) {
- var h Heap[int]
- h.SetForkNumber(Quadratic).SetLessFunc(dao.AscFunc[int]).SetCap(8)
+ var h = NewWithForks(Quadratic, cmp.Less[int])
+ h.SetCap(8)
h.Push(1)
h.Push(3)
h.Push(2)
@@ -108,14 +108,13 @@ func TestHeap_Reset(t *testing.T) {
}
func TestHeap_Pop(t *testing.T) {
- var h = NewHeap(dao.AscFunc[int])
+ var h = New[int]()
assert.Equal(t, h.Pop(), 0)
h.Push(1)
assert.Equal(t, h.Pop(), 1)
}
func TestHeap_SetForkNumber(t *testing.T) {
- var h = NewHeap(dao.AscFunc[int])
var catch = func(f func()) (err error) {
defer func() {
if excp := recover(); excp != nil {
@@ -127,19 +126,18 @@ func TestHeap_SetForkNumber(t *testing.T) {
}
var err1 = catch(func() {
- h.SetForkNumber(3)
+ NewWithForks(3, cmp.Less[int])
})
assert.Error(t, err1)
var err2 = catch(func() {
- h.SetForkNumber(4)
+ NewWithForks(4, cmp.Less[int])
})
assert.Nil(t, err2)
}
func TestHeap_Clone(t *testing.T) {
- var h = NewHeap(dao.AscFunc[int])
- h.SetForkNumber(4)
+ var h = New[int]()
h.Push(1)
h.Push(3)
h.Push(2)
@@ -154,3 +152,11 @@ func TestHeap_Clone(t *testing.T) {
assert.NotEqual(t, addr, addr1)
assert.Equal(t, addr, addr2)
}
+
+func TestHeap_UnWrap(t *testing.T) {
+ var h = NewWithForks(2, cmp.Less[int])
+ h.Push(1)
+ h.Push(2)
+ h.Push(3)
+ assert.True(t, utils.IsSameSlice(h.UnWrap(), []int{1, 2, 3}))
+}
diff --git a/interface.go b/interface.go
index 9083c52..4f7a8ff 100644
--- a/interface.go
+++ b/interface.go
@@ -1,44 +1,6 @@
package dao
-import "cmp"
-
-const (
- Less = -1
- Equal = 0
- Greater = 1
-)
-
-type (
- // LessFunc 比大小
- LessFunc[T any] func(a, b T) bool
-
- // CompareFunc 比较函数
- // a>b, 返回1; a b }
-
-type Order uint8
-
-const (
- ASC Order = 0 // 升序
- DESC Order = 1 // 降序
-)
-
type (
- Number interface {
- Integer | ~float32 | ~float64
- }
-
- Integer interface {
- ~int64 | ~int | ~int32 | ~int16 | ~int8 | ~uint64 | ~uint | ~uint32 | ~uint16 | ~uint8
- }
-
// Map 键不可重复
Map[K comparable, V any] interface {
Len() int
diff --git a/internal/mlist/mlist.go b/internal/mlist/mlist.go
index f52e84b..e61fd7b 100644
--- a/internal/mlist/mlist.go
+++ b/internal/mlist/mlist.go
@@ -35,7 +35,6 @@ func NewMList[K comparable, V any](size uint32) *MList[K, V] {
func (c *MList[K, V]) Reset() {
c.length, c.serial = 0, 0
- clear(c.Buckets)
c.recyclable = c.recyclable[:0]
c.Buckets = c.Buckets[:1]
}
diff --git a/internal/utils/helper.go b/internal/utils/helper.go
index a370e53..3795c59 100644
--- a/internal/utils/helper.go
+++ b/internal/utils/helper.go
@@ -63,3 +63,10 @@ type Integer interface {
func IsBinaryNumber[T Integer](x T) bool {
return x&(x-1) == 0
}
+
+func Clone[S ~[]E, E any](s S) S {
+ if s == nil {
+ return nil
+ }
+ return append(S([]E{}), s...)
+}
diff --git a/internal/utils/helper_test.go b/internal/utils/helper_test.go
index a101d39..0aa6490 100644
--- a/internal/utils/helper_test.go
+++ b/internal/utils/helper_test.go
@@ -62,3 +62,17 @@ func TestIsBinaryNumber(t *testing.T) {
assert.False(t, IsBinaryNumber(7))
assert.False(t, IsBinaryNumber(21))
}
+
+func TestClone(t *testing.T) {
+ {
+ var a []int
+ var b = Clone(a)
+ assert.True(t, len(b) == 0)
+ }
+
+ {
+ var a = []int{1, 2, 3}
+ var b = Clone(a)
+ assert.ElementsMatch(t, b, a)
+ }
+}
diff --git a/queue/queue.go b/queue/queue.go
new file mode 100644
index 0000000..966f6c8
--- /dev/null
+++ b/queue/queue.go
@@ -0,0 +1,67 @@
+package queue
+
+import "github.com/lxzan/dao/internal/utils"
+
+type Queue[T any] struct {
+ offset int
+ tpl T
+ data []T
+}
+
+// New 创建队列
+func New[T any](capacity int) *Queue[T] {
+ return &Queue[T]{data: make([]T, 0, capacity)}
+}
+
+// NewFrom 从切片创建队列
+func NewFrom[T any](values ...T) *Queue[T] {
+ return &Queue[T]{data: values}
+}
+
+// Reset 重置
+func (c *Queue[T]) Reset() {
+ c.offset = 0
+ c.data = c.data[:0]
+}
+
+// Len 获取队列长度
+func (c *Queue[T]) Len() int {
+ return len(c.data) - c.offset
+}
+
+// Push 追加元素到队列尾部
+func (c *Queue[T]) Push(v T) {
+ c.data = append(c.data, v)
+}
+
+// Pop 从队列头部弹出元素
+func (c *Queue[T]) Pop() (value T) {
+ if n := c.Len(); n > 0 {
+ value = c.data[c.offset]
+ c.data[c.offset] = c.tpl
+ c.offset++
+ if c.offset == len(c.data) {
+ c.Reset()
+ }
+ }
+ return value
+}
+
+// Range 遍历
+func (c *Queue[T]) Range(f func(value T) bool) {
+ for _, item := range c.data {
+ if !f(item) {
+ return
+ }
+ }
+}
+
+// UnWrap 解包, 返回底层数组
+func (c *Queue[T]) UnWrap() []T {
+ return c.data[c.offset:]
+}
+
+// Clone 拷贝副本
+func (c *Queue[T]) Clone() *Queue[T] {
+ return &Queue[T]{data: utils.Clone(c.data)}
+}
diff --git a/queue/queue_test.go b/queue/queue_test.go
new file mode 100644
index 0000000..accec45
--- /dev/null
+++ b/queue/queue_test.go
@@ -0,0 +1,86 @@
+package queue
+
+import (
+ "github.com/lxzan/dao/internal/utils"
+ "github.com/stretchr/testify/assert"
+ "testing"
+)
+
+func TestQueue(t *testing.T) {
+ t.Run("", func(t *testing.T) {
+ q := NewFrom(1, 3, 5, 7, 9)
+ var a []int
+ for q.Len() > 0 {
+ a = append(a, q.Pop())
+ }
+ assert.True(t, utils.IsSameSlice(a, []int{1, 3, 5, 7, 9}))
+ assert.Equal(t, q.offset, 0)
+ })
+
+ t.Run("", func(t *testing.T) {
+ q := NewFrom[int](1)
+ q.Push(3)
+ q.Pop()
+ assert.Equal(t, q.offset, 1)
+ assert.Equal(t, q.data[0], 0)
+ assert.Equal(t, q.Len(), 1)
+ })
+}
+
+func TestQueue_UnWrap(t *testing.T) {
+ t.Run("", func(t *testing.T) {
+ q := NewFrom(1, 3, 5, 7, 9)
+ q.Pop()
+ a := q.UnWrap()
+ assert.True(t, utils.IsSameSlice(a, []int{3, 5, 7, 9}))
+ })
+
+ t.Run("", func(t *testing.T) {
+ q := NewFrom(1)
+ q.Pop()
+ a := q.UnWrap()
+ assert.Equal(t, len(a), 0)
+ })
+
+ t.Run("", func(t *testing.T) {
+ q := New[int](8)
+ a := q.UnWrap()
+ assert.Equal(t, len(a), 0)
+ })
+}
+
+func TestQueue_Clone(t *testing.T) {
+ q := NewFrom(1, 3, 5, 7, 9)
+ b := q.Clone()
+ assert.True(t, utils.IsSameSlice(b.UnWrap(), []int{1, 3, 5, 7, 9}))
+}
+
+func TestQueue_Range(t *testing.T) {
+ t.Run("", func(t *testing.T) {
+ var s = NewFrom(1, 3, 5)
+
+ var arr []int
+ s.Range(func(value int) bool {
+ arr = append(arr, value)
+ return true
+ })
+ assert.True(t, utils.IsSameSlice(arr, []int{1, 3, 5}))
+
+ s.Reset()
+ assert.Equal(t, s.Len(), 0)
+ })
+
+ t.Run("", func(t *testing.T) {
+ var s = New[int](8)
+ s.Push(1)
+ s.Push(3)
+ s.Push(5)
+
+ var arr []int
+ s.Range(func(value int) bool {
+ arr = append(arr, value)
+ return len(arr) < 2
+ })
+ assert.True(t, utils.IsSameSlice(arr, []int{1, 3}))
+ })
+}
diff --git a/rbtree/query.go b/rbtree/query.go
index bd65e59..4504b38 100644
--- a/rbtree/query.go
+++ b/rbtree/query.go
@@ -1,105 +1,35 @@
package rbtree
import (
- "cmp"
- "github.com/lxzan/dao"
"github.com/lxzan/dao/algorithm"
"github.com/lxzan/dao/stack"
+ "github.com/lxzan/dao/types/cmp"
)
-// Get 查询一个key
-func (c *RBTree[K, V]) Get(key K) (result V, exist bool) {
- for i := c.begin(); !c.end(i); i = c.next(i, key) {
- if key == i.data.Key {
- return i.data.Val, true
- }
- }
- return result, false
-}
-
-// Range 遍历树
-func (c *RBTree[K, V]) Range(fn func(key K, value V) bool) {
- var next = true
- c.do_range(c.root, &next, fn)
-}
-
-func (c *RBTree[K, V]) do_range(node *rbtree_node[K, V], next *bool, fn func(K, V) bool) {
- if c.end(node) || !*next {
- return
- }
+type Order uint8
- if ok := fn(node.data.Key, node.data.Val); !ok {
- *next = ok
- }
-
- c.do_range(node.left, next, fn)
- c.do_range(node.right, next, fn)
-}
-
-// GetMinKey 获取最小的key, 过滤条件可为空
-func (c *RBTree[K, V]) GetMinKey(filter func(key K) bool) (result Pair[K, V], exist bool) {
- return c.doGetMinKey(stack.New[*rbtree_node[K, V]](10), filter)
-}
-
-func (c *RBTree[K, V]) doGetMinKey(s *stack.Stack[*rbtree_node[K, V]], filter func(key K) bool) (result Pair[K, V], exist bool) {
- s.Reset()
- filter = algorithm.SelectValue(filter == nil, TrueFunc[K], filter)
- s.Push(c.root)
- for s.Len() > 0 {
- var node = s.Pop()
- if c.end(node) {
- continue
- }
- if filter(node.data.Key) {
- if !exist || node.data.Key < result.Key {
- exist = true
- result = *node.data
- }
- s.Push(node.left)
- } else {
- s.Push(node.right)
- }
- }
- return result, exist
-}
+const (
+ ASC Order = 0 // 升序
+ DESC Order = 1 // 降序
+)
-// GetMaxKey 获取最大的key, 过滤条件可为空
-func (c *RBTree[K, V]) GetMaxKey(filter func(key K) bool) (result Pair[K, V], exist bool) {
- return c.doGetMaxKey(stack.New[*rbtree_node[K, V]](10), filter)
-}
+func TrueFunc[K cmp.Ordered](d K) bool { return true }
-func (c *RBTree[K, V]) doGetMaxKey(s *stack.Stack[*rbtree_node[K, V]], filter func(key K) bool) (result Pair[K, V], exist bool) {
- s.Reset()
- filter = algorithm.SelectValue(filter == nil, TrueFunc[K], filter)
- s.Push(c.root)
- for s.Len() > 0 {
- var node = s.Pop()
- if c.end(node) {
- continue
- }
- if filter(node.data.Key) {
- if !exist || node.data.Key > result.Key {
- exist = true
- result = *node.data
- }
- s.Push(node.right)
- } else {
- s.Push(node.left)
- }
- }
- return result, exist
-}
-
-func TrueFunc[K cmp.Ordered](d K) bool {
- return true
+// NewQuery 新建一个查询
+func (c *RBTree[K, V]) NewQuery() *QueryBuilder[K, V] {
+ return &QueryBuilder[K, V]{tree: c}
}
type QueryBuilder[K cmp.Ordered, V any] struct {
- tree *RBTree[K, V]
- leftFilter func(key K) bool
- rightFilter func(key K) bool
- limit int
- order dao.Order
+ tree *RBTree[K, V] // 红黑树
+ results []Pair[K, V] // 查询结果
+
+ leftFilter func(key K) bool // 左边界条件
+ rightFilter func(key K) bool // 右边界条件
+ limit int // 单页限制
+ total int // 总条数
+ offset int // 偏移量
+ order Order // 排序
}
func (c *QueryBuilder[K, V]) init() *QueryBuilder[K, V] {
@@ -112,6 +42,7 @@ func (c *QueryBuilder[K, V]) init() *QueryBuilder[K, V] {
if c.limit <= 0 {
c.limit = 10
}
+ c.total = c.offset + c.limit
return c
}
@@ -128,7 +59,7 @@ func (c *QueryBuilder[K, V]) Right(f func(key K) bool) *QueryBuilder[K, V] {
}
// Order 设置排序, 默认ASC
-func (c *QueryBuilder[K, V]) Order(o dao.Order) *QueryBuilder[K, V] {
+func (c *QueryBuilder[K, V]) Order(o Order) *QueryBuilder[K, V] {
c.order = o
return c
}
@@ -139,59 +70,182 @@ func (c *QueryBuilder[K, V]) Limit(n int) *QueryBuilder[K, V] {
return c
}
-// Do 执行查询
-func (c *QueryBuilder[K, V]) Do() []Pair[K, V] {
- return c.tree.do_query(c)
+func (c *QueryBuilder[K, V]) Offset(n int) *QueryBuilder[K, V] {
+ c.offset = n
+ return c
}
-// NewQuery 新建一个查询
-func (c *RBTree[K, V]) NewQuery() *QueryBuilder[K, V] {
- return &QueryBuilder[K, V]{tree: c}
+// FindAll 执行查询
+func (c *QueryBuilder[K, V]) FindAll() []Pair[K, V] {
+ c.init()
+ c.results = make([]Pair[K, V], 0, c.total)
+
+ switch c.order {
+ case DESC:
+ c.rangeDesc(c.tree.root)
+ case ASC:
+ c.rangeAsc(c.tree.root)
+ }
+
+ if c.offset > 0 {
+ if len(c.results) > c.offset {
+ c.results = c.results[c.offset:]
+ } else {
+ c.results = c.results[:0]
+ }
+ }
+ return c.results
}
-func (c *RBTree[K, V]) do_query(q *QueryBuilder[K, V]) []Pair[K, V] {
- q.init()
- var results = make([]Pair[K, V], 0, q.limit)
- var s = stack.New[*rbtree_node[K, V]](uint32(q.limit))
+// 降序遍历 中序遍历是有序的
+func (c *QueryBuilder[K, V]) rangeDesc(node *rbtree_node[K, V]) {
+ if c.tree.end(node) || len(c.results) >= c.total {
+ return
+ }
+
+ state := 0
+ if c.rightFilter(node.data.Key) {
+ state += 1
+ }
+ if c.leftFilter(node.data.Key) {
+ state += 2
+ }
+
+ switch state {
+ case 3:
+ c.rangeDesc(node.right)
+ if len(c.results) < c.total {
+ c.results = append(c.results, *node.data)
+ } else {
+ return
+ }
+ c.rangeDesc(node.left)
+ case 2:
+ c.rangeDesc(node.left)
+ case 1:
+ c.rangeDesc(node.right)
+ }
+}
- if q.order == dao.DESC {
- maxEle, exist := c.doGetMaxKey(s, q.rightFilter)
- if exist && q.leftFilter(maxEle.Key) {
- results = append(results, maxEle)
+// 升序遍历 中序遍历是有序的
+func (c *QueryBuilder[K, V]) rangeAsc(node *rbtree_node[K, V]) {
+ if c.tree.end(node) || len(c.results) >= c.total {
+ return
+ }
+
+ state := 0
+ if c.rightFilter(node.data.Key) {
+ state += 1
+ }
+ if c.leftFilter(node.data.Key) {
+ state += 2
+ }
+
+ switch state {
+ case 3:
+ c.rangeAsc(node.left)
+ if len(c.results) < c.total {
+ c.results = append(c.results, *node.data)
} else {
- return results
+ return
}
+ c.rangeAsc(node.right)
+ case 2:
+ c.rangeAsc(node.left)
+ case 1:
+ c.rangeAsc(node.right)
+ }
+}
- for i := 0; i < q.limit-1; i++ {
- result, exist := c.doGetMaxKey(s, func(key K) bool {
- return key < maxEle.Key
- })
- if exist && q.leftFilter(result.Key) {
- results = append(results, result)
- maxEle = result
- } else {
- break
+func (c *QueryBuilder[K, V]) FindOne() (p Pair[K, V], exist bool) {
+ c.Limit(1).Offset(0).init()
+
+ switch c.order {
+ case DESC:
+ if v, ok := c.getMaxPair(c.rightFilter); ok {
+ if c.leftFilter(v.Key) {
+ return *v, true
+ }
+ }
+ case ASC:
+ if v, ok := c.getMinPair(c.leftFilter); ok {
+ if c.rightFilter(v.Key) {
+ return *v, true
}
}
- } else {
- minEle, exist := c.doGetMinKey(s, q.leftFilter)
- if exist && q.rightFilter(minEle.Key) {
- results = append(results, minEle)
+ }
+
+ return p, exist
+}
+
+func (c *QueryBuilder[K, V]) getMaxPair(filter func(key K) bool) (result *Pair[K, V], exist bool) {
+ var s = stack.Stack[*rbtree_node[K, V]]{}
+ s.Push(c.tree.root)
+ for s.Len() > 0 {
+ var node = s.Pop()
+ if c.tree.end(node) {
+ continue
+ }
+ if filter(node.data.Key) {
+ if !exist || node.data.Key > result.Key {
+ exist = true
+ result = node.data
+ }
+ s.Push(node.right)
} else {
- return results
+ s.Push(node.left)
}
+ }
+ return result, exist
+}
- for i := 0; i < q.limit-1; i++ {
- result, exist := c.doGetMinKey(s, func(key K) bool {
- return key > minEle.Key
- })
- if exist && q.rightFilter(result.Key) {
- results = append(results, result)
- minEle = result
- } else {
- break
+func (c *QueryBuilder[K, V]) getMinPair(filter func(key K) bool) (result *Pair[K, V], exist bool) {
+ var s = stack.Stack[*rbtree_node[K, V]]{}
+ filter = algorithm.SelectValue(filter == nil, TrueFunc[K], filter)
+ s.Push(c.tree.root)
+ for s.Len() > 0 {
+ var node = s.Pop()
+ if c.tree.end(node) {
+ continue
+ }
+ if filter(node.data.Key) {
+ if !exist || node.data.Key < result.Key {
+ exist = true
+ result = node.data
}
+ s.Push(node.left)
+ } else {
+ s.Push(node.right)
+ }
+ }
+ return result, exist
+}
+
+// Get 查询一个key
+func (c *RBTree[K, V]) Get(key K) (result V, exist bool) {
+ for i := c.begin(); !c.end(i); i = c.next(i, key) {
+ if key == i.data.Key {
+ return i.data.Val, true
}
}
- return results
+ return result, false
+}
+
+// Range 遍历树
+func (c *RBTree[K, V]) Range(fn func(key K, value V) bool) {
+ var next = true
+ c.do_range(c.root, &next, fn)
+}
+
+func (c *RBTree[K, V]) do_range(node *rbtree_node[K, V], next *bool, fn func(K, V) bool) {
+ if c.end(node) || !*next {
+ return
+ }
+
+ if ok := fn(node.data.Key, node.data.Val); !ok {
+ *next = ok
+ }
+
+ c.do_range(node.left, next, fn)
+ c.do_range(node.right, next, fn)
}
diff --git a/rbtree/rbtree.go b/rbtree/rbtree.go
index 9823470..01a70b8 100644
--- a/rbtree/rbtree.go
+++ b/rbtree/rbtree.go
@@ -1,7 +1,7 @@
package rbtree
import (
- "cmp"
+ "github.com/lxzan/dao/types/cmp"
)
type Color uint8
diff --git a/rbtree/rbtree_test.go b/rbtree/rbtree_test.go
index db73079..d441990 100644
--- a/rbtree/rbtree_test.go
+++ b/rbtree/rbtree_test.go
@@ -1,14 +1,13 @@
package rbtree
import (
- "cmp"
"fmt"
- "github.com/lxzan/dao"
"github.com/lxzan/dao/algorithm"
+ "github.com/lxzan/dao/hashmap"
"github.com/lxzan/dao/internal/utils"
"github.com/lxzan/dao/internal/validator"
+ "github.com/lxzan/dao/types/cmp"
"github.com/stretchr/testify/assert"
- "math/rand"
"sort"
"strconv"
"testing"
@@ -141,49 +140,131 @@ func TestRBTree_ForEach(t *testing.T) {
}
func TestRBTree_Between(t *testing.T) {
- var tree = New[string, int]()
- var m = make(map[string]int)
- for i := 0; i < 10000; i++ {
- var length = utils.Rand.Intn(16) + 1
- var key = utils.Numeric.Generate(4)
- m[key] = length
- tree.Set(key, length)
- }
+ t.Run("desc", func(t *testing.T) {
+ var tree = New[string, int]()
+ var m = make(map[string]int)
+ for i := 0; i < 10000; i++ {
+ var length = utils.Rand.Intn(16) + 1
+ var key = utils.Numeric.Generate(4)
+ m[key] = length
+ tree.Set(key, length)
+ }
- var limit = 100
- for i := 0; i < 100; i++ {
- var left = utils.Numeric.Generate(4)
- x, _ := strconv.Atoi(left)
+ var limit = 100
+ for i := 0; i < 100; i++ {
+ var left = utils.Numeric.Generate(4)
+ x, _ := strconv.Atoi(left)
- var right = fmt.Sprintf("%04d", x+limit)
- if left > right {
- right, left = left, right
- }
- var keys1 = tree.
- NewQuery().
- Left(func(key string) bool { return key >= left }).
- Right(func(key string) bool { return key <= right }).
- Order(dao.DESC).
- Limit(limit).
- Do()
- var keys2 = make([]string, 0)
- for k := range m {
- if k >= left && k <= right {
- keys2 = append(keys2, k)
+ var right = fmt.Sprintf("%04d", x+limit)
+ if left > right {
+ right, left = left, right
+ }
+ var values = tree.
+ NewQuery().
+ Left(func(key string) bool { return key >= left }).
+ Right(func(key string) bool { return key <= right }).
+ Order(DESC).
+ Limit(limit).
+ FindAll()
+ var keys1 = algorithm.Map[Pair[string, int], string](values, func(i int, v Pair[string, int]) string {
+ return v.Key
+ })
+
+ var keys2 = make([]string, 0)
+ for k := range m {
+ if k >= left && k <= right {
+ keys2 = append(keys2, k)
+ }
}
+ sort.Strings(keys2)
+ algorithm.Reverse(keys2)
+ if len(keys2) > limit {
+ keys2 = keys2[:limit]
+ }
+
+ assert.True(t, utils.IsSameSlice(keys1, keys2))
}
- sort.Strings(keys2)
- algorithm.Reverse(keys2)
- if len(keys2) > limit {
- keys2 = keys2[:limit]
+ })
+
+ t.Run("asc", func(t *testing.T) {
+ var tree = New[string, int]()
+ var m = make(map[string]int)
+ for i := 0; i < 10000; i++ {
+ var length = utils.Rand.Intn(16) + 1
+ var key = utils.Numeric.Generate(4)
+ m[key] = length
+ tree.Set(key, length)
}
- if !utils.IsSameSlice(keys2, algorithm.Map(keys1, func(i int, x Pair[string, int]) string {
- return x.Key
- })) {
- t.Fatal("error!")
+ var limit = 100
+ for i := 0; i < 100; i++ {
+ var left = utils.Numeric.Generate(4)
+ x, _ := strconv.Atoi(left)
+
+ var right = fmt.Sprintf("%04d", x+limit)
+ if left > right {
+ right, left = left, right
+ }
+ var values = tree.
+ NewQuery().
+ Left(func(key string) bool { return key >= left }).
+ Right(func(key string) bool { return key <= right }).
+ Order(ASC).
+ Limit(limit).
+ Offset(10).
+ FindAll()
+ var keys1 = algorithm.Map[Pair[string, int], string](values, func(i int, v Pair[string, int]) string {
+ return v.Key
+ })
+
+ var keys2 = make([]string, 0)
+ for k := range m {
+ if k >= left && k <= right {
+ keys2 = append(keys2, k)
+ }
+ }
+ sort.Strings(keys2)
+ if len(keys2) > 10 {
+ keys2 = keys2[10:]
+ } else {
+ keys2 = keys2[:0]
+ }
+ if len(keys2) > limit {
+ keys2 = keys2[:limit]
+ }
+
+ assert.True(t, utils.IsSameSlice(keys1, keys2))
}
- }
+ })
+
+ t.Run("", func(t *testing.T) {
+ var tree = New[int, uint8]()
+ tree.Set(1, 1)
+ tree.Set(2, 1)
+ tree.Set(3, 1)
+ tree.Set(4, 1)
+ tree.Set(5, 1)
+
+ var values0 = tree.
+ NewQuery().
+ Left(func(key int) bool { return key >= 1 }).
+ Right(func(key int) bool { return key <= 3 }).
+ Order(ASC).
+ Limit(10).
+ Offset(5).
+ FindAll()
+ assert.Equal(t, len(values0), 0)
+
+ var values1 = tree.
+ NewQuery().
+ Left(func(key int) bool { return key >= 1 }).
+ Right(func(key int) bool { return key <= 3 }).
+ Order(ASC).
+ Limit(10).
+ FindAll()
+ var keys1 = algorithm.Map(values1, func(i int, v Pair[int, uint8]) int { return v.Key })
+ assert.True(t, utils.IsSameSlice(keys1, []int{1, 2, 3}))
+ })
}
func TestRBTree_GreaterEqual(t *testing.T) {
@@ -199,11 +280,14 @@ func TestRBTree_GreaterEqual(t *testing.T) {
var limit = 100
for i := 0; i < 100; i++ {
var left = utils.Numeric.Generate(4)
- var keys1 = tree.
+ var values = tree.
NewQuery().
Left(func(key string) bool { return key >= left }).
Limit(limit).
- Do()
+ FindAll()
+ var keys1 = algorithm.Map[Pair[string, int], string](values, func(i int, v Pair[string, int]) string {
+ return v.Key
+ })
var keys2 = make([]string, 0)
for k := range m {
if k >= left {
@@ -215,11 +299,7 @@ func TestRBTree_GreaterEqual(t *testing.T) {
keys2 = keys2[:limit]
}
- if !utils.IsSameSlice(keys2, algorithm.Map(keys1, func(i int, x Pair[string, int]) string {
- return x.Key
- })) {
- t.Fatal("error!")
- }
+ assert.True(t, utils.IsSameSlice(keys1, keys2))
}
}
@@ -236,11 +316,15 @@ func TestRBTree_LessEqual(t *testing.T) {
var limit = 10
for i := 0; i < 100; i++ {
var target = utils.Numeric.Generate(4)
- var keys1 = tree.
+ var results = tree.
NewQuery().
Right(func(key string) bool { return key <= target }).
- Order(dao.DESC).
- Do()
+ Order(DESC).
+ Limit(limit).
+ FindAll()
+ var keys1 = algorithm.Map[Pair[string, int], string](results, func(i int, v Pair[string, int]) string {
+ return v.Key
+ })
var keys2 = make([]string, 0)
for k := range m {
if k <= target {
@@ -253,78 +337,85 @@ func TestRBTree_LessEqual(t *testing.T) {
keys2 = keys2[:limit]
}
- if !utils.IsSameSlice(keys2, algorithm.Map(keys1, func(i int, x Pair[string, int]) string {
- return x.Key
- })) {
- t.Fatal("error!")
- }
+ assert.True(t, utils.IsSameSlice(keys1, keys2))
}
}
-func TestRBTree_GetMinKey(t *testing.T) {
- var tree = New[string, int]()
-
- const test_count = 100
- for i := 0; i < test_count; i++ {
- var v = rand.Intn(10000)
- tree.Set(strconv.Itoa(v), v)
- }
-
- for i := 0; i < test_count; i++ {
- var k = strconv.Itoa(rand.Intn(10000))
- result, exist := tree.GetMinKey(func(key string) bool {
- return key >= k
- })
+func TestRBTree_FindOne(t *testing.T) {
+ t.Run("desc", func(t *testing.T) {
+ var tree = New[string, int]()
+ var m = hashmap.New[string, int](0)
+ for i := 0; i < 10000; i++ {
+ var length = utils.Rand.Intn(16) + 1
+ var key = utils.Numeric.Generate(4)
+ m.Set(key, length)
+ tree.Set(key, length)
+ }
- if !exist {
- tree.Range(func(key string, value int) bool {
- if key >= k {
- t.Fatal("error!")
+ for i := 0; i < 100; i++ {
+ var target = utils.Numeric.Generate(4)
+ v0, ok0 := tree.
+ NewQuery().
+ Right(func(key string) bool { return key <= target }).
+ Order(DESC).
+ FindOne()
+
+ var v1, ok1 = "", false
+ m.Range(func(key string, val int) bool {
+ if key <= target && (v1 == "" || key > v1) {
+ v1 = key
+ ok1 = true
}
return true
})
- } else {
- tree.Range(func(key string, value int) bool {
- if key < result.Key && key >= k {
- t.Fatal("error!")
- }
- return true
- })
- }
- }
-}
-
-func TestRBTree_GetMaxKey(t *testing.T) {
- var tree = New[string, int]()
- const test_count = 100
- for i := 0; i < test_count; i++ {
- var v = rand.Intn(10000)
- tree.Set(strconv.Itoa(v), v)
- }
+ assert.Equal(t, ok0, ok1)
+ if ok0 {
+ assert.Equal(t, v0.Key, v1)
+ }
+ }
+ })
- for i := 0; i < test_count; i++ {
- var k = strconv.Itoa(rand.Intn(10000))
- result, exist := tree.GetMaxKey(func(key string) bool {
- return key <= k
- })
+ t.Run("asc", func(t *testing.T) {
+ var tree = New[string, int]()
+ var m = hashmap.New[string, int](0)
+ for i := 0; i < 10000; i++ {
+ var length = utils.Rand.Intn(16) + 1
+ var key = utils.Numeric.Generate(4)
+ m.Set(key, length)
+ tree.Set(key, length)
+ }
- if !exist {
- tree.Range(func(key string, value int) bool {
- if key <= k {
- t.Fatal("error!")
- }
- return true
- })
- } else {
- tree.Range(func(key string, value int) bool {
- if key > result.Key && key <= k {
- t.Fatal("error!")
+ for i := 0; i < 100; i++ {
+ var target = utils.Numeric.Generate(4)
+ v0, ok0 := tree.
+ NewQuery().
+ Left(func(key string) bool { return key >= target }).
+ Order(ASC).
+ FindOne()
+
+ var v1, ok1 = "", false
+ m.Range(func(key string, val int) bool {
+ if key >= target && (v1 == "" || key < v1) {
+ v1 = key
+ ok1 = true
}
return true
})
+
+ assert.Equal(t, ok0, ok1)
+ if ok0 {
+ assert.Equal(t, v0.Key, v1)
+ }
}
- }
+ })
+
+ t.Run("", func(t *testing.T) {
+ var tree = New[string, int]()
+ var qb = QueryBuilder[string, int]{tree: tree}
+ _, ok := qb.FindOne()
+ assert.False(t, ok)
+ })
}
func TestDict_Map(t *testing.T) {
@@ -344,7 +435,7 @@ func TestRBTree_NewQuery(t *testing.T) {
var results = tree.
NewQuery().
Left(func(key int) bool { return key > 10 }).
- Do()
+ FindAll()
assert.Equal(t, len(results), 0)
})
@@ -352,8 +443,8 @@ func TestRBTree_NewQuery(t *testing.T) {
var results = tree.
NewQuery().
Left(func(key int) bool { return key > 10 }).
- Order(dao.DESC).
- Do()
+ Order(DESC).
+ FindAll()
assert.Equal(t, len(results), 0)
})
}
diff --git a/segment_tree/segement_tree_test.go b/segment_tree/segement_tree_test.go
index 1c93634..681dc5e 100644
--- a/segment_tree/segement_tree_test.go
+++ b/segment_tree/segement_tree_test.go
@@ -21,7 +21,7 @@ func TestSegmentTree_Query(t *testing.T) {
if left > right {
left, right = right, left
}
- var result1 = tree.Query(left, right)
+ var result1 = tree.Query(left, right+1)
var result2 = Int64Schema{
MaxValue: arr[left].Value(),
@@ -51,7 +51,7 @@ func TestSegmentTree_Query(t *testing.T) {
if left > right {
left, right = right, left
}
- var result1 = tree.Query(left, right)
+ var result1 = tree.Query(left, right+1)
var result2 = Int64Schema{
MaxValue: arr[left].Value(),
diff --git a/segment_tree/segment_tree.go b/segment_tree/segment_tree.go
index 13d7aac..c9bf094 100644
--- a/segment_tree/segment_tree.go
+++ b/segment_tree/segment_tree.go
@@ -63,11 +63,11 @@ func (c *SegmentTree[S, T]) build(cur *Element[S, T]) {
cur.data = cur.son.data.Merge(cur.daughter.data)
}
-// Query 查询 left <= index <= right 区间
-func (c *SegmentTree[S, T]) Query(left int, right int) S {
+// Query 查询 begin <= index < end 区间
+func (c *SegmentTree[S, T]) Query(begin int, end int) S {
var result S
- result = c.arr[left].Init(OperateQuery)
- c.doQuery(c.root, left, right, &result)
+ result = c.arr[begin].Init(OperateQuery)
+ c.doQuery(c.root, begin, end-1, &result)
return result
}
diff --git a/stack/stack.go b/stack/stack.go
index 6dae270..cccfdce 100644
--- a/stack/stack.go
+++ b/stack/stack.go
@@ -3,24 +3,34 @@ package stack
// Stack 可以不使用New函数, 声明为值类型自动初始化
type Stack[T any] []T
-func New[T any](capacity uint32) *Stack[T] {
+// New 创建栈
+func New[T any](capacity int) *Stack[T] {
s := Stack[T](make([]T, 0, capacity))
return &s
}
+// NewFrom 从可变参数切片创建栈
+func NewFrom[T any](values ...T) *Stack[T] {
+ c := Stack[T](values)
+ return &c
+}
+
+// Reset 重置
func (c *Stack[T]) Reset() {
- clear(*c)
*c = (*c)[:0]
}
+// Len 获取元素数量
func (c *Stack[T]) Len() int {
return len(*c)
}
+// Push 追加元素
func (c *Stack[T]) Push(v T) {
*c = append(*c, v)
}
+// Pop 弹出元素
func (c *Stack[T]) Pop() (value T) {
n := c.Len()
switch n {
@@ -33,6 +43,7 @@ func (c *Stack[T]) Pop() (value T) {
}
}
+// Range 遍历
func (c *Stack[T]) Range(f func(value T) bool) {
for _, item := range *c {
if !f(item) {
@@ -40,3 +51,8 @@ func (c *Stack[T]) Range(f func(value T) bool) {
}
}
}
+
+// UnWrap 解包, 返回底层数组
+func (c *Stack[T]) UnWrap() []T {
+ return *(*[]T)(c)
+}
diff --git a/stack/stack_test.go b/stack/stack_test.go
index cd2d809..e7d7788 100644
--- a/stack/stack_test.go
+++ b/stack/stack_test.go
@@ -22,10 +22,7 @@ func TestStack_Pop(t *testing.T) {
func TestStack_Range(t *testing.T) {
t.Run("", func(t *testing.T) {
- var s = New[int](8)
- s.Push(1)
- s.Push(3)
- s.Push(5)
+ var s = NewFrom(1, 3, 5)
var arr []int
s.Range(func(value int) bool {
@@ -52,3 +49,9 @@ func TestStack_Range(t *testing.T) {
assert.True(t, utils.IsSameSlice(arr, []int{1, 3}))
})
}
+
+func TestStack_UnWrap(t *testing.T) {
+ var s = NewFrom(1, 3, 5)
+ var a = s.UnWrap()
+ assert.ElementsMatch(t, a, []int{1, 3, 5})
+}
diff --git a/types/cmp/cmp.go b/types/cmp/cmp.go
new file mode 100644
index 0000000..76572d3
--- /dev/null
+++ b/types/cmp/cmp.go
@@ -0,0 +1,47 @@
+package cmp
+
+type (
+ Ordered interface {
+ ~int | ~int8 | ~int16 | ~int32 | ~int64 |
+ ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr |
+ ~float32 | ~float64 |
+ ~string
+ }
+
+ Number interface {
+ Integer | ~float32 | ~float64
+ }
+
+ Integer interface {
+ ~int64 | ~int | ~int32 | ~int16 | ~int8 | ~uint64 | ~uint | ~uint32 | ~uint16 | ~uint8
+ }
+)
+
+const (
+ LT = -1 // 小于
+ EQ = 0 // 等于
+ GT = 1 // 大于
+)
+
+type (
+ // LessFunc 比大小
+ LessFunc[T any] func(a, b T) bool
+
+ // CompareFunc 比较函数
+ // a>b, 返回1; a y {
+ return +1
+ }
+ return 0
+}
diff --git a/vector/types.go b/vector/types.go
new file mode 100644
index 0000000..476c60c
--- /dev/null
+++ b/vector/types.go
@@ -0,0 +1,27 @@
+package vector
+
+import (
+ "github.com/lxzan/dao/types/cmp"
+)
+
+type Document[T cmp.Ordered] interface {
+ GetID() T
+}
+
+type (
+ Int int
+
+ Int64 int64
+
+ String string
+)
+
+func (c Int) GetID() int {
+ return int(c)
+}
+
+func (c Int64) GetID() int64 {
+ return int64(c)
+}
+
+func (c String) GetID() string { return string(c) }
diff --git a/vector/vector.go b/vector/vector.go
new file mode 100644
index 0000000..fbcbdc0
--- /dev/null
+++ b/vector/vector.go
@@ -0,0 +1,183 @@
+package vector
+
+import (
+ "github.com/lxzan/dao/algorithm"
+ "github.com/lxzan/dao/hashmap"
+ "github.com/lxzan/dao/internal/utils"
+ "github.com/lxzan/dao/types/cmp"
+ "unsafe"
+)
+
+// New 创建动态数组
+func New[D Document[K], K cmp.Ordered](capacity int) *Vector[D, K] {
+ c := Vector[D, K](make([]D, 0, capacity))
+ return &c
+}
+
+// NewFromDocs 从可变参数创建动态数组
+func NewFromDocs[D Document[K], K cmp.Ordered](values ...D) *Vector[D, K] {
+ c := Vector[D, K](values)
+ return &c
+}
+
+// NewFromInts 创建动态数组
+func NewFromInts(values ...int) *Vector[Int, int] {
+ var b = *(*[]Int)(unsafe.Pointer(&values))
+ v := Vector[Int, int](b)
+ return &v
+}
+
+// NewFromInt64s 创建动态数组
+func NewFromInt64s(values ...int64) *Vector[Int64, int64] {
+ var b = *(*[]Int64)(unsafe.Pointer(&values))
+ v := Vector[Int64, int64](b)
+ return &v
+}
+
+// NewFromStrings 创建动态数组
+func NewFromStrings(values ...string) *Vector[String, string] {
+ var b = *(*[]String)(unsafe.Pointer(&values))
+ v := Vector[String, string](b)
+ return &v
+}
+
+// Vector 动态数组
+type Vector[D Document[K], K cmp.Ordered] []D
+
+// Reset 重置
+func (c *Vector[D, K]) Reset() {
+ *c = (*c)[:0]
+}
+
+// Len 获取元素数量
+func (c *Vector[D, K]) Len() int {
+ return len(*c)
+}
+
+// Get 根据下标取值
+func (c *Vector[D, K]) Get(index int) D {
+ return (*c)[index]
+}
+
+// Update 根据下标修改值
+func (c *Vector[D, K]) Update(index int, value D) {
+ (*c)[index] = value
+}
+
+// Elem 取值
+func (c *Vector[D, K]) Elem() []D {
+ return *c
+}
+
+// Exists 根据id判断某条数据是否存在
+func (c *Vector[D, K]) Exists(id K) (v D, exist bool) {
+ for _, item := range *c {
+ if item.GetID() == id {
+ return item, true
+ }
+ }
+ return v, exist
+}
+
+// Unique 排序并根据id去重
+func (c *Vector[D, K]) Unique() *Vector[D, K] {
+ *c = algorithm.UniqueBy(*c, func(item D) K {
+ return item.GetID()
+ })
+ return c
+}
+
+// Filter 过滤
+func (c *Vector[D, K]) Filter(f func(i int, v D) bool) *Vector[D, K] {
+ *c = algorithm.Filter(*c, f)
+ return c
+}
+
+// Sort 排序
+func (c *Vector[D, K]) Sort() *Vector[D, K] {
+ algorithm.SortBy(*c, func(a, b D) int {
+ return cmp.Compare(a.GetID(), b.GetID())
+ })
+ return c
+}
+
+// IdList 获取id数组
+func (c *Vector[D, K]) IdList() []K {
+ var d D
+ switch any(d).(type) {
+ case Int, Int64, String:
+ var keys = *(*[]K)(unsafe.Pointer(c))
+ return keys
+ default:
+ var keys = make([]K, 0, c.Len())
+ for _, item := range *c {
+ keys = append(keys, item.GetID())
+ }
+ return keys
+ }
+}
+
+// ToMap 生成map[K]D
+func (c *Vector[D, K]) ToMap() hashmap.HashMap[K, D] {
+ var m = hashmap.New[K, D](c.Len())
+ for _, item := range *c {
+ m.Set(item.GetID(), item)
+ }
+ return m
+}
+
+// PushBack 向尾部追加元素
+func (c *Vector[D, K]) PushBack(v D) {
+ *c = append(*c, v)
+}
+
+// PopFront 从头部弹出元素
+func (c *Vector[D, K]) PopFront() (value D) {
+ switch c.Len() {
+ case 0:
+ return value
+ default:
+ value = (*c)[0]
+ *c = (*c)[1:]
+ return value
+ }
+}
+
+// PopBack 从尾部弹出元素
+func (c *Vector[D, K]) PopBack() (value D) {
+ n := c.Len()
+ switch n {
+ case 0:
+ return value
+ default:
+ value = (*c)[n-1]
+ *c = (*c)[:n-1]
+ return value
+ }
+}
+
+// Range 遍历
+func (c *Vector[D, K]) Range(f func(i int, v D) bool) {
+ for index, value := range *c {
+ if !f(index, value) {
+ return
+ }
+ }
+}
+
+// Clone 拷贝
+func (c *Vector[D, K]) Clone() *Vector[D, K] {
+ var d = utils.Clone(*c)
+ return &d
+}
+
+// Slice 截取子数组
+func (c *Vector[D, K]) Slice(start, end int) *Vector[D, K] {
+ var children = (*c)[start:end]
+ return &children
+}
+
+func (c *Vector[D, K]) Reverse() *Vector[D, K] {
+ *c = algorithm.Reverse(*c)
+ return c
+}
diff --git a/vector/vector_test.go b/vector/vector_test.go
new file mode 100644
index 0000000..f87cbb5
--- /dev/null
+++ b/vector/vector_test.go
@@ -0,0 +1,180 @@
+package vector
+
+import (
+ "github.com/lxzan/dao/internal/utils"
+ "github.com/stretchr/testify/assert"
+ "testing"
+ "unsafe"
+)
+
+func TestUser_GetID(t *testing.T) {
+ var docs Vector[user, string]
+ docs = append(docs, user{ID: "a"})
+ docs = append(docs, user{ID: "c"})
+ docs = append(docs, user{ID: "c"})
+ docs = append(docs, user{ID: "b"})
+ docs.Unique()
+ docs.Sort()
+ docs.Filter(func(i int, v user) bool {
+ return v.ID == "b"
+ })
+}
+
+type user struct {
+ ID string
+}
+
+func (u user) GetID() string {
+ return u.ID
+}
+
+func TestNewFromInts(t *testing.T) {
+ var a = NewFromInts(1, 3, 5)
+ var b = a.IdList()
+ assert.ElementsMatch(t, b, []int{1, 3, 5})
+}
+
+func TestNewFromInt64s(t *testing.T) {
+ var a = NewFromInt64s(1, 3, 5)
+ var b = a.IdList()
+ assert.ElementsMatch(t, b, []int64{1, 3, 5})
+}
+
+func TestVector_Keys(t *testing.T) {
+ t.Run("", func(t *testing.T) {
+ var a = NewFromStrings("a", "b", "c")
+ var b = a.IdList()
+ assert.ElementsMatch(t, b, []string{"a", "b", "c"})
+ assert.Equal(t, a.Get(0).GetID(), "a")
+
+ var addr0 = (uintptr)(unsafe.Pointer(&(*a)[0]))
+ var addr1 = (uintptr)(unsafe.Pointer(&b[0]))
+ assert.Equal(t, addr0, addr1)
+
+ var values = a.Elem()
+ assert.ElementsMatch(t, values, []String{"a", "b", "c"})
+ })
+
+ t.Run("", func(t *testing.T) {
+ var docs = NewFromDocs[user, string](
+ user{ID: "a"},
+ user{ID: "b"},
+ user{ID: "c"},
+ )
+ assert.ElementsMatch(t, docs.IdList(), []string{"a", "b", "c"})
+ })
+}
+
+func TestVector_Exists(t *testing.T) {
+ var v = New[Int, int](8)
+ v.PushBack(1)
+ v.PushBack(3)
+ v.PushBack(5)
+
+ {
+ _, ok := v.Exists(1)
+ assert.True(t, ok)
+ }
+
+ {
+ _, ok := v.Exists(3)
+ assert.True(t, ok)
+ }
+
+ {
+ _, ok := v.Exists(2)
+ assert.False(t, ok)
+ }
+}
+
+func TestVector_PushBack(t *testing.T) {
+ var v = New[Int, int](8)
+ v.PushBack(1)
+ v.PushBack(3)
+ v.PushBack(5)
+
+ var arr []int
+ for v.Len() > 0 {
+ arr = append(arr, v.PopBack().GetID())
+ }
+ assert.True(t, utils.IsSameSlice(arr, []int{5, 3, 1}))
+ assert.Equal(t, v.PopBack().GetID(), 0)
+}
+
+func TestVector_PopFront(t *testing.T) {
+ var v = New[Int, int](8)
+ v.PushBack(1)
+ v.PushBack(3)
+ v.PushBack(5)
+
+ var arr []int
+ for v.Len() > 0 {
+ arr = append(arr, v.PopFront().GetID())
+ }
+ assert.True(t, utils.IsSameSlice(arr, []int{1, 3, 5}))
+ assert.Equal(t, v.PopFront().GetID(), 0)
+}
+
+func TestVector_Range(t *testing.T) {
+ t.Run("", func(t *testing.T) {
+ var a = NewFromInt64s(1, 3, 5)
+ var v = a.Clone()
+ var arr []int64
+ v.Range(func(i int, value Int64) bool {
+ arr = append(arr, value.GetID())
+ return true
+ })
+ assert.True(t, utils.IsSameSlice(arr, []int64{1, 3, 5}))
+ })
+
+ t.Run("", func(t *testing.T) {
+ var v = NewFromInt64s(1, 3, 5)
+ var arr []int64
+ v.Range(func(i int, value Int64) bool {
+ arr = append(arr, value.GetID())
+ return len(arr) < 2
+ })
+ assert.True(t, utils.IsSameSlice(arr, []int64{1, 3}))
+ })
+}
+
+func TestVector_ToMap(t *testing.T) {
+ var a = NewFromDocs[user, string](
+ user{ID: "a"},
+ user{ID: "b"},
+ user{ID: "c"},
+ )
+ var values = a.ToMap().Keys()
+ assert.ElementsMatch(t, values, []string{"a", "b", "c"})
+}
+
+func TestVector_Slice(t *testing.T) {
+ var a = NewFromStrings("a", "b", "c", "d")
+ var b = a.Slice(1, 3)
+ var values = b.IdList()
+ assert.ElementsMatch(t, values, []string{"b", "c"})
+
+ assert.Equal(t, a.Len(), 4)
+ a.Reset()
+ assert.Equal(t, a.Len(), 0)
+}
+
+func TestVector_Sort(t *testing.T) {
+ var a = NewFromInts(1, 3, 5, 2, 4, 6).Sort().IdList()
+ assert.True(t, utils.IsSameSlice(a, []int{1, 2, 3, 4, 5, 6}))
+}
+
+func TestVector_Update(t *testing.T) {
+ var v = NewFromInts(1, 3, 5)
+ assert.True(t, utils.IsSameSlice(v.Elem(), []Int{1, 3, 5}))
+ v.Update(0, 2)
+ v.Update(1, 4)
+ v.Update(2, 6)
+ assert.True(t, utils.IsSameSlice(v.Elem(), []Int{2, 4, 6}))
+}
+
+func TestVector_Reverse(t *testing.T) {
+ var v = NewFromInts(1, 2, 3)
+ v.Reverse()
+ assert.True(t, utils.IsSameSlice(v.Elem(), []Int{3, 2, 1}))
+}