Skip to content

Commit

Permalink
Add initial support for fb retries
Browse files Browse the repository at this point in the history
  • Loading branch information
tulir committed Feb 9, 2024
1 parent e6138b3 commit e82ca29
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 41 deletions.
4 changes: 2 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ type Client struct {
userDevicesCache map[types.JID]deviceCache
userDevicesCacheLock sync.Mutex

recentMessagesMap map[recentMessageKey]*waProto.Message
recentMessagesMap map[recentMessageKey]RecentMessage
recentMessagesList [recentMessagesSize]recentMessageKey
recentMessagesPtr int
recentMessagesLock sync.RWMutex
Expand Down Expand Up @@ -220,7 +220,7 @@ func NewClient(deviceStore *store.Device, log waLog.Logger) *Client {
groupParticipantsCache: make(map[types.JID][]types.JID),
userDevicesCache: make(map[types.JID]deviceCache),

recentMessagesMap: make(map[recentMessageKey]*waProto.Message, recentMessagesSize),
recentMessagesMap: make(map[recentMessageKey]RecentMessage, recentMessagesSize),
sessionRecreateHistory: make(map[types.JID]time.Time),
GetMessageForRetry: func(requester, to types.JID, id types.MessageID) *waProto.Message { return nil },
appStateKeyRequests: make(map[string]time.Time),
Expand Down
137 changes: 104 additions & 33 deletions retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ package whatsmeow

import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/binary"
"fmt"
"time"
Expand All @@ -19,6 +21,10 @@ import (
"google.golang.org/protobuf/proto"

waBinary "go.mau.fi/whatsmeow/binary"
"go.mau.fi/whatsmeow/binary/armadillo/waCommon"
"go.mau.fi/whatsmeow/binary/armadillo/waConsumerApplication"
"go.mau.fi/whatsmeow/binary/armadillo/waMsgApplication"
"go.mau.fi/whatsmeow/binary/armadillo/waMsgTransport"
waProto "go.mau.fi/whatsmeow/binary/proto"
"go.mau.fi/whatsmeow/types"
"go.mau.fi/whatsmeow/types/events"
Expand All @@ -32,19 +38,22 @@ type recentMessageKey struct {
ID types.MessageID
}

// RecentMessage contains the info needed to re-send a message when another device fails to decrypt it.
type RecentMessage struct {
Proto *waProto.Message
Timestamp time.Time
wa *waProto.Message
fb *waMsgApplication.MessageApplication
}

func (cli *Client) addRecentMessage(to types.JID, id types.MessageID, message *waProto.Message) {
func (rm RecentMessage) IsEmpty() bool {
return rm.wa == nil && rm.fb == nil
}

func (cli *Client) addRecentMessage(to types.JID, id types.MessageID, wa *waProto.Message, fb *waMsgApplication.MessageApplication) {
cli.recentMessagesLock.Lock()
key := recentMessageKey{to, id}
if cli.recentMessagesList[cli.recentMessagesPtr].ID != "" {
delete(cli.recentMessagesMap, cli.recentMessagesList[cli.recentMessagesPtr])
}
cli.recentMessagesMap[key] = message
cli.recentMessagesMap[key] = RecentMessage{wa: wa, fb: fb}
cli.recentMessagesList[cli.recentMessagesPtr] = key
cli.recentMessagesPtr++
if cli.recentMessagesPtr >= len(cli.recentMessagesList) {
Expand All @@ -53,26 +62,27 @@ func (cli *Client) addRecentMessage(to types.JID, id types.MessageID, message *w
cli.recentMessagesLock.Unlock()
}

func (cli *Client) getRecentMessage(to types.JID, id types.MessageID) *waProto.Message {
func (cli *Client) getRecentMessage(to types.JID, id types.MessageID) RecentMessage {
cli.recentMessagesLock.RLock()
msg, _ := cli.recentMessagesMap[recentMessageKey{to, id}]
cli.recentMessagesLock.RUnlock()
return msg
}

func (cli *Client) getMessageForRetry(receipt *events.Receipt, messageID types.MessageID) (*waProto.Message, error) {
func (cli *Client) getMessageForRetry(receipt *events.Receipt, messageID types.MessageID) (RecentMessage, error) {
msg := cli.getRecentMessage(receipt.Chat, messageID)
if msg == nil {
msg = cli.GetMessageForRetry(receipt.Sender, receipt.Chat, messageID)
if msg == nil {
return nil, fmt.Errorf("couldn't find message %s", messageID)
if msg.IsEmpty() {
waMsg := cli.GetMessageForRetry(receipt.Sender, receipt.Chat, messageID)
if waMsg == nil {
return RecentMessage{}, fmt.Errorf("couldn't find message %s", messageID)
} else {
cli.Log.Debugf("Found message in GetMessageForRetry to accept retry receipt for %s/%s from %s", receipt.Chat, messageID, receipt.Sender)
}
msg = RecentMessage{wa: waMsg}
} else {
cli.Log.Debugf("Found message in local cache to accept retry receipt for %s/%s from %s", receipt.Chat, messageID, receipt.Sender)
}
return proto.Clone(msg).(*waProto.Message), nil
return msg, nil
}

const recreateSessionTimeout = 1 * time.Hour
Expand Down Expand Up @@ -101,11 +111,6 @@ type incomingRetryKey struct {

// handleRetryReceipt handles an incoming retry receipt for an outgoing message.
func (cli *Client) handleRetryReceipt(receipt *events.Receipt, node *waBinary.Node) error {
// TODO implement replying to retry receipts in messenger mode
if cli.MessengerConfig != nil {
return nil
}

retryChild, ok := node.GetOptionalChildByTag("retry")
if !ok {
return &ElementMissingError{Tag: "retry", In: "retry receipt"}
Expand All @@ -121,6 +126,16 @@ func (cli *Client) handleRetryReceipt(receipt *events.Receipt, node *waBinary.No
if err != nil {
return err
}
var fbConsumerMsg *waConsumerApplication.ConsumerApplication
if msg.fb != nil {
subProto, ok := msg.fb.GetPayload().GetSubProtocol().GetSubProtocol().(*waMsgApplication.MessageApplication_SubProtocolPayload_ConsumerMessage)
if ok {
fbConsumerMsg, err = subProto.Decode()
if err != nil {
return fmt.Errorf("failed to decode consumer message for retry: %w", err)
}
}
}

retryKey := incomingRetryKey{receipt.Sender, messageID}
cli.incomingRetryRequestCounterLock.Lock()
Expand All @@ -137,35 +152,61 @@ func (cli *Client) handleRetryReceipt(receipt *events.Receipt, node *waBinary.No
return ErrNotLoggedIn
}

var fbSKDM *waMsgTransport.MessageTransport_Protocol_Ancillary_SenderKeyDistributionMessage
var fbDSM *waMsgTransport.MessageTransport_Protocol_Integral_DeviceSentMessage
if receipt.IsGroup {
builder := groups.NewGroupSessionBuilder(cli.Store, pbSerializer)
senderKeyName := protocol.NewSenderKeyName(receipt.Chat.String(), ownID.SignalAddress())
signalSKDMessage, err := builder.Create(senderKeyName)
if err != nil {
cli.Log.Warnf("Failed to create sender key distribution message to include in retry of %s in %s to %s: %v", messageID, receipt.Chat, receipt.Sender, err)
} else {
msg.SenderKeyDistributionMessage = &waProto.SenderKeyDistributionMessage{
}
if msg.wa != nil {
msg.wa.SenderKeyDistributionMessage = &waProto.SenderKeyDistributionMessage{
GroupId: proto.String(receipt.Chat.String()),
AxolotlSenderKeyDistributionMessage: signalSKDMessage.Serialize(),
}
} else {
fbSKDM = &waMsgTransport.MessageTransport_Protocol_Ancillary_SenderKeyDistributionMessage{
GroupID: receipt.Chat.String(),
AxolotlSenderKeyDistributionMessage: signalSKDMessage.Serialize(),
}
}
} else if receipt.IsFromMe {
msg = &waProto.Message{
DeviceSentMessage: &waProto.DeviceSentMessage{
DestinationJid: proto.String(receipt.Chat.String()),
Message: msg,
},
if msg.wa != nil {
msg.wa = &waProto.Message{
DeviceSentMessage: &waProto.DeviceSentMessage{
DestinationJid: proto.String(receipt.Chat.String()),
Message: msg.wa,
},
}
} else {
fbDSM = &waMsgTransport.MessageTransport_Protocol_Integral_DeviceSentMessage{
DestinationJID: receipt.Chat.String(),
}
}
}

if cli.PreRetryCallback != nil && !cli.PreRetryCallback(receipt, messageID, retryCount, msg) {
// TODO pre-retry callback for fb
if cli.PreRetryCallback != nil && !cli.PreRetryCallback(receipt, messageID, retryCount, msg.wa) {
cli.Log.Debugf("Cancelled retry receipt in PreRetryCallback")
return nil
}

plaintext, err := proto.Marshal(msg)
if err != nil {
return fmt.Errorf("failed to marshal message: %w", err)
var plaintext, frankingTag []byte
if msg.wa != nil {
plaintext, err = proto.Marshal(msg.wa)
if err != nil {
return fmt.Errorf("failed to marshal message: %w", err)
}
} else {
plaintext, err = proto.Marshal(msg.fb)
if err != nil {
return fmt.Errorf("failed to marshal consumer message: %w", err)
}
frankingHash := hmac.New(sha256.New, msg.fb.GetMetadata().GetFrankingKey())
frankingHash.Write(plaintext)
frankingTag = frankingHash.Sum(nil)
}
_, hasKeys := node.GetOptionalChildByTag("keys")
var bundle *prekey.Bundle
Expand All @@ -189,18 +230,39 @@ func (cli *Client) handleRetryReceipt(receipt *events.Receipt, node *waBinary.No
}
}
encAttrs := waBinary.Attrs{}
if mediaType := getMediaTypeFromMessage(msg); mediaType != "" {
encAttrs["mediatype"] = mediaType
var msgAttrs messageAttrs
if msg.wa != nil {
msgAttrs.MediaType = getMediaTypeFromMessage(msg.wa)
msgAttrs.Type = getTypeFromMessage(msg.wa)
} else if fbConsumerMsg != nil {
msgAttrs = getAttrsFromFBMessage(fbConsumerMsg)
} else {
msgAttrs.Type = "text"
}
if msgAttrs.MediaType != "" {
encAttrs["mediatype"] = msgAttrs.MediaType
}
var encrypted *waBinary.Node
var includeDeviceIdentity bool
if msg.wa != nil {
encrypted, includeDeviceIdentity, err = cli.encryptMessageForDevice(plaintext, receipt.Sender, bundle, encAttrs)
} else {
encrypted, err = cli.encryptMessageForDeviceV3(&waMsgTransport.MessageTransport_Payload{
ApplicationPayload: &waCommon.SubProtocol{
Payload: plaintext,
Version: FBMessageApplicationVersion,
},
FutureProof: waCommon.FutureProofBehavior_PLACEHOLDER,
}, fbSKDM, fbDSM, receipt.Sender, bundle, encAttrs)
}
encrypted, includeDeviceIdentity, err := cli.encryptMessageForDevice(plaintext, receipt.Sender, bundle, encAttrs)
if err != nil {
return fmt.Errorf("failed to encrypt message for retry: %w", err)
}
encrypted.Attrs["count"] = retryCount

attrs := waBinary.Attrs{
"to": node.Attrs["from"],
"type": getTypeFromMessage(msg),
"type": msgAttrs.Type,
"id": messageID,
"t": timestamp.Unix(),
}
Expand All @@ -216,10 +278,19 @@ func (cli *Client) handleRetryReceipt(receipt *events.Receipt, node *waBinary.No
if edit, ok := node.Attrs["edit"]; ok {
attrs["edit"] = edit
}
var content []waBinary.Node
if msg.wa != nil {
content = cli.getMessageContent(*encrypted, msg.wa, attrs, includeDeviceIdentity)
} else {
content = []waBinary.Node{
*encrypted,
{Tag: "franking", Content: []waBinary.Node{{Tag: "franking_tag", Content: frankingTag}}},
}
}
err = cli.sendNode(waBinary.Node{
Tag: "message",
Attrs: attrs,
Content: cli.getMessageContent(*encrypted, msg, attrs, includeDeviceIdentity),
Content: content,
})
if err != nil {
return fmt.Errorf("failed to send retry message: %w", err)
Expand Down
2 changes: 1 addition & 1 deletion send.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ func (cli *Client) SendMessage(ctx context.Context, to types.JID, message *waPro
respChan := cli.waitResponse(req.ID)
// Peer message retries aren't implemented yet
if !req.Peer {
cli.addRecentMessage(to, req.ID, message)
cli.addRecentMessage(to, req.ID, message, nil)
}
if message.GetMessageContextInfo().GetMessageSecret() != nil {
err = cli.Store.MsgSecrets.PutMessageSecret(to, ownID, req.ID, message.GetMessageContextInfo().GetMessageSecret())
Expand Down
11 changes: 6 additions & 5 deletions sendfb.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func (cli *Client) SendFBMessage(
metadata.FrankingVersion = 0
metadata.FrankingKey = random.Bytes(32)
msgAttrs := getAttrsFromFBMessage(message)
messageApp, err := proto.Marshal(&waMsgApplication.MessageApplication{
messageAppProto := &waMsgApplication.MessageApplication{
Payload: &waMsgApplication.MessageApplication_Payload{
Content: &waMsgApplication.MessageApplication_Payload_SubProtocol{
SubProtocol: &waMsgApplication.MessageApplication_SubProtocolPayload{
Expand All @@ -77,7 +77,8 @@ func (cli *Client) SendFBMessage(
},
},
Metadata: metadata,
})
}
messageApp, err := proto.Marshal(messageAppProto)
if err != nil {
return resp, fmt.Errorf("failed to marshal message application: %w", err)
}
Expand Down Expand Up @@ -109,9 +110,9 @@ func (cli *Client) SendFBMessage(
defer cli.messageSendLock.Unlock()

respChan := cli.waitResponse(req.ID)
//if !req.Peer {
// cli.addRecentMessage(to, req.ID, message)
//}
if !req.Peer {
cli.addRecentMessage(to, req.ID, nil, messageAppProto)
}
var phash string
var data []byte
switch to.Server {
Expand Down

0 comments on commit e82ca29

Please sign in to comment.