diff --git a/sdk/storage/azblob/chunkwriting.go b/sdk/storage/azblob/chunkwriting.go index 5f4fc4d89571..8baedb82c27b 100644 --- a/sdk/storage/azblob/chunkwriting.go +++ b/sdk/storage/azblob/chunkwriting.go @@ -184,9 +184,7 @@ func (c *copier) write(chunk copierChunk) { // close commits our blocks to blob storage and closes our writer. func (c *copier) close() error { - c.wg.Wait() - - if err := c.getErr(); err != nil { + if err := c.waitForFinish(); err != nil { return err } @@ -196,6 +194,44 @@ func (c *copier) close() error { return err } +// waitForFinish waits for all writes to complete while combining errors from errCh +func (c *copier) waitForFinish() error { + var err error + done := make(chan struct{}) + go func() { + // when write latencies are long, several errors might have occurred + // drain them all as we wait for writes to complete. + err = c.drainErrs(done) + }() + + c.wg.Wait() + close(done) + return err +} + +// drainErrs drains all outstanding errors from writes +func (c *copier) drainErrs(done chan struct{}) error { + var err error + for { + select { + case <-done: + return err + default: + if writeErr := c.getErr(); writeErr != nil { + err = combineErrs(err, writeErr) + } + } + } +} + +// combineErrs combines err with newErr so multiple errors can be represented +func combineErrs(err, newErr error) error { + if err == nil { + return newErr + } + return fmt.Errorf("%s, %w", err.Error(), newErr) +} + // id allows the creation of unique IDs based on UUID4 + an int32. This auto-increments. type id struct { u [64]byte diff --git a/sdk/storage/azblob/zt_chunkwriting_test.go b/sdk/storage/azblob/zt_chunkwriting_test.go index 1d2f152fd241..08d1f3f573d2 100644 --- a/sdk/storage/azblob/zt_chunkwriting_test.go +++ b/sdk/storage/azblob/zt_chunkwriting_test.go @@ -25,6 +25,7 @@ type fakeBlockWriter struct { path string block int32 errOnBlock int32 + stageDelay time.Duration } //nolint @@ -46,7 +47,12 @@ func newFakeBlockWriter() *fakeBlockWriter { //nolint func (f *fakeBlockWriter) StageBlock(_ context.Context, blockID string, body io.ReadSeekCloser, _ *StageBlockOptions) (BlockBlobStageBlockResponse, error) { n := atomic.AddInt32(&f.block, 1) - if n == f.errOnBlock { + + if f.stageDelay > 0 { + time.Sleep(f.stageDelay) + } + + if f.errOnBlock > -1 && n >= f.errOnBlock { return BlockBlobStageBlockResponse{}, io.ErrNoProgress } @@ -195,6 +201,46 @@ func (s *azblobUnrecordedTestSuite) TestGetErr() { } } +// nolint +func (s *azblobUnrecordedTestSuite) TestSlowDestCopyFrom() { + p, err := createSrcFile(_1MiB + 500*1024) //This should cause 2 reads + if err != nil { + panic(err) + } + defer func(name string) { + _ = os.Remove(name) + }(p) + + from, err := os.Open(p) + if err != nil { + panic(err) + } + defer from.Close() + + br := newFakeBlockWriter() + defer br.cleanup() + + br.stageDelay = 200 * time.Millisecond + br.errOnBlock = 0 + + errs := make(chan error, 1) + go func() { + _, err := copyFromReader(context.Background(), from, br, UploadStreamToBlockBlobOptions{}) + errs <- err + }() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + select { + case <-ctx.Done(): + failMsg := "TestSlowDestCopyFrom(slow writes shouldn't cause deadlock) failed: Context expired, copy deadlocked" + s.T().Error(failMsg) + case <-errs: + return + } +} + //nolint func (s *azblobUnrecordedTestSuite) TestCopyFromReader() { s.T().Parallel() @@ -302,6 +348,7 @@ func (s *azblobUnrecordedTestSuite) TestCopyFromReader() { if err != nil { panic(err) } + defer from.Close() br := newFakeBlockWriter() defer br.cleanup()