diff --git a/.gitignore b/.gitignore index 66bc8edb..e15440d3 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ matterircd matterircd.toml +matterircd-lastsaved.db diff --git a/matterircd.toml.example b/matterircd.toml.example index 13febba8..45a1377f 100644 --- a/matterircd.toml.example +++ b/matterircd.toml.example @@ -134,6 +134,14 @@ SuffixContext = false #in your mattermost "word that trigger mentions" notifications. ShowMentions = false +# Path to file to store last viewed information. This is useful for replying only +# the messages missed. +LastViewedSaveFile = "matterircd-lastsaved.db" +# Interval for how often to save last viewed information. +LastViewedSaveInterval = "5m" +# Consider saved last view information stale if last saved older than this time +LastViewedStaleDuration = "30d" + ############################# ##### SLACK EXAMPLE ######### ############################# diff --git a/mm-go-irckit/server_commands.go b/mm-go-irckit/server_commands.go index e6b79ab2..da9d6d77 100644 --- a/mm-go-irckit/server_commands.go +++ b/mm-go-irckit/server_commands.go @@ -389,8 +389,8 @@ func CmdPrivMsg(s Server, u *User, msg *irc.Message) error { u.msgLastMutex.Lock() defer u.msgLastMutex.Unlock() - u.msgLast[ch.ID()] = [2]string{msgID, ""} + u.saveLastViewedAt(ch.ID()) if u.v.GetBool(u.br.Protocol()+".prefixcontext") || u.v.GetBool(u.br.Protocol()+".suffixcontext") { u.prefixContext(ch.ID(), msgID, "", "") @@ -426,6 +426,7 @@ func CmdPrivMsg(s Server, u *User, msg *irc.Message) error { u.msgLastMutex.Lock() defer u.msgLastMutex.Unlock() u.msgLast[toUser.User] = [2]string{msgID, ""} + u.saveLastViewedAt(toUser.User) if u.v.GetBool(u.br.Protocol()+".prefixcontext") || u.v.GetBool(u.br.Protocol()+".suffixcontext") { u.prefixContext(toUser.User, msgID, "", "") @@ -512,6 +513,8 @@ func parseModifyMsg(u *User, msg *irc.Message, channelID string) bool { return false } u.MsgSpoofUser(u, u.br.Protocol(), "msg: "+text+" could not be modified: "+err.Error()) + } else { + u.saveLastViewedAt(channelID) } return true @@ -592,8 +595,8 @@ func threadMsgChannel(u *User, msg *irc.Message, channelID string) bool { u.msgLastMutex.Lock() defer u.msgLastMutex.Unlock() - u.msgLast[channelID] = [2]string{msgID, threadID} + u.saveLastViewedAt(channelID) if u.v.GetBool(u.br.Protocol()+".prefixcontext") || u.v.GetBool(u.br.Protocol()+".suffixcontext") { u.prefixContext(channelID, msgID, "", "") @@ -616,8 +619,8 @@ func threadMsgUser(u *User, msg *irc.Message, toUser string) bool { u.msgLastMutex.Lock() defer u.msgLastMutex.Unlock() - u.msgLast[toUser] = [2]string{msgID, threadID} + u.saveLastViewedAt(toUser) if u.v.GetBool(u.br.Protocol()+".prefixcontext") || u.v.GetBool(u.br.Protocol()+".suffixcontext") { u.prefixContext(toUser, msgID, "", "") diff --git a/mm-go-irckit/userbridge.go b/mm-go-irckit/userbridge.go index 2a202a42..b035e312 100644 --- a/mm-go-irckit/userbridge.go +++ b/mm-go-irckit/userbridge.go @@ -1,13 +1,17 @@ package irckit import ( + "errors" "fmt" "math/rand" "net" + "os" "strings" "sync" "time" + "encoding/gob" + "github.com/42wim/matterircd/bridge" "github.com/42wim/matterircd/bridge/mattermost" "github.com/42wim/matterircd/bridge/slack" @@ -19,19 +23,25 @@ import ( ) type UserBridge struct { - Srv Server - Credentials bridge.Credentials - br bridge.Bridger //nolint:structcheck - inprogress bool //nolint:structcheck - lastViewedAt map[string]int64 //nolint:structcheck - lastViewedAtMutex sync.RWMutex //nolint:structcheck - msgCounter map[string]int //nolint:structcheck - msgLast map[string][2]string //nolint:structcheck - msgLastMutex sync.RWMutex //nolint:structcheck - msgMap map[string]map[string]int //nolint:structcheck - msgMapMutex sync.RWMutex //nolint:structcheck - updateCounter map[string]time.Time //nolint:structcheck - updateCounterMutex sync.Mutex //nolint:structcheck + Srv Server + Credentials bridge.Credentials + br bridge.Bridger //nolint:structcheck + inprogress bool //nolint:structcheck + + lastViewedAtMutex sync.RWMutex //nolint:structcheck + lastViewedAt map[string]int64 //nolint:structcheck + + lastViewedAtSaved int64 //nolint:structcheck + msgCounter map[string]int //nolint:structcheck + + msgLastMutex sync.RWMutex //nolint:structcheck + msgLast map[string][2]string //nolint:structcheck + + msgMapMutex sync.RWMutex //nolint:structcheck + msgMap map[string]map[string]int //nolint:structcheck + + updateCounterMutex sync.Mutex //nolint:structcheck + updateCounter map[string]time.Time //nolint:structcheck } func NewUserBridge(c net.Conn, srv Server, cfg *viper.Viper) *User { @@ -43,7 +53,7 @@ func NewUserBridge(c net.Conn, srv Server, cfg *viper.Viper) *User { u.Srv = srv u.v = cfg - u.lastViewedAt = make(map[string]int64) + u.lastViewedAt = u.loadLastViewedAt() u.msgLast = make(map[string][2]string) u.msgMap = make(map[string]map[string]int) u.msgCounter = make(map[string]int) @@ -150,9 +160,7 @@ func (u *User) handleDirectMessageEvent(event *bridge.DirectMessageEvent) { if !u.v.GetBool(u.br.Protocol() + ".disableautoview") { u.updateLastViewed(event.ChannelID) } - u.lastViewedAtMutex.Lock() - defer u.lastViewedAtMutex.Unlock() - u.lastViewedAt[event.ChannelID] = model.GetMillis() + u.saveLastViewedAt(event.ChannelID) } func (u *User) handleChannelAddEvent(event *bridge.ChannelAddEvent) { @@ -176,9 +184,7 @@ func (u *User) handleChannelAddEvent(event *bridge.ChannelAddEvent) { if !u.v.GetBool(u.br.Protocol() + ".disableautoview") { u.updateLastViewed(event.ChannelID) } - u.lastViewedAtMutex.Lock() - defer u.lastViewedAtMutex.Unlock() - u.lastViewedAt[event.ChannelID] = model.GetMillis() + u.saveLastViewedAt(event.ChannelID) } func (u *User) handleChannelRemoveEvent(event *bridge.ChannelRemoveEvent) { @@ -198,6 +204,7 @@ func (u *User) handleChannelRemoveEvent(event *bridge.ChannelRemoveEvent) { ch.SpoofMessage("system", "removed "+removed.Nick+" from the channel by "+event.Remover.Nick) } } + u.saveLastViewedAt(event.ChannelID) } func (u *User) getMessageChannel(channelID, channelType string, sender *bridge.UserInfo) Channel { @@ -284,9 +291,7 @@ func (u *User) handleChannelMessageEvent(event *bridge.ChannelMessageEvent) { if !u.v.GetBool(u.br.Protocol() + ".disableautoview") { u.updateLastViewed(event.ChannelID) } - u.lastViewedAtMutex.Lock() - defer u.lastViewedAtMutex.Unlock() - u.lastViewedAt[event.ChannelID] = model.GetMillis() + u.saveLastViewedAt(event.ChannelID) } func (u *User) handleFileEvent(event *bridge.FileEvent) { @@ -384,6 +389,7 @@ func (u *User) handleReactionEvent(event interface{}) { } u.handleDirectMessageEvent(e) + u.saveLastViewedAt(channelID) return } @@ -399,6 +405,7 @@ func (u *User) handleReactionEvent(event interface{}) { } u.handleChannelMessageEvent(e) + u.saveLastViewedAt(channelID) } func (u *User) CreateUserFromInfo(info *bridge.UserInfo) *User { @@ -626,12 +633,12 @@ func (u *User) addUserToChannelWorker(channels <-chan *bridge.ChannelInfo, throt } } - if !u.v.GetBool(u.br.Protocol() + ".disableautoview") { - u.updateLastViewed(brchannel.ID) + if len(mmPostList.Order) > 0 { + if !u.v.GetBool(u.br.Protocol() + ".disableautoview") { + u.updateLastViewed(brchannel.ID) + } + u.saveLastViewedAt(brchannel.ID) } - u.lastViewedAtMutex.Lock() - u.lastViewedAt[brchannel.ID] = model.GetMillis() - u.lastViewedAtMutex.Unlock() } } @@ -885,3 +892,129 @@ func (u *User) updateLastViewed(channelID string) { u.br.UpdateLastViewed(channelID) }() } + +func (u *User) loadLastViewedAt() map[string]int64 { + statePath := u.v.GetString("mattermost.lastviewedsavefile") + if statePath == "" { + return make(map[string]int64) + } + + staleDuration := u.v.GetString("mattermost.lastviewedstaleduration") + lastViewedAt, err := loadLastViewedAtStateFile(statePath, staleDuration) + if err != nil { + logger.Warning("Unable to load saved lastViewedAt, using empty values: ", err) + return make(map[string]int64) + } + + logger.Info("Loaded lastViewedAt from ", time.Unix(lastViewedAt["__LastViewedStateSavedTime__"]/1000, 0)) + u.lastViewedAtSaved = model.GetMillis() + + return lastViewedAt +} + +const defaultSaveInterval = int64((5 * time.Minute) / time.Millisecond) + +func (u *User) saveLastViewedAt(channelID string) { + u.lastViewedAtMutex.Lock() + defer u.lastViewedAtMutex.Unlock() + if channelID != "" { + u.lastViewedAt[channelID] = model.GetMillis() + } + + statePath := u.v.GetString(u.br.Protocol() + ".lastviewedsavefile") + if statePath == "" { + return + } + + // We only want to save or dump out saved lastViewedAt on new + // messages after X time. + var saveInterval int64 + val, err := time.ParseDuration(u.v.GetString(u.br.Protocol() + ".lastviewedsaveinterval")) + if err != nil { + saveInterval = defaultSaveInterval + } else { + saveInterval = val.Milliseconds() + } + if u.lastViewedAtSaved < (model.GetMillis() - saveInterval) { + saveLastViewedAtStateFile(statePath, u.lastViewedAt) + u.lastViewedAtSaved = model.GetMillis() + } +} + +const lastViewedStateFormat = int64(1) + +func saveLastViewedAtStateFile(statePath string, lastViewedAt map[string]int64) error { + f, err := os.Create(statePath) + if err != nil { + logger.Debug("Unable to save lastViewedAt: ", err) + return err + } + defer f.Close() + + currentTime := model.GetMillis() + + lastViewedAt["__LastViewedStateFormat__"] = lastViewedStateFormat + if _, ok := lastViewedAt["__LastViewedStateCreateTime__"]; !ok { + lastViewedAt["__LastViewedStateCreateTime__"] = currentTime + } + lastViewedAt["__LastViewedStateSavedTime__"] = currentTime + // Simple checksum + lastViewedAt["__LastViewedStateChecksum__"] = lastViewedAt["__LastViewedStateCreateTime__"] ^ currentTime + + logger.Debug("Saving lastViewedAt") + + if err := gob.NewEncoder(f).Encode(lastViewedAt); err != nil { + return fmt.Errorf("gob encoding failed: %s", err) + } + + return nil +} + +const defaultStaleDuration = int64((30 * 24 * time.Hour) / time.Millisecond) + +func loadLastViewedAtStateFile(statePath string, staleDuration string) (map[string]int64, error) { + f, err := os.Open(statePath) + if err != nil { + logger.Debug("Unable to load lastViewedAt: ", err) + return nil, err + } + defer f.Close() + + var lastViewedAt map[string]int64 + err = gob.NewDecoder(f).Decode(&lastViewedAt) + if err != nil { + logger.Debug("Unable to load lastViewedAt: ", err) + return nil, err + } + + if lastViewedAt["__LastViewedStateFormat__"] != lastViewedStateFormat { + logger.Debug("State format version mismatch: ", lastViewedAt["__LastViewedStateFormat__"], " vs. ", lastViewedStateFormat) + return nil, errors.New("version mismatch") + } + checksum := lastViewedAt["__LastViewedStateChecksum__"] + createtime := lastViewedAt["__LastViewedStateCreateTime__"] + savedtime := lastViewedAt["__LastViewedStateSavedTime__"] + if createtime^savedtime != checksum { + logger.Debug("Checksum mismatch: (saved checksum, state file creation, last saved time)", checksum, createtime, savedtime) + return nil, errors.New("checksum mismatch") + } + + currentTime := model.GetMillis() + + // Check if stale, time last saved older than defined + var stale int64 + val, err := time.ParseDuration(staleDuration) + if err != nil { + stale = defaultStaleDuration + } else { + stale = val.Milliseconds() + } + + lastSaved, ok := lastViewedAt["__LastViewedStateSavedTime__"] + if !ok || (lastSaved > 0 && lastSaved < currentTime-stale) { + logger.Debug("File stale? Last saved too old: ", time.Unix(lastViewedAt["__LastViewedStateSavedTime__"]/1000, 0)) + return nil, errors.New("stale lastViewedAt state file") + } + + return lastViewedAt, nil +}