// ErrResetBackoff is the sentinel error that the function executed by RunF
// returns to ask RunF to reset its backoff state. RunF matches it with
// errors.Is, so a wrapped ErrResetBackoff is honored as well.
var ErrResetBackoff = errors.New("reset backoff state")

// RunF runs f repeatedly until ctx is done or f returns a non-nil error that is
// not ErrResetBackoff.
//
// The first call to f happens immediately. After each call:
//   - if f returns ErrResetBackoff, the retry counter is reset to zero and f is
//     run again without waiting;
//   - if f returns any other non-nil error, RunF returns;
//   - if f returns nil, RunF waits for backoff(attempt) before the next call
//     and then increments attempt.
//
// backoff accepts the number of retries so far (starting from zero after a
// reset) and returns the amount of time to wait before the next call.
//
// RunF reports nothing to the caller; f is responsible for recording or
// logging any error that causes the loop to stop.
func RunF(ctx context.Context, f func() error, backoff func(int) time.Duration) {
	attempt := 0
	timer := time.NewTimer(0)
	// Stop the timer on every exit path. Without this, leaving the loop
	// through the ctx.Err() check right after the timer was re-armed would
	// leave it pending until it fired on its own.
	defer timer.Stop()

	for ctx.Err() == nil {
		// Wait for the current backoff period to elapse, unless the context
		// is done first.
		select {
		case <-timer.C:
		case <-ctx.Done():
			return
		}

		// The timer channel was drained by the select above, so calling Reset
		// in the branches below is safe.
		err := f()
		switch {
		case errors.Is(err, ErrResetBackoff):
			// f made progress: start over with no delay and a fresh counter.
			timer.Reset(0)
			attempt = 0
		case err != nil:
			// Any other error is terminal.
			return
		default:
			timer.Reset(backoff(attempt))
			attempt++
		}
	}
}
backoffTimer := time.NewTimer(0) - for ctx.Err() == nil { - select { - case <-backoffTimer.C: - case <-ctx.Done(): - return - } - + runStream := func() error { resetBackoff, err := p.runStream(ctx, interval) - - if resetBackoff { - backoffTimer.Reset(0) - backoffAttempt = 0 - } else { - backoffTimer.Reset(p.backoff(backoffAttempt)) - backoffAttempt++ - } - - switch { - case err == nil: - // No error was encountered; restart the stream. - case ctx.Err() != nil: - // Producer was stopped; exit immediately and without logging an - // error. - return - case status.Code(err) == codes.Unimplemented: + if status.Code(err) == codes.Unimplemented { // Unimplemented; do not retry. logger.Error("Server doesn't support ORCA OOB load reporting protocol; not listening for load reports.") - return - case status.Code(err) == codes.Unavailable, status.Code(err) == codes.Canceled: - // TODO: these codes should ideally log an error, too, but for now - // we receive them when shutting down the ClientConn (Unavailable - // if the stream hasn't started yet, and Canceled if it happens - // mid-stream). Once we can determine the state or ensure the - // producer is stopped before the stream ends, we can log an error - // when it's not a natural shutdown. - default: - // Log all other errors. + return err + } + // Retry for all other errors. + if code := status.Code(err); code != codes.Unavailable && code != codes.Canceled { + // TODO: Unavailable and Canceled should also ideally log an error, + // but for now we receive them when shutting down the ClientConn + // (Unavailable if the stream hasn't started yet, and Canceled if it + // happens mid-stream). Once we can determine the state or ensure + // the producer is stopped before the stream ends, we can log an + // error when it's not a natural shutdown. 
logger.Error("Received unexpected stream error:", err) } + if resetBackoff { + return backoff.ErrResetBackoff + } + return nil } + backoff.RunF(ctx, runStream, p.backoff) } // runStream runs a single stream on the subchannel and returns the resulting diff --git a/xds/internal/xdsclient/transport/loadreport.go b/xds/internal/xdsclient/transport/loadreport.go index 89ffc4fcec66..4b8ca29ce93f 100644 --- a/xds/internal/xdsclient/transport/loadreport.go +++ b/xds/internal/xdsclient/transport/loadreport.go @@ -25,6 +25,7 @@ import ( "time" "github.com/golang/protobuf/ptypes" + "google.golang.org/grpc/internal/backoff" "google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/pretty" "google.golang.org/grpc/xds/internal" @@ -100,54 +101,36 @@ func (t *Transport) lrsRunner(ctx context.Context) { node := proto.Clone(t.nodeProto).(*v3corepb.Node) node.ClientFeatures = append(node.ClientFeatures, "envoy.lrs.supports_send_all_clusters") - backoffAttempt := 0 - backoffTimer := time.NewTimer(0) - for ctx.Err() == nil { - select { - case <-backoffTimer.C: - case <-ctx.Done(): - backoffTimer.Stop() - return + runLoadReportStream := func() error { + // streamCtx is created and canceled in case we terminate the stream + // early for any reason, to avoid gRPC-Go leaking the RPC's monitoring + // goroutine. + streamCtx, cancel := context.WithCancel(ctx) + defer cancel() + stream, err := v3lrsgrpc.NewLoadReportingServiceClient(t.cc).StreamLoadStats(streamCtx) + if err != nil { + t.logger.Warningf("Creating LRS stream to server %q failed: %v", t.serverURI, err) + return nil } + t.logger.Infof("Created LRS stream to server %q", t.serverURI) - // We reset backoff state when we successfully receive at least one - // message from the server. - resetBackoff := func() bool { - // streamCtx is created and canceled in case we terminate the stream - // early for any reason, to avoid gRPC-Go leaking the RPC's monitoring - // goroutine. 
- streamCtx, cancel := context.WithCancel(ctx) - defer cancel() - stream, err := v3lrsgrpc.NewLoadReportingServiceClient(t.cc).StreamLoadStats(streamCtx) - if err != nil { - t.logger.Warningf("Creating LRS stream to server %q failed: %v", t.serverURI, err) - return false - } - t.logger.Infof("Created LRS stream to server %q", t.serverURI) - - if err := t.sendFirstLoadStatsRequest(stream, node); err != nil { - t.logger.Warningf("Sending first LRS request failed: %v", err) - return false - } - - clusters, interval, err := t.recvFirstLoadStatsResponse(stream) - if err != nil { - t.logger.Warningf("Reading from LRS stream failed: %v", err) - return false - } - - t.sendLoads(streamCtx, stream, clusters, interval) - return true - }() + if err := t.sendFirstLoadStatsRequest(stream, node); err != nil { + t.logger.Warningf("Sending first LRS request failed: %v", err) + return nil + } - if resetBackoff { - backoffTimer.Reset(0) - backoffAttempt = 0 - } else { - backoffTimer.Reset(t.backoff(backoffAttempt)) - backoffAttempt++ + clusters, interval, err := t.recvFirstLoadStatsResponse(stream) + if err != nil { + t.logger.Warningf("Reading from LRS stream failed: %v", err) + return nil } + + // We reset backoff state when we successfully receive at least one + // message from the server. 
+ t.sendLoads(streamCtx, stream, clusters, interval) + return backoff.ErrResetBackoff } + backoff.RunF(ctx, runLoadReportStream, t.backoff) } func (t *Transport) sendLoads(ctx context.Context, stream lrsStream, clusterNames []string, interval time.Duration) { diff --git a/xds/internal/xdsclient/transport/transport.go b/xds/internal/xdsclient/transport/transport.go index 86803588a7cc..001552d7b479 100644 --- a/xds/internal/xdsclient/transport/transport.go +++ b/xds/internal/xdsclient/transport/transport.go @@ -325,43 +325,29 @@ func (t *Transport) adsRunner(ctx context.Context) { go t.send(ctx) - backoffAttempt := 0 - backoffTimer := time.NewTimer(0) - for ctx.Err() == nil { - select { - case <-backoffTimer.C: - case <-ctx.Done(): - backoffTimer.Stop() - return + // We reset backoff state when we successfully receive at least one + // message from the server. + runStreamWithBackoff := func() error { + stream, err := t.newAggregatedDiscoveryServiceStream(ctx, t.cc) + if err != nil { + t.onErrorHandler(err) + t.logger.Warningf("Creating new ADS stream failed: %v", err) + return nil } + t.logger.Infof("ADS stream created") - // We reset backoff state when we successfully receive at least one - // message from the server. 
- resetBackoff := func() bool { - stream, err := t.newAggregatedDiscoveryServiceStream(ctx, t.cc) - if err != nil { - t.onErrorHandler(err) - t.logger.Warningf("Creating new ADS stream failed: %v", err) - return false - } - t.logger.Infof("ADS stream created") - - select { - case <-t.adsStreamCh: - default: - } - t.adsStreamCh <- stream - return t.recv(stream) - }() - - if resetBackoff { - backoffTimer.Reset(0) - backoffAttempt = 0 - } else { - backoffTimer.Reset(t.backoff(backoffAttempt)) - backoffAttempt++ + select { + case <-t.adsStreamCh: + default: + } + t.adsStreamCh <- stream + msgReceived := t.recv(stream) + if msgReceived { + return backoff.ErrResetBackoff } + return nil } + backoff.RunF(ctx, runStreamWithBackoff, t.backoff) } // send is a separate goroutine for sending resource requests on the ADS stream.