diff --git a/sdk/storage/azblob/blockblob/chunkwriting.go b/sdk/storage/azblob/blockblob/chunkwriting.go index 0ed98c403281..16927ecf895e 100644 --- a/sdk/storage/azblob/blockblob/chunkwriting.go +++ b/sdk/storage/azblob/blockblob/chunkwriting.go @@ -13,12 +13,12 @@ import ( "encoding/binary" "errors" "fmt" - "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/internal/shared" "io" "sync" "sync/atomic" "github.com/Azure/azure-sdk-for-go/sdk/internal/uuid" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/internal/shared" ) // blockWriter provides methods to upload blocks that represent a file to a server and commit them. @@ -181,8 +181,12 @@ func (c *copier) write(chunk copierChunk) { stageBlockOptions := c.o.getStageBlockOptions() _, err := c.to.StageBlock(c.ctx, chunk.id, shared.NopCloser(bytes.NewReader(chunk.buffer[:chunk.length])), stageBlockOptions) if err != nil { - c.errCh <- fmt.Errorf("write error: %w", err) - return + select { + case c.errCh <- err: + // failed to stage block, cancel the copy + default: + // don't block the goroutine if there's a pending error + } } } diff --git a/sdk/storage/azblob/blockblob/chunkwriting_test.go b/sdk/storage/azblob/blockblob/chunkwriting_test.go index afbf7c0e63da..addc3fbd97be 100644 --- a/sdk/storage/azblob/blockblob/chunkwriting_test.go +++ b/sdk/storage/azblob/blockblob/chunkwriting_test.go @@ -30,6 +30,7 @@ type fakeBlockWriter struct { path string block int32 errOnBlock int32 + stageDelay time.Duration } func newFakeBlockWriter() *fakeBlockWriter { @@ -49,7 +50,12 @@ func newFakeBlockWriter() *fakeBlockWriter { func (f *fakeBlockWriter) StageBlock(_ context.Context, blockID string, body io.ReadSeekCloser, _ *StageBlockOptions) (StageBlockResponse, error) { n := atomic.AddInt32(&f.block, 1) - if n == f.errOnBlock { + + if f.stageDelay > 0 { + time.Sleep(f.stageDelay) + } + + if f.errOnBlock > -1 && n >= f.errOnBlock { return StageBlockResponse{}, io.ErrNoProgress } @@ -192,6 +198,45 @@ func TestGetErr(t *testing.T) { } } +func TestSlowDestCopyFrom(t *testing.T) { + p, err := createSrcFile(_1MiB + 500*1024) //This should cause 2 reads + if err != nil { + panic(err) + } + defer func(name string) { + _ = os.Remove(name) + }(p) + + from, err := os.Open(p) + if err != nil { + panic(err) + } + defer from.Close() + + br := newFakeBlockWriter() + defer br.cleanup() + + br.stageDelay = 200 * time.Millisecond + br.errOnBlock = 0 + + errs := make(chan error, 1) + go func() { + _, err := copyFromReader(context.Background(), from, br, UploadStreamOptions{}) + errs <- err + }() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + select { + case <-ctx.Done(): + failMsg := "TestSlowDestCopyFrom(slow writes shouldn't cause deadlock) failed: Context expired, copy deadlocked" + t.Error(failMsg) + case <-errs: + return + } +} + func TestCopyFromReader(t *testing.T) { t.Parallel()