Skip to content

Commit

Permalink
simplify segment_tree
Browse files Browse the repository at this point in the history
  • Loading branch information
lixizan committed Jun 12, 2024
1 parent 1496ea2 commit c5fa764
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 130 deletions.
11 changes: 6 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -356,14 +356,15 @@ ability to perform interval queries and interval updates in `O(logn)` time.
package main

import (
tree "github.com/lxzan/dao/segment_tree"
"fmt"
st "github.com/lxzan/dao/segment_tree"
)

func main() {
var data = []tree.Int64{1, 3, 5, 7, 9, 2, 4, 6, 8, 10}
var lines = tree.New[tree.Int64Schema, tree.Int64](data)
var result = lines.Query(0, 10)
println(result.MinValue, result.MaxValue, result.Sum)
var a = []int{1, 3, 5, 7, 9, 2, 4, 6, 8, 10}
var t = st.New(a, st.NewIntSummary[int], st.MergeIntSummary[int])
var r = t.Query(3, 6)
fmt.Printf("%v\n", r)
}

```
Expand Down
12 changes: 6 additions & 6 deletions README_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -346,16 +346,16 @@ func main() {
package main

import (
tree "github.com/lxzan/dao/segment_tree"
"fmt"
st "github.com/lxzan/dao/segment_tree"
)

func main() {
var data = []tree.Int64{1, 3, 5, 7, 9, 2, 4, 6, 8, 10}
var lines = tree.New[tree.Int64Schema, tree.Int64](data)
var result = lines.Query(0, 10)
println(result.MinValue, result.MaxValue, result.Sum)
var a = []int{1, 3, 5, 7, 9, 2, 4, 6, 8, 10}
var t = st.New(a, st.NewIntSummary[int], st.MergeIntSummary[int])
var r = t.Query(3, 6)
fmt.Printf("%v\n", r)
}

```

### 基准测试
Expand Down
46 changes: 20 additions & 26 deletions segment_tree/impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,39 +2,33 @@ package segment_tree

import (
"github.com/lxzan/dao/algo"
"github.com/lxzan/dao/types/cmp"
)

type Int64 int64
type (
NewSummary[T any, S any] func(T, Operate) S

// Init 初始化摘要结构
func (c Int64) Init(op Operate) Int64Schema {
var val = int64(c)
var result = Int64Schema{
MaxValue: val,
MinValue: val,
Sum: val,
}
if op == OperateQuery {
result.Sum = 0
}
return result
}
MergeSummary[S any] func(a, b S) S
)

func (c Int64) Value() int64 {
return int64(c)
type IntSummary[T cmp.Integer] struct {
MaxValue T
MinValue T
Sum T
}

type Int64Schema struct {
MaxValue int64
MinValue int64
Sum int64
func NewIntSummary[T cmp.Integer](num T, op Operate) IntSummary[T] {
var r = IntSummary[T]{MaxValue: num, MinValue: num, Sum: num}
if op == OperateQuery {
r.Sum = 0
}
return r
}

// Merge 合并摘要信息
func (c Int64Schema) Merge(d Int64Schema) Int64Schema {
return Int64Schema{
MaxValue: algo.Max(c.MaxValue, d.MaxValue),
MinValue: algo.Min(c.MinValue, d.MinValue),
Sum: c.Sum + d.Sum,
func MergeIntSummary[T cmp.Integer](a, b IntSummary[T]) IntSummary[T] {
return IntSummary[T]{
MaxValue: algo.Max(a.MaxValue, b.MaxValue),
MinValue: algo.Min(a.MinValue, b.MinValue),
Sum: a.Sum + b.Sum,
}
}
84 changes: 28 additions & 56 deletions segment_tree/segement_tree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,69 +3,41 @@ package segment_tree
import (
"github.com/lxzan/dao/algo"
"github.com/lxzan/dao/internal/utils"
"github.com/stretchr/testify/assert"
"testing"
)

func TestSegmentTree_Query(t *testing.T) {
var n = 10000
var arr = make([]Int64, 0)
var arr = make([]int, 0)
for i := 0; i < n; i++ {
arr = append(arr, Int64(utils.Rand.Intn(n)))
arr = append(arr, utils.Rand.Intn(n))
}

var tree = New[Int64Schema, Int64](arr)

for i := 0; i < 1000; i++ {
var left = utils.Rand.Intn(n)
var right = utils.Rand.Intn(n)
if left > right {
left, right = right, left
}
var result1 = tree.Query(left, right+1)

var result2 = Int64Schema{
MaxValue: arr[left].Value(),
MinValue: arr[left].Value(),
Sum: 0,
}
for j := left; j <= right; j++ {
result2.Sum += arr[j].Value()
result2.MaxValue = algo.Max(result2.MaxValue, arr[j].Value())
result2.MinValue = algo.Min(result2.MinValue, arr[j].Value())
}

if result1.Sum != result2.Sum || result1.MinValue != result2.MinValue || result1.MaxValue != result2.MaxValue {
t.Fatal("error!")
}
}

for i := 0; i < 1000; i++ {
var index = utils.Rand.Intn(n)
var value = Int64(utils.Rand.Intn(n))
tree.Update(index, value)
}

for i := 0; i < 1000; i++ {
var left = utils.Rand.Intn(n)
var right = utils.Rand.Intn(n)
if left > right {
left, right = right, left
}
var result1 = tree.Query(left, right+1)

var result2 = Int64Schema{
MaxValue: arr[left].Value(),
MinValue: arr[left].Value(),
Sum: 0,
}
for j := left; j <= right; j++ {
result2.Sum += arr[j].Value()
result2.MaxValue = algo.Max(result2.MaxValue, arr[j].Value())
result2.MinValue = algo.Min(result2.MinValue, arr[j].Value())
}

if result1.Sum != result2.Sum || result1.MinValue != result2.MinValue || result1.MaxValue != result2.MaxValue {
t.Fatal("error!")
var stree = New(arr, NewIntSummary[int], MergeIntSummary[int])
for i := 0; i < 100; i++ {
var x, y = utils.Alphabet.Intn(n), utils.Alphabet.Intn(n)
if x == y {
continue
}
if x > y {
x, y = y, x
}

var flag = utils.Alphabet.Intn(4)
switch flag {
case 0:
stree.Update(x, y)
default:
r0 := stree.Query(x, y)
r1 := NewIntSummary(arr[x], OperateQuery)
for j := x; j < y; j++ {
r1.MaxValue = algo.Max(r1.MaxValue, arr[j])
r1.MinValue = algo.Min(r1.MinValue, arr[j])
r1.Sum += arr[j]
}
assert.Equal(t, r0.MaxValue, r1.MaxValue)
assert.Equal(t, r0.MinValue, r1.MinValue)
assert.Equal(t, r0.Sum, r1.Sum)
}
}
}
64 changes: 27 additions & 37 deletions segment_tree/segment_tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,95 +8,85 @@ const (
OperateUpdate Operate = 2
)

type (
Initer[T any] interface {
Init(op Operate) T
}

Merger[T any] interface {
Merge(T) T
}
)

type Element[S Merger[S], T Initer[S]] struct {
type element[T any, S any] struct {
left int
right int
son *Element[S, T]
daughter *Element[S, T]
son *element[T, S]
daughter *element[T, S]
data S
}

type SegmentTree[S Merger[S], T Initer[S]] struct {
root *Element[S, T]
arr []T
type SegmentTree[T any, S any] struct {
root *element[T, S]
arr []T
newSummary NewSummary[T, S]
mergeSummary MergeSummary[S]
}

func New[S Merger[S], T Initer[S]](arr []T) *SegmentTree[S, T] {
var obj = &SegmentTree[S, T]{
root: &Element[S, T]{
func New[T any, S any](arr []T, newSummary NewSummary[T, S], mergeSummary MergeSummary[S]) *SegmentTree[T, S] {
var obj = &SegmentTree[T, S]{
root: &element[T, S]{
left: 0,
right: len(arr) - 1,
},
arr: arr,
arr: arr,
newSummary: newSummary,
mergeSummary: mergeSummary,
}
obj.build(obj.root)
return obj
}

func (c *SegmentTree[S, T]) build(cur *Element[S, T]) {
func (c *SegmentTree[T, S]) build(cur *element[T, S]) {
if cur.left == cur.right {
cur.data = c.arr[cur.left].Init(OperateCreate)
cur.data = c.newSummary(c.arr[cur.left], OperateCreate)
return
}

var mid = (cur.left + cur.right) / 2
cur.son = &Element[S, T]{
cur.son = &element[T, S]{
left: cur.left,
right: mid,
}
cur.daughter = &Element[S, T]{
cur.daughter = &element[T, S]{
left: mid + 1,
right: cur.right,
}
c.build(cur.son)
c.build(cur.daughter)
cur.data = cur.son.data.Merge(cur.daughter.data)
cur.data = c.mergeSummary(cur.son.data, cur.daughter.data)
}

// Query 查询 begin <= index < end 区间
func (c *SegmentTree[S, T]) Query(begin int, end int) S {
var result S
result = c.arr[begin].Init(OperateQuery)
func (c *SegmentTree[T, S]) Query(begin int, end int) S {
result := c.newSummary(c.arr[begin], OperateQuery)
c.doQuery(c.root, begin, end-1, &result)
return result
}

func (c *SegmentTree[S, T]) doQuery(cur *Element[S, T], left int, right int, result *S) {
func (c *SegmentTree[T, S]) doQuery(cur *element[T, S], left int, right int, result *S) {
if cur.left >= left && cur.right <= right {
*result = cur.data.Merge(*result)
*result = c.mergeSummary(*result, cur.data)
} else if !(cur.left > right || cur.right < left) {
c.doQuery(cur.son, left, right, result)
c.doQuery(cur.daughter, left, right, result)
}
}

// Update 更新
func (c *SegmentTree[S, T]) Update(i int, v T) {
func (c *SegmentTree[T, S]) Update(i int, v T) {
c.arr[i] = v
c.rebuild(c.root, i)
}

func (c *SegmentTree[S, T]) rebuild(cur *Element[S, T], i int) {
func (c *SegmentTree[T, S]) rebuild(cur *element[T, S], i int) {
if !(i >= cur.left && i <= cur.right) {
return
}

if cur.left == cur.right && cur.left == i {
cur.data = c.arr[cur.left].Init(OperateUpdate)
cur.data = c.newSummary(c.arr[cur.left], OperateUpdate)
return
}

c.rebuild(cur.son, i)
c.rebuild(cur.daughter, i)
cur.data = cur.son.data.Merge(cur.daughter.data)
cur.data = c.mergeSummary(cur.son.data, cur.daughter.data)
}

0 comments on commit c5fa764

Please sign in to comment.