Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: Msmsg support (Meta AI and other personalities) #615

Merged
merged 15 commits into from
Jul 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ var (
ErrUnknownServer = errors.New("can't send message to unknown server")
ErrRecipientADJID = errors.New("message recipient must be a user JID with no device part")
ErrServerReturnedError = errors.New("server returned error")
ErrInvalidInlineBotID = errors.New("invalid inline bot ID")
)

type DownloadHTTPError struct {
Expand Down
61 changes: 61 additions & 0 deletions mdtest/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"net/http"
"os"
"os/signal"
"regexp"
"strconv"
"strings"
"sync/atomic"
Expand Down Expand Up @@ -1022,6 +1023,66 @@ func handleCmd(cmd string, args []string) {
if err != nil {
log.Errorf("Error editing label: %v", err)
}
case "sendbotmsg":
if len(args) < 1 {
log.Errorf("Usage: sendBotMsg <inline jid (optional)> <text>")
return
}
var inlineJID types.JID
if len(args) > 1 {
var numbersRegex = regexp.MustCompile(`^[0-9]+$`)
jid, ok := parseJID(args[0])
if ok && numbersRegex.MatchString(jid.User) {
inlineJID = jid
} else {
inlineJID = types.EmptyJID
}
}

personaID := proto.String("867051314767696$760019659443059") // default meta bot personality: "Assistant"

var resp, err = whatsmeow.SendResponse{}, error(nil)
if !inlineJID.IsEmpty() {
text := fmt.Sprintf("@%s %s", types.MetaAIJID.User, strings.Join(args[1:], " "))
msg := &waE2E.Message{
ExtendedTextMessage: &waE2E.ExtendedTextMessage{
Text: &text,
ContextInfo: &waE2E.ContextInfo{
MentionedJID: []string{types.MetaAIJID.String()},
},
},
MessageContextInfo: &waE2E.MessageContextInfo{
BotMetadata: &waE2E.BotMetadata{
PersonaID: personaID,
},
},
}

resp, err = cli.SendMessage(context.Background(), inlineJID, msg, whatsmeow.SendRequestExtra{
InlineBotJID: types.MetaAIJID,
})
} else {
text := strings.Join(args, " ")
msg := &waE2E.Message{
Conversation: &text,
MessageContextInfo: &waE2E.MessageContextInfo{
BotMetadata: &waE2E.BotMetadata{
PersonaID: personaID,
},
},
}
resp, err = cli.SendMessage(context.Background(), types.MetaAIJID, msg)
}
if err != nil {
log.Errorf("Error sending bot message: %v", err)
} else {
log.Infof("Bot message sent (server timestamp: %s)", resp.Timestamp)
}
case "fetchbotprofiles":
list, _ := cli.GetBotListV2()
log.Infof("Bots list: %+v", list)
profiles, _ := cli.GetBotProfiles(list)
log.Infof("Bots profiles: %+v", profiles)
}
}

Expand Down
87 changes: 86 additions & 1 deletion message.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ import (
"runtime/debug"
"time"

"go.mau.fi/whatsmeow/proto/waE2E"

"go.mau.fi/libsignal/groups"
"go.mau.fi/libsignal/protocol"
"go.mau.fi/libsignal/session"
Expand Down Expand Up @@ -88,6 +90,16 @@ func (cli *Client) parseMessageSource(node *waBinary.Node, requireParticipant bo
} else {
source.Chat = from.ToNonAD()
}
} else if from.IsBot() {
source.Sender = from
meta := node.GetChildByTag("meta")
ag = meta.AttrGetter()
targetChatJID := ag.OptionalJID("target_chat_jid")
if targetChatJID != nil {
source.Chat = targetChatJID.ToNonAD()
} else {
source.Chat = from
}
} else {
source.Chat = from.ToNonAD()
source.Sender = from
Expand All @@ -96,6 +108,32 @@ func (cli *Client) parseMessageSource(node *waBinary.Node, requireParticipant bo
return
}

func (cli *Client) parseMsgBotInfo(node waBinary.Node) (botInfo types.MsgBotInfo, err error) {
botNode := node.GetChildByTag("bot")

ag := botNode.AttrGetter()
botInfo.EditType = types.BotEditType(ag.String("edit"))
if botInfo.EditType == types.EditTypeInner || botInfo.EditType == types.EditTypeLast {
botInfo.EditTargetID = types.MessageID(ag.String("edit_target_id"))
botInfo.EditSenderTimestampMS = ag.UnixMilli("sender_timestamp_ms")
}
err = ag.Error()
return
}

func (cli *Client) parseMsgMetaInfo(node waBinary.Node) (metaInfo types.MsgMetaInfo, err error) {
metaNode := node.GetChildByTag("meta")

ag := metaNode.AttrGetter()
metaInfo.TargetID = types.MessageID(ag.String("target_id"))
targetSenderJID := ag.OptionalJIDOrEmpty("target_sender_jid")
if targetSenderJID.User != "" {
metaInfo.TargetSender = targetSenderJID
}
err = ag.Error()
return
}

func (cli *Client) parseMessageInfo(node *waBinary.Node) (*types.MessageInfo, error) {
var info types.MessageInfo
var err error
Expand Down Expand Up @@ -124,6 +162,16 @@ func (cli *Client) parseMessageInfo(node *waBinary.Node) (*types.MessageInfo, er
if err != nil {
cli.Log.Warnf("Failed to parse verified_name node in %s: %v", info.ID, err)
}
case "bot":
info.MsgBotInfo, err = cli.parseMsgBotInfo(child)
if err != nil {
cli.Log.Warnf("Failed to parse <bot> node in %s: %v", info.ID, err)
}
case "meta":
info.MsgMetaInfo, err = cli.parseMsgMetaInfo(child)
if err != nil {
cli.Log.Warnf("Failed to parse <meta> node in %s: %v", info.ID, err)
}
case "franking":
// TODO
case "trace":
Expand Down Expand Up @@ -200,10 +248,47 @@ func (cli *Client) decryptMessages(info *types.MessageInfo, node *waBinary.Node)
containsDirectMsg = true
} else if info.IsGroup && encType == "skmsg" {
decrypted, err = cli.decryptGroupMsg(&child, info.Sender, info.Chat)
} else if encType == "msmsg" && info.Sender.IsBot() {
// Meta AI / other bots (biz?):

// step 1: get message secret
targetSenderJID := info.MsgMetaInfo.TargetSender
if targetSenderJID.User == "" {
// if no targetSenderJID in <meta> this must be ourselves (one-one-one mode)
targetSenderJID = cli.getOwnID()
}

messageSecret, err := cli.Store.MsgSecrets.GetMessageSecret(info.Chat, targetSenderJID, info.MsgMetaInfo.TargetID)
if err != nil || messageSecret == nil {
cli.Log.Warnf("Error getting message secret for bot msg with id %s", node.AttrGetter().String("id"))
continue
}

// step 2: get MessageSecretMessage
byteContents := child.Content.([]byte) // <enc> contents
var msMsg waE2E.MessageSecretMessage

err = proto.Unmarshal(byteContents, &msMsg)
if err != nil {
cli.Log.Warnf("Error decoding MessageSecretMesage protobuf %v", err)
continue
}

// step 3: determine best message id for decryption
var messageID string
if info.MsgBotInfo.EditType == types.EditTypeInner || info.MsgBotInfo.EditType == types.EditTypeLast {
messageID = info.MsgBotInfo.EditTargetID
} else {
messageID = info.ID
}

// step 4: decrypt and voila
decrypted, err = cli.decryptBotMessage(messageSecret, &msMsg, messageID, targetSenderJID, info)
} else {
cli.Log.Warnf("Unhandled encrypted message (type %s) from %s", encType, info.SourceString())
continue
}

if err != nil {
cli.Log.Warnf("Error decrypting message from %s: %v", info.SourceString(), err)
isUnavailable := encType == "skmsg" && !containsDirectMsg && errors.Is(err, signalerror.ErrNoSenderKeyForUser)
Expand All @@ -220,7 +305,7 @@ func (cli *Client) decryptMessages(info *types.MessageInfo, node *waBinary.Node)
cli.cancelDelayedRequestFromPhone(info.ID)
}

var msg waProto.Message
var msg waE2E.Message
switch ag.Int("v") {
case 2:
err = proto.Unmarshal(decrypted, &msg)
Expand Down
40 changes: 30 additions & 10 deletions msgsecret.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ import (
"fmt"
"time"

"go.mau.fi/whatsmeow/proto/waCommon"
"go.mau.fi/whatsmeow/proto/waE2E"

"go.mau.fi/util/random"
"google.golang.org/protobuf/proto"

Expand All @@ -26,8 +29,13 @@ type MsgSecretType string
const (
EncSecretPollVote MsgSecretType = "Poll Vote"
EncSecretReaction MsgSecretType = "Enc Reaction"
EncSecretBotMsg MsgSecretType = "Bot Message"
)

func applyBotMessageHKDF(messageSecret []byte) []byte {
return hkdfutil.SHA256(messageSecret, nil, []byte(EncSecretBotMsg), 32)
}

func generateMsgSecretKey(
modificationType MsgSecretType, modificationSender types.JID,
origMsgID types.MessageID, origMsgSender types.JID, origMsgSecret []byte,
Expand All @@ -47,7 +55,7 @@ func generateMsgSecretKey(
return secretKey, additionalData
}

func getOrigSenderFromKey(msg *events.Message, key *waProto.MessageKey) (types.JID, error) {
func getOrigSenderFromKey(msg *events.Message, key *waCommon.MessageKey) (types.JID, error) {
if key.GetFromMe() {
// fromMe always means the poll and vote were sent by the same user
return msg.Info.Sender, nil
Expand All @@ -74,18 +82,18 @@ type messageEncryptedSecret interface {
GetEncPayload() []byte
}

func (cli *Client) decryptMsgSecret(msg *events.Message, useCase MsgSecretType, encrypted messageEncryptedSecret, origMsgKey *waProto.MessageKey) ([]byte, error) {
func (cli *Client) decryptMsgSecret(msg *events.Message, useCase MsgSecretType, encrypted messageEncryptedSecret, origMsgKey *waCommon.MessageKey) ([]byte, error) {
pollSender, err := getOrigSenderFromKey(msg, origMsgKey)
if err != nil {
return nil, err
}
baseEncKey, err := cli.Store.MsgSecrets.GetMessageSecret(msg.Info.Chat, pollSender, origMsgKey.GetId())
baseEncKey, err := cli.Store.MsgSecrets.GetMessageSecret(msg.Info.Chat, pollSender, origMsgKey.GetID())
if err != nil {
return nil, fmt.Errorf("failed to get original message secret key: %w", err)
} else if baseEncKey == nil {
return nil, ErrOriginalMessageSecretNotFound
}
secretKey, additionalData := generateMsgSecretKey(useCase, msg.Info.Sender, origMsgKey.GetId(), pollSender, baseEncKey)
secretKey, additionalData := generateMsgSecretKey(useCase, msg.Info.Sender, origMsgKey.GetID(), pollSender, baseEncKey)
plaintext, err := gcmutil.Decrypt(secretKey, encrypted.GetEncIV(), encrypted.GetEncPayload(), additionalData)
if err != nil {
return nil, fmt.Errorf("failed to decrypt secret message: %w", err)
Expand Down Expand Up @@ -115,6 +123,18 @@ func (cli *Client) encryptMsgSecret(chat, origSender types.JID, origMsgID types.
return ciphertext, iv, nil
}

func (cli *Client) decryptBotMessage(messageSecret []byte, msMsg messageEncryptedSecret, messageID types.MessageID, targetSenderJID types.JID, info *types.MessageInfo) ([]byte, error) {
// gcm decrypt key generation
newKey, additionalData := generateMsgSecretKey("", info.Sender, messageID, targetSenderJID, applyBotMessageHKDF(messageSecret))

plaintext, err := gcmutil.Decrypt(newKey, msMsg.GetEncIV(), msMsg.GetEncPayload(), additionalData)
if err != nil {
return nil, fmt.Errorf("failed to decrypt secret message: %w", err)
}

return plaintext, nil
}

// DecryptReaction decrypts a reaction update message. This form of reactions hasn't been rolled out yet,
// so this function is likely not of much use.
//
Expand All @@ -126,7 +146,7 @@ func (cli *Client) encryptMsgSecret(chat, origSender types.JID, origMsgID types.
// }
// fmt.Printf("Reaction message: %+v\n", reaction)
// }
func (cli *Client) DecryptReaction(reaction *events.Message) (*waProto.ReactionMessage, error) {
func (cli *Client) DecryptReaction(reaction *events.Message) (*waE2E.ReactionMessage, error) {
encReaction := reaction.Message.GetEncReactionMessage()
if encReaction == nil {
return nil, ErrNotEncryptedReactionMessage
Expand All @@ -135,7 +155,7 @@ func (cli *Client) DecryptReaction(reaction *events.Message) (*waProto.ReactionM
if err != nil {
return nil, fmt.Errorf("failed to decrypt reaction: %w", err)
}
var msg waProto.ReactionMessage
var msg waE2E.ReactionMessage
err = proto.Unmarshal(plaintext, &msg)
if err != nil {
return nil, fmt.Errorf("failed to decode reaction protobuf: %w", err)
Expand All @@ -156,7 +176,7 @@ func (cli *Client) DecryptReaction(reaction *events.Message) (*waProto.ReactionM
// fmt.Printf("- %X\n", hash)
// }
// }
func (cli *Client) DecryptPollVote(vote *events.Message) (*waProto.PollVoteMessage, error) {
func (cli *Client) DecryptPollVote(vote *events.Message) (*waE2E.PollVoteMessage, error) {
pollUpdate := vote.Message.GetPollUpdateMessage()
if pollUpdate == nil {
return nil, ErrNotPollUpdateMessage
Expand All @@ -165,16 +185,16 @@ func (cli *Client) DecryptPollVote(vote *events.Message) (*waProto.PollVoteMessa
if err != nil {
return nil, fmt.Errorf("failed to decrypt poll vote: %w", err)
}
var msg waProto.PollVoteMessage
var msg waE2E.PollVoteMessage
err = proto.Unmarshal(plaintext, &msg)
if err != nil {
return nil, fmt.Errorf("failed to decode poll vote protobuf: %w", err)
}
return &msg, nil
}

func getKeyFromInfo(msgInfo *types.MessageInfo) *waProto.MessageKey {
creationKey := &waProto.MessageKey{
func getKeyFromInfo(msgInfo *types.MessageInfo) *waCommon.MessageKey {
creationKey := &waCommon.MessageKey{
RemoteJID: proto.String(msgInfo.Chat.String()),
FromMe: proto.Bool(msgInfo.IsFromMe),
ID: proto.String(msgInfo.ID),
Expand Down
29 changes: 22 additions & 7 deletions prekeys.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,19 +174,34 @@ func nodeToPreKeyBundle(deviceID uint32, node waBinary.Node) (*prekey.Bundle, er
}
identityKeyPub := *(*[32]byte)(identityKeyRaw)

preKey, err := nodeToPreKey(keysNode.GetChildByTag("key"))
if err != nil {
return nil, fmt.Errorf("invalid prekey in prekey response: %w", err)
preKeyNode, ok := keysNode.GetOptionalChildByTag("key")
preKey := &keys.PreKey{}
if ok {
var err error
preKey, err = nodeToPreKey(preKeyNode)
if err != nil {
return nil, fmt.Errorf("invalid prekey in prekey response: %w", err)
}
}

signedPreKey, err := nodeToPreKey(keysNode.GetChildByTag("skey"))
if err != nil {
return nil, fmt.Errorf("invalid signed prekey in prekey response: %w", err)
}

return prekey.NewBundle(registrationID, deviceID,
optional.NewOptionalUint32(preKey.KeyID), signedPreKey.KeyID,
ecc.NewDjbECPublicKey(*preKey.Pub), ecc.NewDjbECPublicKey(*signedPreKey.Pub), *signedPreKey.Signature,
identity.NewKey(ecc.NewDjbECPublicKey(identityKeyPub))), nil
var bundle *prekey.Bundle
if ok {
bundle = prekey.NewBundle(registrationID, deviceID,
optional.NewOptionalUint32(preKey.KeyID), signedPreKey.KeyID,
ecc.NewDjbECPublicKey(*preKey.Pub), ecc.NewDjbECPublicKey(*signedPreKey.Pub), *signedPreKey.Signature,
identity.NewKey(ecc.NewDjbECPublicKey(identityKeyPub)))
} else {
bundle = prekey.NewBundle(registrationID, deviceID, optional.NewEmptyUint32(), signedPreKey.KeyID,
nil, ecc.NewDjbECPublicKey(*signedPreKey.Pub), *signedPreKey.Signature,
identity.NewKey(ecc.NewDjbECPublicKey(identityKeyPub)))
}

return bundle, nil
}

func nodeToPreKey(node waBinary.Node) (*keys.PreKey, error) {
Expand Down
2 changes: 1 addition & 1 deletion retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ func (cli *Client) handleRetryReceipt(receipt *events.Receipt, node *waBinary.No
}
var content []waBinary.Node
if msg.wa != nil {
content = cli.getMessageContent(*encrypted, msg.wa, attrs, includeDeviceIdentity)
content = cli.getMessageContent(*encrypted, msg.wa, attrs, includeDeviceIdentity, nil)
} else {
content = []waBinary.Node{
*encrypted,
Expand Down
Loading