diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 47fed02..001f254 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -19,7 +19,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v3 with: - go-version: 1.21 + go-version: 1.18 - name: Test run: go test -v ./... diff --git a/README.md b/README.md index 5ce9712..5fc5c95 100644 --- a/README.md +++ b/README.md @@ -1,36 +1,99 @@

DAO

logo -
道生一, 一生二, 二生三, 三生万物; 万物负阴而抱阳, 冲气以为和.
+
道生一, 一生二, 二生三, 三生万物; 万物负阴而抱阳, 冲气以为和
-[![Build Status](https://github.com/lxzan/dao/workflows/Go%20Test/badge.svg?branch=main)](https://github.com/lxzan/dao/actions?query=branch%3Amain) [![go-version](https://img.shields.io/badge/go-%3E%3D1.21-30dff3?style=flat-square&logo=go)](https://github.com/lxzan/dao) - +[![Build Status](https://github.com/lxzan/dao/workflows/Go%20Test/badge.svg?branch=main)](https://github.com/lxzan/dao/actions?query=branch%3Amain) [![codecov](https://codecov.io/gh/lxzan/dao/graph/badge.svg?token=BQM1JHCDEE)](https://codecov.io/gh/lxzan/dao) [![go-version](https://img.shields.io/badge/go-%3E%3D1.18-30dff3?style=flat-square&logo=go)](https://github.com/lxzan/dao) [![license](https://img.shields.io/badge/license-MIT-blue.svg)](LICENSE) ### 简介 -Go 数据结构与算法库 +`DAO` 是一个基于泛型的数据结构与算法库 ### 目录 - [简介](#简介) - [目录](#目录) +- [动态数组](#动态数组) + - [去重](#去重) + - [排序](#排序) + - [过滤](#过滤) - [堆](#堆) - [二叉堆](#二叉堆) - [四叉堆](#四叉堆) + - [八叉堆](#八叉堆) - [栈](#栈) +- [队列](#队列) - [双端队列](#双端队列) - [双向链表](#双向链表) - [红黑树](#红黑树) - - [区间查询](#区间查询) - - [极值查询](#极值查询) -- [前缀树](#前缀树) +- [字典树](#字典树) - [哈希表](#哈希表) - [线段树](#线段树) - [基准测试](#基准测试) +### 动态数组 + +#### 去重 + +```go +package main + +import ( + "fmt" + "github.com/lxzan/dao/vector" +) + +func main() { + var v = vector.NewFromInts(1, 3, 5, 3) + v.Uniq() + fmt.Printf("%v", v.Elem()) +} + +``` + +#### 排序 + +```go +package main + +import ( + "fmt" + "github.com/lxzan/dao/vector" +) + +func main() { + var v = vector.NewFromInts(1, 3, 5, 2, 4, 6) + v.Sort() + fmt.Printf("%v", v.Elem()) +} + +``` + +#### 过滤 + +```go +package main + +import ( + "fmt" + "github.com/lxzan/dao/vector" +) + +func main() { + var v = vector.NewFromInts(1, 3, 5, 2, 4, 6) + v.Filter(func(i int, v vector.Int) bool { + return v.GetID()%2 == 0 + }) + fmt.Printf("%v", v.Elem()) +} + +``` + ### 堆 +**堆** 又称之为优先队列, 堆顶元素总是最大或最小的. 常用的是四叉堆, `Push/Pop` 性能较为均衡. 
+ #### 二叉堆 ```go @@ -38,10 +101,11 @@ package main import ( "github.com/lxzan/dao/heap" + "github.com/lxzan/dao/types/cmp" ) func main() { - var h = heap.New[int]().SetForkNumber(heap.Binary) + var h = heap.NewWithForks(heap.Binary, cmp.Less[int]) h.Push(1) h.Push(3) h.Push(5) @@ -62,10 +126,11 @@ package main import ( "github.com/lxzan/dao/heap" + "github.com/lxzan/dao/types/cmp" ) func main() { - var h = heap.New[int]().SetForkNumber(heap.Quadratic) + var h = heap.NewWithForks(heap.Quadratic, cmp.Less[int]) h.Push(1) h.Push(3) h.Push(5) @@ -76,10 +141,38 @@ func main() { println(h.Pop()) } } + +``` + +#### 八叉堆 + +```go +package main + +import ( + "github.com/lxzan/dao/heap" + "github.com/lxzan/dao/types/cmp" +) + +func main() { + var h = heap.NewWithForks(heap.Octal, cmp.Less[int]) + h.Push(1) + h.Push(3) + h.Push(5) + h.Push(2) + h.Push(4) + h.Push(6) + for h.Len() > 0 { + println(h.Pop()) + } +} + ``` ### 栈 +**栈** 先进后出 (`LIFO`) 的数据结构 + ```go package main @@ -98,8 +191,35 @@ func main() { } ``` +### 队列 + +**队列** 先进先出 (`FIFO`) 的数据结构. `dao/queue` 在全部元素弹出后会自动重置, 复用内存空间 + +```go +package main + +import ( + "github.com/lxzan/dao/queue" +) + +func main() { + var s = queue.New[int](0) + s.Push(1) + s.Push(3) + s.Push(5) + for s.Len() > 0 { + println(s.Pop()) + } +} + +``` + ### 双端队列 +**双端队列** 类似于双向链表, 两端均可高效执行插入删除操作. + +`dao/deque` 基于数组下标模拟指针实现, 删除后的空间后续仍可复用, 且不依赖 `sync.Pool` + ```go package main @@ -156,13 +276,12 @@ func main() { ### 红黑树 -#### 区间查询 +高性能红黑树实现, 可作为内存数据库使用. ```go package main import ( - "github.com/lxzan/dao" "github.com/lxzan/dao/rbtree" ) @@ -176,39 +295,18 @@ func main() { NewQuery(). Left(func(key int) bool { return key >= 3 }). Right(func(key int) bool { return key <= 5 }). - Order(dao.DESC). - Do() + Order(rbtree.ASC). 
+ FindAll() for _, item := range results { println(item.Key) } } -``` - -#### 极值查询 - -```go -package main -import ( - "fmt" - "github.com/lxzan/dao/rbtree" -) - -func main() { - var tree = rbtree.New[int, struct{}]() - for i := 0; i < 10; i++ { - tree.Set(i, struct{}{}) - } - - minimum, _ := tree.GetMinKey(rbtree.TrueFunc[int]) - maximum, _ := tree.GetMaxKey(rbtree.TrueFunc[int]) - fmt.Printf("%v %v", minimum.Key, maximum.Key) -} ``` -### 前缀树 +### 字典树 -可以动态配置槽位宽度的前缀树 +**字典树** 又叫前缀树, 可以高效匹配字符串前缀. `dao/dict` 可以动态配置槽位宽度(由索引控制). 注意: 合理设置索引, 超出索引长度的字符不能被索引优化 @@ -258,6 +356,8 @@ func main() { ### 线段树 +**线段树** 是一种二叉树,它的每个节点都表示一个区间。 线段树的特点是可以在`O(logn)`的时间内进行区间查询和区间更新。 + ```go package main @@ -268,7 +368,7 @@ import ( func main() { var data = []tree.Int64{1, 3, 5, 7, 9, 2, 4, 6, 8, 10} var lines = tree.New[tree.Int64Schema, tree.Int64](data) - var result = lines.Query(0, 9) + var result = lines.Query(0, 10) println(result.MinValue, result.MaxValue, result.Sum) } @@ -276,7 +376,42 @@ func main() { ### 基准测试 -- 10,000 elements +- 1,000 elements + +``` +go test -benchmem -bench '^Benchmark' ./benchmark/ +goos: windows +goarch: amd64 +pkg: github.com/lxzan/dao/benchmark +cpu: AMD Ryzen 5 PRO 4650G with Radeon Graphics +BenchmarkDict_Set-12 8647 124.1 ns/op 48190 B/op 1003 allocs/op +BenchmarkDict_Get-12 9609 122.0 ns/op 48000 B/op 1000 allocs/op +BenchmarkDict_Match-12 3270 349.3 ns/op 48000 B/op 1000 allocs/op +BenchmarkHeap_Push_Binary-12 55809 21.0 ns/op 38912 B/op 3 allocs/op +BenchmarkHeap_Push_Quadratic-12 65932 18.0 ns/op 38912 B/op 3 allocs/op +BenchmarkHeap_Push_Octal-12 71005 16.1 ns/op 38912 B/op 3 allocs/op +BenchmarkHeap_Pop_Binary-12 10000 100.1 ns/op 16384 B/op 1 allocs/op +BenchmarkHeap_Pop_Quadratic-12 10000 100.7 ns/op 16384 B/op 1 allocs/op +BenchmarkHeap_Pop_Octal-12 9681 124.9 ns/op 16384 B/op 1 allocs/op +BenchmarkStdList_Push-12 24715 48.7 ns/op 54000 B/op 1745 allocs/op +BenchmarkStdList_PushAndPop-12 22006 54.7 ns/op 54000 B/op 1745 allocs/op 
+BenchmarkLinkedList_Push-12 38464 31.7 ns/op 24000 B/op 1000 allocs/op +BenchmarkLinkedList_PushAndPop-12 36898 32.6 ns/op 24000 B/op 1000 allocs/op +BenchmarkDeque_Push-12 100468 11.7 ns/op 24576 B/op 1 allocs/op +BenchmarkDeque_PushAndPop-12 51649 21.7 ns/op 37496 B/op 12 allocs/op +BenchmarkRBTree_Set-12 9999 113.9 ns/op 72048 B/op 2001 allocs/op +BenchmarkRBTree_Get-12 51806 22.7 ns/op 0 B/op 0 allocs/op +BenchmarkRBTree_FindAll-12 2808 421.3 ns/op 288001 B/op 5000 allocs/op +BenchmarkRBTree_FindAOne-12 4722 252.2 ns/op 56000 B/op 5000 allocs/op +BenchmarkSegmentTree_Query-12 7498 164.4 ns/op 20 B/op 0 allocs/op +BenchmarkSegmentTree_Update-12 10000 108.6 ns/op 15 B/op 0 allocs/op +BenchmarkSort_Quick-12 24488 48.5 ns/op 0 B/op 0 allocs/op +BenchmarkSort_Std-12 21703 54.9 ns/op 8216 B/op 2 allocs/op +PASS +ok github.com/lxzan/dao/benchmark 31.100s +``` + +- 1,000,000 elements ``` go test -benchmem -bench '^Benchmark' ./benchmark/ @@ -284,28 +419,29 @@ goos: windows goarch: amd64 pkg: github.com/lxzan/dao/benchmark cpu: AMD Ryzen 5 PRO 4650G with Radeon Graphics -BenchmarkDict_Set-12 423 244.9 ns/op 517811 B/op 10645 allocs/op -BenchmarkDict_Get-12 499 241.9 ns/op 480001 B/op 10000 allocs/op -BenchmarkDict_Match-12 265 456.1 ns/op 480000 B/op 10000 allocs/op -BenchmarkHeap_Push_Binary-12 4455 28.7 ns/op 507905 B/op 4 allocs/op -BenchmarkHeap_Push_Quadratic-12 5960 26.4 ns/op 507906 B/op 4 allocs/op -BenchmarkHeap_Push_Octal-12 5793 22.5 ns/op 507907 B/op 4 allocs/op -BenchmarkHeap_Pop_Binary-12 808 149.7 ns/op 163840 B/op 1 allocs/op -BenchmarkHeap_Pop_Quadratic-12 846 145.4 ns/op 163840 B/op 1 allocs/op -BenchmarkHeap_Pop_Octal-12 673 178.8 ns/op 163840 B/op 1 allocs/op -BenchmarkStdList_Push-12 1958 59.6 ns/op 558002 B/op 19745 allocs/op -BenchmarkStdList_PushAndPop-12 1729 65.2 ns/op 558001 B/op 19745 allocs/op -BenchmarkLinkedList_Push-12 3770 31.9 ns/op 240001 B/op 10000 allocs/op -BenchmarkLinkedList_PushAndPop-12 2539 46.5 ns/op 240002 B/op 10000 
allocs/op -BenchmarkDeque_Push-12 8560 12.2 ns/op 245761 B/op 1 allocs/op -BenchmarkDeque_PushAndPop-12 5599 37.8 ns/op 386937 B/op 18 allocs/op -BenchmarkRBTree_Set-12 540 219.4 ns/op 720051 B/op 20001 allocs/op -BenchmarkRBTree_Get-12 3272 36.5 ns/op 0 B/op 0 allocs/op -BenchmarkRBTree_Query-12 60 1809.6 ns/op 3680048 B/op 60000 allocs/op -BenchmarkSegmentTree_Query-12 418 273.4 ns/op 3917 B/op 47 allocs/op -BenchmarkSegmentTree_Update-12 686 174.5 ns/op 2387 B/op 29 allocs/op -BenchmarkSort_Quick-12 1588 75.8 ns/op 81920 B/op 1 allocs/op -BenchmarkSort_Std-12 1377 86.2 ns/op 81944 B/op 2 allocs/op +BenchmarkDict_Set-12 1 2295.2 ns/op 1405087408 B/op 24868109 allocs/op +BenchmarkDict_Get-12 2 784.0 ns/op 48000000 B/op 1000000 allocs/op +BenchmarkDict_Match-12 2 961.0 ns/op 48000000 B/op 1000000 allocs/op +BenchmarkHeap_Push_Binary-12 48 24.8 ns/op 65708034 B/op 5 allocs/op +BenchmarkHeap_Push_Quadratic-12 58 19.4 ns/op 65708033 B/op 5 allocs/op +BenchmarkHeap_Push_Octal-12 69 17.1 ns/op 65708033 B/op 5 allocs/op +BenchmarkHeap_Pop_Binary-12 3 376.3 ns/op 16007168 B/op 1 allocs/op +BenchmarkHeap_Pop_Quadratic-12 3 342.8 ns/op 16007168 B/op 1 allocs/op +BenchmarkHeap_Pop_Octal-12 3 374.8 ns/op 16007168 B/op 1 allocs/op +BenchmarkStdList_Push-12 21 55.0 ns/op 55998007 B/op 1999745 allocs/op +BenchmarkStdList_PushAndPop-12 15 67.5 ns/op 55998008 B/op 1999745 allocs/op +BenchmarkLinkedList_Push-12 43 29.5 ns/op 24000000 B/op 1000000 allocs/op +BenchmarkLinkedList_PushAndPop-12 39 34.7 ns/op 24000002 B/op 1000000 allocs/op +BenchmarkDeque_Push-12 123 9.4 ns/op 24002560 B/op 1 allocs/op +BenchmarkDeque_PushAndPop-12 60 18.7 ns/op 45098876 B/op 37 allocs/op +BenchmarkRBTree_Set-12 6 171.9 ns/op 72000064 B/op 2000001 allocs/op +BenchmarkRBTree_Get-12 22 50.1 ns/op 0 B/op 0 allocs/op +BenchmarkRBTree_FindAll-12 1 1936.8 ns/op 288000128 B/op 5000001 allocs/op +BenchmarkRBTree_FindAOne-12 1 1793.4 ns/op 56000000 B/op 5000000 allocs/op +BenchmarkSegmentTree_Query-12 1 1630.0 
ns/op 169678048 B/op 2000038 allocs/op +BenchmarkSegmentTree_Update-12 1 1025.0 ns/op 169678048 B/op 2000038 allocs/op +BenchmarkSort_Quick-12 10 109.5 ns/op 8003584 B/op 1 allocs/op +BenchmarkSort_Std-12 9 123.0 ns/op 8003608 B/op 2 allocs/op PASS -ok github.com/lxzan/dao/benchmark 32.279s -``` \ No newline at end of file +ok github.com/lxzan/dao/benchmark 47.376s +``` diff --git a/algorithm/helper.go b/algorithm/helper.go index 416f4c2..358eae1 100644 --- a/algorithm/helper.go +++ b/algorithm/helper.go @@ -1,14 +1,12 @@ package algorithm import ( - "cmp" - "github.com/lxzan/dao" - "slices" + "github.com/lxzan/dao/types/cmp" "strconv" ) // ToString 数字转字符串 -func ToString[T dao.Integer](x T) string { +func ToString[T cmp.Integer](x T) string { return strconv.Itoa(int(x)) } @@ -35,12 +33,12 @@ func Swap[T any](a, b *T) { *b = temp } -func Unique[T cmp.Ordered](arr []T) []T { +func Unique[T cmp.Ordered, A ~[]T](arr A) A { if len(arr) == 0 { return arr } - slices.Sort(arr) + Sort(arr) var n = len(arr) var j = 1 @@ -54,21 +52,13 @@ func Unique[T cmp.Ordered](arr []T) []T { return arr } -func UniqueBy[T any, K cmp.Ordered](arr []T, getKey func(item T) K) []T { +func UniqueBy[T any, K cmp.Ordered, A ~[]T](arr A, getKey func(item T) K) A { if len(arr) == 0 { return arr } - slices.SortFunc(arr, func(a, b T) int { - x := getKey(a) - y := getKey(b) - if x < y { - return -1 - } else if x > y { - return 1 - } else { - return 0 - } + SortBy(arr, func(a, b T) int { + return cmp.Compare(getKey(a), getKey(b)) }) var n = len(arr) @@ -84,11 +74,12 @@ func UniqueBy[T any, K cmp.Ordered](arr []T, getKey func(item T) K) []T { } // Reverse 反转数组 -func Reverse[T any](arr []T) { +func Reverse[T any, A ~[]T](arr A) A { var n = len(arr) for i := 0; i < n/2; i++ { arr[i], arr[n-1-i] = arr[n-1-i], arr[i] } + return arr } // SelectValue 选择一个值 三元操作符替代品 @@ -119,7 +110,7 @@ func Map[A any, B any](arr []A, transfer func(i int, v A) B) []B { } // Filter 过滤器 -func Filter[T any](arr []T, check func(i 
int, v T) bool) []T { +func Filter[T any, A ~[]T](arr A, check func(i int, v T) bool) A { var results = make([]T, 0, len(arr)) for i, v := range arr { if check(i, v) { diff --git a/algorithm/helper_test.go b/algorithm/helper_test.go index e0a61df..5394dee 100644 --- a/algorithm/helper_test.go +++ b/algorithm/helper_test.go @@ -15,6 +15,8 @@ func TestToString(t *testing.T) { } func TestUnique(t *testing.T) { + Unique[int, []int](nil) + t.Run("", func(t *testing.T) { arr := Unique([]int{}) assert.ElementsMatch(t, arr, []int{}) @@ -133,6 +135,8 @@ func TestSwap(t *testing.T) { } func TestReverse(t *testing.T) { + Reverse[int, []int](nil) + t.Run("", func(t *testing.T) { var list = []int{1, 2, 3, 4} Reverse(list) diff --git a/algorithm/sort.go b/algorithm/sort.go index 3bc0a88..3867fa4 100644 --- a/algorithm/sort.go +++ b/algorithm/sort.go @@ -1,17 +1,16 @@ package algorithm import ( - "cmp" - "github.com/lxzan/dao" + "github.com/lxzan/dao/types/cmp" ) -func IsSorted[T any](arr []T, cmp dao.CompareFunc[T]) bool { +func IsSorted[T any](arr []T, compare cmp.CompareFunc[T]) bool { var n = len(arr) if n <= 1 { return true } for i := 1; i < n; i++ { - if cmp(arr[i], arr[i-1]) < 0 { + if compare(arr[i], arr[i-1]) < 0 { return false } } @@ -26,22 +25,22 @@ func Sort[T cmp.Ordered](arr []T) { QuickSort(arr, 0, len(arr)-1, f) } -func SortBy[T any](arr []T, cmp dao.CompareFunc[T]) { - if IsSorted(arr, cmp) { +func SortBy[T any](arr []T, compare cmp.CompareFunc[T]) { + if IsSorted(arr, compare) { return } - QuickSort(arr, 0, len(arr)-1, cmp) + QuickSort(arr, 0, len(arr)-1, compare) } -func getMedium[T any](arr []T, begin int, end int, cmp dao.CompareFunc[T]) int { +func getMedium[T any](arr []T, begin int, end int, compare cmp.CompareFunc[T]) int { var mid = (begin + end) / 2 - var x = cmp(arr[begin], arr[mid]) - var y = cmp(arr[mid], arr[end]) + var x = compare(arr[begin], arr[mid]) + var y = compare(arr[mid], arr[end]) if x+y != 0 { return mid } - var z = cmp(arr[begin], 
arr[end]) + var z = compare(arr[begin], arr[end]) y *= -1 if y+z != 0 { return end @@ -49,9 +48,9 @@ func getMedium[T any](arr []T, begin int, end int, cmp dao.CompareFunc[T]) int { return begin } -func insertionSort[T any](arr []T, a, b int, cmp dao.CompareFunc[T]) { +func insertionSort[T any](arr []T, a, b int, compare cmp.CompareFunc[T]) { for i := a + 1; i <= b; i++ { - for j := i; j > a && cmp(arr[j], arr[j-1]) == dao.Less; j-- { + for j := i; j > a && compare(arr[j], arr[j-1]) == cmp.LT; j-- { arr[j], arr[j-1] = arr[j-1], arr[j] } } @@ -59,34 +58,34 @@ func insertionSort[T any](arr []T, a, b int, cmp dao.CompareFunc[T]) { // QuickSort 快速排序 begin <= x <= end 区间 // 对于随机数据, 此算法比标准库稍快; 对于本身比较有序的数据, 标准库表现更佳. -func QuickSort[T any](arr []T, begin int, end int, cmp dao.CompareFunc[T]) { +func QuickSort[T any](arr []T, begin int, end int, compare cmp.CompareFunc[T]) { if begin >= end { return } if end-begin <= 15 { - insertionSort(arr, begin, end, cmp) + insertionSort(arr, begin, end, compare) return } var index = begin - var mid = getMedium(arr, begin, end, cmp) + var mid = getMedium(arr, begin, end, compare) arr[mid], arr[begin] = arr[begin], arr[mid] for i := begin + 1; i <= end; i++ { - var flag = cmp(arr[i], arr[begin]) - if flag == dao.Less || (flag == dao.Equal && i%2 == 0) { + var flag = compare(arr[i], arr[begin]) + if flag == cmp.LT || (flag == cmp.EQ && i%2 == 0) { index++ arr[index], arr[i] = arr[i], arr[index] } } arr[index], arr[begin] = arr[begin], arr[index] - QuickSort(arr, begin, index-1, cmp) - QuickSort(arr, index+1, end, cmp) + QuickSort(arr, begin, index-1, compare) + QuickSort(arr, index+1, end, compare) } // BinarySearch 二分搜索 // @return 数组下标 如果不存在, 返回-1 -func BinarySearch[T any](arr []T, target T, cmp dao.CompareFunc[T]) int { +func BinarySearch[T any](arr []T, target T, compare cmp.CompareFunc[T]) int { var n = len(arr) if n == 0 { return -1 @@ -96,19 +95,19 @@ func BinarySearch[T any](arr []T, target T, cmp dao.CompareFunc[T]) int { var 
right = n - 1 for right-left > 1 { var mid = (left + right) / 2 - switch cmp(arr[mid], target) { - case dao.Equal: + switch compare(arr[mid], target) { + case cmp.EQ: return mid - case dao.Greater: + case cmp.GT: right = mid default: left = mid } } - if cmp(arr[left], target) == dao.Equal { + if compare(arr[left], target) == cmp.EQ { return left - } else if cmp(arr[right], target) == dao.Equal { + } else if compare(arr[right], target) == cmp.EQ { return right } else { return -1 diff --git a/algorithm/sort_test.go b/algorithm/sort_test.go index 5b32ef7..0fc3048 100644 --- a/algorithm/sort_test.go +++ b/algorithm/sort_test.go @@ -1,10 +1,11 @@ package algorithm import ( - "cmp" "github.com/lxzan/dao/internal/utils" + "github.com/lxzan/dao/types/cmp" "github.com/stretchr/testify/assert" "math/rand" + "sort" "testing" ) @@ -64,9 +65,9 @@ func TestSort(t *testing.T) { } t.Run("", func(t *testing.T) { - var a = []int{1, 2, 3} + var a = []int{2, 1} SortBy(a, cmp.Compare[int]) - assert.True(t, utils.IsSameSlice(a, []int{1, 2, 3})) + assert.True(t, utils.IsSameSlice(a, []int{1, 2})) }) t.Run("", func(t *testing.T) { @@ -80,6 +81,42 @@ func TestSort(t *testing.T) { Sort(a) assert.True(t, utils.IsSameSlice(a, []int{1, 2, 3, 4})) }) + + t.Run("", func(t *testing.T) { + Sort([]int{}) + Sort([]int{1}) + }) + + t.Run("", func(t *testing.T) { + for j := 0; j < 100; j++ { + var count = rand.Intn(100) + var arr0 []int + for i := 0; i < count; i++ { + arr0 = append(arr0, rand.Intn(count)) + } + var arr1 = utils.Clone(arr0) + + Sort(arr0) + sort.Ints(arr1) + assert.True(t, utils.IsSameSlice(arr0, arr1)) + } + + for j := 0; j < 100; j++ { + var count = rand.Intn(100) + var arr0 []int + for i := 0; i < count; i++ { + arr0 = append(arr0, rand.Intn(count)) + } + var arr1 = utils.Clone(arr0) + + SortBy(arr0, func(a, b int) int { + return -1 * cmp.Compare(a, b) + }) + sort.Ints(arr1) + Reverse(arr1) + assert.True(t, utils.IsSameSlice(arr0, arr1)) + } + }) } func TestBinarySearch(t 
*testing.T) { diff --git a/benchmark/heap_test.go b/benchmark/heap_test.go index 21e9d93..f8cedce 100644 --- a/benchmark/heap_test.go +++ b/benchmark/heap_test.go @@ -2,12 +2,13 @@ package benchmark import ( "github.com/lxzan/dao/heap" + "github.com/lxzan/dao/types/cmp" "math/rand" "testing" ) func BenchmarkHeap_Push_Binary(b *testing.B) { - var tpl = heap.New[int]().SetForkNumber(heap.Binary) + var tpl = heap.NewWithForks(heap.Binary, cmp.Less[int]) for j := 0; j < bench_count; j++ { tpl.Push(rand.Int()) } @@ -23,7 +24,7 @@ func BenchmarkHeap_Push_Binary(b *testing.B) { } func BenchmarkHeap_Push_Quadratic(b *testing.B) { - var tpl = heap.New[int]().SetForkNumber(heap.Quadratic) + var tpl = heap.NewWithForks(heap.Quadratic, cmp.Less[int]) for j := 0; j < bench_count; j++ { tpl.Push(rand.Int()) } @@ -38,7 +39,7 @@ func BenchmarkHeap_Push_Quadratic(b *testing.B) { } } func BenchmarkHeap_Push_Octal(b *testing.B) { - var tpl = heap.New[int]().SetForkNumber(heap.Octal) + var tpl = heap.NewWithForks(heap.Octal, cmp.Less[int]) for j := 0; j < bench_count; j++ { tpl.Push(rand.Int()) } @@ -54,7 +55,7 @@ func BenchmarkHeap_Push_Octal(b *testing.B) { } func BenchmarkHeap_Pop_Binary(b *testing.B) { - var tpl = heap.New[int]().SetForkNumber(heap.Binary) + var tpl = heap.NewWithForks(heap.Binary, cmp.Less[int]) for j := 0; j < bench_count*2; j++ { tpl.Push(rand.Int()) } @@ -69,7 +70,7 @@ func BenchmarkHeap_Pop_Binary(b *testing.B) { } func BenchmarkHeap_Pop_Quadratic(b *testing.B) { - var tpl = heap.New[int]().SetForkNumber(heap.Quadratic) + var tpl = heap.NewWithForks(heap.Quadratic, cmp.Less[int]) for j := 0; j < bench_count*2; j++ { tpl.Push(rand.Int()) } @@ -84,7 +85,7 @@ func BenchmarkHeap_Pop_Quadratic(b *testing.B) { } func BenchmarkHeap_Pop_Octal(b *testing.B) { - var tpl = heap.New[int]().SetForkNumber(heap.Octal) + var tpl = heap.NewWithForks(heap.Octal, cmp.Less[int]) for j := 0; j < bench_count*2; j++ { tpl.Push(rand.Int()) } diff --git a/benchmark/rbtree_test.go 
b/benchmark/rbtree_test.go index cf6ba72..807d922 100644 --- a/benchmark/rbtree_test.go +++ b/benchmark/rbtree_test.go @@ -30,7 +30,7 @@ func BenchmarkRBTree_Get(b *testing.B) { } } -func BenchmarkRBTree_Query(b *testing.B) { +func BenchmarkRBTree_FindAll(b *testing.B) { var tree = rbtree.New[int, string]() for j := 0; j < bench_count; j++ { tree.Set(j, "") @@ -47,7 +47,32 @@ func BenchmarkRBTree_Query(b *testing.B) { NewQuery(). Left(func(key int) bool { return key >= x }). Right(func(key int) bool { return key <= y }). - Do() + Order(rbtree.DESC). + Limit(10). + FindAll() + } + } +} + +func BenchmarkRBTree_FindAOne(b *testing.B) { + var tree = rbtree.New[int, string]() + for j := 0; j < bench_count; j++ { + tree.Set(j, "") + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for j := 0; j < bench_count; j++ { + x, y := rand.Intn(bench_count), rand.Intn(bench_count) + if x > y { + algorithm.Swap(&x, &y) + } + tree. + NewQuery(). + Left(func(key int) bool { return key >= x }). + Right(func(key int) bool { return key <= y }). + Order(rbtree.ASC). 
+ FindOne() } } } diff --git a/deque/deque.go b/deque/deque.go index ecc33ae..073c269 100644 --- a/deque/deque.go +++ b/deque/deque.go @@ -45,7 +45,7 @@ func (c *Element[T]) Value() T { } // New 创建双端队列 -func New[T any](capacity uint32) *Deque[T] { +func New[T any](capacity int) *Deque[T] { return &Deque[T]{elements: make([]Element[T], 1, 1+capacity)} } @@ -83,7 +83,6 @@ func (c *Deque[T]) putElement(ele *Element[T]) { // Reset 重置 func (c *Deque[T]) Reset() { - clear(c.elements) c.autoReset() } diff --git a/deque/deque_test.go b/deque/deque_test.go index 9613ba9..da3cf1b 100644 --- a/deque/deque_test.go +++ b/deque/deque_test.go @@ -1,9 +1,9 @@ package deque import ( - "cmp" "container/list" "github.com/lxzan/dao/internal/utils" + "github.com/lxzan/dao/types/cmp" "github.com/stretchr/testify/assert" "math/rand" "testing" diff --git a/dict/dict.go b/dict/dict.go index 18dd5f7..783dd65 100644 --- a/dict/dict.go +++ b/dict/dict.go @@ -12,30 +12,33 @@ type element struct { } type Dict[T any] struct { - indexes []uint8 // 索引 - root *element // 根节点 - storage *mlist.MList[string, T] // 存储 + indexes []uint8 // 索引 + binaryIndex bool // 是否使用2进制索引 + root *element // 根节点 + storage *mlist.MList[string, T] // 存储 } // New 新建字典树 // 注意: key不能重复 func New[T any]() *Dict[T] { return &Dict[T]{ - indexes: defaultIndexes, - root: &element{Children: make([]*element, defaultIndexes[0])}, - storage: mlist.NewMList[string, T](8), + indexes: defaultIndexes, + binaryIndex: true, + root: &element{Children: make([]*element, defaultIndexes[0])}, + storage: mlist.NewMList[string, T](8), } } // WithIndexes 设置索引 -// 索引元素必须满足 y=2^x +// 索引长度至少为2; 如果每个数字都满足y=pow(2,x), 索引效率更高. 
func (c *Dict[T]) WithIndexes(indexes []uint8) *Dict[T] { if len(indexes) < 2 { panic("indexes length at least 2") } for _, item := range indexes { if !utils.IsBinaryNumber(item) { - panic("indexes contains elements that must satisfy y=2^x") + c.binaryIndex = false + break } } c.indexes = indexes diff --git a/dict/dict_test.go b/dict/dict_test.go index 13363a0..606408d 100644 --- a/dict/dict_test.go +++ b/dict/dict_test.go @@ -186,7 +186,7 @@ func TestDict_WithIndexes(t *testing.T) { err := f(func() { New[uint8]().WithIndexes([]uint8{1, 2, 3, 4}) }) - assert.Error(t, err) + assert.NoError(t, err) }) t.Run("", func(t *testing.T) { @@ -199,7 +199,9 @@ func TestDict_WithIndexes(t *testing.T) { func TestDict_Random(t *testing.T) { var count = 1000000 - var d = New[int]() + var d = New[int]().WithIndexes( + []uint8{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, + ) var m = make(map[string]int) for i := 0; i < count; i++ { key := strconv.Itoa(i) diff --git a/dict/encode.go b/dict/encode.go index c3ed003..74ec4ce 100644 --- a/dict/encode.go +++ b/dict/encode.go @@ -1,5 +1,7 @@ package dict +import "github.com/lxzan/dao/algorithm" + var defaultIndexes = []uint8{32, 32, 16, 16, 16, 16, 16, 8, 8, 8, 8, 8, 8, 4, 4, 4, 4} type iterator struct { @@ -15,11 +17,14 @@ func (c *iterator) hit() bool { } func (c *Dict[T]) getIndex(iter *iterator) int { - return int(iter.Key[iter.Cursor]) & int(c.indexes[iter.Cursor]-1) + if c.binaryIndex { + return int(iter.Key[iter.Cursor]) & int(c.indexes[iter.Cursor]-1) + } + return int(iter.Key[iter.Cursor]) % int(c.indexes[iter.Cursor]) } func (c *Dict[T]) begin(key string, initialize bool) *iterator { - var iter = &iterator{Node: c.root, Key: key, N: min(len(key), len(c.indexes)-1), Initialize: initialize} + var iter = &iterator{Node: c.root, Key: key, N: algorithm.Min(len(key), len(c.indexes)-1), Initialize: initialize} var idx = c.getIndex(iter) if iter.Node.Children[idx] == nil { if !iter.Initialize { diff --git a/go.mod b/go.mod 
index 10d38f7..a9d68ac 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/lxzan/dao -go 1.21 +go 1.18 require github.com/stretchr/testify v1.8.4 diff --git a/hashmap/hashmap.go b/hashmap/hashmap.go index 93c43f0..af56003 100644 --- a/hashmap/hashmap.go +++ b/hashmap/hashmap.go @@ -4,14 +4,17 @@ type HashMap[K comparable, V any] map[K]V // New instantiates a hashmap // at most one param, means initial capacity -func New[K comparable, V any](capacity uint32) HashMap[K, V] { +func New[K comparable, V any](capacity int) HashMap[K, V] { return make(map[K]V, capacity) } -// Reset clear contents -func (c HashMap[K, V]) Reset() { - clear(c) -} +//Reset clear contents +//func (c HashMap[K, V]) Reset() { +// keys := c.Keys() +// for _, key := range keys { +// delete(c, key) +// } +//} // Len get the length of hashmap func (c HashMap[K, V]) Len() int { diff --git a/hashmap/hashmap_test.go b/hashmap/hashmap_test.go index b8df2df..471db56 100644 --- a/hashmap/hashmap_test.go +++ b/hashmap/hashmap_test.go @@ -118,11 +118,3 @@ func TestHashMap_Range(t *testing.T) { }) assert.Equal(t, len(keys), 2) } - -func TestHashMap_Reset(t *testing.T) { - var m = New[string, int](8) - m.Set("a", 1) - assert.Equal(t, m.Len(), 1) - m.Reset() - assert.Equal(t, m.Len(), 0) -} diff --git a/heap/heap.go b/heap/heap.go index e21c3a8..d732c08 100644 --- a/heap/heap.go +++ b/heap/heap.go @@ -1,8 +1,7 @@ package heap import ( - "cmp" - "github.com/lxzan/dao" + "github.com/lxzan/dao/types/cmp" ) const ( @@ -13,35 +12,34 @@ const ( Octal = 8 ) -// New 新建一个最小堆 -// Create a new minimum heap -func New[T cmp.Ordered]() *Heap[T] { return NewHeap(cmp.Less[T]) } +// New 新建一个最小四叉堆 +// Create a new minimum quadratic heap +func New[T cmp.Ordered]() *Heap[T] { return NewWithForks(Quadratic, cmp.Less[T]) } -// NewHeap 新建一个堆 -// Create a new heap -func NewHeap[T any](less dao.LessFunc[T]) *Heap[T] { - h := &Heap[T]{cmp: less} - h.SetForkNumber(Quadratic) +// NewWithForks 新建堆 +// @forks 分叉数, 可选值为: 
2,4,8 +// @lessFunc 比较函数 +func NewWithForks[T any](forks uint32, lessFunc cmp.LessFunc[T]) *Heap[T] { + h := &Heap[T]{lessFunc: lessFunc} + h.setForkNumber(forks) return h } -// Heap 可以不使用New函数, 声明为值类型自动初始化 -// 如果使用值类型, 需要设置比较函数和分叉数 type Heap[T any] struct { - bits uint32 - forks int - data []T - cmp func(a, b T) bool + bits uint32 + forks int + data []T + lessFunc func(a, b T) bool } // SetCap 设置预分配容量 -func (c *Heap[T]) SetCap(n uint32) *Heap[T] { +func (c *Heap[T]) SetCap(n int) *Heap[T] { c.data = make([]T, 0, n) return c } -// SetForkNumber 设置分叉数 -func (c *Heap[T]) SetForkNumber(n uint32) *Heap[T] { +// setForkNumber 设置分叉数 +func (c *Heap[T]) setForkNumber(n uint32) *Heap[T] { c.forks = int(n) switch n { case Quadratic, Binary: @@ -54,19 +52,13 @@ func (c *Heap[T]) SetForkNumber(n uint32) *Heap[T] { return c } -// SetLessFunc 设置比较函数 -func (c *Heap[T]) SetLessFunc(less dao.LessFunc[T]) *Heap[T] { - c.cmp = less - return c -} - // Len 获取元素数量 func (c *Heap[T]) Len() int { return len(c.data) } func (c *Heap[T]) less(i, j int) bool { - return c.cmp(c.data[i], c.data[j]) + return c.lessFunc(c.data[i], c.data[j]) } func (c *Heap[T]) swap(i, j int) { @@ -102,7 +94,6 @@ func (c *Heap[T]) down(i, n int) { // Reset 重置堆 func (c *Heap[T]) Reset() { - clear(c.data) c.data = c.data[:0] } @@ -149,3 +140,8 @@ func (c *Heap[T]) Clone() *Heap[T] { copy(v.data, c.data) return &v } + +// UnWrap 解包, 返回底层数组 +func (c *Heap[T]) UnWrap() []T { + return c.data +} diff --git a/heap/heap_test.go b/heap/heap_test.go index 4001c50..f5e4344 100644 --- a/heap/heap_test.go +++ b/heap/heap_test.go @@ -2,20 +2,22 @@ package heap import ( "fmt" - "github.com/lxzan/dao" "github.com/lxzan/dao/internal/utils" + "github.com/lxzan/dao/types/cmp" "github.com/stretchr/testify/assert" "sort" "testing" "unsafe" ) +func desc[T cmp.Ordered](a, b T) bool { + return a > b +} + func TestNew(t *testing.T) { const count = 1000 { - var h = NewHeap[string](func(a, b string) bool { - return a < b - }) + var h = 
New[string]() var arr1 = make([]string, 0) var arr2 = make([]string, 0) for i := 0; i < count; i++ { @@ -34,9 +36,8 @@ func TestNew(t *testing.T) { } func TestDesc(t *testing.T) { - var h = NewHeap(dao.DescFunc[int]) + var h = NewWithForks(Octal, desc[int]) h.SetCap(8) - h.SetForkNumber(Octal) h.Push(1) assert.Equal(t, h.Top(), 1) h.Push(3) @@ -52,9 +53,8 @@ func TestDesc(t *testing.T) { } func TestAsc(t *testing.T) { - var h = New[int]() + var h = NewWithForks(Binary, cmp.Less[int]) h.SetCap(8) - h.SetForkNumber(Binary) h.Push(1) h.Push(3) h.Push(2) @@ -69,8 +69,8 @@ func TestAsc(t *testing.T) { } func TestHeap_Range(t *testing.T) { - var h Heap[int] - h.SetForkNumber(Quadratic).SetLessFunc(dao.AscFunc[int]).SetCap(8) + var h = NewWithForks(Quadratic, cmp.Less[int]) + h.SetCap(8) h.Push(1) h.Push(3) h.Push(2) @@ -108,14 +108,13 @@ func TestHeap_Reset(t *testing.T) { } func TestHeap_Pop(t *testing.T) { - var h = NewHeap(dao.AscFunc[int]) + var h = New[int]() assert.Equal(t, h.Pop(), 0) h.Push(1) assert.Equal(t, h.Pop(), 1) } func TestHeap_SetForkNumber(t *testing.T) { - var h = NewHeap(dao.AscFunc[int]) var catch = func(f func()) (err error) { defer func() { if excp := recover(); excp != nil { @@ -127,19 +126,18 @@ func TestHeap_SetForkNumber(t *testing.T) { } var err1 = catch(func() { - h.SetForkNumber(3) + NewWithForks(3, cmp.Less[int]) }) assert.Error(t, err1) var err2 = catch(func() { - h.SetForkNumber(4) + NewWithForks(4, cmp.Less[int]) }) assert.Nil(t, err2) } func TestHeap_Clone(t *testing.T) { - var h = NewHeap(dao.AscFunc[int]) - h.SetForkNumber(4) + var h = New[int]() h.Push(1) h.Push(3) h.Push(2) @@ -154,3 +152,11 @@ func TestHeap_Clone(t *testing.T) { assert.NotEqual(t, addr, addr1) assert.Equal(t, addr, addr2) } + +func TestHeap_UnWrap(t *testing.T) { + var h = NewWithForks(2, cmp.Less[int]) + h.Push(1) + h.Push(2) + h.Push(3) + assert.True(t, utils.IsSameSlice(h.UnWrap(), []int{1, 2, 3})) +} diff --git a/interface.go b/interface.go index 
9083c52..4f7a8ff 100644 --- a/interface.go +++ b/interface.go @@ -1,44 +1,6 @@ package dao -import "cmp" - -const ( - Less = -1 - Equal = 0 - Greater = 1 -) - -type ( - // LessFunc 比大小 - LessFunc[T any] func(a, b T) bool - - // CompareFunc 比较函数 - // a>b, 返回1; a b } - -type Order uint8 - -const ( - ASC Order = 0 // 升序 - DESC Order = 1 // 降序 -) - type ( - Number interface { - Integer | ~float32 | ~float64 - } - - Integer interface { - ~int64 | ~int | ~int32 | ~int16 | ~int8 | ~uint64 | ~uint | ~uint32 | ~uint16 | ~uint8 - } - // Map 键不可重复 Map[K comparable, V any] interface { Len() int diff --git a/internal/mlist/mlist.go b/internal/mlist/mlist.go index f52e84b..e61fd7b 100644 --- a/internal/mlist/mlist.go +++ b/internal/mlist/mlist.go @@ -35,7 +35,6 @@ func NewMList[K comparable, V any](size uint32) *MList[K, V] { func (c *MList[K, V]) Reset() { c.length, c.serial = 0, 0 - clear(c.Buckets) c.recyclable = c.recyclable[:0] c.Buckets = c.Buckets[:1] } diff --git a/internal/utils/helper.go b/internal/utils/helper.go index a370e53..3795c59 100644 --- a/internal/utils/helper.go +++ b/internal/utils/helper.go @@ -63,3 +63,10 @@ type Integer interface { func IsBinaryNumber[T Integer](x T) bool { return x&(x-1) == 0 } + +func Clone[S ~[]E, E any](s S) S { + if s == nil { + return nil + } + return append(S([]E{}), s...) 
+} diff --git a/internal/utils/helper_test.go b/internal/utils/helper_test.go index a101d39..0aa6490 100644 --- a/internal/utils/helper_test.go +++ b/internal/utils/helper_test.go @@ -62,3 +62,17 @@ func TestIsBinaryNumber(t *testing.T) { assert.False(t, IsBinaryNumber(7)) assert.False(t, IsBinaryNumber(21)) } + +func TestClone(t *testing.T) { + { + var a []int + var b = Clone(a) + assert.True(t, len(b) == 0) + } + + { + var a = []int{1, 2, 3} + var b = Clone(a) + assert.ElementsMatch(t, b, a) + } +} diff --git a/queue/queue.go b/queue/queue.go new file mode 100644 index 0000000..966f6c8 --- /dev/null +++ b/queue/queue.go @@ -0,0 +1,67 @@ +package queue + +import "github.com/lxzan/dao/internal/utils" + +type Queue[T any] struct { + offset int + tpl T + data []T +} + +// New 创建队列 +func New[T any](capacity int) *Queue[T] { + return &Queue[T]{data: make([]T, 0, capacity)} +} + +// NewFrom 从切片创建队列 +func NewFrom[T any](values ...T) *Queue[T] { + return &Queue[T]{data: values} +} + +// Reset 重置 +func (c *Queue[T]) Reset() { + c.offset = 0 + c.data = c.data[:0] +} + +// Len 获取队列长度 +func (c *Queue[T]) Len() int { + return len(c.data) - c.offset +} + +// Push 追加元素到队列尾部 +func (c *Queue[T]) Push(v T) { + c.data = append(c.data, v) +} + +// Pop 从队列头部弹出元素 +func (c *Queue[T]) Pop() (value T) { + if n := c.Len(); n > 0 { + value = c.data[c.offset] + c.data[c.offset] = c.tpl + c.offset++ + if c.offset == len(c.data) { + c.Reset() + } + } + return value +} + +// Range 遍历 +func (c *Queue[T]) Range(f func(value T) bool) { + for _, item := range c.data { + if !f(item) { + return + } + } +} + +// UnWrap 解包, 返回底层数组 +func (c *Queue[T]) UnWrap() []T { + return c.data[c.offset:] +} + +// Clone 拷贝副本 +func (c *Queue[T]) Clone() *Queue[T] { + return &Queue[T]{data: utils.Clone(c.data)} +} diff --git a/queue/queue_test.go b/queue/queue_test.go new file mode 100644 index 0000000..accec45 --- /dev/null +++ b/queue/queue_test.go @@ -0,0 +1,86 @@ +package queue + +import ( + 
"github.com/lxzan/dao/internal/utils" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestQueue(t *testing.T) { + t.Run("", func(t *testing.T) { + q := NewFrom(1, 3, 5, 7, 9) + var a []int + for q.Len() > 0 { + a = append(a, q.Pop()) + } + assert.True(t, utils.IsSameSlice(a, []int{1, 3, 5, 7, 9})) + assert.Equal(t, q.offset, 0) + }) + + t.Run("", func(t *testing.T) { + q := NewFrom[int](1) + q.Push(3) + q.Pop() + assert.Equal(t, q.offset, 1) + assert.Equal(t, q.data[0], 0) + assert.Equal(t, q.Len(), 1) + }) +} + +func TestQueue_UnWrap(t *testing.T) { + t.Run("", func(t *testing.T) { + q := NewFrom(1, 3, 5, 7, 9) + q.Pop() + a := q.UnWrap() + assert.True(t, utils.IsSameSlice(a, []int{3, 5, 7, 9})) + }) + + t.Run("", func(t *testing.T) { + q := NewFrom(1) + q.Pop() + a := q.UnWrap() + assert.Equal(t, len(a), 0) + }) + + t.Run("", func(t *testing.T) { + q := New[int](8) + a := q.UnWrap() + assert.Equal(t, len(a), 0) + }) +} + +func TestQueue_Clone(t *testing.T) { + q := NewFrom(1, 3, 5, 7, 9) + b := q.Clone() + assert.True(t, utils.IsSameSlice(b.UnWrap(), []int{1, 3, 5, 7, 9})) +} + +func TestQueue_Range(t *testing.T) { + t.Run("", func(t *testing.T) { + var s = NewFrom(1, 3, 5) + + var arr []int + s.Range(func(value int) bool { + arr = append(arr, value) + return true + }) + assert.True(t, utils.IsSameSlice(arr, []int{1, 3, 5})) + + s.Reset() + assert.Equal(t, s.Len(), 0) + }) + + t.Run("", func(t *testing.T) { + var s = New[int](8) + s.Push(1) + s.Push(3) + s.Push(5) + + var arr []int + s.Range(func(value int) bool { + arr = append(arr, value) + return len(arr) < 2 + }) + assert.True(t, utils.IsSameSlice(arr, []int{1, 3})) + }) +} diff --git a/rbtree/query.go b/rbtree/query.go index bd65e59..4504b38 100644 --- a/rbtree/query.go +++ b/rbtree/query.go @@ -1,105 +1,35 @@ package rbtree import ( - "cmp" - "github.com/lxzan/dao" "github.com/lxzan/dao/algorithm" "github.com/lxzan/dao/stack" + "github.com/lxzan/dao/types/cmp" ) -// Get 查询一个key -func (c 
*RBTree[K, V]) Get(key K) (result V, exist bool) { - for i := c.begin(); !c.end(i); i = c.next(i, key) { - if key == i.data.Key { - return i.data.Val, true - } - } - return result, false -} - -// Range 遍历树 -func (c *RBTree[K, V]) Range(fn func(key K, value V) bool) { - var next = true - c.do_range(c.root, &next, fn) -} - -func (c *RBTree[K, V]) do_range(node *rbtree_node[K, V], next *bool, fn func(K, V) bool) { - if c.end(node) || !*next { - return - } +type Order uint8 - if ok := fn(node.data.Key, node.data.Val); !ok { - *next = ok - } - - c.do_range(node.left, next, fn) - c.do_range(node.right, next, fn) -} - -// GetMinKey 获取最小的key, 过滤条件可为空 -func (c *RBTree[K, V]) GetMinKey(filter func(key K) bool) (result Pair[K, V], exist bool) { - return c.doGetMinKey(stack.New[*rbtree_node[K, V]](10), filter) -} - -func (c *RBTree[K, V]) doGetMinKey(s *stack.Stack[*rbtree_node[K, V]], filter func(key K) bool) (result Pair[K, V], exist bool) { - s.Reset() - filter = algorithm.SelectValue(filter == nil, TrueFunc[K], filter) - s.Push(c.root) - for s.Len() > 0 { - var node = s.Pop() - if c.end(node) { - continue - } - if filter(node.data.Key) { - if !exist || node.data.Key < result.Key { - exist = true - result = *node.data - } - s.Push(node.left) - } else { - s.Push(node.right) - } - } - return result, exist -} +const ( + ASC Order = 0 // 升序 + DESC Order = 1 // 降序 +) -// GetMaxKey 获取最大的key, 过滤条件可为空 -func (c *RBTree[K, V]) GetMaxKey(filter func(key K) bool) (result Pair[K, V], exist bool) { - return c.doGetMaxKey(stack.New[*rbtree_node[K, V]](10), filter) -} +func TrueFunc[K cmp.Ordered](d K) bool { return true } -func (c *RBTree[K, V]) doGetMaxKey(s *stack.Stack[*rbtree_node[K, V]], filter func(key K) bool) (result Pair[K, V], exist bool) { - s.Reset() - filter = algorithm.SelectValue(filter == nil, TrueFunc[K], filter) - s.Push(c.root) - for s.Len() > 0 { - var node = s.Pop() - if c.end(node) { - continue - } - if filter(node.data.Key) { - if !exist || node.data.Key > 
result.Key { - exist = true - result = *node.data - } - s.Push(node.right) - } else { - s.Push(node.left) - } - } - return result, exist -} - -func TrueFunc[K cmp.Ordered](d K) bool { - return true +// NewQuery 新建一个查询 +func (c *RBTree[K, V]) NewQuery() *QueryBuilder[K, V] { + return &QueryBuilder[K, V]{tree: c} } type QueryBuilder[K cmp.Ordered, V any] struct { - tree *RBTree[K, V] - leftFilter func(key K) bool - rightFilter func(key K) bool - limit int - order dao.Order + tree *RBTree[K, V] // 红黑树 + results []Pair[K, V] // 查询结果 + + leftFilter func(key K) bool // 左边界条件 + rightFilter func(key K) bool // 右边界条件 + limit int // 单页限制 + total int // 总条数 + offset int // 偏移量 + order Order // 排序 } func (c *QueryBuilder[K, V]) init() *QueryBuilder[K, V] { @@ -112,6 +42,7 @@ func (c *QueryBuilder[K, V]) init() *QueryBuilder[K, V] { if c.limit <= 0 { c.limit = 10 } + c.total = c.offset + c.limit return c } @@ -128,7 +59,7 @@ func (c *QueryBuilder[K, V]) Right(f func(key K) bool) *QueryBuilder[K, V] { } // Order 设置排序, 默认ASC -func (c *QueryBuilder[K, V]) Order(o dao.Order) *QueryBuilder[K, V] { +func (c *QueryBuilder[K, V]) Order(o Order) *QueryBuilder[K, V] { c.order = o return c } @@ -139,59 +70,182 @@ func (c *QueryBuilder[K, V]) Limit(n int) *QueryBuilder[K, V] { return c } -// Do 执行查询 -func (c *QueryBuilder[K, V]) Do() []Pair[K, V] { - return c.tree.do_query(c) +func (c *QueryBuilder[K, V]) Offset(n int) *QueryBuilder[K, V] { + c.offset = n + return c } -// NewQuery 新建一个查询 -func (c *RBTree[K, V]) NewQuery() *QueryBuilder[K, V] { - return &QueryBuilder[K, V]{tree: c} +// FindAll 执行查询 +func (c *QueryBuilder[K, V]) FindAll() []Pair[K, V] { + c.init() + c.results = make([]Pair[K, V], 0, c.total) + + switch c.order { + case DESC: + c.rangeDesc(c.tree.root) + case ASC: + c.rangeAsc(c.tree.root) + } + + if c.offset > 0 { + if len(c.results) > c.offset { + c.results = c.results[c.offset:] + } else { + c.results = c.results[:0] + } + } + return c.results } -func (c *RBTree[K, V]) 
do_query(q *QueryBuilder[K, V]) []Pair[K, V] { - q.init() - var results = make([]Pair[K, V], 0, q.limit) - var s = stack.New[*rbtree_node[K, V]](uint32(q.limit)) +// 降序遍历 中序遍历是有序的 +func (c *QueryBuilder[K, V]) rangeDesc(node *rbtree_node[K, V]) { + if c.tree.end(node) || len(c.results) >= c.total { + return + } + + state := 0 + if c.rightFilter(node.data.Key) { + state += 1 + } + if c.leftFilter(node.data.Key) { + state += 2 + } + + switch state { + case 3: + c.rangeDesc(node.right) + if len(c.results) < c.total { + c.results = append(c.results, *node.data) + } else { + return + } + c.rangeDesc(node.left) + case 2: + c.rangeDesc(node.left) + case 1: + c.rangeDesc(node.right) + } +} - if q.order == dao.DESC { - maxEle, exist := c.doGetMaxKey(s, q.rightFilter) - if exist && q.leftFilter(maxEle.Key) { - results = append(results, maxEle) +// 升序遍历 中序遍历是有序的 +func (c *QueryBuilder[K, V]) rangeAsc(node *rbtree_node[K, V]) { + if c.tree.end(node) || len(c.results) >= c.total { + return + } + + state := 0 + if c.rightFilter(node.data.Key) { + state += 1 + } + if c.leftFilter(node.data.Key) { + state += 2 + } + + switch state { + case 3: + c.rangeAsc(node.left) + if len(c.results) < c.total { + c.results = append(c.results, *node.data) } else { - return results + return } + c.rangeAsc(node.right) + case 2: + c.rangeAsc(node.left) + case 1: + c.rangeAsc(node.right) + } +} - for i := 0; i < q.limit-1; i++ { - result, exist := c.doGetMaxKey(s, func(key K) bool { - return key < maxEle.Key - }) - if exist && q.leftFilter(result.Key) { - results = append(results, result) - maxEle = result - } else { - break +func (c *QueryBuilder[K, V]) FindOne() (p Pair[K, V], exist bool) { + c.Limit(1).Offset(0).init() + + switch c.order { + case DESC: + if v, ok := c.getMaxPair(c.rightFilter); ok { + if c.leftFilter(v.Key) { + return *v, true + } + } + case ASC: + if v, ok := c.getMinPair(c.leftFilter); ok { + if c.rightFilter(v.Key) { + return *v, true } } - } else { - minEle, exist := 
c.doGetMinKey(s, q.leftFilter) - if exist && q.rightFilter(minEle.Key) { - results = append(results, minEle) + } + + return p, exist +} + +func (c *QueryBuilder[K, V]) getMaxPair(filter func(key K) bool) (result *Pair[K, V], exist bool) { + var s = stack.Stack[*rbtree_node[K, V]]{} + s.Push(c.tree.root) + for s.Len() > 0 { + var node = s.Pop() + if c.tree.end(node) { + continue + } + if filter(node.data.Key) { + if !exist || node.data.Key > result.Key { + exist = true + result = node.data + } + s.Push(node.right) } else { - return results + s.Push(node.left) } + } + return result, exist +} - for i := 0; i < q.limit-1; i++ { - result, exist := c.doGetMinKey(s, func(key K) bool { - return key > minEle.Key - }) - if exist && q.rightFilter(result.Key) { - results = append(results, result) - minEle = result - } else { - break +func (c *QueryBuilder[K, V]) getMinPair(filter func(key K) bool) (result *Pair[K, V], exist bool) { + var s = stack.Stack[*rbtree_node[K, V]]{} + filter = algorithm.SelectValue(filter == nil, TrueFunc[K], filter) + s.Push(c.tree.root) + for s.Len() > 0 { + var node = s.Pop() + if c.tree.end(node) { + continue + } + if filter(node.data.Key) { + if !exist || node.data.Key < result.Key { + exist = true + result = node.data } + s.Push(node.left) + } else { + s.Push(node.right) + } + } + return result, exist +} + +// Get 查询一个key +func (c *RBTree[K, V]) Get(key K) (result V, exist bool) { + for i := c.begin(); !c.end(i); i = c.next(i, key) { + if key == i.data.Key { + return i.data.Val, true } } - return results + return result, false +} + +// Range 遍历树 +func (c *RBTree[K, V]) Range(fn func(key K, value V) bool) { + var next = true + c.do_range(c.root, &next, fn) +} + +func (c *RBTree[K, V]) do_range(node *rbtree_node[K, V], next *bool, fn func(K, V) bool) { + if c.end(node) || !*next { + return + } + + if ok := fn(node.data.Key, node.data.Val); !ok { + *next = ok + } + + c.do_range(node.left, next, fn) + c.do_range(node.right, next, fn) } diff --git 
a/rbtree/rbtree.go b/rbtree/rbtree.go index 9823470..01a70b8 100644 --- a/rbtree/rbtree.go +++ b/rbtree/rbtree.go @@ -1,7 +1,7 @@ package rbtree import ( - "cmp" + "github.com/lxzan/dao/types/cmp" ) type Color uint8 diff --git a/rbtree/rbtree_test.go b/rbtree/rbtree_test.go index db73079..d441990 100644 --- a/rbtree/rbtree_test.go +++ b/rbtree/rbtree_test.go @@ -1,14 +1,13 @@ package rbtree import ( - "cmp" "fmt" - "github.com/lxzan/dao" "github.com/lxzan/dao/algorithm" + "github.com/lxzan/dao/hashmap" "github.com/lxzan/dao/internal/utils" "github.com/lxzan/dao/internal/validator" + "github.com/lxzan/dao/types/cmp" "github.com/stretchr/testify/assert" - "math/rand" "sort" "strconv" "testing" @@ -141,49 +140,131 @@ func TestRBTree_ForEach(t *testing.T) { } func TestRBTree_Between(t *testing.T) { - var tree = New[string, int]() - var m = make(map[string]int) - for i := 0; i < 10000; i++ { - var length = utils.Rand.Intn(16) + 1 - var key = utils.Numeric.Generate(4) - m[key] = length - tree.Set(key, length) - } + t.Run("desc", func(t *testing.T) { + var tree = New[string, int]() + var m = make(map[string]int) + for i := 0; i < 10000; i++ { + var length = utils.Rand.Intn(16) + 1 + var key = utils.Numeric.Generate(4) + m[key] = length + tree.Set(key, length) + } - var limit = 100 - for i := 0; i < 100; i++ { - var left = utils.Numeric.Generate(4) - x, _ := strconv.Atoi(left) + var limit = 100 + for i := 0; i < 100; i++ { + var left = utils.Numeric.Generate(4) + x, _ := strconv.Atoi(left) - var right = fmt.Sprintf("%04d", x+limit) - if left > right { - right, left = left, right - } - var keys1 = tree. - NewQuery(). - Left(func(key string) bool { return key >= left }). - Right(func(key string) bool { return key <= right }). - Order(dao.DESC). - Limit(limit). 
- Do() - var keys2 = make([]string, 0) - for k := range m { - if k >= left && k <= right { - keys2 = append(keys2, k) + var right = fmt.Sprintf("%04d", x+limit) + if left > right { + right, left = left, right + } + var values = tree. + NewQuery(). + Left(func(key string) bool { return key >= left }). + Right(func(key string) bool { return key <= right }). + Order(DESC). + Limit(limit). + FindAll() + var keys1 = algorithm.Map[Pair[string, int], string](values, func(i int, v Pair[string, int]) string { + return v.Key + }) + + var keys2 = make([]string, 0) + for k := range m { + if k >= left && k <= right { + keys2 = append(keys2, k) + } } + sort.Strings(keys2) + algorithm.Reverse(keys2) + if len(keys2) > limit { + keys2 = keys2[:limit] + } + + assert.True(t, utils.IsSameSlice(keys1, keys2)) } - sort.Strings(keys2) - algorithm.Reverse(keys2) - if len(keys2) > limit { - keys2 = keys2[:limit] + }) + + t.Run("asc", func(t *testing.T) { + var tree = New[string, int]() + var m = make(map[string]int) + for i := 0; i < 10000; i++ { + var length = utils.Rand.Intn(16) + 1 + var key = utils.Numeric.Generate(4) + m[key] = length + tree.Set(key, length) } - if !utils.IsSameSlice(keys2, algorithm.Map(keys1, func(i int, x Pair[string, int]) string { - return x.Key - })) { - t.Fatal("error!") + var limit = 100 + for i := 0; i < 100; i++ { + var left = utils.Numeric.Generate(4) + x, _ := strconv.Atoi(left) + + var right = fmt.Sprintf("%04d", x+limit) + if left > right { + right, left = left, right + } + var values = tree. + NewQuery(). + Left(func(key string) bool { return key >= left }). + Right(func(key string) bool { return key <= right }). + Order(ASC). + Limit(limit). + Offset(10). 
+ FindAll() + var keys1 = algorithm.Map[Pair[string, int], string](values, func(i int, v Pair[string, int]) string { + return v.Key + }) + + var keys2 = make([]string, 0) + for k := range m { + if k >= left && k <= right { + keys2 = append(keys2, k) + } + } + sort.Strings(keys2) + if len(keys2) > 10 { + keys2 = keys2[10:] + } else { + keys2 = keys2[:0] + } + if len(keys2) > limit { + keys2 = keys2[:limit] + } + + assert.True(t, utils.IsSameSlice(keys1, keys2)) } - } + }) + + t.Run("", func(t *testing.T) { + var tree = New[int, uint8]() + tree.Set(1, 1) + tree.Set(2, 1) + tree.Set(3, 1) + tree.Set(4, 1) + tree.Set(5, 1) + + var values0 = tree. + NewQuery(). + Left(func(key int) bool { return key >= 1 }). + Right(func(key int) bool { return key <= 3 }). + Order(ASC). + Limit(10). + Offset(5). + FindAll() + assert.Equal(t, len(values0), 0) + + var values1 = tree. + NewQuery(). + Left(func(key int) bool { return key >= 1 }). + Right(func(key int) bool { return key <= 3 }). + Order(ASC). + Limit(10). + FindAll() + var keys1 = algorithm.Map(values1, func(i int, v Pair[int, uint8]) int { return v.Key }) + assert.True(t, utils.IsSameSlice(keys1, []int{1, 2, 3})) + }) } func TestRBTree_GreaterEqual(t *testing.T) { @@ -199,11 +280,14 @@ func TestRBTree_GreaterEqual(t *testing.T) { var limit = 100 for i := 0; i < 100; i++ { var left = utils.Numeric.Generate(4) - var keys1 = tree. + var values = tree. NewQuery(). Left(func(key string) bool { return key >= left }). Limit(limit). 
- Do() + FindAll() + var keys1 = algorithm.Map[Pair[string, int], string](values, func(i int, v Pair[string, int]) string { + return v.Key + }) var keys2 = make([]string, 0) for k := range m { if k >= left { @@ -215,11 +299,7 @@ func TestRBTree_GreaterEqual(t *testing.T) { keys2 = keys2[:limit] } - if !utils.IsSameSlice(keys2, algorithm.Map(keys1, func(i int, x Pair[string, int]) string { - return x.Key - })) { - t.Fatal("error!") - } + assert.True(t, utils.IsSameSlice(keys1, keys2)) } } @@ -236,11 +316,15 @@ func TestRBTree_LessEqual(t *testing.T) { var limit = 10 for i := 0; i < 100; i++ { var target = utils.Numeric.Generate(4) - var keys1 = tree. + var results = tree. NewQuery(). Right(func(key string) bool { return key <= target }). - Order(dao.DESC). - Do() + Order(DESC). + Limit(limit). + FindAll() + var keys1 = algorithm.Map[Pair[string, int], string](results, func(i int, v Pair[string, int]) string { + return v.Key + }) var keys2 = make([]string, 0) for k := range m { if k <= target { @@ -253,78 +337,85 @@ func TestRBTree_LessEqual(t *testing.T) { keys2 = keys2[:limit] } - if !utils.IsSameSlice(keys2, algorithm.Map(keys1, func(i int, x Pair[string, int]) string { - return x.Key - })) { - t.Fatal("error!") - } + assert.True(t, utils.IsSameSlice(keys1, keys2)) } } -func TestRBTree_GetMinKey(t *testing.T) { - var tree = New[string, int]() - - const test_count = 100 - for i := 0; i < test_count; i++ { - var v = rand.Intn(10000) - tree.Set(strconv.Itoa(v), v) - } - - for i := 0; i < test_count; i++ { - var k = strconv.Itoa(rand.Intn(10000)) - result, exist := tree.GetMinKey(func(key string) bool { - return key >= k - }) +func TestRBTree_FindOne(t *testing.T) { + t.Run("desc", func(t *testing.T) { + var tree = New[string, int]() + var m = hashmap.New[string, int](0) + for i := 0; i < 10000; i++ { + var length = utils.Rand.Intn(16) + 1 + var key = utils.Numeric.Generate(4) + m.Set(key, length) + tree.Set(key, length) + } - if !exist { - tree.Range(func(key string, 
value int) bool { - if key >= k { - t.Fatal("error!") + for i := 0; i < 100; i++ { + var target = utils.Numeric.Generate(4) + v0, ok0 := tree. + NewQuery(). + Right(func(key string) bool { return key <= target }). + Order(DESC). + FindOne() + + var v1, ok1 = "", false + m.Range(func(key string, val int) bool { + if key <= target && (v1 == "" || key > v1) { + v1 = key + ok1 = true } return true }) - } else { - tree.Range(func(key string, value int) bool { - if key < result.Key && key >= k { - t.Fatal("error!") - } - return true - }) - } - } -} - -func TestRBTree_GetMaxKey(t *testing.T) { - var tree = New[string, int]() - const test_count = 100 - for i := 0; i < test_count; i++ { - var v = rand.Intn(10000) - tree.Set(strconv.Itoa(v), v) - } + assert.Equal(t, ok0, ok1) + if ok0 { + assert.Equal(t, v0.Key, v1) + } + } + }) - for i := 0; i < test_count; i++ { - var k = strconv.Itoa(rand.Intn(10000)) - result, exist := tree.GetMaxKey(func(key string) bool { - return key <= k - }) + t.Run("asc", func(t *testing.T) { + var tree = New[string, int]() + var m = hashmap.New[string, int](0) + for i := 0; i < 10000; i++ { + var length = utils.Rand.Intn(16) + 1 + var key = utils.Numeric.Generate(4) + m.Set(key, length) + tree.Set(key, length) + } - if !exist { - tree.Range(func(key string, value int) bool { - if key <= k { - t.Fatal("error!") - } - return true - }) - } else { - tree.Range(func(key string, value int) bool { - if key > result.Key && key <= k { - t.Fatal("error!") + for i := 0; i < 100; i++ { + var target = utils.Numeric.Generate(4) + v0, ok0 := tree. + NewQuery(). + Left(func(key string) bool { return key >= target }). + Order(ASC). 
+ FindOne() + + var v1, ok1 = "", false + m.Range(func(key string, val int) bool { + if key >= target && (v1 == "" || key < v1) { + v1 = key + ok1 = true } return true }) + + assert.Equal(t, ok0, ok1) + if ok0 { + assert.Equal(t, v0.Key, v1) + } } - } + }) + + t.Run("", func(t *testing.T) { + var tree = New[string, int]() + var qb = QueryBuilder[string, int]{tree: tree} + _, ok := qb.FindOne() + assert.False(t, ok) + }) } func TestDict_Map(t *testing.T) { @@ -344,7 +435,7 @@ func TestRBTree_NewQuery(t *testing.T) { var results = tree. NewQuery(). Left(func(key int) bool { return key > 10 }). - Do() + FindAll() assert.Equal(t, len(results), 0) }) @@ -352,8 +443,8 @@ func TestRBTree_NewQuery(t *testing.T) { var results = tree. NewQuery(). Left(func(key int) bool { return key > 10 }). - Order(dao.DESC). - Do() + Order(DESC). + FindAll() assert.Equal(t, len(results), 0) }) } diff --git a/segment_tree/segement_tree_test.go b/segment_tree/segement_tree_test.go index 1c93634..681dc5e 100644 --- a/segment_tree/segement_tree_test.go +++ b/segment_tree/segement_tree_test.go @@ -21,7 +21,7 @@ func TestSegmentTree_Query(t *testing.T) { if left > right { left, right = right, left } - var result1 = tree.Query(left, right) + var result1 = tree.Query(left, right+1) var result2 = Int64Schema{ MaxValue: arr[left].Value(), @@ -51,7 +51,7 @@ func TestSegmentTree_Query(t *testing.T) { if left > right { left, right = right, left } - var result1 = tree.Query(left, right) + var result1 = tree.Query(left, right+1) var result2 = Int64Schema{ MaxValue: arr[left].Value(), diff --git a/segment_tree/segment_tree.go b/segment_tree/segment_tree.go index 13d7aac..c9bf094 100644 --- a/segment_tree/segment_tree.go +++ b/segment_tree/segment_tree.go @@ -63,11 +63,11 @@ func (c *SegmentTree[S, T]) build(cur *Element[S, T]) { cur.data = cur.son.data.Merge(cur.daughter.data) } -// Query 查询 left <= index <= right 区间 -func (c *SegmentTree[S, T]) Query(left int, right int) S { +// Query 查询 begin <= index < 
end 区间 +func (c *SegmentTree[S, T]) Query(begin int, end int) S { var result S - result = c.arr[left].Init(OperateQuery) - c.doQuery(c.root, left, right, &result) + result = c.arr[begin].Init(OperateQuery) + c.doQuery(c.root, begin, end-1, &result) return result } diff --git a/stack/stack.go b/stack/stack.go index 6dae270..cccfdce 100644 --- a/stack/stack.go +++ b/stack/stack.go @@ -3,24 +3,34 @@ package stack // Stack 可以不使用New函数, 声明为值类型自动初始化 type Stack[T any] []T -func New[T any](capacity uint32) *Stack[T] { +// New 创建栈 +func New[T any](capacity int) *Stack[T] { s := Stack[T](make([]T, 0, capacity)) return &s } +// NewFrom 从可变参数切片创建栈 +func NewFrom[T any](values ...T) *Stack[T] { + c := Stack[T](values) + return &c +} + +// Reset 重置 func (c *Stack[T]) Reset() { - clear(*c) *c = (*c)[:0] } +// Len 获取元素数量 func (c *Stack[T]) Len() int { return len(*c) } +// Push 追加元素 func (c *Stack[T]) Push(v T) { *c = append(*c, v) } +// Pop 弹出元素 func (c *Stack[T]) Pop() (value T) { n := c.Len() switch n { @@ -33,6 +43,7 @@ func (c *Stack[T]) Pop() (value T) { } } +// Range 遍历 func (c *Stack[T]) Range(f func(value T) bool) { for _, item := range *c { if !f(item) { @@ -40,3 +51,8 @@ func (c *Stack[T]) Range(f func(value T) bool) { } } } + +// UnWrap 解包, 返回底层数组 +func (c *Stack[T]) UnWrap() []T { + return *(*[]T)(c) +} diff --git a/stack/stack_test.go b/stack/stack_test.go index cd2d809..e7d7788 100644 --- a/stack/stack_test.go +++ b/stack/stack_test.go @@ -22,10 +22,7 @@ func TestStack_Pop(t *testing.T) { func TestStack_Range(t *testing.T) { t.Run("", func(t *testing.T) { - var s = New[int](8) - s.Push(1) - s.Push(3) - s.Push(5) + var s = NewFrom(1, 3, 5) var arr []int s.Range(func(value int) bool { @@ -52,3 +49,9 @@ func TestStack_Range(t *testing.T) { assert.True(t, utils.IsSameSlice(arr, []int{1, 3})) }) } + +func TestStack_UnWrap(t *testing.T) { + var s = NewFrom(1, 3, 5) + var a = s.UnWrap() + assert.ElementsMatch(t, a, []int{1, 3, 5}) +} diff --git a/types/cmp/cmp.go 
b/types/cmp/cmp.go new file mode 100644 index 0000000..76572d3 --- /dev/null +++ b/types/cmp/cmp.go @@ -0,0 +1,47 @@ +package cmp + +type ( + Ordered interface { + ~int | ~int8 | ~int16 | ~int32 | ~int64 | + ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr | + ~float32 | ~float64 | + ~string + } + + Number interface { + Integer | ~float32 | ~float64 + } + + Integer interface { + ~int64 | ~int | ~int32 | ~int16 | ~int8 | ~uint64 | ~uint | ~uint32 | ~uint16 | ~uint8 + } +) + +const ( + LT = -1 // 小于 + EQ = 0 // 等于 + GT = 1 // 大于 +) + +type ( + // LessFunc 比大小 + LessFunc[T any] func(a, b T) bool + + // CompareFunc 比较函数 + // a>b, 返回1; a y { + return +1 + } + return 0 +} diff --git a/vector/types.go b/vector/types.go new file mode 100644 index 0000000..476c60c --- /dev/null +++ b/vector/types.go @@ -0,0 +1,27 @@ +package vector + +import ( + "github.com/lxzan/dao/types/cmp" +) + +type Document[T cmp.Ordered] interface { + GetID() T +} + +type ( + Int int + + Int64 int64 + + String string +) + +func (c Int) GetID() int { + return int(c) +} + +func (c Int64) GetID() int64 { + return int64(c) +} + +func (c String) GetID() string { return string(c) } diff --git a/vector/vector.go b/vector/vector.go new file mode 100644 index 0000000..fbcbdc0 --- /dev/null +++ b/vector/vector.go @@ -0,0 +1,183 @@ +package vector + +import ( + "github.com/lxzan/dao/algorithm" + "github.com/lxzan/dao/hashmap" + "github.com/lxzan/dao/internal/utils" + "github.com/lxzan/dao/types/cmp" + "unsafe" +) + +// New 创建动态数组 +func New[D Document[K], K cmp.Ordered](capacity int) *Vector[D, K] { + c := Vector[D, K](make([]D, 0, capacity)) + return &c +} + +// NewFromDocs 从可变参数创建动态数组 +func NewFromDocs[D Document[K], K cmp.Ordered](values ...D) *Vector[D, K] { + c := Vector[D, K](values) + return &c +} + +// NewFromInts 创建动态数组 +func NewFromInts(values ...int) *Vector[Int, int] { + var b = *(*[]Int)(unsafe.Pointer(&values)) + v := Vector[Int, int](b) + return &v +} + +// NewFromInt64s 创建动态数组 +func 
NewFromInt64s(values ...int64) *Vector[Int64, int64] { + var b = *(*[]Int64)(unsafe.Pointer(&values)) + v := Vector[Int64, int64](b) + return &v +} + +// NewFromStrings 创建动态数组 +func NewFromStrings(values ...string) *Vector[String, string] { + var b = *(*[]String)(unsafe.Pointer(&values)) + v := Vector[String, string](b) + return &v +} + +// Vector 动态数组 +type Vector[D Document[K], K cmp.Ordered] []D + +// Reset 重置 +func (c *Vector[D, K]) Reset() { + *c = (*c)[:0] +} + +// Len 获取元素数量 +func (c *Vector[D, K]) Len() int { + return len(*c) +} + +// Get 根据下标取值 +func (c *Vector[D, K]) Get(index int) D { + return (*c)[index] +} + +// Update 根据下标修改值 +func (c *Vector[D, K]) Update(index int, value D) { + (*c)[index] = value +} + +// Elem 取值 +func (c *Vector[D, K]) Elem() []D { + return *c +} + +// Exists 根据id判断某条数据是否存在 +func (c *Vector[D, K]) Exists(id K) (v D, exist bool) { + for _, item := range *c { + if item.GetID() == id { + return item, true + } + } + return v, exist +} + +// Unique 排序并根据id去重 +func (c *Vector[D, K]) Unique() *Vector[D, K] { + *c = algorithm.UniqueBy(*c, func(item D) K { + return item.GetID() + }) + return c +} + +// Filter 过滤 +func (c *Vector[D, K]) Filter(f func(i int, v D) bool) *Vector[D, K] { + *c = algorithm.Filter(*c, f) + return c +} + +// Sort 排序 +func (c *Vector[D, K]) Sort() *Vector[D, K] { + algorithm.SortBy(*c, func(a, b D) int { + return cmp.Compare(a.GetID(), b.GetID()) + }) + return c +} + +// IdList 获取id数组 +func (c *Vector[D, K]) IdList() []K { + var d D + switch any(d).(type) { + case Int, Int64, String: + var keys = *(*[]K)(unsafe.Pointer(c)) + return keys + default: + var keys = make([]K, 0, c.Len()) + for _, item := range *c { + keys = append(keys, item.GetID()) + } + return keys + } +} + +// ToMap 生成map[K]D +func (c *Vector[D, K]) ToMap() hashmap.HashMap[K, D] { + var m = hashmap.New[K, D](c.Len()) + for _, item := range *c { + m.Set(item.GetID(), item) + } + return m +} + +// PushBack 向尾部追加元素 +func (c *Vector[D, K]) PushBack(v D) { 
+ *c = append(*c, v) +} + +// PopFront 从头部弹出元素 +func (c *Vector[D, K]) PopFront() (value D) { + switch c.Len() { + case 0: + return value + default: + value = (*c)[0] + *c = (*c)[1:] + return value + } +} + +// PopBack 从尾部弹出元素 +func (c *Vector[D, K]) PopBack() (value D) { + n := c.Len() + switch n { + case 0: + return value + default: + value = (*c)[n-1] + *c = (*c)[:n-1] + return value + } +} + +// Range 遍历 +func (c *Vector[D, K]) Range(f func(i int, v D) bool) { + for index, value := range *c { + if !f(index, value) { + return + } + } +} + +// Clone 拷贝 +func (c *Vector[D, K]) Clone() *Vector[D, K] { + var d = utils.Clone(*c) + return &d +} + +// Slice 截取子数组 +func (c *Vector[D, K]) Slice(start, end int) *Vector[D, K] { + var children = (*c)[start:end] + return &children +} + +func (c *Vector[D, K]) Reverse() *Vector[D, K] { + *c = algorithm.Reverse(*c) + return c +} diff --git a/vector/vector_test.go b/vector/vector_test.go new file mode 100644 index 0000000..f87cbb5 --- /dev/null +++ b/vector/vector_test.go @@ -0,0 +1,180 @@ +package vector + +import ( + "github.com/lxzan/dao/internal/utils" + "github.com/stretchr/testify/assert" + "testing" + "unsafe" +) + +func TestUser_GetID(t *testing.T) { + var docs Vector[user, string] + docs = append(docs, user{ID: "a"}) + docs = append(docs, user{ID: "c"}) + docs = append(docs, user{ID: "c"}) + docs = append(docs, user{ID: "b"}) + docs.Unique() + docs.Sort() + docs.Filter(func(i int, v user) bool { + return v.ID == "b" + }) +} + +type user struct { + ID string +} + +func (u user) GetID() string { + return u.ID +} + +func TestNewFromInts(t *testing.T) { + var a = NewFromInts(1, 3, 5) + var b = a.IdList() + assert.ElementsMatch(t, b, []int{1, 3, 5}) +} + +func TestNewFromInt64s(t *testing.T) { + var a = NewFromInt64s(1, 3, 5) + var b = a.IdList() + assert.ElementsMatch(t, b, []int64{1, 3, 5}) +} + +func TestVector_Keys(t *testing.T) { + t.Run("", func(t *testing.T) { + var a = NewFromStrings("a", "b", "c") + var b = 
a.IdList() + assert.ElementsMatch(t, b, []string{"a", "b", "c"}) + assert.Equal(t, a.Get(0).GetID(), "a") + + var addr0 = (uintptr)(unsafe.Pointer(&(*a)[0])) + var addr1 = (uintptr)(unsafe.Pointer(&b[0])) + assert.Equal(t, addr0, addr1) + + var values = a.Elem() + assert.ElementsMatch(t, values, []String{"a", "b", "c"}) + }) + + t.Run("", func(t *testing.T) { + var docs = NewFromDocs[user, string]( + user{ID: "a"}, + user{ID: "b"}, + user{ID: "c"}, + ) + assert.ElementsMatch(t, docs.IdList(), []string{"a", "b", "c"}) + }) +} + +func TestVector_Exists(t *testing.T) { + var v = New[Int, int](8) + v.PushBack(1) + v.PushBack(3) + v.PushBack(5) + + { + _, ok := v.Exists(1) + assert.True(t, ok) + } + + { + _, ok := v.Exists(3) + assert.True(t, ok) + } + + { + _, ok := v.Exists(2) + assert.False(t, ok) + } +} + +func TestVector_PushBack(t *testing.T) { + var v = New[Int, int](8) + v.PushBack(1) + v.PushBack(3) + v.PushBack(5) + + var arr []int + for v.Len() > 0 { + arr = append(arr, v.PopBack().GetID()) + } + assert.True(t, utils.IsSameSlice(arr, []int{5, 3, 1})) + assert.Equal(t, v.PopBack().GetID(), 0) +} + +func TestVector_PopFront(t *testing.T) { + var v = New[Int, int](8) + v.PushBack(1) + v.PushBack(3) + v.PushBack(5) + + var arr []int + for v.Len() > 0 { + arr = append(arr, v.PopFront().GetID()) + } + assert.True(t, utils.IsSameSlice(arr, []int{1, 3, 5})) + assert.Equal(t, v.PopFront().GetID(), 0) +} + +func TestVector_Range(t *testing.T) { + t.Run("", func(t *testing.T) { + var a = NewFromInt64s(1, 3, 5) + var v = a.Clone() + var arr []int64 + v.Range(func(i int, value Int64) bool { + arr = append(arr, value.GetID()) + return true + }) + assert.True(t, utils.IsSameSlice(arr, []int64{1, 3, 5})) + }) + + t.Run("", func(t *testing.T) { + var v = NewFromInt64s(1, 3, 5) + var arr []int64 + v.Range(func(i int, value Int64) bool { + arr = append(arr, value.GetID()) + return len(arr) < 2 + }) + assert.True(t, utils.IsSameSlice(arr, []int64{1, 3})) + }) +} + +func 
TestVector_ToMap(t *testing.T) { + var a = NewFromDocs[user, string]( + user{ID: "a"}, + user{ID: "b"}, + user{ID: "c"}, + ) + var values = a.ToMap().Keys() + assert.ElementsMatch(t, values, []string{"a", "b", "c"}) +} + +func TestVector_Slice(t *testing.T) { + var a = NewFromStrings("a", "b", "c", "d") + var b = a.Slice(1, 3) + var values = b.IdList() + assert.ElementsMatch(t, values, []string{"b", "c"}) + + assert.Equal(t, a.Len(), 4) + a.Reset() + assert.Equal(t, a.Len(), 0) +} + +func TestVector_Sort(t *testing.T) { + var a = NewFromInts(1, 3, 5, 2, 4, 6).Sort().IdList() + assert.True(t, utils.IsSameSlice(a, []int{1, 2, 3, 4, 5, 6})) +} + +func TestVector_Update(t *testing.T) { + var v = NewFromInts(1, 3, 5) + assert.True(t, utils.IsSameSlice(v.Elem(), []Int{1, 3, 5})) + v.Update(0, 2) + v.Update(1, 4) + v.Update(2, 6) + assert.True(t, utils.IsSameSlice(v.Elem(), []Int{2, 4, 6})) +} + +func TestVector_Reverse(t *testing.T) { + var v = NewFromInts(1, 2, 3) + v.Reverse() + assert.True(t, utils.IsSameSlice(v.Elem(), []Int{3, 2, 1})) +}