From 5c8ee9a8acdad9e60d8eea7917554dc0885a0ee7 Mon Sep 17 00:00:00 2001 From: Rowan Seymour Date: Wed, 6 Nov 2024 12:32:52 -0500 Subject: [PATCH 1/2] Agressively close redis connections --- handlers/firebase/handler.go | 9 ++++++--- handlers/hormuud/handler.go | 4 ++-- handlers/jiochat/handler.go | 21 ++++++++++++--------- handlers/mtn/handler.go | 17 ++++++++++------- handlers/wechat/handler.go | 21 ++++++++++++--------- handlers/whatsapp_legacy/handler.go | 28 +++++++++++++++++----------- 6 files changed, 59 insertions(+), 41 deletions(-) diff --git a/handlers/firebase/handler.go b/handlers/firebase/handler.go index f7eb228e7..86106c40f 100644 --- a/handlers/firebase/handler.go +++ b/handlers/firebase/handler.go @@ -338,15 +338,15 @@ func (h *handler) sendWithCredsJSON(msg courier.MsgOut, res *courier.SendResult, } func (h *handler) getAccessToken(channel courier.Channel) (string, error) { - rc := h.Backend().RedisPool().Get() - defer rc.Close() - tokenKey := fmt.Sprintf("channel-token:%s", channel.UUID()) h.fetchTokenMutex.Lock() defer h.fetchTokenMutex.Unlock() + rc := h.Backend().RedisPool().Get() token, err := redis.String(rc.Do("GET", tokenKey)) + rc.Close() + if err != nil && err != redis.ErrNil { return "", fmt.Errorf("error reading cached access token: %w", err) } @@ -360,7 +360,10 @@ func (h *handler) getAccessToken(channel courier.Channel) (string, error) { return "", fmt.Errorf("error fetching new access token: %w", err) } + rc = h.Backend().RedisPool().Get() _, err = rc.Do("SET", tokenKey, token, "EX", int(expires/time.Second)) + rc.Close() + if err != nil { return "", fmt.Errorf("error updating cached access token: %w", err) } diff --git a/handlers/hormuud/handler.go b/handlers/hormuud/handler.go index 5219f22cb..198eefacd 100644 --- a/handlers/hormuud/handler.go +++ b/handlers/hormuud/handler.go @@ -173,9 +173,9 @@ func (h *handler) FetchToken(ctx context.Context, channel courier.Channel, msg c // we got a token, cache it to redis with an expiration from the response(we default to 60 minutes) rc = h.Backend().RedisPool().Get() - defer rc.Close() - _, err = rc.Do("SETEX", fmt.Sprintf("hm_token_%s", channel.UUID()), expiration, token) + rc.Close() + if err != nil { slog.Error("error caching HM access token", "error", err) } diff --git a/handlers/jiochat/handler.go b/handlers/jiochat/handler.go index f37c3c997..b52f9b60f 100644 --- a/handlers/jiochat/handler.go +++ b/handlers/jiochat/handler.go @@ -163,7 +163,7 @@ type mtPayload struct { } func (h *handler) Send(ctx context.Context, msg courier.MsgOut, res *courier.SendResult, clog *courier.ChannelLog) error { - accessToken, err := h.getAccessToken(ctx, msg.Channel(), clog) + accessToken, err := h.getAccessToken(msg.Channel(), clog) if err != nil { return courier.ErrChannelConfig } @@ -198,7 +198,7 @@ func (h *handler) Send(ctx context.Context, msg courier.MsgOut, res *courier.Sen // DescribeURN handles Jiochat contact details func (h *handler) DescribeURN(ctx context.Context, channel courier.Channel, urn urns.URN, clog *courier.ChannelLog) (map[string]string, error) { - accessToken, err := h.getAccessToken(ctx, channel, clog) + accessToken, err := h.getAccessToken(channel, clog) if err != nil { return nil, err } @@ -237,7 +237,7 @@ func (h *handler) BuildAttachmentRequest(ctx context.Context, b courier.Backend, return nil, err } - accessToken, err := h.getAccessToken(ctx, channel, clog) + accessToken, err := h.getAccessToken(channel, clog) if err != nil { return nil, err } @@ -250,16 +250,16 @@ func (h *handler) BuildAttachmentRequest(ctx context.Context, b courier.Backend, var _ courier.AttachmentRequestBuilder = (*handler)(nil) -func (h *handler) getAccessToken(ctx context.Context, channel courier.Channel, clog *courier.ChannelLog) (string, error) { - rc := h.Backend().RedisPool().Get() - defer rc.Close() - +func (h *handler) getAccessToken(channel courier.Channel, clog *courier.ChannelLog) (string, error) { tokenKey := fmt.Sprintf("channel-token:%s", channel.UUID()) h.fetchTokenMutex.Lock() defer h.fetchTokenMutex.Unlock() + rc := h.Backend().RedisPool().Get() token, err := redis.String(rc.Do("GET", tokenKey)) + rc.Close() + if err != nil && err != redis.ErrNil { return "", fmt.Errorf("error reading cached access token: %w", err) } @@ -268,12 +268,15 @@ func (h *handler) getAccessToken(ctx context.Context, channel courier.Channel, c return token, nil } - token, expires, err := h.fetchAccessToken(ctx, channel, clog) + token, expires, err := h.fetchAccessToken(channel, clog) if err != nil { return "", fmt.Errorf("error fetching new access token: %w", err) } + rc = h.Backend().RedisPool().Get() _, err = rc.Do("SET", tokenKey, token, "EX", int(expires/time.Second)) + rc.Close() + if err != nil { return "", fmt.Errorf("error updating cached access token: %w", err) } @@ -288,7 +291,7 @@ type fetchPayload struct { } // fetchAccessToken tries to fetch a new token for our channel -func (h *handler) fetchAccessToken(ctx context.Context, channel courier.Channel, clog *courier.ChannelLog) (string, time.Duration, error) { +func (h *handler) fetchAccessToken(channel courier.Channel, clog *courier.ChannelLog) (string, time.Duration, error) { tokenURL, _ := url.Parse(fmt.Sprintf("%s/%s", sendURL, "auth/token.action")) payload := &fetchPayload{ GrantType: "client_credentials", diff --git a/handlers/mtn/handler.go b/handlers/mtn/handler.go index 21946e5b2..bc6af3e3a 100644 --- a/handlers/mtn/handler.go +++ b/handlers/mtn/handler.go @@ -121,7 +121,7 @@ type mtPayload struct { } func (h *handler) Send(ctx context.Context, msg courier.MsgOut, res *courier.SendResult, clog *courier.ChannelLog) error { - accessToken, err := h.getAccessToken(ctx, msg.Channel(), clog) + accessToken, err := h.getAccessToken(msg.Channel(), clog) if err != nil { return courier.ErrChannelConfig } @@ -175,16 +175,16 @@ func (h *handler) RedactValues(ch courier.Channel) []string { } } -func (h *handler) getAccessToken(ctx context.Context, channel courier.Channel, clog *courier.ChannelLog) (string, error) { - rc := h.Backend().RedisPool().Get() - defer rc.Close() - +func (h *handler) getAccessToken(channel courier.Channel, clog *courier.ChannelLog) (string, error) { tokenKey := fmt.Sprintf("channel-token:%s", channel.UUID()) h.fetchTokenMutex.Lock() defer h.fetchTokenMutex.Unlock() + rc := h.Backend().RedisPool().Get() token, err := redis.String(rc.Do("GET", tokenKey)) + rc.Close() + if err != nil && err != redis.ErrNil { return "", fmt.Errorf("error reading cached access token: %w", err) } @@ -193,12 +193,15 @@ func (h *handler) getAccessToken(ctx context.Context, channel courier.Channel, c return token, nil } - token, expires, err := h.fetchAccessToken(ctx, channel, clog) + token, expires, err := h.fetchAccessToken(channel, clog) if err != nil { return "", fmt.Errorf("error fetching new access token: %w", err) } + rc = h.Backend().RedisPool().Get() _, err = rc.Do("SET", tokenKey, token, "EX", int(expires/time.Second)) + rc.Close() + if err != nil { return "", fmt.Errorf("error updating cached access token: %w", err) } @@ -207,7 +210,7 @@ func (h *handler) getAccessToken(ctx context.Context, channel courier.Channel, c } // fetchAccessToken tries to fetch a new token for our channel, setting the result in redis -func (h *handler) fetchAccessToken(ctx context.Context, channel courier.Channel, clog *courier.ChannelLog) (string, time.Duration, error) { +func (h *handler) fetchAccessToken(channel courier.Channel, clog *courier.ChannelLog) (string, time.Duration, error) { form := url.Values{ "client_id": []string{channel.StringConfigForKey(courier.ConfigAPIKey, "")}, "client_secret": []string{channel.StringConfigForKey(courier.ConfigAuthToken, "")}, diff --git a/handlers/wechat/handler.go b/handlers/wechat/handler.go index 49eb1adf7..74ec56ad5 100644 --- a/handlers/wechat/handler.go +++ b/handlers/wechat/handler.go @@ -177,7 +177,7 @@ type mtPayload struct { } func (h *handler) Send(ctx context.Context, msg courier.MsgOut, res *courier.SendResult, clog *courier.ChannelLog) error { - accessToken, err := h.getAccessToken(ctx, msg.Channel(), clog) + accessToken, err := h.getAccessToken(msg.Channel(), clog) if err != nil { return err } @@ -216,7 +216,7 @@ func (h *handler) Send(ctx context.Context, msg courier.MsgOut, res *courier.Sen // DescribeURN handles WeChat contact details func (h *handler) DescribeURN(ctx context.Context, channel courier.Channel, urn urns.URN, clog *courier.ChannelLog) (map[string]string, error) { - accessToken, err := h.getAccessToken(ctx, channel, clog) + accessToken, err := h.getAccessToken(channel, clog) if err != nil { return nil, err } @@ -255,7 +255,7 @@ func (h *handler) RedactValues(ch courier.Channel) []string { // BuildAttachmentRequest download media for message attachment func (h *handler) BuildAttachmentRequest(ctx context.Context, b courier.Backend, channel courier.Channel, attachmentURL string, clog *courier.ChannelLog) (*http.Request, error) { - accessToken, err := h.getAccessToken(ctx, channel, clog) + accessToken, err := h.getAccessToken(channel, clog) if err != nil { return nil, err } @@ -275,16 +275,16 @@ func (h *handler) BuildAttachmentRequest(ctx context.Context, b courier.Backend, var _ courier.AttachmentRequestBuilder = (*handler)(nil) -func (h *handler) getAccessToken(ctx context.Context, channel courier.Channel, clog *courier.ChannelLog) (string, error) { - rc := h.Backend().RedisPool().Get() - defer rc.Close() - +func (h *handler) getAccessToken(channel courier.Channel, clog *courier.ChannelLog) (string, error) { tokenKey := fmt.Sprintf("channel-token:%s", channel.UUID()) h.fetchTokenMutex.Lock() defer h.fetchTokenMutex.Unlock() + rc := h.Backend().RedisPool().Get() token, err := redis.String(rc.Do("GET", tokenKey)) + rc.Close() + if err != nil && err != redis.ErrNil { return "", fmt.Errorf("error reading cached access token: %w", err) } @@ -293,12 +293,15 @@ func (h *handler) getAccessToken(ctx context.Context, channel courier.Channel, c return token, nil } - token, expires, err := h.fetchAccessToken(ctx, channel, clog) + token, expires, err := h.fetchAccessToken(channel, clog) if err != nil { return "", fmt.Errorf("error fetching new access token: %w", err) } + rc = h.Backend().RedisPool().Get() _, err = rc.Do("SET", tokenKey, token, "EX", int(expires/time.Second)) + rc.Close() + if err != nil { return "", fmt.Errorf("error updating cached access token: %w", err) } @@ -307,7 +310,7 @@ func (h *handler) getAccessToken(ctx context.Context, channel courier.Channel, c } // fetchAccessToken tries to fetch a new token for our channel, setting the result in redis -func (h *handler) fetchAccessToken(ctx context.Context, channel courier.Channel, clog *courier.ChannelLog) (string, time.Duration, error) { +func (h *handler) fetchAccessToken(channel courier.Channel, clog *courier.ChannelLog) (string, time.Duration, error) { form := url.Values{ "grant_type": []string{"client_credential"}, "appid": []string{channel.StringConfigForKey(configAppID, "")}, diff --git a/handlers/whatsapp_legacy/handler.go b/handlers/whatsapp_legacy/handler.go index 50646800d..50a7b2883 100644 --- a/handlers/whatsapp_legacy/handler.go +++ b/handlers/whatsapp_legacy/handler.go @@ -16,7 +16,6 @@ import ( "time" "github.com/buger/jsonparser" - "github.com/gomodule/redigo/redis" "github.com/nyaruka/courier" "github.com/nyaruka/courier/handlers" "github.com/nyaruka/courier/utils" @@ -495,9 +494,6 @@ type mtErrorPayload struct { const maxMsgLength = 4096 func (h *handler) Send(ctx context.Context, msg courier.MsgOut, res *courier.SendResult, clog *courier.ChannelLog) error { - conn := h.Backend().RedisPool().Get() - defer conn.Close() - // get our token token := msg.Channel().StringConfigForKey(courier.ConfigAuthToken, "") urlStr := msg.Channel().StringConfigForKey(courier.ConfigBaseURL, "") @@ -519,7 +515,7 @@ func (h *handler) Send(ctx context.Context, msg courier.MsgOut, res *courier.Sen for _, payload := range payloads { externalID := "" - wppID, externalID, err = h.sendWhatsAppMsg(conn, msg, sendPath, payload, clog) + wppID, externalID, err = h.sendWhatsAppMsg(msg, sendPath, payload, clog) if err != nil { return err } @@ -562,7 +558,7 @@ func buildPayloads(msg courier.MsgOut, h *handler, clog *courier.ChannelLog) ([] for attachmentCount, attachment := range msg.Attachments() { mimeType, mediaURL := handlers.SplitAttachment(attachment) - mediaID, err := h.fetchMediaID(msg, mimeType, mediaURL, clog) + mediaID, err := h.fetchMediaID(msg, mediaURL, clog) if err != nil { slog.Error("error while uploading media to whatsapp", "error", err, "channel_uuid", msg.Channel().UUID()) } @@ -817,14 +813,15 @@ func buildPayloads(msg courier.MsgOut, h *handler, clog *courier.ChannelLog) ([] } // fetchMediaID tries to fetch the id for the uploaded media, setting the result in redis. -func (h *handler) fetchMediaID(msg courier.MsgOut, mimeType, mediaURL string, clog *courier.ChannelLog) (string, error) { +func (h *handler) fetchMediaID(msg courier.MsgOut, mediaURL string, clog *courier.ChannelLog) (string, error) { // check in cache first - rc := h.Backend().RedisPool().Get() - defer rc.Close() - cacheKey := fmt.Sprintf(mediaCacheKeyPattern, msg.Channel().UUID()) mediaCache := redisx.NewIntervalHash(cacheKey, time.Hour*24, 2) + + rc := h.Backend().RedisPool().Get() mediaID, err := mediaCache.Get(rc, mediaURL) + rc.Close() + if err != nil { return "", fmt.Errorf("error reading media id from redis: %s : %s: %w", cacheKey, mediaURL, err) } else if mediaID != "" { @@ -885,7 +882,10 @@ func (h *handler) fetchMediaID(msg courier.MsgOut, mimeType, mediaURL string, cl } // put in cache + rc = h.Backend().RedisPool().Get() err = mediaCache.Set(rc, mediaURL, mediaID) + rc.Close() + if err != nil { return "", fmt.Errorf("error setting media id in cache: %w", err) } @@ -893,7 +893,7 @@ func (h *handler) fetchMediaID(msg courier.MsgOut, mimeType, mediaURL string, cl return mediaID, nil } -func (h *handler) sendWhatsAppMsg(rc redis.Conn, msg courier.MsgOut, sendPath *url.URL, payload any, clog *courier.ChannelLog) (string, string, error) { +func (h *handler) sendWhatsAppMsg(msg courier.MsgOut, sendPath *url.URL, payload any, clog *courier.ChannelLog) (string, string, error) { jsonBody := jsonx.MustMarshal(payload) req, _ := http.NewRequest(http.MethodPost, sendPath.String(), bytes.NewReader(jsonBody)) @@ -906,12 +906,15 @@ func (h *handler) sendWhatsAppMsg(rc redis.Conn, msg courier.MsgOut, sendPath *u if resp != nil && (resp.StatusCode == 429 || resp.StatusCode == 503) { rateLimitKey := fmt.Sprintf("rate_limit:%s", msg.Channel().UUID()) + + rc := h.Backend().RedisPool().Get() rc.Do("SET", rateLimitKey, "engaged") // The rate limit is 50 requests per second // We pause sending 2 seconds so the limit count is reset // TODO: In the future we should the header value when available rc.Do("EXPIRE", rateLimitKey, 2) + rc.Close() return "", "", courier.ErrConnectionThrottled } @@ -923,11 +926,14 @@ func (h *handler) sendWhatsAppMsg(rc redis.Conn, msg courier.MsgOut, sendPath *u if err == nil && len(errPayload.Errors) > 0 { if hasTiersError(*errPayload) { rateLimitBulkKey := fmt.Sprintf("rate_limit_bulk:%s", msg.Channel().UUID()) + + rc := h.Backend().RedisPool().Get() rc.Do("SET", rateLimitBulkKey, "engaged") // The WA tiers spam rate limit hit // We pause the bulk queue for 24 hours and 5min rc.Do("EXPIRE", rateLimitBulkKey, (60*60*24)+(5*60)) + rc.Close() return "", "", courier.ErrConnectionThrottled } From 9c173f8c1291d4c8ba068c3a151afc8d36dd60ed Mon Sep 17 00:00:00 2001 From: Rowan Seymour Date: Wed, 6 Nov 2024 12:49:06 -0500 Subject: [PATCH 2/2] Add BaseHandler.WithRedisConn util method --- handlers/base.go | 8 ++++++ handlers/firebase/handler.go | 14 +++++----- handlers/hormuud/handler.go | 20 +++++++------- handlers/jiochat/handler.go | 14 +++++----- handlers/mtn/handler.go | 14 +++++----- handlers/wechat/handler.go | 14 +++++----- handlers/whatsapp_legacy/handler.go | 41 ++++++++++++++++------------- 7 files changed, 72 insertions(+), 53 deletions(-) diff --git a/handlers/base.go b/handlers/base.go index bdc4e9d41..e9a50b46f 100644 --- a/handlers/base.go +++ b/handlers/base.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" + "github.com/gomodule/redigo/redis" "github.com/nyaruka/courier" "github.com/nyaruka/gocommon/httpx" ) @@ -148,3 +149,10 @@ func (h *BaseHandler) WriteRequestError(ctx context.Context, w http.ResponseWrit func (h *BaseHandler) WriteRequestIgnored(ctx context.Context, w http.ResponseWriter, details string) error { return courier.WriteIgnored(w, details) } + +// WithRedisConn is a utility to execute some code with a redis connection +func (h *BaseHandler) WithRedisConn(fn func(rc redis.Conn)) { + rc := h.Backend().RedisPool().Get() + defer rc.Close() + fn(rc) +} diff --git a/handlers/firebase/handler.go b/handlers/firebase/handler.go index 86106c40f..b88de4e81 100644 --- a/handlers/firebase/handler.go +++ b/handlers/firebase/handler.go @@ -343,9 +343,11 @@ func (h *handler) getAccessToken(channel courier.Channel) (string, error) { h.fetchTokenMutex.Lock() defer h.fetchTokenMutex.Unlock() - rc := h.Backend().RedisPool().Get() - token, err := redis.String(rc.Do("GET", tokenKey)) - rc.Close() + var token string + var err error + h.WithRedisConn(func(rc redis.Conn) { + token, err = redis.String(rc.Do("GET", tokenKey)) + }) if err != nil && err != redis.ErrNil { return "", fmt.Errorf("error reading cached access token: %w", err) @@ -360,9 +362,9 @@ func (h *handler) getAccessToken(channel courier.Channel) (string, error) { return "", fmt.Errorf("error fetching new access token: %w", err) } - rc = h.Backend().RedisPool().Get() - _, err = rc.Do("SET", tokenKey, token, "EX", int(expires/time.Second)) - rc.Close() + h.WithRedisConn(func(rc redis.Conn) { + _, err = rc.Do("SET", tokenKey, token, "EX", int(expires/time.Second)) + }) if err != nil { return "", fmt.Errorf("error updating cached access token: %w", err) diff --git a/handlers/hormuud/handler.go b/handlers/hormuud/handler.go index 198eefacd..615972ad9 100644 --- a/handlers/hormuud/handler.go +++ b/handlers/hormuud/handler.go @@ -130,9 +130,10 @@ func (h *handler) Send(ctx context.Context, msg courier.MsgOut, res *courier.Sen // FetchToken gets the current token for this channel, either from Redis if cached or by requesting it func (h *handler) FetchToken(ctx context.Context, channel courier.Channel, msg courier.MsgOut, username, password string, clog *courier.ChannelLog) (string, error) { // first check whether we have it in redis - rc := h.Backend().RedisPool().Get() - token, _ := redis.String(rc.Do("GET", fmt.Sprintf("hm_token_%s", channel.UUID()))) - rc.Close() + var token string + h.WithRedisConn(func(rc redis.Conn) { + token, _ = redis.String(rc.Do("GET", fmt.Sprintf("hm_token_%s", channel.UUID()))) + }) // got a token, use it if token != "" { @@ -172,13 +173,12 @@ func (h *handler) FetchToken(ctx context.Context, channel courier.Channel, msg c } // we got a token, cache it to redis with an expiration from the response(we default to 60 minutes) - rc = h.Backend().RedisPool().Get() - _, err = rc.Do("SETEX", fmt.Sprintf("hm_token_%s", channel.UUID()), expiration, token) - rc.Close() - - if err != nil { - slog.Error("error caching HM access token", "error", err) - } + h.WithRedisConn(func(rc redis.Conn) { + _, err = rc.Do("SETEX", fmt.Sprintf("hm_token_%s", channel.UUID()), expiration, token) + if err != nil { + slog.Error("error caching HM access token", "error", err) + } + }) return token, nil } diff --git a/handlers/jiochat/handler.go b/handlers/jiochat/handler.go index b52f9b60f..0acccd885 100644 --- a/handlers/jiochat/handler.go +++ b/handlers/jiochat/handler.go @@ -256,9 +256,11 @@ func (h *handler) getAccessToken(channel courier.Channel, clog *courier.ChannelL h.fetchTokenMutex.Lock() defer h.fetchTokenMutex.Unlock() - rc := h.Backend().RedisPool().Get() - token, err := redis.String(rc.Do("GET", tokenKey)) - rc.Close() + var token string + var err error + h.WithRedisConn(func(rc redis.Conn) { + token, err = redis.String(rc.Do("GET", tokenKey)) + }) if err != nil && err != redis.ErrNil { return "", fmt.Errorf("error reading cached access token: %w", err) @@ -273,9 +275,9 @@ func (h *handler) getAccessToken(channel courier.Channel, clog *courier.ChannelL return "", fmt.Errorf("error fetching new access token: %w", err) } - rc = h.Backend().RedisPool().Get() - _, err = rc.Do("SET", tokenKey, token, "EX", int(expires/time.Second)) - rc.Close() + h.WithRedisConn(func(rc redis.Conn) { + _, err = rc.Do("SET", tokenKey, token, "EX", int(expires/time.Second)) + }) if err != nil { return "", fmt.Errorf("error updating cached access token: %w", err) diff --git a/handlers/mtn/handler.go b/handlers/mtn/handler.go index bc6af3e3a..af4c8730b 100644 --- a/handlers/mtn/handler.go +++ b/handlers/mtn/handler.go @@ -181,9 +181,11 @@ func (h *handler) getAccessToken(channel courier.Channel, clog *courier.ChannelL h.fetchTokenMutex.Lock() defer h.fetchTokenMutex.Unlock() - rc := h.Backend().RedisPool().Get() - token, err := redis.String(rc.Do("GET", tokenKey)) - rc.Close() + var token string + var err error + h.WithRedisConn(func(rc redis.Conn) { + token, err = redis.String(rc.Do("GET", tokenKey)) + }) if err != nil && err != redis.ErrNil { return "", fmt.Errorf("error reading cached access token: %w", err) @@ -198,9 +200,9 @@ func (h *handler) getAccessToken(channel courier.Channel, clog *courier.ChannelL return "", fmt.Errorf("error fetching new access token: %w", err) } - rc = h.Backend().RedisPool().Get() - _, err = rc.Do("SET", tokenKey, token, "EX", int(expires/time.Second)) - rc.Close() + h.WithRedisConn(func(rc redis.Conn) { + _, err = rc.Do("SET", tokenKey, token, "EX", int(expires/time.Second)) + }) if err != nil { return "", fmt.Errorf("error updating cached access token: %w", err) diff --git a/handlers/wechat/handler.go b/handlers/wechat/handler.go index 74ec56ad5..c91acd810 100644 --- a/handlers/wechat/handler.go +++ b/handlers/wechat/handler.go @@ -281,9 +281,11 @@ func (h *handler) getAccessToken(channel courier.Channel, clog *courier.ChannelL h.fetchTokenMutex.Lock() defer h.fetchTokenMutex.Unlock() - rc := h.Backend().RedisPool().Get() - token, err := redis.String(rc.Do("GET", tokenKey)) - rc.Close() + var token string + var err error + h.WithRedisConn(func(rc redis.Conn) { + token, err = redis.String(rc.Do("GET", tokenKey)) + }) if err != nil && err != redis.ErrNil { return "", fmt.Errorf("error reading cached access token: %w", err) @@ -298,9 +300,9 @@ func (h *handler) getAccessToken(channel courier.Channel, clog *courier.ChannelL return "", fmt.Errorf("error fetching new access token: %w", err) } - rc = h.Backend().RedisPool().Get() - _, err = rc.Do("SET", tokenKey, token, "EX", int(expires/time.Second)) - rc.Close() + h.WithRedisConn(func(rc redis.Conn) { + _, err = rc.Do("SET", tokenKey, token, "EX", int(expires/time.Second)) + }) if err != nil { return "", fmt.Errorf("error updating cached access token: %w", err) diff --git a/handlers/whatsapp_legacy/handler.go b/handlers/whatsapp_legacy/handler.go index 50a7b2883..5b7ce4e1d 100644 --- a/handlers/whatsapp_legacy/handler.go +++ b/handlers/whatsapp_legacy/handler.go @@ -16,6 +16,7 @@ import ( "time" "github.com/buger/jsonparser" + "github.com/gomodule/redigo/redis" "github.com/nyaruka/courier" "github.com/nyaruka/courier/handlers" "github.com/nyaruka/courier/utils" @@ -818,9 +819,11 @@ func (h *handler) fetchMediaID(msg courier.MsgOut, mediaURL string, clog *courie cacheKey := fmt.Sprintf(mediaCacheKeyPattern, msg.Channel().UUID()) mediaCache := redisx.NewIntervalHash(cacheKey, time.Hour*24, 2) - rc := h.Backend().RedisPool().Get() - mediaID, err := mediaCache.Get(rc, mediaURL) - rc.Close() + var mediaID string + var err error + h.WithRedisConn(func(rc redis.Conn) { + mediaID, err = mediaCache.Get(rc, mediaURL) + }) if err != nil { return "", fmt.Errorf("error reading media id from redis: %s : %s: %w", cacheKey, mediaURL, err) @@ -882,9 +885,9 @@ func (h *handler) fetchMediaID(msg courier.MsgOut, mediaURL string, clog *courie } // put in cache - rc = h.Backend().RedisPool().Get() - err = mediaCache.Set(rc, mediaURL, mediaID) - rc.Close() + h.WithRedisConn(func(rc redis.Conn) { + err = mediaCache.Set(rc, mediaURL, mediaID) + }) if err != nil { return "", fmt.Errorf("error setting media id in cache: %w", err) @@ -907,14 +910,14 @@ func (h *handler) sendWhatsAppMsg(msg courier.MsgOut, sendPath *url.URL, payload if resp != nil && (resp.StatusCode == 429 || resp.StatusCode == 503) { rateLimitKey := fmt.Sprintf("rate_limit:%s", msg.Channel().UUID()) - rc := h.Backend().RedisPool().Get() - rc.Do("SET", rateLimitKey, "engaged") + h.WithRedisConn(func(rc redis.Conn) { + rc.Do("SET", rateLimitKey, "engaged") - // The rate limit is 50 requests per second - // We pause sending 2 seconds so the limit count is reset - // TODO: In the future we should the header value when available - rc.Do("EXPIRE", rateLimitKey, 2) - rc.Close() + // The rate limit is 50 requests per second + // We pause sending 2 seconds so the limit count is reset + // TODO: In the future we should the header value when available + rc.Do("EXPIRE", rateLimitKey, 2) + }) return "", "", courier.ErrConnectionThrottled } @@ -927,13 +930,13 @@ func (h *handler) sendWhatsAppMsg(msg courier.MsgOut, sendPath *url.URL, payload if hasTiersError(*errPayload) { rateLimitBulkKey := fmt.Sprintf("rate_limit_bulk:%s", msg.Channel().UUID()) - rc := h.Backend().RedisPool().Get() - rc.Do("SET", rateLimitBulkKey, "engaged") + h.WithRedisConn(func(rc redis.Conn) { + rc.Do("SET", rateLimitBulkKey, "engaged") - // The WA tiers spam rate limit hit - // We pause the bulk queue for 24 hours and 5min - rc.Do("EXPIRE", rateLimitBulkKey, (60*60*24)+(5*60)) - rc.Close() + // The WA tiers spam rate limit hit + // We pause the bulk queue for 24 hours and 5min + rc.Do("EXPIRE", rateLimitBulkKey, (60*60*24)+(5*60)) + }) return "", "", courier.ErrConnectionThrottled }