Skip to content

Commit

Permalink
Send() can use progress writers
Browse files Browse the repository at this point in the history
  • Loading branch information
gildas committed Feb 12, 2024
1 parent df47123 commit 61440c4
Show file tree
Hide file tree
Showing 5 changed files with 254 additions and 4 deletions.
64 changes: 64 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ res, err := request.Send(&request.Options{
}, nil)
```

The file name and its key will be written in the `multipart/form-data`'s `Content-Disposition` header as: `form-data; name="file"; filename="image.png"`.

To send the request again when receiving a Service Unavailable (`Attempts` and `Timeout` are optional):

```go
Expand Down Expand Up @@ -163,6 +165,68 @@ res, err := request.Send(&request.Options{
}, nil)
```

When sending requests to upload data streams, you can provide an `io.Writer` to write the progress to:

```go
import "github.com/schollz/progressbar/v3"

reader, err := os.Open(pathToFile)
defer reader.Close()
stat, err := reader.Stat()
bar := progressbar.DefaultBytes(stat.Size(), "Uploading")
res, err := request.Send(&request.Options{
Method: http.MethodPost,
URL: serverURL,
Payload: reader,
ProgressWriter: bar,
}, reader)
```

If the progress `io.Writer` is also an `io.Closer`, it will be closed at the end of the `request.Send()`.

When sending requests to download data streams, you can provide an `io.Writer` to write the progress to:

```go
import "github.com/schollz/progressbar/v3"

writer, err := os.Create(pathToFile)
defer writer.Close()
bar := progressbar.DefaultBytes(-1, "Downloading") // will use a spinner
res, err := request.Send(&request.Options{
URL: serverURL,
ProgressWriter: bar,
}, writer)
```

Again, if the progress `io.Writer` is also an `io.Closer`, it will be closed at the end of the `request.Send()`.

if you provide a `request.Options.ProgressSetMax` func or if the `io.Writer` is a `request.ProgressBarMaxSetter` or a `request.ProgressBarMaxChanger`, `request.Send` will call it to set the maximum value of the progress bar from the response `Content-Length`:

```go
import "github.com/schollz/progressbar/v3"

writer, err := os.Create(pathToFile)
defer writer.Close()
bar := progressbar.DefaultBytes(1, "Downloading") // use a temporary max value
res, err := request.Send(&request.Options{
URL: serverURL,
ProgressWriter: bar,
}, writer)
```

```go
import "github.com/cheggaaa/pb/v3"

writer, err := os.Create(pathToFile)
defer writer.Close()
bar := pb.StartNew(1) // use a temporary max value
res, err := request.Send(&request.Options{
URL: serverURL,
ProgressWriter: bar,
ProgressSetMaxFunc: func(max int64) { bar.SetTotal64(max) },
}, writer)
```

**Notes:**

- if the PayloadType is not mentioned, it is calculated when processing the Payload.
Expand Down
26 changes: 26 additions & 0 deletions progress.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package request

import "io"

// ProgressBarMaxSetter is an interface that allows setting the maximum value of a progress bar
type ProgressBarMaxSetter interface {
SetMax64(int64)
}

// ProgressBarMaxChanger is an interface that allows setting the maximum value of a progress bar
//
// This interface allows packages such as "/github.com/schollz/progressbar/v3" to be used as progress bars
type ProgressBarMaxChanger interface {
ChangeMax64(int64)
}

type progressReader struct {
io.Reader
Progress io.Writer
}

func (reader *progressReader) Read(p []byte) (n int, err error) {
n, err = reader.Reader.Read(p)
_, _ = reader.Progress.Write(p[:n])
return
}
44 changes: 40 additions & 4 deletions request.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ type Options struct {
RequestID string
UserAgent string
Transport *http.Transport
ProgressWriter io.Writer // if not nil, the progress of the request will be written to this writer
ProgressSetMaxFunc func(int64)
RetryableStatusCodes []int // Status codes that should be retried, by default: 429, 502, 503, 504
Attempts uint // number of attempts, by default: 5
InterAttemptDelay time.Duration // how long to wait between 2 attempts during the first backoff interval, by default: 3s
Expand Down Expand Up @@ -80,6 +82,12 @@ func Send(options *Options, results interface{}) (*Content, error) {
}
log := options.Logger.Child(nil, "request", "reqid", options.RequestID, "method", options.Method)

if progressCloser, ok := options.ProgressWriter.(io.Closer); ok {
defer func() {
progressCloser.Close()
}()
}

log.Debugf("HTTP %s %s", options.Method, options.URL.String())
req, err := buildRequest(log, options)
if err != nil {
Expand Down Expand Up @@ -183,6 +191,22 @@ func Send(options *Options, results interface{}) (*Content, error) {
// Reading the response body

if writer, ok := results.(io.Writer); ok {
if options.ProgressWriter != nil {
if options.ProgressSetMaxFunc != nil {
if size, err := strconv.ParseInt(res.Header.Get("Content-Length"), 10, 64); err == nil {
options.ProgressSetMaxFunc(size)
}
} else if maxSetter, ok := options.ProgressWriter.(ProgressBarMaxSetter); ok {
if size, err := strconv.ParseInt(res.Header.Get("Content-Length"), 10, 64); err == nil {
maxSetter.SetMax64(size)
}
} else if maxChanger, ok := options.ProgressWriter.(ProgressBarMaxChanger); ok {
if size, err := strconv.ParseInt(res.Header.Get("Content-Length"), 10, 64); err == nil {
maxChanger.ChangeMax64(size)
}
}
writer = io.MultiWriter(writer, options.ProgressWriter)
}
bytesRead, err := io.Copy(writer, res.Body)
if err != nil {
return nil, errors.WithStack(err)
Expand Down Expand Up @@ -430,9 +454,12 @@ func buildRequestContent(log *logger.Logger, options *Options) (content *Content
if err != nil {
return nil, errors.Wrapf(err, "Failed to create multipart for field %s", key)
}
_, err = options.Attachment.(io.Seeker).Seek(0, io.SeekStart)
if err != nil {
return nil, errors.Wrapf(err, "Failed to seek to beginning of attachment for field %s", key)
// if options.Attempts == 1, we don't need to seek to the beginning of the attachment
if options.Attempts > 1 {
_, err = options.Attachment.(io.Seeker).Seek(0, io.SeekStart)
if err != nil {
return nil, errors.Wrapf(err, "Failed to seek to beginning of attachment for field %s", key)
}
}
written, err := io.Copy(part, options.Attachment)
if err != nil {
Expand Down Expand Up @@ -484,7 +511,16 @@ func buildRequest(log *logger.Logger, options *Options) (*http.Request, error) {
log.Tracef("Computed HTTP method: %s", options.Method)
}

req, err := http.NewRequestWithContext(options.Context, options.Method, options.URL.String(), reqContent.Reader())
reader := reqContent.Reader()

if options.ProgressWriter != nil {
reader = &progressReader{
Reader: reqContent.Reader(),
Progress: options.ProgressWriter,
}
}

req, err := http.NewRequestWithContext(options.Context, options.Method, options.URL.String(), reader)
if err != nil {
return nil, errors.WithStack(err)
}
Expand Down
79 changes: 79 additions & 0 deletions request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1123,3 +1123,82 @@ func (suite *RequestSuite) TestCanSendRequestWithWriterStream() {
suite.Assert().Equal("application/octet-stream", content.Type)
suite.Assert().Equal(uint64(4), content.Length)
}

func (suite *RequestSuite) TestCandSendRequestWithUploadDataAndProgress() {
serverURL, _ := url.Parse(suite.Server.URL)
serverURL, _ = serverURL.Parse("/image")
bar := &progressWriter{}
content, err := request.Send(&request.Options{
URL: serverURL,
Payload: map[string]string{"ID": "1234", ">file": "image.png"},
AttachmentType: "image/png",
Attachment: bytes.NewReader(smallPNG()),
ProgressWriter: bar,
Logger: suite.Logger,
}, nil)
suite.Require().NoError(err, "Failed sending request, err=%+v", err)
suite.Require().NotNil(content, "Content should not be nil")
suite.Assert().Equal("1", string(content.Data))
suite.Assert().Equal(int64(408), bar.Total)
}

func (suite *RequestSuite) TestCandSendRequestWithDownloadDataAndProgress() {
writer := new(bytes.Buffer)
suite.Logger.Memoryf("Before sending request")
serverURL, _ := url.Parse(suite.Server.URL)
serverURL, _ = serverURL.Parse("/binary_data")
bar := &progressWriter{}
content, err := request.Send(&request.Options{
URL: serverURL,
ProgressWriter: bar,
ProgressSetMaxFunc: func(max int64) { bar.Max = max },
Logger: suite.Logger,
}, writer)
suite.Logger.Memoryf("After sending request")
suite.Require().NoError(err, "Failed sending request, err=%+v", err)
suite.Assert().Equal("application/octet-stream", content.Type)
suite.Assert().Equal(uint64(4), content.Length)
suite.Assert().Equal([]byte("body"), writer.Bytes())
suite.Assert().Equal(int64(4), bar.Total)
suite.Assert().Equal(int64(4), bar.Max)
}

func (suite *RequestSuite) TestCandSendRequestWithDownloadDataAndProgressMaxSetter() {
writer := new(bytes.Buffer)
suite.Logger.Memoryf("Before sending request")
serverURL, _ := url.Parse(suite.Server.URL)
serverURL, _ = serverURL.Parse("/binary_data")
bar := &progressWriter2{}
content, err := request.Send(&request.Options{
URL: serverURL,
ProgressWriter: bar,
Logger: suite.Logger,
}, writer)
suite.Logger.Memoryf("After sending request")
suite.Require().NoError(err, "Failed sending request, err=%+v", err)
suite.Assert().Equal("application/octet-stream", content.Type)
suite.Assert().Equal(uint64(4), content.Length)
suite.Assert().Equal([]byte("body"), writer.Bytes())
suite.Assert().Equal(int64(4), bar.Total)
suite.Assert().Equal(int64(4), bar.Max)
}

func (suite *RequestSuite) TestCandSendRequestWithDownloadDataAndProgressMaxChanger() {
writer := new(bytes.Buffer)
suite.Logger.Memoryf("Before sending request")
serverURL, _ := url.Parse(suite.Server.URL)
serverURL, _ = serverURL.Parse("/binary_data")
bar := &progressWriter3{}
content, err := request.Send(&request.Options{
URL: serverURL,
ProgressWriter: bar,
Logger: suite.Logger,
}, writer)
suite.Logger.Memoryf("After sending request")
suite.Require().NoError(err, "Failed sending request, err=%+v", err)
suite.Assert().Equal("application/octet-stream", content.Type)
suite.Assert().Equal(uint64(4), content.Length)
suite.Assert().Equal([]byte("body"), writer.Bytes())
suite.Assert().Equal(int64(4), bar.Total)
suite.Assert().Equal(int64(4), bar.Max)
}
45 changes: 45 additions & 0 deletions types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,57 @@ func (s stuff) String() string {
return s.ID
}

type progressWriter struct {
Total int64
Max int64
}

func (w *progressWriter) Write(p []byte) (n int, err error) {
w.Total += int64(len(p))
return len(p), nil
}

func (w *progressWriter) Close() error {
return nil
}

type progressWriter2 struct {
Total int64
Max int64
}

func (w *progressWriter2) SetMax64(max int64) {
w.Max = max
}

func (w *progressWriter2) Write(p []byte) (n int, err error) {
w.Total += int64(len(p))
return len(p), nil
}

type progressWriter3 struct {
Total int64
Max int64
}

func (w *progressWriter3) ChangeMax64(max int64) {
w.Max = max
}

func (w *progressWriter3) Write(p []byte) (n int, err error) {
w.Total += int64(len(p))
return len(p), nil
}

func TestStuffShouldBeStringer(t *testing.T) {
s := stuff{"1234"}
var z interface{} = s
assert.NotNil(t, z.(fmt.Stringer), "Integer type is not a Stringer")
}

// smallPNG returns a small PNG image as a byte array
//
// This is a 1x1 pixel PNG image and is 408 bytes in size
func smallPNG() []byte {
image := "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAACklEQVR4nGMAAQAABQABDQottAAAAABJRU5ErkJggg=="
data, err := base64.StdEncoding.DecodeString(image)
Expand Down

0 comments on commit 61440c4

Please sign in to comment.