-
Notifications
You must be signed in to change notification settings - Fork 6
/
transport.go
98 lines (85 loc) · 2.45 KB
/
transport.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: 2024-Present Defense Unicorns
package helpers
import (
"io"
"net/http"
"time"
"oras.land/oras-go/v2/registry/remote/retry"
)
// Transport is an http.RoundTripper that forwards requests to Base, retrying
// them per the oras retry policy, and reports upload progress to ProgressBar.
type Transport struct {
	// Base performs the underlying HTTP exchange for every attempt.
	Base http.RoundTripper

	// ProgressBar receives progress updates; when nil, no progress is reported.
	ProgressBar ProgressWriter
}
// NewTransport returns a Transport that delegates each request to base and
// reports progress to bar. A nil bar disables progress reporting.
func NewTransport(base http.RoundTripper, bar ProgressWriter) *Transport {
	transport := new(Transport)
	transport.Base = base
	transport.ProgressBar = bar
	return transport
}
// RoundTrip sends req through t.roundTrip, retrying according to
// retry.DefaultPolicy. It mirrors RoundTrip in oras-go's retry package, but
// calls this package's roundTrip (which reports progress to t.ProgressBar)
// instead of retry's private one.
//
// Between attempts the request body is rewound with req.GetBody. If the body
// cannot be rewound, the current attempt's response and error are returned
// as-is. Response bodies of attempts that are being retried are closed. If
// the request's context is done while waiting for the backoff delay, the wait
// is abandoned and ctx.Err() is returned.
//
// NOTE(review): every attempt streams the request body through roundTrip
// again, so a retried upload reports its bytes to the progress bar once per
// attempt.
//
// https://github.com/oras-project/oras-go/blob/main/registry/remote/retry/client.go
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
	ctx := req.Context()
	policy := retry.DefaultPolicy

	for attempt := 0; ; attempt++ {
		resp, rtErr := t.roundTrip(req)

		backoff, policyErr := policy.Retry(attempt, resp, rtErr)
		if policyErr != nil {
			if rtErr == nil {
				resp.Body.Close()
			}
			return nil, policyErr
		}
		// A negative backoff means the policy will not retry this attempt.
		if backoff < 0 {
			return resp, rtErr
		}

		// Rewind the request body so it can be sent again; when that is not
		// possible, hand back this attempt's outcome instead of retrying.
		if req.Body != nil {
			if req.GetBody == nil {
				return resp, rtErr
			}
			rewound, rewindErr := req.GetBody()
			if rewindErr != nil {
				return resp, rtErr
			}
			req.Body = rewound
		}

		// This attempt's response is discarded in favour of a retry.
		if rtErr == nil {
			resp.Body.Close()
		}

		wait := time.NewTimer(backoff)
		select {
		case <-wait.C:
		case <-ctx.Done():
			wait.Stop()
			return nil, ctx.Err()
		}
	}
}
// roundTrip performs a single attempt of req against t.Base and, when
// t.ProgressBar is set, reports progress to it:
//
//   - For non-HEAD requests with a body, the body is teed into the progress
//     bar as the base transport reads it. Close is forwarded to the original
//     body, so the base transport still releases the underlying reader.
//   - For a HEAD answered with 200 OK, the response's Content-Length is
//     reported as progress. NOTE(review): in the publish flow this presumably
//     marks content that already exists remotely and will be skipped —
//     confirm against callers. Other statuses (404, 5xx, ...) report nothing,
//     because their Content-Length describes an error body, not the content.
//
// Progress reporting is best-effort: a failing progress writer never fails
// the request.
//
// This is currently only used to track the progress of publishes, not pulls.
func (t *Transport) roundTrip(req *http.Request) (*http.Response, error) {
	if t.ProgressBar != nil && req.Method != http.MethodHead &&
		req.Body != nil && req.Body != http.NoBody {
		// Keep the original Closer: io.NopCloser here would stop the base
		// transport from ever closing the caller's body.
		req.Body = struct {
			io.Reader
			io.Closer
		}{
			Reader: io.TeeReader(req.Body, t.ProgressBar),
			Closer: req.Body,
		}
	}

	resp, err := t.Base.RoundTrip(req)
	if err == nil && resp != nil && t.ProgressBar != nil &&
		req.Method == http.MethodHead && resp.StatusCode == http.StatusOK &&
		resp.ContentLength > 0 {
		// Best-effort: a progress-reporting failure must not fail the request.
		_ = t.advanceProgress(resp.ContentLength)
	}
	return resp, err
}

// advanceProgress reports n bytes to t.ProgressBar by writing zeroed bytes in
// chunks of at most 32 KiB. Reporting large content therefore never needs an
// n-byte allocation or an int conversion that could overflow on 32-bit
// platforms. It returns the first write error, or io.ErrShortWrite if the
// writer accepts fewer bytes than offered.
func (t *Transport) advanceProgress(n int64) error {
	const maxChunk = 32 * 1024
	if n <= 0 || t.ProgressBar == nil {
		return nil
	}
	size := int64(maxChunk)
	if n < size {
		size = n
	}
	buf := make([]byte, size)
	for n > 0 {
		chunk := buf
		if n < int64(len(chunk)) {
			chunk = chunk[:n]
		}
		written, err := t.ProgressBar.Write(chunk)
		if err != nil {
			return err
		}
		if written < len(chunk) {
			return io.ErrShortWrite
		}
		n -= int64(written)
	}
	return nil
}