From 3d1497a6343c96f1a2a6f151fdb2f26c51b6e49b Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Fri, 7 Feb 2025 01:19:27 +0000 Subject: [PATCH] fix: fix early cancel when RequestTimeout is provided for streaming requests (#3904) --- internal/requestconfig/requestconfig.go | 59 ++++++++++++++++++++++--- 1 file changed, 52 insertions(+), 7 deletions(-) diff --git a/internal/requestconfig/requestconfig.go b/internal/requestconfig/requestconfig.go index 724d554f283..75c11791fc2 100644 --- a/internal/requestconfig/requestconfig.go +++ b/internal/requestconfig/requestconfig.go @@ -295,6 +295,41 @@ func parseRetryAfterHeader(resp *http.Response) (time.Duration, bool) { return 0, false } +// isBeforeContextDeadline reports whether the non-zero Time t is +// before ctx's deadline. If ctx does not have a deadline, it +// always reports true (the deadline is considered infinite). +func isBeforeContextDeadline(t time.Time, ctx context.Context) bool { + d, ok := ctx.Deadline() + if !ok { + return true + } + return t.Before(d) +} + +// bodyWithTimeout is an io.ReadCloser which can observe a context's cancel func +// to handle timeouts etc. It wraps an existing io.ReadCloser. +type bodyWithTimeout struct { + stop func() // stops the time.Timer waiting to cancel the request + rc io.ReadCloser +} + +func (b *bodyWithTimeout) Read(p []byte) (n int, err error) { + n, err = b.rc.Read(p) + if err == nil { + return n, nil + } + if err == io.EOF { + return n, err + } + return n, err +} + +func (b *bodyWithTimeout) Close() error { + err := b.rc.Close() + b.stop() + return err +} + func retryDelay(res *http.Response, retryCount int) time.Duration { // If the API asks us to wait a certain amount of time (and it's a reasonable amount), // just do what it says. @@ -356,12 +391,17 @@ func (cfg *RequestConfig) Execute() (err error) { shouldSendRetryCount := cfg.Request.Header.Get("X-Stainless-Retry-Count") == "0" var res *http.Response + var cancel context.CancelFunc for retryCount := 0; retryCount <= cfg.MaxRetries; retryCount += 1 { ctx := cfg.Request.Context() - if cfg.RequestTimeout != time.Duration(0) { - var cancel context.CancelFunc + if cfg.RequestTimeout != time.Duration(0) && isBeforeContextDeadline(time.Now().Add(cfg.RequestTimeout), ctx) { ctx, cancel = context.WithTimeout(ctx, cfg.RequestTimeout) - defer cancel() + defer func() { + // The cancel function is nil if it was handed off to be handled in a different scope. + if cancel != nil { + cancel() + } + }() } req := cfg.Request.Clone(ctx) @@ -429,10 +469,15 @@ func (cfg *RequestConfig) Execute() (err error) { return &aerr } - if cfg.ResponseBodyInto == nil { - return nil - } - if _, ok := cfg.ResponseBodyInto.(**http.Response); ok { + _, intoCustomResponseBody := cfg.ResponseBodyInto.(**http.Response) + if cfg.ResponseBodyInto == nil || intoCustomResponseBody { + // We aren't reading the response body in this scope, but whoever is will need the + // cancel func from the context to observe request timeouts. + // Put the cancel function in the response body so it can be handled elsewhere. + if cancel != nil { + res.Body = &bodyWithTimeout{rc: res.Body, stop: cancel} + cancel = nil + } return nil }