@@ -7,8 +7,12 @@ package http
7
7
import (
8
8
"bufio"
9
9
"bytes"
10
+ "crypto/rand"
11
+ "fmt"
10
12
"io"
11
13
"io/ioutil"
14
+ "os"
15
+ "reflect"
12
16
"strings"
13
17
"testing"
14
18
)
@@ -90,3 +94,187 @@ func TestDetectInMemoryReaders(t *testing.T) {
90
94
}
91
95
}
92
96
}
97
+
98
+ type mockTransferWriter struct {
99
+ CalledReader io.Reader
100
+ WriteCalled bool
101
+ }
102
+
103
+ var _ io.ReaderFrom = (* mockTransferWriter )(nil )
104
+
105
+ func (w * mockTransferWriter ) ReadFrom (r io.Reader ) (int64 , error ) {
106
+ w .CalledReader = r
107
+ return io .Copy (ioutil .Discard , r )
108
+ }
109
+
110
+ func (w * mockTransferWriter ) Write (p []byte ) (int , error ) {
111
+ w .WriteCalled = true
112
+ return ioutil .Discard .Write (p )
113
+ }
114
+
115
+ func TestTransferWriterWriteBodyReaderTypes (t * testing.T ) {
116
+ fileType := reflect .TypeOf (& os.File {})
117
+ bufferType := reflect .TypeOf (& bytes.Buffer {})
118
+
119
+ nBytes := int64 (1 << 10 )
120
+ newFileFunc := func () (r io.Reader , done func (), err error ) {
121
+ f , err := ioutil .TempFile ("" , "net-http-newfilefunc" )
122
+ if err != nil {
123
+ return nil , nil , err
124
+ }
125
+
126
+ // Write some bytes to the file to enable reading.
127
+ if _ , err := io .CopyN (f , rand .Reader , nBytes ); err != nil {
128
+ return nil , nil , fmt .Errorf ("failed to write data to file: %v" , err )
129
+ }
130
+ if _ , err := f .Seek (0 , 0 ); err != nil {
131
+ return nil , nil , fmt .Errorf ("failed to seek to front: %v" , err )
132
+ }
133
+
134
+ done = func () {
135
+ f .Close ()
136
+ os .Remove (f .Name ())
137
+ }
138
+
139
+ return f , done , nil
140
+ }
141
+
142
+ newBufferFunc := func () (io.Reader , func (), error ) {
143
+ return bytes .NewBuffer (make ([]byte , nBytes )), func () {}, nil
144
+ }
145
+
146
+ cases := []struct {
147
+ name string
148
+ bodyFunc func () (io.Reader , func (), error )
149
+ method string
150
+ contentLength int64
151
+ transferEncoding []string
152
+ limitedReader bool
153
+ expectedReader reflect.Type
154
+ expectedWrite bool
155
+ }{
156
+ {
157
+ name : "file, non-chunked, size set" ,
158
+ bodyFunc : newFileFunc ,
159
+ method : "PUT" ,
160
+ contentLength : nBytes ,
161
+ limitedReader : true ,
162
+ expectedReader : fileType ,
163
+ },
164
+ {
165
+ name : "file, non-chunked, size set, nopCloser wrapped" ,
166
+ method : "PUT" ,
167
+ bodyFunc : func () (io.Reader , func (), error ) {
168
+ r , cleanup , err := newFileFunc ()
169
+ return ioutil .NopCloser (r ), cleanup , err
170
+ },
171
+ contentLength : nBytes ,
172
+ limitedReader : true ,
173
+ expectedReader : fileType ,
174
+ },
175
+ {
176
+ name : "file, non-chunked, negative size" ,
177
+ method : "PUT" ,
178
+ bodyFunc : newFileFunc ,
179
+ contentLength : - 1 ,
180
+ expectedReader : fileType ,
181
+ },
182
+ {
183
+ name : "file, non-chunked, CONNECT, negative size" ,
184
+ method : "CONNECT" ,
185
+ bodyFunc : newFileFunc ,
186
+ contentLength : - 1 ,
187
+ expectedReader : fileType ,
188
+ },
189
+ {
190
+ name : "file, chunked" ,
191
+ method : "PUT" ,
192
+ bodyFunc : newFileFunc ,
193
+ transferEncoding : []string {"chunked" },
194
+ expectedWrite : true ,
195
+ },
196
+ {
197
+ name : "buffer, non-chunked, size set" ,
198
+ bodyFunc : newBufferFunc ,
199
+ method : "PUT" ,
200
+ contentLength : nBytes ,
201
+ limitedReader : true ,
202
+ expectedReader : bufferType ,
203
+ },
204
+ {
205
+ name : "buffer, non-chunked, size set, nopCloser wrapped" ,
206
+ method : "PUT" ,
207
+ bodyFunc : func () (io.Reader , func (), error ) {
208
+ r , cleanup , err := newBufferFunc ()
209
+ return ioutil .NopCloser (r ), cleanup , err
210
+ },
211
+ contentLength : nBytes ,
212
+ limitedReader : true ,
213
+ expectedReader : bufferType ,
214
+ },
215
+ {
216
+ name : "buffer, non-chunked, negative size" ,
217
+ method : "PUT" ,
218
+ bodyFunc : newBufferFunc ,
219
+ contentLength : - 1 ,
220
+ expectedWrite : true ,
221
+ },
222
+ {
223
+ name : "buffer, non-chunked, CONNECT, negative size" ,
224
+ method : "CONNECT" ,
225
+ bodyFunc : newBufferFunc ,
226
+ contentLength : - 1 ,
227
+ expectedWrite : true ,
228
+ },
229
+ {
230
+ name : "buffer, chunked" ,
231
+ method : "PUT" ,
232
+ bodyFunc : newBufferFunc ,
233
+ transferEncoding : []string {"chunked" },
234
+ expectedWrite : true ,
235
+ },
236
+ }
237
+
238
+ for _ , tc := range cases {
239
+ t .Run (tc .name , func (t * testing.T ) {
240
+ body , cleanup , err := tc .bodyFunc ()
241
+ if err != nil {
242
+ t .Fatal (err )
243
+ }
244
+ defer cleanup ()
245
+
246
+ mw := & mockTransferWriter {}
247
+ tw := & transferWriter {
248
+ Body : body ,
249
+ ContentLength : tc .contentLength ,
250
+ TransferEncoding : tc .transferEncoding ,
251
+ }
252
+
253
+ if err := tw .writeBody (mw ); err != nil {
254
+ t .Fatal (err )
255
+ }
256
+
257
+ if tc .expectedReader != nil {
258
+ if mw .CalledReader == nil {
259
+ t .Fatal ("did not call ReadFrom" )
260
+ }
261
+
262
+ var actualReader reflect.Type
263
+ lr , ok := mw .CalledReader .(* io.LimitedReader )
264
+ if ok && tc .limitedReader {
265
+ actualReader = reflect .TypeOf (lr .R )
266
+ } else {
267
+ actualReader = reflect .TypeOf (mw .CalledReader )
268
+ }
269
+
270
+ if tc .expectedReader != actualReader {
271
+ t .Fatalf ("got reader %T want %T" , actualReader , tc .expectedReader )
272
+ }
273
+ }
274
+
275
+ if tc .expectedWrite && ! mw .WriteCalled {
276
+ t .Fatal ("did not invoke Write" )
277
+ }
278
+ })
279
+ }
280
+ }
0 commit comments