diff --git a/handlers/base.go b/handlers/base.go index bdc4e9d4..e9a50b46 100644 --- a/handlers/base.go +++ b/handlers/base.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" + "github.com/gomodule/redigo/redis" "github.com/nyaruka/courier" "github.com/nyaruka/gocommon/httpx" ) @@ -148,3 +149,10 @@ func (h *BaseHandler) WriteRequestError(ctx context.Context, w http.ResponseWrit func (h *BaseHandler) WriteRequestIgnored(ctx context.Context, w http.ResponseWriter, details string) error { return courier.WriteIgnored(w, details) } + +// WithRedisConn is a utility to execute some code with a redis connection +func (h *BaseHandler) WithRedisConn(fn func(rc redis.Conn)) { + rc := h.Backend().RedisPool().Get() + defer rc.Close() + fn(rc) +} diff --git a/handlers/firebase/handler.go b/handlers/firebase/handler.go index f7eb228e..b88de4e8 100644 --- a/handlers/firebase/handler.go +++ b/handlers/firebase/handler.go @@ -338,15 +338,17 @@ func (h *handler) sendWithCredsJSON(msg courier.MsgOut, res *courier.SendResult, } func (h *handler) getAccessToken(channel courier.Channel) (string, error) { - rc := h.Backend().RedisPool().Get() - defer rc.Close() - tokenKey := fmt.Sprintf("channel-token:%s", channel.UUID()) h.fetchTokenMutex.Lock() defer h.fetchTokenMutex.Unlock() - token, err := redis.String(rc.Do("GET", tokenKey)) + var token string + var err error + h.WithRedisConn(func(rc redis.Conn) { + token, err = redis.String(rc.Do("GET", tokenKey)) + }) + if err != nil && err != redis.ErrNil { return "", fmt.Errorf("error reading cached access token: %w", err) } @@ -360,7 +362,10 @@ func (h *handler) getAccessToken(channel courier.Channel) (string, error) { return "", fmt.Errorf("error fetching new access token: %w", err) } - _, err = rc.Do("SET", tokenKey, token, "EX", int(expires/time.Second)) + h.WithRedisConn(func(rc redis.Conn) { + _, err = rc.Do("SET", tokenKey, token, "EX", int(expires/time.Second)) + }) + if err != nil { return "", fmt.Errorf("error updating cached access token: %w", err) } diff --git a/handlers/hormuud/handler.go b/handlers/hormuud/handler.go index 5219f22c..615972ad 100644 --- a/handlers/hormuud/handler.go +++ b/handlers/hormuud/handler.go @@ -130,9 +130,10 @@ func (h *handler) Send(ctx context.Context, msg courier.MsgOut, res *courier.Sen // FetchToken gets the current token for this channel, either from Redis if cached or by requesting it func (h *handler) FetchToken(ctx context.Context, channel courier.Channel, msg courier.MsgOut, username, password string, clog *courier.ChannelLog) (string, error) { // first check whether we have it in redis - rc := h.Backend().RedisPool().Get() - token, _ := redis.String(rc.Do("GET", fmt.Sprintf("hm_token_%s", channel.UUID()))) - rc.Close() + var token string + h.WithRedisConn(func(rc redis.Conn) { + token, _ = redis.String(rc.Do("GET", fmt.Sprintf("hm_token_%s", channel.UUID()))) + }) // got a token, use it if token != "" { @@ -172,13 +173,12 @@ func (h *handler) FetchToken(ctx context.Context, channel courier.Channel, msg c } // we got a token, cache it to redis with an expiration from the response(we default to 60 minutes) - rc = h.Backend().RedisPool().Get() - defer rc.Close() - - _, err = rc.Do("SETEX", fmt.Sprintf("hm_token_%s", channel.UUID()), expiration, token) - if err != nil { - slog.Error("error caching HM access token", "error", err) - } + h.WithRedisConn(func(rc redis.Conn) { + _, err = rc.Do("SETEX", fmt.Sprintf("hm_token_%s", channel.UUID()), expiration, token) + if err != nil { + slog.Error("error caching HM access token", "error", err) + } + }) return token, nil } diff --git a/handlers/jiochat/handler.go b/handlers/jiochat/handler.go index f37c3c99..0acccd88 100644 --- a/handlers/jiochat/handler.go +++ b/handlers/jiochat/handler.go @@ -163,7 +163,7 @@ type mtPayload struct { } func (h *handler) Send(ctx context.Context, msg courier.MsgOut, res *courier.SendResult, clog *courier.ChannelLog) error { - accessToken, err := h.getAccessToken(ctx, msg.Channel(), clog) + accessToken, err := h.getAccessToken(msg.Channel(), clog) if err != nil { return courier.ErrChannelConfig } @@ -198,7 +198,7 @@ func (h *handler) Send(ctx context.Context, msg courier.MsgOut, res *courier.Sen // DescribeURN handles Jiochat contact details func (h *handler) DescribeURN(ctx context.Context, channel courier.Channel, urn urns.URN, clog *courier.ChannelLog) (map[string]string, error) { - accessToken, err := h.getAccessToken(ctx, channel, clog) + accessToken, err := h.getAccessToken(channel, clog) if err != nil { return nil, err } @@ -237,7 +237,7 @@ func (h *handler) BuildAttachmentRequest(ctx context.Context, b courier.Backend, return nil, err } - accessToken, err := h.getAccessToken(ctx, channel, clog) + accessToken, err := h.getAccessToken(channel, clog) if err != nil { return nil, err } @@ -250,16 +250,18 @@ func (h *handler) BuildAttachmentRequest(ctx context.Context, b courier.Backend, var _ courier.AttachmentRequestBuilder = (*handler)(nil) -func (h *handler) getAccessToken(ctx context.Context, channel courier.Channel, clog *courier.ChannelLog) (string, error) { - rc := h.Backend().RedisPool().Get() - defer rc.Close() - +func (h *handler) getAccessToken(channel courier.Channel, clog *courier.ChannelLog) (string, error) { tokenKey := fmt.Sprintf("channel-token:%s", channel.UUID()) h.fetchTokenMutex.Lock() defer h.fetchTokenMutex.Unlock() - token, err := redis.String(rc.Do("GET", tokenKey)) + var token string + var err error + h.WithRedisConn(func(rc redis.Conn) { + token, err = redis.String(rc.Do("GET", tokenKey)) + }) + if err != nil && err != redis.ErrNil { return "", fmt.Errorf("error reading cached access token: %w", err) } @@ -268,12 +270,15 @@ func (h *handler) getAccessToken(ctx context.Context, channel courier.Channel, c return token, nil } - token, expires, err := h.fetchAccessToken(ctx, channel, clog) + token, expires, err := h.fetchAccessToken(channel, clog) if err != nil { return "", fmt.Errorf("error fetching new access token: %w", err) } - _, err = rc.Do("SET", tokenKey, token, "EX", int(expires/time.Second)) + h.WithRedisConn(func(rc redis.Conn) { + _, err = rc.Do("SET", tokenKey, token, "EX", int(expires/time.Second)) + }) + if err != nil { return "", fmt.Errorf("error updating cached access token: %w", err) } @@ -288,7 +293,7 @@ type fetchPayload struct { } // fetchAccessToken tries to fetch a new token for our channel -func (h *handler) fetchAccessToken(ctx context.Context, channel courier.Channel, clog *courier.ChannelLog) (string, time.Duration, error) { +func (h *handler) fetchAccessToken(channel courier.Channel, clog *courier.ChannelLog) (string, time.Duration, error) { tokenURL, _ := url.Parse(fmt.Sprintf("%s/%s", sendURL, "auth/token.action")) payload := &fetchPayload{ GrantType: "client_credentials", diff --git a/handlers/mtn/handler.go b/handlers/mtn/handler.go index 21946e5b..af4c8730 100644 --- a/handlers/mtn/handler.go +++ b/handlers/mtn/handler.go @@ -121,7 +121,7 @@ type mtPayload struct { } func (h *handler) Send(ctx context.Context, msg courier.MsgOut, res *courier.SendResult, clog *courier.ChannelLog) error { - accessToken, err := h.getAccessToken(ctx, msg.Channel(), clog) + accessToken, err := h.getAccessToken(msg.Channel(), clog) if err != nil { return courier.ErrChannelConfig } @@ -175,16 +175,18 @@ func (h *handler) RedactValues(ch courier.Channel) []string { } } -func (h *handler) getAccessToken(ctx context.Context, channel courier.Channel, clog *courier.ChannelLog) (string, error) { - rc := h.Backend().RedisPool().Get() - defer rc.Close() - +func (h *handler) getAccessToken(channel courier.Channel, clog *courier.ChannelLog) (string, error) { tokenKey := fmt.Sprintf("channel-token:%s", channel.UUID()) h.fetchTokenMutex.Lock() defer h.fetchTokenMutex.Unlock() - token, err := redis.String(rc.Do("GET", tokenKey)) + var token string + var err error + h.WithRedisConn(func(rc redis.Conn) { + token, err = redis.String(rc.Do("GET", tokenKey)) + }) + if err != nil && err != redis.ErrNil { return "", fmt.Errorf("error reading cached access token: %w", err) } @@ -193,12 +195,15 @@ func (h *handler) getAccessToken(ctx context.Context, channel courier.Channel, c return token, nil } - token, expires, err := h.fetchAccessToken(ctx, channel, clog) + token, expires, err := h.fetchAccessToken(channel, clog) if err != nil { return "", fmt.Errorf("error fetching new access token: %w", err) } - _, err = rc.Do("SET", tokenKey, token, "EX", int(expires/time.Second)) + h.WithRedisConn(func(rc redis.Conn) { + _, err = rc.Do("SET", tokenKey, token, "EX", int(expires/time.Second)) + }) + if err != nil { return "", fmt.Errorf("error updating cached access token: %w", err) } @@ -207,7 +212,7 @@ func (h *handler) getAccessToken(ctx context.Context, channel courier.Channel, c } // fetchAccessToken tries to fetch a new token for our channel, setting the result in redis -func (h *handler) fetchAccessToken(ctx context.Context, channel courier.Channel, clog *courier.ChannelLog) (string, time.Duration, error) { +func (h *handler) fetchAccessToken(channel courier.Channel, clog *courier.ChannelLog) (string, time.Duration, error) { form := url.Values{ "client_id": []string{channel.StringConfigForKey(courier.ConfigAPIKey, "")}, "client_secret": []string{channel.StringConfigForKey(courier.ConfigAuthToken, "")}, diff --git a/handlers/wechat/handler.go b/handlers/wechat/handler.go index 49eb1adf..c91acd81 100644 --- a/handlers/wechat/handler.go +++ b/handlers/wechat/handler.go @@ -177,7 +177,7 @@ type mtPayload struct { } func (h *handler) Send(ctx context.Context, msg courier.MsgOut, res *courier.SendResult, clog *courier.ChannelLog) error { - accessToken, err := h.getAccessToken(ctx, msg.Channel(), clog) + accessToken, err := h.getAccessToken(msg.Channel(), clog) if err != nil { return err } @@ -216,7 +216,7 @@ func (h *handler) Send(ctx context.Context, msg courier.MsgOut, res *courier.Sen // DescribeURN handles WeChat contact details func (h *handler) DescribeURN(ctx context.Context, channel courier.Channel, urn urns.URN, clog *courier.ChannelLog) (map[string]string, error) { - accessToken, err := h.getAccessToken(ctx, channel, clog) + accessToken, err := h.getAccessToken(channel, clog) if err != nil { return nil, err } @@ -255,7 +255,7 @@ func (h *handler) RedactValues(ch courier.Channel) []string { // BuildAttachmentRequest download media for message attachment func (h *handler) BuildAttachmentRequest(ctx context.Context, b courier.Backend, channel courier.Channel, attachmentURL string, clog *courier.ChannelLog) (*http.Request, error) { - accessToken, err := h.getAccessToken(ctx, channel, clog) + accessToken, err := h.getAccessToken(channel, clog) if err != nil { return nil, err } @@ -275,16 +275,18 @@ func (h *handler) BuildAttachmentRequest(ctx context.Context, b courier.Backend, var _ courier.AttachmentRequestBuilder = (*handler)(nil) -func (h *handler) getAccessToken(ctx context.Context, channel courier.Channel, clog *courier.ChannelLog) (string, error) { - rc := h.Backend().RedisPool().Get() - defer rc.Close() - +func (h *handler) getAccessToken(channel courier.Channel, clog *courier.ChannelLog) (string, error) { tokenKey := fmt.Sprintf("channel-token:%s", channel.UUID()) h.fetchTokenMutex.Lock() defer h.fetchTokenMutex.Unlock() - token, err := redis.String(rc.Do("GET", tokenKey)) + var token string + var err error + h.WithRedisConn(func(rc redis.Conn) { + token, err = redis.String(rc.Do("GET", tokenKey)) + }) + if err != nil && err != redis.ErrNil { return "", fmt.Errorf("error reading cached access token: %w", err) } @@ -293,12 +295,15 @@ func (h *handler) getAccessToken(ctx context.Context, channel courier.Channel, c return token, nil } - token, expires, err := h.fetchAccessToken(ctx, channel, clog) + token, expires, err := h.fetchAccessToken(channel, clog) if err != nil { return "", fmt.Errorf("error fetching new access token: %w", err) } - _, err = rc.Do("SET", tokenKey, token, "EX", int(expires/time.Second)) + h.WithRedisConn(func(rc redis.Conn) { + _, err = rc.Do("SET", tokenKey, token, "EX", int(expires/time.Second)) + }) + if err != nil { return "", fmt.Errorf("error updating cached access token: %w", err) } @@ -307,7 +312,7 @@ func (h *handler) getAccessToken(ctx context.Context, channel courier.Channel, c } // fetchAccessToken tries to fetch a new token for our channel, setting the result in redis -func (h *handler) fetchAccessToken(ctx context.Context, channel courier.Channel, clog *courier.ChannelLog) (string, time.Duration, error) { +func (h *handler) fetchAccessToken(channel courier.Channel, clog *courier.ChannelLog) (string, time.Duration, error) { form := url.Values{ "grant_type": []string{"client_credential"}, "appid": []string{channel.StringConfigForKey(configAppID, "")}, diff --git a/handlers/whatsapp_legacy/handler.go b/handlers/whatsapp_legacy/handler.go index 50646800..5b7ce4e1 100644 --- a/handlers/whatsapp_legacy/handler.go +++ b/handlers/whatsapp_legacy/handler.go @@ -495,9 +495,6 @@ type mtErrorPayload struct { const maxMsgLength = 4096 func (h *handler) Send(ctx context.Context, msg courier.MsgOut, res *courier.SendResult, clog *courier.ChannelLog) error { - conn := h.Backend().RedisPool().Get() - defer conn.Close() - // get our token token := msg.Channel().StringConfigForKey(courier.ConfigAuthToken, "") urlStr := msg.Channel().StringConfigForKey(courier.ConfigBaseURL, "") @@ -519,7 +516,7 @@ func (h *handler) Send(ctx context.Context, msg courier.MsgOut, res *courier.Sen for _, payload := range payloads { externalID := "" - wppID, externalID, err = h.sendWhatsAppMsg(conn, msg, sendPath, payload, clog) + wppID, externalID, err = h.sendWhatsAppMsg(msg, sendPath, payload, clog) if err != nil { return err } @@ -562,7 +559,7 @@ func buildPayloads(msg courier.MsgOut, h *handler, clog *courier.ChannelLog) ([] for attachmentCount, attachment := range msg.Attachments() { mimeType, mediaURL := handlers.SplitAttachment(attachment) - mediaID, err := h.fetchMediaID(msg, mimeType, mediaURL, clog) + mediaID, err := h.fetchMediaID(msg, mediaURL, clog) if err != nil { slog.Error("error while uploading media to whatsapp", "error", err, "channel_uuid", msg.Channel().UUID()) } @@ -817,14 +814,17 @@ func buildPayloads(msg courier.MsgOut, h *handler, clog *courier.ChannelLog) ([] } // fetchMediaID tries to fetch the id for the uploaded media, setting the result in redis. -func (h *handler) fetchMediaID(msg courier.MsgOut, mimeType, mediaURL string, clog *courier.ChannelLog) (string, error) { +func (h *handler) fetchMediaID(msg courier.MsgOut, mediaURL string, clog *courier.ChannelLog) (string, error) { // check in cache first - rc := h.Backend().RedisPool().Get() - defer rc.Close() - cacheKey := fmt.Sprintf(mediaCacheKeyPattern, msg.Channel().UUID()) mediaCache := redisx.NewIntervalHash(cacheKey, time.Hour*24, 2) - mediaID, err := mediaCache.Get(rc, mediaURL) + + var mediaID string + var err error + h.WithRedisConn(func(rc redis.Conn) { + mediaID, err = mediaCache.Get(rc, mediaURL) + }) + if err != nil { return "", fmt.Errorf("error reading media id from redis: %s : %s: %w", cacheKey, mediaURL, err) } else if mediaID != "" { @@ -885,7 +885,10 @@ func (h *handler) fetchMediaID(msg courier.MsgOut, mimeType, mediaURL string, cl } // put in cache - err = mediaCache.Set(rc, mediaURL, mediaID) + h.WithRedisConn(func(rc redis.Conn) { + err = mediaCache.Set(rc, mediaURL, mediaID) + }) + if err != nil { return "", fmt.Errorf("error setting media id in cache: %w", err) } @@ -893,7 +896,7 @@ func (h *handler) fetchMediaID(msg courier.MsgOut, mimeType, mediaURL string, cl return mediaID, nil } -func (h *handler) sendWhatsAppMsg(rc redis.Conn, msg courier.MsgOut, sendPath *url.URL, payload any, clog *courier.ChannelLog) (string, string, error) { +func (h *handler) sendWhatsAppMsg(msg courier.MsgOut, sendPath *url.URL, payload any, clog *courier.ChannelLog) (string, string, error) { jsonBody := jsonx.MustMarshal(payload) req, _ := http.NewRequest(http.MethodPost, sendPath.String(), bytes.NewReader(jsonBody)) @@ -906,12 +909,15 @@ func (h *handler) sendWhatsAppMsg(rc redis.Conn, msg courier.MsgOut, sendPath *u if resp != nil && (resp.StatusCode == 429 || resp.StatusCode == 503) { rateLimitKey := fmt.Sprintf("rate_limit:%s", msg.Channel().UUID()) - rc.Do("SET", rateLimitKey, "engaged") - // The rate limit is 50 requests per second - // We pause sending 2 seconds so the limit count is reset - // TODO: In the future we should the header value when available - rc.Do("EXPIRE", rateLimitKey, 2) + h.WithRedisConn(func(rc redis.Conn) { + rc.Do("SET", rateLimitKey, "engaged") + + // The rate limit is 50 requests per second + // We pause sending 2 seconds so the limit count is reset + // TODO: In the future we should the header value when available + rc.Do("EXPIRE", rateLimitKey, 2) + }) return "", "", courier.ErrConnectionThrottled } @@ -923,11 +929,14 @@ func (h *handler) sendWhatsAppMsg(rc redis.Conn, msg courier.MsgOut, sendPath *u if err == nil && len(errPayload.Errors) > 0 { if hasTiersError(*errPayload) { rateLimitBulkKey := fmt.Sprintf("rate_limit_bulk:%s", msg.Channel().UUID()) - rc.Do("SET", rateLimitBulkKey, "engaged") - // The WA tiers spam rate limit hit - // We pause the bulk queue for 24 hours and 5min - rc.Do("EXPIRE", rateLimitBulkKey, (60*60*24)+(5*60)) + h.WithRedisConn(func(rc redis.Conn) { + rc.Do("SET", rateLimitBulkKey, "engaged") + + // The WA tiers spam rate limit hit + // We pause the bulk queue for 24 hours and 5min + rc.Do("EXPIRE", rateLimitBulkKey, (60*60*24)+(5*60)) + }) return "", "", courier.ErrConnectionThrottled }