Skip to content

Commit

Permalink
coolq: allow upload media concurrently for normal message
Browse files Browse the repository at this point in the history
  • Loading branch information
wdvxdr1123 committed Feb 22, 2022
1 parent d1f143e commit 987daad
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 50 deletions.
29 changes: 9 additions & 20 deletions coolq/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"runtime"
"strconv"
"strings"
"sync"
"time"

"github.com/segmentio/asm/base64"
Expand Down Expand Up @@ -793,32 +792,27 @@ func (bot *CQBot) uploadForwardElement(m gjson.Result, groupID int64) *message.F
fm := message.NewForwardMessage()
source := message.Source{SourceType: message.SourceGroup, PrimaryID: groupID}

var lazyUpload []func()
var wg sync.WaitGroup
var w worker
resolveElement := func(elems []message.IMessageElement) []message.IMessageElement {
for i, elem := range elems {
iescape := i
p := &elems[i]
switch o := elem.(type) {
case *LocalVideoElement:
wg.Add(1)
lazyUpload = append(lazyUpload, func() {
defer wg.Done()
w.do(func() {
gm, err := bot.uploadLocalVideo(source, o)
if err != nil {
log.Warnf("警告: 群 %d %s上传失败: %v", groupID, o.Type().String(), err)
log.Warnf(uploadFailedTemplate, "群", groupID, "视频", err)
} else {
elems[iescape] = gm
*p = gm
}
})
case *LocalImageElement:
wg.Add(1)
lazyUpload = append(lazyUpload, func() {
defer wg.Done()
w.do(func() {
gm, err := bot.uploadLocalImage(source, o)
if err != nil {
log.Warnf("警告: 群 %d %s上传失败: %v", groupID, o.Type().String(), err)
log.Warnf(uploadFailedTemplate, "群", groupID, "图片", err)
} else {
elems[iescape] = gm
*p = gm
}
})
}
Expand Down Expand Up @@ -903,12 +897,7 @@ func (bot *CQBot) uploadForwardElement(m gjson.Result, groupID int64) *message.F
fm.AddNode(node)
}
}

for _, upload := range lazyUpload {
go upload()
}
wg.Wait()

w.wait()
return bot.Client.UploadGroupForwardMessage(groupID, fm)
}

Expand Down
103 changes: 73 additions & 30 deletions coolq/bot.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,22 @@ func (bot *CQBot) OnEventPush(f func(e *Event)) {
bot.lock.Unlock()
}

type worker struct {
wg sync.WaitGroup
}

func (w *worker) do(f func()) {
w.wg.Add(1)
go func() {
defer w.wg.Done()
f()
}()
}

func (w *worker) wait() {
w.wg.Wait()
}

// uploadLocalImage 上传本地图片
func (bot *CQBot) uploadLocalImage(target message.Source, img *LocalImageElement) (i message.IMessageElement, err error) {
if img.File != "" {
Expand Down Expand Up @@ -163,39 +179,70 @@ func (bot *CQBot) uploadLocalVideo(target message.Source, v *LocalVideoElement)
return bot.Client.UploadShortVideo(target, video, v.thumb, 4)
}

func (bot *CQBot) uploadMedia(target message.Source, elements []message.IMessageElement) []message.IMessageElement {
func removeLocalElement(elements []message.IMessageElement) []message.IMessageElement {
var j int
for _, m := range elements {
raw := m // upload failed will make m nil, so copy it
var err error
for i, e := range elements {
switch e.(type) {
case *LocalImageElement, *LocalVideoElement:
case *message.VoiceElement: // 未上传的语音消息, 也删除
default:
if j < i {
elements[j] = e
}
j++
}
}
return elements[:j]
}

const uploadFailedTemplate = "警告: %s %d %s上传失败: %v"

func (bot *CQBot) uploadMedia(target message.Source, elements []message.IMessageElement) []message.IMessageElement {
var w worker
var source string
switch target.SourceType { // nolint:exhaustive
case message.SourceGroup:
source = "群"
case message.SourcePrivate:
source = "私聊"
case message.SourceGuildChannel:
source = "频道"
}

for i, m := range elements {
p := &elements[i]
switch e := m.(type) {
case *LocalImageElement:
m, err = bot.uploadLocalImage(target, e)
w.do(func() {
m, err := bot.uploadLocalImage(target, e)
if err != nil {
log.Warnf(uploadFailedTemplate, source, target.PrimaryID, "图片", err)
} else {
*p = m
}
})
case *message.VoiceElement:
if target.SourceType == message.SourceGuildChannel {
continue // todo
}
m, err = bot.Client.UploadVoice(target, bytes.NewReader(e.Data))
w.do(func() {
m, err := bot.Client.UploadVoice(target, bytes.NewReader(e.Data))
if err != nil {
log.Warnf(uploadFailedTemplate, source, target.PrimaryID, "语音", err)
} else {
*p = m
}
})
case *LocalVideoElement:
m, err = bot.uploadLocalVideo(target, e)
}
if err != nil {
var source string
switch target.SourceType { // nolint:exhaustive
case message.SourceGroup:
source = "群"
case message.SourcePrivate:
source = "私聊"
case message.SourceGuildChannel:
source = "频道"
}
log.Warnf("警告: %s %d %s上传失败: %v", source, target.PrimaryID, raw.Type().String(), err)
continue
w.do(func() {
m, err := bot.uploadLocalVideo(target, e)
if err != nil {
log.Warnf(uploadFailedTemplate, source, target.PrimaryID, "视频", err)
} else {
*p = m
}
})
}
elements[j] = m
j++
}
return elements[:j]
w.wait()
return removeLocalElement(elements)
}

// SendGroupMessage 发送群消息
Expand Down Expand Up @@ -514,10 +561,6 @@ func (bot *CQBot) InsertGuildChannelMessage(m *message.GuildChannelMessage) stri
return msg.ID
}

// Release 释放Bot实例
func (bot *CQBot) Release() {
}

func (bot *CQBot) dispatchEventMessage(m global.MSG) {
bot.lock.RLock()
defer bot.lock.RUnlock()
Expand Down

0 comments on commit 987daad

Please sign in to comment.