Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(heap): add go codes #246

Merged
merged 6 commits into from
Jan 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions codes/go/chapter_heap/heap.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// File: intHeap.go
// Created Time: 2023-01-12
// Author: Reanon (793584285@qq.com)

package chapter_heap

// Go 语言中可以通过实现 heap.Interface 来构建整数大顶堆
// 实现 heap.Interface 需要同时实现 sort.Interface
type intHeap []any

// Push heap.Interface 的方法,实现推入元素到堆
func (h *intHeap) Push(x any) {
// Push 和 Pop 使用 pointer receiver 作为参数
// 因为它们不仅会对切片的内容进行调整,还会修改切片的长度。
*h = append(*h, x.(int))
}

// Pop heap.Interface 的方法,实现弹出堆顶元素
func (h *intHeap) Pop() any {
// 待出堆元素存放在最后
last := (*h)[len(*h)-1]
*h = (*h)[:len(*h)-1]
return last
}

// Len sort.Interface 的方法
func (h *intHeap) Len() int {
return len(*h)
}

// Less sort.Interface 的方法
func (h *intHeap) Less(i, j int) bool {
// 如果实现小顶堆,则需要调整为小于号
return (*h)[i].(int) > (*h)[j].(int)
}

// Swap sort.Interface 的方法
func (h *intHeap) Swap(i, j int) {
(*h)[i], (*h)[j] = (*h)[j], (*h)[i]
}

// Top 获取堆顶元素
func (h *intHeap) Top() any {
return (*h)[0]
}
90 changes: 90 additions & 0 deletions codes/go/chapter_heap/heap_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
// File: heap_test.go
// Created Time: 2023-01-12
// Author: Reanon (793584285@qq.com)

package chapter_heap

import (
"container/heap"
"fmt"
"testing"

. "github.com/krahets/hello-algo/pkg"
)

func testPush(h *intHeap, val int) {
// 调用 heap.Interface 的方法,来添加元素
heap.Push(h, val)
fmt.Printf("\n元素 %d 入堆后 \n", val)
PrintHeap(*h)
}

func testPop(h *intHeap) {
// 调用 heap.Interface 的方法,来移除元素
val := heap.Pop(h)
fmt.Printf("\n堆顶元素 %d 出堆后 \n", val)
PrintHeap(*h)
}

func TestHeap(t *testing.T) {
/* 初始化堆 */
// 初始化大顶堆
maxHeap := &intHeap{}
heap.Init(maxHeap)
/* 元素入堆 */
testPush(maxHeap, 1)
testPush(maxHeap, 3)
testPush(maxHeap, 2)
testPush(maxHeap, 5)
testPush(maxHeap, 4)

/* 获取堆顶元素 */
top := maxHeap.Top()
fmt.Printf("堆顶元素为 %d\n", top)

/* 堆顶元素出堆 */
testPop(maxHeap)
testPop(maxHeap)
testPop(maxHeap)
testPop(maxHeap)
testPop(maxHeap)

/* 获取堆大小 */
size := len(*maxHeap)
fmt.Printf("堆元素数量为 %d\n", size)

/* 判断堆是否为空 */
isEmpty := len(*maxHeap) == 0
fmt.Printf("堆是否为空 %t\n", isEmpty)
}

func TestMyHeap(t *testing.T) {
/* 初始化堆 */
// 初始化大顶堆
maxHeap := newMaxHeap([]any{9, 8, 6, 6, 7, 5, 2, 1, 4, 3, 6, 2})
fmt.Printf("输入数组并建堆后\n")
maxHeap.print()

/* 获取堆顶元素 */
peek := maxHeap.peek()
fmt.Printf("\n堆顶元素为 %d\n", peek)

/* 元素入堆 */
val := 7
maxHeap.push(val)
fmt.Printf("\n元素 %d 入堆后\n", val)
maxHeap.print()

/* 堆顶元素出堆 */
peek = maxHeap.poll()
fmt.Printf("\n堆顶元素 %d 出堆后\n", peek)
maxHeap.print()

/* 获取堆大小 */
size := maxHeap.size()
fmt.Printf("\n堆元素数量为 %d\n", size)

/* 判断堆是否为空 */
isEmpty := maxHeap.isEmpty()
fmt.Printf("\n堆是否为空 %t\n", isEmpty)
}
139 changes: 139 additions & 0 deletions codes/go/chapter_heap/my_heap.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
// File: my_heap.go
// Created Time: 2023-01-12
// Author: Reanon (793584285@qq.com)

package chapter_heap

import (
"fmt"

. "github.com/krahets/hello-algo/pkg"
)

type maxHeap struct {
// 使用切片而非数组,这样无需考虑扩容问题
data []any
}

/* 构造函数,建立空堆 */
func newHeap() *maxHeap {
return &maxHeap{
data: make([]any, 0),
}
}

/* 构造函数,根据切片建堆 */
func newMaxHeap(nums []any) *maxHeap {
// 所有元素入堆
h := &maxHeap{data: nums}
for i := len(h.data) - 1; i >= 0; i-- {
// 堆化除叶结点以外的其他所有结点
h.siftDown(i)
}
return h
}

/* 获取左子结点索引 */
func (h *maxHeap) left(i int) int {
return 2*i + 1
}

/* 获取右子结点索引 */
func (h *maxHeap) right(i int) int {
return 2*i + 2
}

/* 获取父结点索引 */
func (h *maxHeap) parent(i int) int {
// 向下整除
return (i - 1) / 2
}

/* 交换元素 */
func (h *maxHeap) swap(i, j int) {
h.data[i], h.data[j] = h.data[j], h.data[i]
}

/* 获取堆大小 */
func (h *maxHeap) size() int {
return len(h.data)
}

/* 判断堆是否为空 */
func (h *maxHeap) isEmpty() bool {
return len(h.data) == 0
}

/* 访问堆顶元素 */
func (h *maxHeap) peek() any {
return h.data[0]
}

/* 元素入堆 */
func (h *maxHeap) push(val any) {
// 添加结点
h.data = append(h.data, val)
// 从底至顶堆化
h.siftUp(len(h.data) - 1)
}

/* 从结点 i 开始,从底至顶堆化 */
func (h *maxHeap) siftUp(i int) {
for true {
// 获取结点 i 的父结点
p := h.parent(i)
// 当“越过根结点”或“结点无需修复”时,结束堆化
if p < 0 || h.data[i].(int) <= h.data[p].(int) {
break
}
// 交换两结点
h.swap(i, p)
// 循环向上堆化
i = p
}
}

/* 元素出堆 */
func (h *maxHeap) poll() any {
// 判空处理
if h.isEmpty() {
fmt.Println("error")
}
// 交换根结点与最右叶结点(即交换首元素与尾元素)
h.swap(0, h.size()-1)
// 删除结点
val := h.data[len(h.data)-1]
h.data = h.data[:len(h.data)-1]
// 从顶至底堆化
h.siftDown(0)

// 返回堆顶元素
return val
}

/* 从结点 i 开始,从顶至底堆化 */
func (h *maxHeap) siftDown(i int) {
for true {
// 判断结点 i, l, r 中值最大的结点,记为 max
l, r, max := h.left(i), h.right(i), i
if l < h.size() && h.data[l].(int) > h.data[max].(int) {
max = l
}
if r < h.size() && h.data[r].(int) > h.data[max].(int) {
max = r
}
// 若结点 i 最大或索引 l, r 越界,则无需继续堆化,跳出
if max == i {
break
}
// 交换两结点
h.swap(i, max)
// 循环向下堆化
i = max
}
}

/* 打印堆(二叉树) */
func (h *maxHeap) print() {
PrintHeap(h.data)
}
23 changes: 16 additions & 7 deletions codes/go/pkg/print_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,22 @@ func PrintList(list *list.List) {
fmt.Print(e.Value, "]\n")
}

// PrintMap Print a hash map
func PrintMap[K comparable, V any](m map[K]V) {
for key, value := range m {
fmt.Println(key, "->", value)
}
}

// PrintHeap Print a heap
func PrintHeap(h []any) {
fmt.Printf("堆的数组表示:")
fmt.Printf("%v", h)
fmt.Printf("\n堆的树状表示:\n")
root := ArrToTree(h)
PrintTree(root)
}

// PrintLinkedList Print a linked list
func PrintLinkedList(node *ListNode) {
if node == nil {
Expand Down Expand Up @@ -97,10 +113,3 @@ func showTrunk(t *trunk) {
showTrunk(t.prev)
fmt.Print(t.str)
}

// PrintMap Print a hash map
func PrintMap[K comparable, V any](m map[K]V) {
for key, value := range m {
fmt.Println(key, "->", value)
}
}
Loading