|
| 1 | +// Copyright (c) 2025 Tulir Asokan |
| 2 | +// |
| 3 | +// This Source Code Form is subject to the terms of the Mozilla Public |
| 4 | +// License, v. 2.0. If a copy of the MPL was not distributed with this |
| 5 | +// file, You can obtain one at http://mozilla.org/MPL/2.0/. |
| 6 | + |
| 7 | +package commands |
| 8 | + |
| 9 | +import ( |
| 10 | + "context" |
| 11 | + "fmt" |
| 12 | + "runtime/debug" |
| 13 | + "strings" |
| 14 | + "sync" |
| 15 | + |
| 16 | + "github.com/rs/zerolog" |
| 17 | + |
| 18 | + "maunium.net/go/mautrix" |
| 19 | + "maunium.net/go/mautrix/event" |
| 20 | +) |
| 21 | + |
| 22 | +// Processor implements boilerplate code for splitting messages into a command and arguments, |
| 23 | +// and finding the appropriate handler for the command. |
| 24 | +type Processor[MetaType any] struct { |
| 25 | + Client *mautrix.Client |
| 26 | + LogArgs bool |
| 27 | + PreValidator PreValidator[MetaType] |
| 28 | + Meta MetaType |
| 29 | + commands map[string]*Handler[MetaType] |
| 30 | + aliases map[string]string |
| 31 | + lock sync.RWMutex |
| 32 | +} |
| 33 | + |
| 34 | +type Handler[MetaType any] struct { |
| 35 | + Func func(ce *Event[MetaType]) |
| 36 | + |
| 37 | + // Name is the primary name of the command. It must be lowercase. |
| 38 | + Name string |
| 39 | + // Aliases are alternative names for the command. They must be lowercase. |
| 40 | + Aliases []string |
| 41 | +} |
| 42 | + |
| 43 | +// UnknownCommandName is the name of the fallback handler which is used if no other handler is found. |
| 44 | +// If even the unknown command handler is not found, the command is ignored. |
| 45 | +const UnknownCommandName = "unknown-command" |
| 46 | + |
| 47 | +func NewProcessor[MetaType any](cli *mautrix.Client) *Processor[MetaType] { |
| 48 | + proc := &Processor[MetaType]{ |
| 49 | + Client: cli, |
| 50 | + PreValidator: ValidatePrefixSubstring[MetaType]("!"), |
| 51 | + commands: make(map[string]*Handler[MetaType]), |
| 52 | + aliases: make(map[string]string), |
| 53 | + } |
| 54 | + proc.Register(&Handler[MetaType]{ |
| 55 | + Name: UnknownCommandName, |
| 56 | + Func: func(ce *Event[MetaType]) { |
| 57 | + ce.Reply("Unknown command") |
| 58 | + }, |
| 59 | + }) |
| 60 | + return proc |
| 61 | +} |
| 62 | + |
| 63 | +// Register registers the given command handlers. |
| 64 | +func (proc *Processor[MetaType]) Register(handlers ...*Handler[MetaType]) { |
| 65 | + proc.lock.Lock() |
| 66 | + defer proc.lock.Unlock() |
| 67 | + for _, handler := range handlers { |
| 68 | + proc.registerOne(handler) |
| 69 | + } |
| 70 | +} |
| 71 | + |
| 72 | +func (proc *Processor[MetaType]) registerOne(handler *Handler[MetaType]) { |
| 73 | + if strings.ToLower(handler.Name) != handler.Name { |
| 74 | + panic(fmt.Errorf("command %q is not lowercase", handler.Name)) |
| 75 | + } |
| 76 | + proc.commands[handler.Name] = handler |
| 77 | + for _, alias := range handler.Aliases { |
| 78 | + if strings.ToLower(alias) != alias { |
| 79 | + panic(fmt.Errorf("alias %q is not lowercase", alias)) |
| 80 | + } |
| 81 | + proc.aliases[alias] = handler.Name |
| 82 | + } |
| 83 | +} |
| 84 | + |
| 85 | +func (proc *Processor[MetaType]) Unregister(handlers ...*Handler[MetaType]) { |
| 86 | + proc.lock.Lock() |
| 87 | + defer proc.lock.Unlock() |
| 88 | + for _, handler := range handlers { |
| 89 | + proc.unregisterOne(handler) |
| 90 | + } |
| 91 | +} |
| 92 | + |
| 93 | +func (proc *Processor[MetaType]) unregisterOne(handler *Handler[MetaType]) { |
| 94 | + delete(proc.commands, handler.Name) |
| 95 | + for _, alias := range handler.Aliases { |
| 96 | + if proc.aliases[alias] == handler.Name { |
| 97 | + delete(proc.aliases, alias) |
| 98 | + } |
| 99 | + } |
| 100 | +} |
| 101 | + |
| 102 | +func (proc *Processor[MetaType]) Process(ctx context.Context, evt *event.Event) { |
| 103 | + log := *zerolog.Ctx(ctx) |
| 104 | + defer func() { |
| 105 | + panicErr := recover() |
| 106 | + if panicErr != nil { |
| 107 | + logEvt := log.Error(). |
| 108 | + Bytes(zerolog.ErrorStackFieldName, debug.Stack()) |
| 109 | + if realErr, ok := panicErr.(error); ok { |
| 110 | + logEvt = logEvt.Err(realErr) |
| 111 | + } else { |
| 112 | + logEvt = logEvt.Any(zerolog.ErrorFieldName, panicErr) |
| 113 | + } |
| 114 | + logEvt.Msg("Panic in command handler") |
| 115 | + _, err := proc.Client.SendReaction(ctx, evt.RoomID, evt.ID, "💥") |
| 116 | + if err != nil { |
| 117 | + log.Err(err).Msg("Failed to send reaction after panic") |
| 118 | + } |
| 119 | + } |
| 120 | + }() |
| 121 | + parsed := ParseEvent[MetaType](ctx, evt) |
| 122 | + if !proc.PreValidator.Validate(parsed) { |
| 123 | + return |
| 124 | + } |
| 125 | + |
| 126 | + realCommand := parsed.Command |
| 127 | + proc.lock.RLock() |
| 128 | + alias, ok := proc.aliases[realCommand] |
| 129 | + if ok { |
| 130 | + realCommand = alias |
| 131 | + } |
| 132 | + handler, ok := proc.commands[realCommand] |
| 133 | + if !ok { |
| 134 | + handler, ok = proc.commands[UnknownCommandName] |
| 135 | + } |
| 136 | + proc.lock.RUnlock() |
| 137 | + if !ok { |
| 138 | + return |
| 139 | + } |
| 140 | + |
| 141 | + logWith := log.With(). |
| 142 | + Str("command", realCommand). |
| 143 | + Stringer("sender", evt.Sender). |
| 144 | + Stringer("room_id", evt.RoomID) |
| 145 | + if proc.LogArgs { |
| 146 | + logWith = logWith.Strs("args", parsed.Args) |
| 147 | + } |
| 148 | + log = logWith.Logger() |
| 149 | + parsed.Ctx = log.WithContext(ctx) |
| 150 | + parsed.Handler = handler |
| 151 | + parsed.Proc = proc |
| 152 | + parsed.Meta = proc.Meta |
| 153 | + |
| 154 | + log.Debug().Msg("Processing command") |
| 155 | + handler.Func(parsed) |
| 156 | +} |
0 commit comments